# coding=utf-8
# Copyright 2024 The BiBo Authors and The HuggingFace Inc. team. All rights reserved.
"""PyTorch BiBo model (based on Qwen2 with MoE modifications).

Note: we can use the MoEwithoutput class.
"""
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from .configuration_bibo import BiBoConfig

try:
    import torch_xla.core.xla_model as xm
    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
    can_return_tuple,
)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "BiBo-MoE-Model"
_CONFIG_FOR_DOC = "BiBoConfig"


class BiBoMLP(nn.Module):
    """Standard SwiGLU MLP used for dense layers."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act(gate_proj(x)) * up_proj(x))``."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MLPExpert(nn.Module):
    """SwiGLU based MLP Expert for MoE Layers (uses ``moe_intermediate_size``)."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act(gate_proj(x)) * up_proj(x))``."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class ModifiedConvolutionalExpert(nn.Module):
    """Causal Convolutional 'Expert' (Shared) for MoE Layers.

    The gate branch is a 1-D convolution over the sequence axis with left-only
    zero padding of ``kernel_size - 1``; the up branch is a plain linear layer.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.kernel_size_gate = config.kernel_size
        self.causal_padding_gate = self.kernel_size_gate - 1
        self.gate_conv = nn.Conv1d(
            self.hidden_size, self.intermediate_size, self.kernel_size_gate, padding=0, bias=False
        )
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv-gated MLP to ``x`` of shape (batch, seq_len, hidden).

        Raises:
            RuntimeError: if the output sequence length differs from the input's.
        """
        # NOTE(review): the convolution only sees the tokens passed in this call, so
        # with cached single-token decoding the left context is zero padding - confirm
        # this train/generate difference is intended.
        _, seq_len, _ = x.shape
        # Conv1d expects (batch, channels, length).
        x_perm = x.permute(0, 2, 1)
        # Left-pad only, so position t never sees positions > t.
        x_padded = F.pad(x_perm, (self.causal_padding_gate, 0))
        gate_ready = self.act_fn(self.gate_conv(x_padded)).permute(0, 2, 1)
        output = self.down_proj(gate_ready * self.up_proj(x))
        if output.shape[1] != seq_len:
            raise RuntimeError("ModifiedConvExpert length mismatch")
        return output


class IdentityExpert(nn.Module):
    """Routed expert that returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class BiBoMoERouter(nn.Module):
    """Top-k softmax router with an additive per-expert bias and optional training noise."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.num_experts = config.num_routed_experts
        self.top_k = config.num_experts_per_tok
        self.temperature = config.router_temperature
        self.router_noise = config.router_noise
        self.bias = nn.Parameter(torch.zeros(self.num_experts))
        self.gate_proj = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor, noise_variance: Optional[float] = None):
        """Route every token to its top-k experts.

        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden).
            noise_variance: variance of the Gaussian noise added to the logits in
                training mode. Defaults to ``self.router_noise`` when ``None``.

        Returns:
            ``(top_k_indices, norm_weights)``, both of shape (batch * seq_len, top_k).
            Indices are int64; weights are renormalised over the selected experts
            and cast to the dtype of ``hidden_states``.
        """
        bsz, seq_len, _ = hidden_states.shape
        num_tokens = bsz * seq_len
        if noise_variance is None:
            noise_variance = self.router_noise

        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat_hidden = hidden_states.reshape(num_tokens, -1)
        router_logits = self.gate_proj(flat_hidden).float()

        # No clamping for now.
        # TODO: @aloobun make clamp range dynamic based on mean/median/mode/std of current logits

        # Noise is only injected in training mode; a None/zero variance disables it.
        if self.training and noise_variance is not None and noise_variance > 0:
            noise_stddev = math.sqrt(noise_variance)
            router_logits = router_logits + torch.randn_like(router_logits) * noise_stddev

        router_logits = router_logits + self.bias
        if self.temperature != 1.0:
            router_logits = router_logits / self.temperature

        routing_weights = F.softmax(router_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalise so the selected experts' weights sum to ~1 (epsilon avoids 0/0).
        norm_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
        return top_k_indices.long(), norm_weights.to(hidden_states.dtype)


class BiBoMoELayer(nn.Module):
    """MoE feed-forward block: routed experts plus an optional shared convolutional expert.

    The routed pool holds ``num_routed_experts - 1`` :class:`MLPExpert` modules
    followed by one :class:`IdentityExpert`.
    """

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts_per_tok = config.num_experts_per_tok
        # These three attributes are read in forward()/update_bias().
        self.num_routed_experts = config.num_routed_experts
        self.router_noise = config.router_noise
        # NOTE(review): assumes the config field is named `bias_update_factor`; when it
        # is absent the bias update is disabled (factor 0.0) - confirm against BiBoConfig.
        self.bias_update_factor = getattr(config, "bias_update_factor", 0.0)

        self.routed_experts = nn.ModuleList()
        num_mlp_routed = config.num_routed_experts - 1
        for _ in range(num_mlp_routed):
            self.routed_experts.append(MLPExpert(config))
        self.routed_experts.append(IdentityExpert(config))
        if len(self.routed_experts) != config.num_routed_experts:
            raise ValueError("Routed experts mismatch")

        self.shared_experts_list = nn.ModuleList()
        if config.num_shared_experts > 0:
            if config.num_shared_experts != 1:
                warnings.warn("Expected 1 shared expert, using 1 Conv.")
            self.shared_experts_list.append(ModifiedConvolutionalExpert(config))

        self.gate = BiBoMoERouter(config)

    @torch.no_grad()  # Bias update should not track gradients
    def update_bias(self, tpe):
        """Nudge the router bias towards a balanced token distribution.

        Each expert's bias moves by ``bias_update_factor`` in the direction of
        ``sign(mean(tpe) - tpe)``: under-loaded experts go up, over-loaded go down.

        Ref: https://gist.github.com/joey00072/f9e65f7fe05b763a19e4824bda29c975

        Args:
            tpe: per-expert token counts, one entry per routed expert.
        """
        if not hasattr(self.gate, "bias") or self.bias_update_factor <= 0:
            return
        counts = tpe.detach().float()
        deviation = counts.mean() - counts
        self.gate.bias.add_(self.bias_update_factor * deviation.sign())

    def forward(self, hidden_states: torch.Tensor):
        """Run routed and shared experts on ``hidden_states`` (batch, seq_len, hidden).

        Returns:
            final_output tensor with the same shape as ``hidden_states``.
        """
        bsz, seq_len, hidden_dim = hidden_states.shape
        num_tokens = bsz * seq_len
        flat_hidden = hidden_states.reshape(num_tokens, hidden_dim)

        top_k_indices, top_k_weights = self.gate(hidden_states, noise_variance=self.router_noise)

        tokens_per_expert = None
        if self.training and self.bias_update_factor > 0:
            tokens_per_expert = torch.bincount(
                top_k_indices.reshape(-1), minlength=self.num_routed_experts
            )

        final_routed = torch.zeros_like(flat_hidden)
        flat_expert_indices = top_k_indices.reshape(-1)
        flat_weights = top_k_weights.reshape(-1)
        # Row i * top_k + j of the flattened routing tensors belongs to token i.
        flat_token_indices = torch.arange(
            num_tokens, device=hidden_states.device
        ).repeat_interleave(self.num_experts_per_tok)

        for expert_idx, expert in enumerate(self.routed_experts):
            mask = flat_expert_indices == expert_idx
            if not mask.any():
                continue
            # topk yields distinct experts per token, so each token occurs at most
            # once here and no de-duplication is needed before dispatch.
            token_ids = flat_token_indices[mask]
            expert_out = expert(flat_hidden[token_ids])
            weighted = expert_out * flat_weights[mask].unsqueeze(1)
            # Cast guards against autocast producing a different dtype than the buffer.
            final_routed.index_add_(0, token_ids, weighted.to(final_routed.dtype))

        final_routed = final_routed.view(bsz, seq_len, hidden_dim)

        if len(self.shared_experts_list) > 0:
            shared_combined = self.shared_experts_list[0](hidden_states)
        else:
            shared_combined = torch.zeros_like(hidden_states)

        final_output = final_routed + shared_combined

        # NOTE(review): this mutates the router bias during forward; if the layer is
        # re-run (e.g. activation recomputation) the update is applied again - confirm.
        if tokens_per_expert is not None:
            self.update_bias(tokens_per_expert)
        return final_output


def rotate_half(x):
    """Split the last dim in halves ``(x1, x2)`` and return ``cat(-x2, x1)``."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos``/``sin`` are unsqueezed at ``unsqueeze_dim`` so they broadcast over the
    head axis. ``position_ids`` is accepted but unused.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(x: torch.Tensor, n: int) -> torch.Tensor:
    """Repeat each key/value head ``n`` times: (b, nk, s, h) -> (b, nk * n, s, h)."""
    if n == 1:
        return x
    b, nk, s, h = x.shape
    return x[:, :, None, :, :].expand(b, nk, n, s, h).reshape(b, nk * n, s, h)


def eager_attention_forward(m, q, k, v, mask, scale, dropout=0.0, **kw):
    """Plain matmul-softmax attention.

    Args:
        m: owning module; supplies ``num_key_value_groups`` and ``training``.
        q, k, v: (batch, heads, seq, head_dim) tensors (k/v may have fewer heads).
        mask: optional additive mask; after slicing to the key length it must have
            shape (batch, 1, q_len, k_len).
        scale: multiplier applied to the attention scores.
        dropout: dropout probability on the attention weights.

    Returns:
        ``(output, weights)`` where output is (batch, q_len, heads, head_dim).

    Raises:
        ValueError: if the mask shape does not match.
    """
    k = repeat_kv(k, m.num_key_value_groups)
    v = repeat_kv(v, m.num_key_value_groups)
    key_len = k.shape[-2]
    if mask is not None:
        mask = mask[:, :, :, :key_len]

    w = torch.matmul(q, k.transpose(2, 3)) * scale
    if mask is not None:
        if mask.size() != (q.shape[0], 1, q.shape[2], k.shape[2]):
            raise ValueError("Mask shape mismatch")
        w = w + mask

    # Softmax in float32 for numerical stability, then back to the query dtype.
    w = F.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
    w = F.dropout(w, p=dropout, training=m.training)
    o = torch.matmul(w, v).transpose(1, 2).contiguous()
    return o, w


class BiBoAttention(nn.Module):
    """Multi-head attention with grouped key/value heads and rotary embeddings (eager only)."""

    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, pos_emb, mask=None, kv_cache=None, output_attentions=False,
                use_cache=False, cache_position=None, **kw):
        """Attend over ``hidden_states`` (batch, q_len, hidden).

        Args:
            pos_emb: ``(cos, sin)`` pair from the rotary embedding module.
            mask: optional additive attention mask.
            kv_cache: optional cache object; its ``update`` is called with this layer's index.

        Returns:
            ``(output, attn_weights)``; ``attn_weights`` is ``None`` unless
            ``output_attentions`` is true.
        """
        b, q, _ = hidden_states.size()
        query = self.q_proj(hidden_states).view(b, q, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(b, q, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(b, q, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key, value = kv_cache.update(key, value, self.layer_idx, cache_kwargs)

        # Dropout is gated on self.training inside eager_attention_forward.
        out, weights = eager_attention_forward(
            self, query, key, value, mask, self.scaling, self.attention_dropout
        )
        out = out.reshape(b, q, self.hidden_size)
        out = self.o_proj(out)
        return out, weights if output_attentions else None


class BiBoRMSNorm(nn.Module):
    """RMS normalisation with a learnable scale, computed in float32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class BiBoDecoderLayer(nn.Module):
    """Pre-norm decoder layer; first and last layers use a dense MLP, the rest use MoE."""

    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BiBoAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx
        self.num_hidden_layers = config.num_hidden_layers

        is_first_layer = layer_idx == 0
        is_last_layer = layer_idx == config.num_hidden_layers - 1

        # Conditional MLP/MoE instantiation
        if is_first_layer or is_last_layer:
            self.mlp = BiBoMLP(config)
            self.is_moe_layer = False
        else:
            self.mlp = BiBoMoELayer(config)
            self.is_moe_layer = True

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None,
                output_attentions=False, use_cache=False, cache_position=None):
        """Returns tuple: (hidden_states,) or (hidden_states, attn_weights)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs, attn_weights = self.self_attn(
            hidden_states, position_embeddings, attention_mask, past_key_value,
            output_attentions, use_cache, cache_position,
        )
        hidden_states = residual + attn_outputs

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Dense MLP and MoE layer share the same call signature.
        ffn_output = self.mlp(hidden_states)
        hidden_states = residual + ffn_output

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class BiBoRotaryEmbedding(nn.Module):
    """Produces rotary ``(cos, sin)`` tables for given position ids."""

    def __init__(self, config: BiBoConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_type = rope_scaling.get("rope_type", "default") if rope_scaling else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` in ``x.dtype`` for ``position_ids`` of shape (batch, seq)."""
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        pos_ids = position_ids[:, None, :].float()
        dev_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast is disabled so the angle computation stays in float32.
        with torch.autocast(device_type=dev_type, enabled=False):
            freqs = (inv_freq.float() @ pos_ids.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


BIBO_START_DOCSTRING = r"""
    BiBo model...
"""

BIBO_INPUTS_DOCSTRING = r"""
    Standard arguments...
"""


@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoPreTrainedModel(PreTrainedModel):
    """Base class wiring BiBo into the Transformers loading/initialisation machinery."""

    config_class = BiBoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BiBoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its type."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, BiBoRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BiBoMoERouter):
            # Matches the constructor default of a zero routing bias.
            module.bias.data.zero_()


def _make_causal_mask(input_shape, dtype, device, past_key_values_length=0):
    """Build an additive causal mask of shape (bsz, 1, tgt_len, past_len + tgt_len).

    Allowed positions hold 0, future positions hold ``torch.finfo(dtype).min``.
    """
    bsz, tgt_len = input_shape
    positions = torch.arange(tgt_len, device=device)
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    # Query i may attend to key j whenever j <= i.
    mask = mask.masked_fill(positions[None, :] <= positions[:, None], 0)
    if past_key_values_length > 0:
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def _expand_mask(mask, dtype, tgt_len=None):
    """Expand a (bsz, src_len) keep-mask (1 = attend) to an additive (bsz, 1, tgt_len, src_len) mask."""
    bsz, src_len = mask.size()
    if tgt_len is None:
        tgt_len = src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoModel(BiBoPreTrainedModel):
    """Stack of :class:`BiBoDecoderLayer` modules with token embeddings and a final RMSNorm."""

    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [BiBoDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BiBoRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, mask, shape, embeds, past_len):
        """Combine the causal mask and the padding mask into one additive mask.

        Args:
            mask: optional (bsz, src_len) keep-mask (1 = attend, 0 = masked).
            shape: ``(bsz, tgt_len)`` of the current input.
            embeds: input embeddings; supplies dtype and device.
            past_len: number of cached positions preceding the current input.

        Returns:
            Additive mask of shape (bsz, 1, tgt_len, src_len), or ``None`` when
            neither a causal nor a padding mask is needed.
        """
        combined_mask = None
        tgt_len = shape[-1]
        # A single query token may attend to every cached key, so no causal mask.
        if tgt_len > 1:
            combined_mask = _make_causal_mask(
                shape, embeds.dtype, device=embeds.device, past_key_values_length=past_len
            ).to(embeds.device)
        if mask is not None:
            expanded_mask = _expand_mask(mask, embeds.dtype, tgt_len=tgt_len).to(embeds.device)
            combined_mask = expanded_mask if combined_mask is None else expanded_mask + combined_mask
        if combined_mask is not None:
            # Summing two masks can overflow to -inf; clamp back to the dtype minimum.
            bool_mask = combined_mask < 0
            combined_mask = combined_mask.masked_fill(bool_mask, torch.finfo(embeds.dtype).min)
        return combined_mask

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None,
                inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None,
                cache_position=None, return_dict=None):
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        Raises:
            ValueError: on missing/duplicate inputs or a non-``Cache`` ``past_key_values``.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify ids or embeds")
        # NOTE(review): the flag only disables caching here; the layer loop below does
        # not itself wrap layers in a checkpoint call.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("Disabling use_cache")
            use_cache = False
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("past_key_values type error")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_len = inputs_embeds.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_len, past_len + seq_len, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._prepare_decoder_attention_mask(
            attention_mask, (inputs_embeds.shape[0], seq_len), inputs_embeds, past_len
        )
        hidden_states = inputs_embeds
        pos_emb = self.rotary_emb(hidden_states, position_ids)

        all_hidden = () if output_hidden_states else None
        all_attn = () if output_attentions else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            layer_outputs = layer(
                hidden_states, pos_emb, causal_mask,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attn += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)

        next_cache = past_key_values if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden, all_attn] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden,
            attentions=all_attn,
        )


@add_start_docstrings(""" BiBo Model with CausalLM head. """, BIBO_START_DOCSTRING)
class BiBoForCausalLM(BiBoPreTrainedModel, GenerationMixin):
    """BiBo decoder with a linear language-modelling head; loss is computed externally."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.model = BiBoModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None,
                inputs_embeds=None, labels=None, use_cache=None, output_attentions=None,
                output_hidden_states=None, cache_position=None, logits_to_keep=0, return_dict=None):
        r"""
        Loss calculation (CrossEntropy) must happen outside this function.

        ``labels`` is accepted for interface compatibility only: a warning is emitted
        and the returned ``loss`` is always ``None``. When ``logits_to_keep`` is a
        non-zero int, only the last ``logits_to_keep`` positions are projected.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0] if not return_dict else model_outputs.last_hidden_state

        if isinstance(logits_to_keep, int) and logits_to_keep != 0:
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # Loss is never computed here.
        loss = None
        if labels is not None:
            warnings.warn("Labels provided but loss calculation must be done externally.")

        if not return_dict:
            # Omit a None loss so outputs[0] is the logits, matching the tuple
            # convention used by BiBoModel (None entries are dropped).
            output = (logits,) + tuple(model_outputs[1:])
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )