# BiBo-MoE-Tiny / modelling_bibo.py
# fhai50032's picture
# Create modelling_bibo.py
# 7dd85ba verified
# coding=utf-8
# Copyright 2024 The BiBo Authors and The HuggingFace Inc. team. All rights reserved.
""" PyTorch BiBo model (Based on Qwen2 with MoE modifications).
we can use MoEwithoutput class; """
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from .configuration_bibo import BiBoConfig
try:
import torch_xla.core.xla_model as xm
_XLA_AVAILABLE = True
except ImportError:
_XLA_AVAILABLE = False
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
can_return_tuple,
)
# Identifiers referenced by the documentation helpers.
_CHECKPOINT_FOR_DOC = "BiBo-MoE-Model"
_CONFIG_FOR_DOC = "BiBoConfig"

# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
class BiBoMLP(nn.Module):
    """Dense SwiGLU feed-forward network used by the non-MoE layers.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))`` with bias-free
    projections of width ``config.intermediate_size``.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class MLPExpert(nn.Module):
    """Routed SwiGLU expert for MoE layers.

    Same structure as the dense MLP, but its inner width is
    ``config.moe_intermediate_size``.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        d_model, d_inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated_gate = self.act_fn(self.gate_proj(x))
        projected_up = self.up_proj(x)
        return self.down_proj(activated_gate * projected_up)
class ModifiedConvolutionalExpert(nn.Module):
    """Shared MoE expert whose gate branch is a causal 1-D convolution.

    The gate is an ``nn.Conv1d`` with kernel ``config.kernel_size``, applied after
    left-padding the sequence with ``kernel_size - 1`` zeros. Output position ``t``
    therefore depends only on input positions ``<= t``. The up branch is a plain
    linear projection, and the two are combined SwiGLU-style before ``down_proj``.

    NOTE(review): when called on a single new token during cached decoding
    (seq_len == 1), the convolution sees that token plus zero padding rather than
    the previous tokens it saw in full-sequence mode. Confirm this mismatch is
    intended.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.kernel_size_gate = config.kernel_size
        # Left padding that makes the convolution causal.
        self.causal_padding_gate = self.kernel_size_gate - 1
        self.gate_conv = nn.Conv1d(
            self.hidden_size,
            self.intermediate_size,
            self.kernel_size_gate,
            padding=0,
            bias=False,
        )
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seq_len, _ = x.shape
        # Conv1d consumes (batch, channels, length): swap the last two axes.
        channels_first = x.permute(0, 2, 1)
        padded = F.pad(channels_first, (self.causal_padding_gate, 0))
        gate = self.act_fn(self.gate_conv(padded)).permute(0, 2, 1)
        output = self.down_proj(gate * self.up_proj(x))
        if output.shape[1] != seq_len:
            raise RuntimeError("ModifiedConvExpert length mismatch")
        return output
class IdentityExpert(nn.Module):
    """Parameter-free routed expert that returns its input unchanged.

    Constructor arguments are accepted (so it can be built like the other
    experts, e.g. ``IdentityExpert(config)``) and ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        passthrough = x
        return passthrough
class BiBoMoERouter(nn.Module):
    """Top-k softmax router with a per-expert load-balancing bias.

    Pipeline:
    1. Logits come from a bias-free linear gate, computed in float32.
    2. During training only, Gaussian noise with variance ``noise_variance``
       (default ``config.router_noise``) is added.
    3. The learnable per-expert ``bias`` is added.
    4. If ``config.router_temperature != 1``, the logits are divided by it.
    5. A softmax is taken, the ``top_k`` largest probabilities are kept, and they
       are renormalised to sum to ~1 per token.
    """

    def __init__(self, config: "BiBoConfig"):
        super().__init__()
        self.num_experts = config.num_routed_experts
        self.top_k = config.num_experts_per_tok
        self.temperature = config.router_temperature
        self.router_noise = config.router_noise
        self.bias = nn.Parameter(torch.zeros(self.num_experts))
        self.gate_proj = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor, noise_variance: Optional[float] = None):
        """Select experts for every token.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_size``. All
                leading dimensions are flattened into a single token axis.
            noise_variance: variance of the training-time logit noise. ``None``
                (default) uses ``config.router_noise``. Has no effect in eval mode.

        Returns:
            ``(top_k_indices, top_k_weights)``, each of shape ``(num_tokens, top_k)``:
            expert ids as ``long`` and renormalised weights cast to
            ``hidden_states.dtype``.
        """
        if noise_variance is None:
            noise_variance = self.router_noise
        flat_hidden = hidden_states.reshape(-1, hidden_states.shape[-1])
        router_logits = self.gate_proj(flat_hidden).float()
        # TODO(@aloobun): optionally clamp the logits, with a range derived
        # dynamically from statistics (mean/median/mode/std) of the current
        # logits. Clamping is disabled for now.
        if self.training and noise_variance > 0:
            noise = torch.randn_like(router_logits) * math.sqrt(noise_variance)
            router_logits = router_logits + noise
        router_logits = router_logits + self.bias
        if self.temperature != 1.0:
            router_logits = router_logits / self.temperature
        routing_weights = F.softmax(router_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # The small epsilon guards against division by zero.
        norm_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
        return top_k_indices.long(), norm_weights.to(hidden_states.dtype)
class BiBoMoELayer(nn.Module):
    """Mixture-of-Experts feed-forward block.

    Routing:
    - Each token goes to ``config.num_experts_per_tok`` of the
      ``config.num_routed_experts`` routed experts chosen by ``BiBoMoERouter``.
    - Expert outputs are summed using the router's renormalised weights.
    - The routed pool is ``num_routed_experts - 1`` ``MLPExpert`` modules followed
      by one ``IdentityExpert`` at the last index.

    Shared expert:
    - When ``config.num_shared_experts > 0``, a single
      ``ModifiedConvolutionalExpert`` is applied to every token.
    - Its output is added to the routed result.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts_per_tok = config.num_experts_per_tok
        # Read by forward()/update_bias(). These were never assigned before, so
        # the first forward pass raised AttributeError.
        self.num_routed_experts = config.num_routed_experts
        # NOTE(review): the config field name is assumed from the attribute name
        # used below. If the config lacks it, bias updates are disabled (0.0).
        self.bias_update_factor = getattr(config, "bias_update_factor", 0.0)
        self.routed_experts = nn.ModuleList()
        num_mlp_routed = config.num_routed_experts - 1
        for _ in range(num_mlp_routed):
            self.routed_experts.append(MLPExpert(config))
        self.routed_experts.append(IdentityExpert(config))
        if len(self.routed_experts) != config.num_routed_experts:
            raise ValueError("Routed experts mismatch")
        self.shared_experts_list = nn.ModuleList()
        if config.num_shared_experts > 0:
            if config.num_shared_experts != 1:
                warnings.warn("Expected 1 shared expert, using 1 Conv.")
            self.shared_experts_list.append(ModifiedConvolutionalExpert(config))
        self.gate = BiBoMoERouter(config)

    @torch.no_grad()  # Bias update should not track gradients
    def update_bias(self, tpe):
        """Nudge the router bias toward a balanced expert load.

        Each expert's bias moves by ``bias_update_factor`` according to the sign
        of ``mean(tpe) - tpe``:
        - experts that received fewer tokens than the mean are raised;
        - experts that received more are lowered.

        Ref: https://gist.github.com/joey00072/f9e65f7fe05b763a19e4824bda29c975

        Args:
            tpe: per-expert token counts (length ``num_routed_experts``).
        """
        if not hasattr(self.gate, "bias") or self.bias_update_factor <= 0:
            return
        counts = tpe.detach().float()
        deviation = counts.mean() - counts
        self.gate.bias.add_(self.bias_update_factor * deviation.sign())

    def forward(self, hidden_states: torch.Tensor):
        """Route tokens through the experts.

        Args:
            hidden_states: ``(batch, seq_len, hidden)`` tensor.

        Returns:
            Tensor of the same shape: the weighted routed-expert output plus the
            shared-expert output (if any). In training mode with a positive
            ``bias_update_factor``, the router bias is also updated in place.
        """
        bsz, seq_len, hidden_dim = hidden_states.shape
        num_tokens = bsz * seq_len
        flat_hidden = hidden_states.reshape(num_tokens, hidden_dim)
        # The router applies its own configured noise (training only). Both
        # returned tensors have shape (num_tokens, num_experts_per_tok).
        top_k_indices, top_k_weights = self.gate(hidden_states)
        flat_expert_indices = top_k_indices.reshape(-1)
        flat_weights = top_k_weights.reshape(-1)

        tokens_per_expert = None
        if self.training and hasattr(self.gate, "bias") and self.bias_update_factor > 0:
            tokens_per_expert = torch.bincount(flat_expert_indices, minlength=self.num_routed_experts)

        # Token id for each (token, slot) entry of the flattened top-k assignment.
        flat_token_indices = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(
            self.num_experts_per_tok
        )
        final_routed = torch.zeros_like(flat_hidden)
        for expert_id, expert in enumerate(self.routed_experts):
            mask = flat_expert_indices == expert_id
            if not mask.any():
                continue
            tokens_idx = flat_token_indices[mask]
            unique_tokens, orig_indices = torch.unique(tokens_idx, return_inverse=True)
            outputs = expert(flat_hidden[unique_tokens])[orig_indices]
            weights = flat_weights[mask].unsqueeze(1)
            final_routed.scatter_add_(0, tokens_idx.unsqueeze(1).expand(-1, hidden_dim), outputs * weights)

        final_output = final_routed.view(bsz, seq_len, hidden_dim)
        if self.shared_experts_list:
            final_output = final_output + self.shared_experts_list[0](hidden_states)
        if tokens_per_expert is not None:
            self.update_bias(tokens_per_expert)
        return final_output
def rotate_half(x):
    """Rotate the last dimension by half: ``[x1, x2] -> [-x2, x1]``.

    Here ``x1`` is the first ``d // 2`` entries and ``x2`` is the remainder.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    ``cos`` and ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they
    broadcast against ``q`` and ``k``. ``position_ids`` is accepted for API
    compatibility and is not used.

    Returns:
        ``(q_rotated, k_rotated)``, where ``t_rotated = t * cos + rotate_half(t) * sin``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
def repeat_kv(x: torch.Tensor, n: int) -> torch.Tensor:
    """Repeat each key/value head ``n`` times along the head axis.

    Input ``(batch, kv_heads, seq, head_dim)`` becomes
    ``(batch, kv_heads * n, seq, head_dim)``. Copies of the same head are adjacent.
    When ``n == 1``, the input is returned as-is.
    """
    batch, kv_heads, seq, head_dim = x.shape
    if n == 1:
        return x
    expanded = x[:, :, None, :, :].expand(batch, kv_heads, n, seq, head_dim)
    return expanded.reshape(batch, kv_heads * n, seq, head_dim)
def eager_attention_forward(m, q, k, v, mask, scale, dropout=0.0, **kw):
    """Scaled dot-product attention implemented with explicit matmul + softmax.

    Args:
        m: the calling attention module. Its ``num_key_value_groups`` and
            ``training`` attributes are read.
        q, k, v: query/key/value tensors. Key/value heads are repeated
            ``m.num_key_value_groups`` times to match the query heads.
        mask: optional additive mask. It is cropped to the key length and must
            then have shape ``(q.shape[0], 1, q.shape[2], k.shape[2])``.
        scale: multiplier applied to ``q @ k^T``.
        dropout: dropout probability on the attention weights, active only
            when ``m.training``.
        **kw: ignored.

    Returns:
        ``(output, weights)``. ``output`` has axes 1 and 2 swapped relative to
        ``weights @ v`` and is made contiguous. ``weights`` are the
        post-softmax, post-dropout probabilities in ``q.dtype``.

    Raises:
        ValueError: if the cropped mask has an unexpected shape.
    """
    groups = m.num_key_value_groups
    k = repeat_kv(k, groups)
    v = repeat_kv(v, groups)
    key_len = k.shape[-2]
    if mask is not None:
        mask = mask[:, :, :, :key_len]
    scores = torch.matmul(q, k.transpose(2, 3)) * scale
    if mask is not None:
        expected_shape = (q.shape[0], 1, q.shape[2], k.shape[2])
        if mask.size() != expected_shape:
            raise ValueError("Mask shape mismatch")
        scores = scores + mask
    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    probs = F.dropout(probs, p=dropout, training=m.training)
    output = torch.matmul(probs, v).transpose(1, 2).contiguous()
    return output, probs
class BiBoAttention(nn.Module):
    """Multi-head self-attention with grouped key/value heads and rotary embeddings.

    Projection biases: q/k/v projections have a bias, the output projection does
    not. Attention itself is computed by ``eager_attention_forward``.
    """

    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # Index used to address this layer's slot in the KV cache.
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each KV head (read by eager_attention_forward).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, pos_emb, mask=None, kv_cache=None, output_attentions=False,
                use_cache=False, cache_position=None, **kw):
        """Attend over ``hidden_states`` of shape ``(batch, q_len, hidden_size)``.

        Args:
            pos_emb: ``(cos, sin)`` rotary tables.
            mask: optional additive mask forwarded to ``eager_attention_forward``.
            kv_cache: optional cache object. When given, its
                ``update(key, value, layer_idx, cache_kwargs)`` result replaces
                the local key/value tensors.
            output_attentions: if True, also return the attention weights.

        Returns:
            ``(output, weights_or_None)``; ``output`` has shape
            ``(batch, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()
        q_shape = (bsz, q_len, self.num_heads, self.head_dim)
        kv_shape = (bsz, q_len, self.num_key_value_heads, self.head_dim)
        query = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
        cos, sin = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key, value = kv_cache.update(key, value, self.layer_idx, cache_kwargs)
        out, weights = eager_attention_forward(self, query, key, value, mask, self.scaling, self.attention_dropout)
        # Merge heads. -1 resolves to num_heads * head_dim, which is what o_proj
        # expects; it equals hidden_size only when hidden_size % num_heads == 0.
        out = self.o_proj(out.reshape(bsz, q_len, -1))
        return out, (weights if output_attentions else None)
class BiBoRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learnable per-channel scale.

    Normalisation runs in float32. The result is cast back to the input dtype
    before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        input_dtype = x.dtype
        x32 = x.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalized = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class BiBoDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: attention block, then feed-forward block.

    The first and last layers use a dense ``BiBoMLP``; every other layer uses a
    ``BiBoMoELayer``. Both feed-forward variants share the same call signature.
    """

    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BiBoAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx
        self.num_hidden_layers = config.num_hidden_layers
        is_first_layer = layer_idx == 0
        is_last_layer = layer_idx == config.num_hidden_layers - 1
        # Dense MLP at the ends of the stack, MoE everywhere in between.
        if is_first_layer or is_last_layer:
            self.mlp = BiBoMLP(config)
            self.is_moe_layer = False
        else:
            self.mlp = BiBoMoELayer(config)
            self.is_moe_layer = True

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None,
                output_attentions=False, use_cache=False, cache_position=None):
        """Run one decoder layer.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is True.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs, attn_weights = self.self_attn(
            hidden_states, position_embeddings, attention_mask, past_key_value,
            output_attentions, use_cache, cache_position,
        )
        hidden_states = residual + attn_outputs

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Dense MLP and MoE layers are called identically, so no branch is needed.
        hidden_states = residual + self.mlp(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
class BiBoRotaryEmbedding(nn.Module):
    """Produces rotary ``(cos, sin)`` tables for a batch of position ids.

    Frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``. ``rope_type`` is
    read from ``config.rope_scaling`` and falls back to ``"default"``.
    ``dynamic_rope_update`` may refresh ``inv_freq`` for dynamic RoPE variants.
    """

    def __init__(self, config: BiBoConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg:
            self.rope_type = scaling_cfg.get("rope_type", "default")
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()
        device_type = x.device.type
        # autocast does not accept "mps"; fall back to "cpu" there.
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        # Compute the tables in full precision regardless of ambient autocast.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (freq_column.float() @ position_row.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Placeholder docstring fragments consumed by the add_start_docstrings* decorators.
BIBO_START_DOCSTRING = " BiBo model... "
BIBO_INPUTS_DOCSTRING = " Standard arguments... "
@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoPreTrainedModel(PreTrainedModel):
    """Base class wiring BiBo modules into Hugging Face ``PreTrainedModel`` machinery."""

    config_class = BiBoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BiBoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    # NOTE(review): the eager mask path is built for past_len + seq_len keys, so a
    # StaticCache (keys padded to max_cache_len) would trip the mask shape check
    # in eager_attention_forward. Confirm before relying on static caches.
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise one submodule's parameters.

        - Linear / Embedding weights: N(0, ``config.initializer_range``).
        - Linear biases and the padding embedding row: zero.
        - RMSNorm scale: one.
        - Conv1d: PyTorch's default kaiming-uniform scheme.
        - ``BiBoMoERouter.bias``: zero, matching its construction.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, BiBoRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BiBoMoERouter):
            # The router's load-balancing bias is a bare Parameter that no other
            # branch covers. Without this it could be left uninitialised when
            # weights are materialised lazily.
            module.bias.data.zero_()
@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoModel(BiBoPreTrainedModel):
    """BiBo decoder stack.

    Layout: token embeddings -> ``num_hidden_layers`` decoder layers -> final
    RMSNorm. Rotary tables are computed once per forward and shared by all layers.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BiBoDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BiBoRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, mask, shape, embeds, past_len):
        """Build the additive 4-D attention mask (causal + padding).

        The mask is built with plain tensor ops. The previous implementation
        called ``nn.functional._make_causal_mask`` / ``_expand_mask``, which do
        not exist in torch, so any multi-token or padded forward raised
        AttributeError.

        Args:
            mask: optional 2-D padding mask (1 = attend, 0 = masked).
                NOTE(review): assumed to cover ``past_len + shape[-1]`` key
                positions when combined with the causal mask — confirm callers.
            shape: ``(batch_size, query_length)``.
            embeds: only its dtype and device are used.
            past_len: number of cached key positions preceding the queries.

        Returns:
            ``None`` when there is nothing to mask (single query, no padding
            mask). Otherwise a ``(batch, 1, query_length, key_length)`` tensor
            holding 0 where attention is allowed and ``torch.finfo(dtype).min``
            where it is masked.
        """
        bsz, tgt_len = shape[0], shape[-1]
        dtype, device = embeds.dtype, embeds.device
        min_value = torch.finfo(dtype).min
        combined_mask = None
        if tgt_len > 1:
            src_len = past_len + tgt_len
            # Query i sits at absolute position past_len + i and may attend to
            # keys j <= past_len + i.
            q_pos = torch.arange(tgt_len, device=device).unsqueeze(1) + past_len
            k_pos = torch.arange(src_len, device=device).unsqueeze(0)
            causal = torch.zeros((tgt_len, src_len), dtype=dtype, device=device)
            causal = causal.masked_fill(k_pos > q_pos, min_value)
            combined_mask = causal[None, None, :, :].expand(bsz, 1, tgt_len, src_len)
        if mask is not None:
            keep = mask[:, None, None, :].to(device=device, dtype=dtype)
            inverted = 1.0 - keep
            padding = inverted.masked_fill(inverted.to(torch.bool), min_value)
            padding = padding.expand(mask.shape[0], 1, tgt_len, mask.shape[-1])
            combined_mask = padding if combined_mask is None else padding + combined_mask
        if combined_mask is not None:
            # Adding two large negatives can overflow to -inf; clamp back to the
            # finite minimum so fully masked rows stay NaN-free.
            combined_mask = combined_mask.masked_fill(combined_mask < 0, min_value)
        return combined_mask

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None,
                inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None,
                cache_position=None, return_dict=None):
        """Run the decoder stack.

        Always builds a ``BaseModelOutputWithPast``. ``return_dict`` is kept for
        API compatibility: ``@can_return_tuple`` consumes it (together with
        ``config.use_return_dict``) and performs the tuple conversion itself.
        Returning a tuple from here would make the decorator call
        ``.to_tuple()`` on a tuple.

        Raises:
            ValueError: if not exactly one of ``input_ids`` / ``inputs_embeds``
                is given, or ``past_key_values`` is not a ``Cache``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify ids or embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("Disabling use_cache")
            use_cache = False
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("past_key_values type error")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_len = inputs_embeds.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_len, past_len + seq_len, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._prepare_decoder_attention_mask(
            attention_mask, (inputs_embeds.shape[0], seq_len), inputs_embeds, past_len
        )
        hidden_states = inputs_embeds
        pos_emb = self.rotary_emb(hidden_states, position_ids)

        all_hidden = () if output_hidden_states else None
        all_attn = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            layer_outputs = layer(
                hidden_states, pos_emb, causal_mask,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attn += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden,
            attentions=all_attn,
        )
@add_start_docstrings(""" BiBo Model with CausalLM head. """, BIBO_START_DOCSTRING)
class BiBoForCausalLM(BiBoPreTrainedModel, GenerationMixin):
    """BiBo decoder with a bias-free language-modelling head.

    Loss is never computed here; callers compute it from ``logits``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.model = BiBoModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None,
                inputs_embeds=None, labels=None, use_cache=None, output_attentions=None,
                output_hidden_states=None, cache_position=None, logits_to_keep=0, return_dict=None):
        r"""Loss calculation (CrossEntropy) must happen outside this function.

        Args:
            logits_to_keep: ``int`` — keep logits for the last N positions
                (0 keeps all). A tensor of sequence indices is also accepted
                and used directly as the index.
            labels: accepted but unused; a warning is emitted when provided.
            return_dict: consumed by ``@can_return_tuple``, which converts the
                returned ``CausalLMOutputWithPast`` to a tuple (dropping the
                ``None`` loss) when requested.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss=None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=True,
        )
        hidden_states = model_outputs.last_hidden_state
        # slice(-0, None) == slice(0, None), so 0 keeps every position.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            warnings.warn("Labels provided but loss calculation must be done externally.")

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )