Create modelling_bibo.py
Browse files- modelling_bibo.py +405 -0
modelling_bibo.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The BiBo Authors and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
""" PyTorch BiBo model (Based on Qwen2 with MoE modifications).
|
| 5 |
+
we can use MoEwithoutput class; """
|
| 6 |
+
import math
|
| 7 |
+
import warnings
|
| 8 |
+
from typing import List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from .configuration_bibo import BiBoConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import torch_xla.core.xla_model as xm
|
| 19 |
+
_XLA_AVAILABLE = True
|
| 20 |
+
except ImportError:
|
| 21 |
+
_XLA_AVAILABLE = False
|
| 22 |
+
|
| 23 |
+
from transformers.activations import ACT2FN
|
| 24 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
|
| 25 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 26 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 27 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 28 |
+
from transformers.generation import GenerationMixin
|
| 29 |
+
from transformers.utils import (
|
| 30 |
+
add_start_docstrings,
|
| 31 |
+
add_start_docstrings_to_model_forward,
|
| 32 |
+
is_flash_attn_2_available,
|
| 33 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 34 |
+
logging,
|
| 35 |
+
replace_return_docstrings,
|
| 36 |
+
can_return_tuple,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
_CHECKPOINT_FOR_DOC = "BiBo-MoE-Model"
|
| 42 |
+
_CONFIG_FOR_DOC = "BiBoConfig"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BiBoMLP(nn.Module):
    """Dense SwiGLU feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Three bias-free projections, following the Qwen2 convention.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MLPExpert(nn.Module):
    """Routed MoE expert: a SwiGLU MLP sized by `moe_intermediate_size`."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        hidden = config.hidden_size
        inner = config.moe_intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Bias-free SwiGLU projections, mirroring BiBoMLP but with the MoE width.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
| 71 |
+
|
| 72 |
+
class ModifiedConvolutionalExpert(nn.Module):
    """Shared MoE expert whose gate branch is a causal 1-D convolution."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.kernel_size_gate = config.kernel_size
        # Left-pad by kernel-1 so the convolution never sees future tokens.
        self.causal_padding_gate = self.kernel_size_gate - 1
        self.gate_conv = nn.Conv1d(self.hidden_size, self.intermediate_size, self.kernel_size_gate, padding=0, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seq_len, _ = x.shape
        # Conv1d expects (batch, channels, seq); pad only on the left for causality.
        channels_first = x.permute(0, 2, 1)
        padded = F.pad(channels_first, (self.causal_padding_gate, 0))
        gate = self.act_fn(self.gate_conv(padded)).permute(0, 2, 1)
        out = self.down_proj(gate * self.up_proj(x))
        # Sanity check: causal padding must keep the sequence length unchanged.
        if out.shape[1] != seq_len:
            raise RuntimeError("ModifiedConvExpert length mismatch")
        return out
|
| 96 |
+
|
| 97 |
+
class IdentityExpert(nn.Module):
    """No-op expert: returns its input unchanged; accepts and ignores any ctor args."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class BiBoMoERouter(nn.Module):
    """Top-k softmax router with optional training noise, a learnable per-expert
    bias (used for load balancing via BiBoMoELayer.update_bias), and temperature
    scaling."""

    def __init__(self, config: "BiBoConfig"):
        super().__init__()
        self.num_experts = config.num_routed_experts
        self.top_k = config.num_experts_per_tok
        self.temperature = config.router_temperature
        self.router_noise = config.router_noise
        # Learnable per-expert bias added to the logits before top-k selection.
        self.bias = nn.Parameter(torch.zeros(self.num_experts))
        self.gate_proj = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor, noise_variance=None):
        """Return (top_k_indices, normalized_top_k_weights) for every token.

        FIX: the caller (BiBoMoELayer) invokes this as
        `self.gate(hidden_states, noise_variance=...)`, which the old signature
        rejected with a TypeError. `noise_variance` is now an optional override;
        None falls back to the configured `router_noise` (backward compatible).
        """
        bsz, seq_len, _ = hidden_states.shape
        num_tokens = bsz * seq_len
        noise_variance = self.router_noise if noise_variance is None else noise_variance
        flat_hidden = hidden_states.view(num_tokens, -1)
        # Logits in float32 for a numerically stable softmax.
        router_logits = self.gate_proj(flat_hidden).float()

        # No clamping for now.
        # TODO(@aloobun): make clamp range dynamic based on mean/median/mode/std of current logits.
        # if self.logit_clamp_val > 0:
        #     router_logits = torch.clamp(router_logits, min=-self.logit_clamp_val, max=self.logit_clamp_val)

        # Exploration noise is only injected during training.
        if self.training and noise_variance > 0:
            noise_stddev = math.sqrt(noise_variance)
            noise = torch.randn_like(router_logits) * noise_stddev
            router_logits = router_logits + noise.detach()

        router_logits = router_logits + self.bias
        if self.temperature != 1.0:
            router_logits = router_logits / self.temperature
        routing_weights = F.softmax(router_logits, dim=1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Re-normalize the kept weights; epsilon guards against a zero sum.
        norm_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
        return top_k_indices.long(), norm_weights.to(hidden_states.dtype)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class BiBoMoELayer(nn.Module):
    """Mixture-of-experts feed-forward: top-k routed experts ((n-1) SwiGLU MLPs
    plus one Identity expert) combined with an always-on shared convolutional
    expert, plus an auxiliary-loss-free bias-update load-balancing scheme.
    """

    def __init__(self, config: "BiBoConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts_per_tok = config.num_experts_per_tok
        # FIX: these attributes were read in forward()/update_bias() but never
        # stored, which raised AttributeError at runtime.
        self.num_routed_experts = config.num_routed_experts
        self.bias_update_factor = getattr(config, "bias_update_factor", 0.0)
        # Routed experts: (n-1) SwiGLU MLPs plus one Identity pass-through.
        self.routed_experts = nn.ModuleList(
            MLPExpert(config) for _ in range(config.num_routed_experts - 1)
        )
        self.routed_experts.append(IdentityExpert(config))
        if len(self.routed_experts) != config.num_routed_experts:
            raise ValueError("Routed experts mismatch")
        self.shared_experts_list = nn.ModuleList()
        if config.num_shared_experts > 0:
            if config.num_shared_experts != 1:
                warnings.warn("Expected 1 shared expert, using 1 Conv.")
            self.shared_experts_list.append(ModifiedConvolutionalExpert(config))
        self.gate = BiBoMoERouter(config)

    @torch.no_grad()  # Bias update should not track gradients
    def update_bias(self, tpe):
        """
        Updates the router's learnable bias based on token distribution.
        Ref: https://gist.github.com/joey00072/f9e65f7fe05b763a19e4824bda29c975
        """
        if not hasattr(self.gate, 'bias') or self.bias_update_factor <= 0:
            return
        counts = tpe.detach().float()
        deviation = counts.mean() - counts
        # Nudge under-used experts up / over-used experts down by a fixed step.
        self.gate.bias.add_(self.bias_update_factor * deviation.sign())

    def forward(self, hidden_states: torch.Tensor):
        """Route each token to its top-k experts and add the shared expert output.

        Returns: final_output tensor of the same shape as `hidden_states`.
        """
        bsz, seq_len, hidden_dim = hidden_states.shape
        num_tokens = bsz * seq_len
        flat_hidden = hidden_states.view(num_tokens, -1)
        # FIX: the router handles its noise internally; the old keyword call
        # `self.gate(hidden_states, noise_variance=self.router_noise)` raised a
        # TypeError (no such parameter) on an attribute that was never set.
        top_k_indices, top_k_weights = self.gate(hidden_states)

        tokens_per_expert = None
        if self.training and hasattr(self.gate, 'bias') and self.bias_update_factor > 0:
            tokens_per_expert = torch.bincount(
                top_k_indices.view(-1), minlength=self.num_routed_experts
            )

        final_routed = torch.zeros_like(flat_hidden)
        flat_expert_indices = top_k_indices.view(-1)
        flat_token_indices = torch.arange(
            num_tokens, device=hidden_states.device
        ).repeat_interleave(self.num_experts_per_tok)
        # Dispatch: run each expert once on the unique tokens routed to it,
        # then scatter the weighted outputs back to token positions.
        for i, expert in enumerate(self.routed_experts):
            mask = flat_expert_indices == i
            if mask.any():
                tokens_idx = flat_token_indices[mask]
                unique_tokens, orig_indices = torch.unique(tokens_idx, return_inverse=True)
                outputs = expert(flat_hidden[unique_tokens])[orig_indices]
                weights = top_k_weights.view(-1)[mask].unsqueeze(1)
                final_routed.scatter_add_(
                    0, tokens_idx.unsqueeze(1).expand(-1, hidden_dim), outputs * weights
                )
        final_routed = final_routed.view(bsz, seq_len, hidden_dim)

        # The shared (convolutional) expert sees every token unconditionally.
        shared_combined = torch.zeros_like(hidden_states)
        if self.shared_experts_list:
            shared_combined = self.shared_experts_list[0](hidden_states)
        final_output = final_routed + shared_combined

        if tokens_per_expert is not None:
            self.update_bias(tokens_per_expert)

        return final_output
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def rotate_half(x):
    """Rotate the last dimension's halves: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
| 209 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to q and k.

    `position_ids` is unused and kept only for signature compatibility;
    `unsqueeze_dim` inserts the broadcasting head dimension into cos/sin.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
|
| 210 |
+
def repeat_kv(x: torch.Tensor, n: int) -> torch.Tensor:
    """Duplicate each KV head n times: (b, kv_heads, s, d) -> (b, kv_heads*n, s, d)."""
    if n == 1:
        return x
    bsz, kv_heads, seq, head_dim = x.shape
    expanded = x[:, :, None, :, :].expand(bsz, kv_heads, n, seq, head_dim)
    return expanded.reshape(bsz, kv_heads * n, seq, head_dim)
|
| 211 |
+
def eager_attention_forward(m, q, k, v, mask, scale, dropout=0.0, **kw):
    """Plain matmul-softmax attention.

    Returns (output transposed to (b, s, heads, d), attention probabilities).
    `m` supplies `num_key_value_groups` and `training`.
    """
    k = repeat_kv(k, m.num_key_value_groups)
    v = repeat_kv(v, m.num_key_value_groups)
    key_len = k.shape[-2]
    if mask is not None:
        mask = mask[:, :, :, :key_len]
    scores = torch.matmul(q, k.transpose(2, 3)) * scale
    if mask is not None:
        if mask.size() != (q.shape[0], 1, q.shape[2], k.shape[2]):
            raise ValueError("Mask shape mismatch")
        scores = scores + mask
    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    probs = F.dropout(probs, p=dropout, training=m.training)
    out = torch.matmul(probs, v).transpose(1, 2).contiguous()
    return out, probs
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class BiBoAttention(nn.Module):
    """Multi-head attention with grouped-query KV heads and rotary embeddings."""

    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim ** -0.5
        # Q/K/V projections carry biases (Qwen2 convention); output does not.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, pos_emb, mask=None, kv_cache=None, output_attentions=False, use_cache=False, cache_position=None, **kw):
        bsz, q_len, _ = hidden_states.size()
        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        cos, sin = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key, value = kv_cache.update(key, value, self.layer_idx, cache_kwargs)
        attn_out, attn_weights = eager_attention_forward(self, query, key, value, mask, self.scaling, self.attention_dropout)
        attn_out = self.o_proj(attn_out.reshape(bsz, q_len, self.hidden_size))
        return attn_out, attn_weights if output_attentions else None
|
| 241 |
+
|
| 242 |
+
class BiBoRMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), computed in float32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        variance = hidden.pow(2).mean(-1, keepdim=True)
        hidden = hidden * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 246 |
+
|
| 247 |
+
class BiBoDecoderLayer(nn.Module):
    """One transformer block: self-attention + feed-forward with pre-norms.

    The first and last layers use the dense BiBoMLP; every interior layer
    uses the BiBoMoELayer mixture-of-experts feed-forward.
    """

    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BiBoAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx
        self.num_hidden_layers = config.num_hidden_layers
        # Dense feed-forward at the boundary layers, MoE everywhere else.
        edge_layer = layer_idx in (0, config.num_hidden_layers - 1)
        self.is_moe_layer = not edge_layer
        self.mlp = BiBoMLP(config) if edge_layer else BiBoMoELayer(config)

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None):
        """Return (hidden_states,) plus attention weights when requested."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs, attn_weights = self.self_attn(
            hidden_states, position_embeddings, attention_mask, past_key_value,
            output_attentions, use_cache, cache_position,
        )
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Both feed-forward kinds take the same call; is_moe_layer just records which.
        ffn_output = self.mlp(hidden_states)
        hidden_states = residual + ffn_output
        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class BiBoRotaryEmbedding(nn.Module):
    """Rotary position embedding with pluggable rope-type initialisation."""

    def __init__(self, config: BiBoConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_type = rope_scaling.get("rope_type", "default") if rope_scaling else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos_ids = position_ids[:, None, :].float()
        # Disable autocast so the trigonometry runs in full float32.
        dev_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=dev_type, enabled=False):
            freqs = (inv_freq.float() @ pos_ids.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
BIBO_START_DOCSTRING = r""" BiBo model... """
|
| 302 |
+
BIBO_INPUTS_DOCSTRING = r""" Standard arguments... """
|
| 303 |
+
|
| 304 |
+
@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoPreTrainedModel(PreTrainedModel):
    """Base class wiring BiBoConfig into the HF weight-init / loading machinery."""

    config_class = BiBoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BiBoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise weights per module type with the configured std."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, BiBoRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                module.bias.data.zero_()
|
| 317 |
+
|
| 318 |
+
@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoModel(BiBoPreTrainedModel):
    """Decoder-only transformer backbone: embeddings -> decoder layers -> final norm."""

    def __init__(self, config: "BiBoConfig"):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BiBoDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BiBoRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, mask, shape, embeds, past_len):
        """Build the additive float causal mask, merged with the padding mask.

        FIX: the previous version called `nn.functional._make_causal_mask` and
        `nn.functional._expand_mask`, which do not exist in torch (those helpers
        lived in transformers.modeling_attn_mask_utils), so any multi-token
        forward crashed with AttributeError. The masks are now built directly
        with torch ops: 0.0 where attention is allowed, dtype-min where masked.
        """
        bsz, tgt_len = shape
        dtype, device = embeds.dtype, embeds.device
        min_val = torch.finfo(dtype).min
        combined_mask = None
        if tgt_len > 1:
            # Causal part: query position i may attend to key j <= i + past_len.
            causal = torch.full((tgt_len, tgt_len), min_val, dtype=dtype, device=device)
            causal = torch.triu(causal, diagonal=1)
            if past_len > 0:
                past_zeros = torch.zeros(tgt_len, past_len, dtype=dtype, device=device)
                causal = torch.cat([past_zeros, causal], dim=-1)
            combined_mask = causal[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_len)
        if mask is not None:
            # Padding part: (bsz, src_len) of 1/0 -> additive (bsz, 1, tgt_len, src_len).
            src_len = mask.shape[-1]
            expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            expanded = (1.0 - expanded) * min_val
            combined_mask = expanded if combined_mask is None else expanded + combined_mask
        if combined_mask is not None:
            # Clamp summed negatives back to dtype-min to avoid overflow past -inf.
            combined_mask = combined_mask.masked_fill(combined_mask < 0, min_val)
        return combined_mask

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, return_dict=None):
        """Run the decoder stack; returns BaseModelOutputWithPast (or a tuple)."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Exactly one of input_ids / inputs_embeds must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify ids or embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("Disabling use_cache")
            use_cache = False
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("past_key_values type error")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_len = inputs_embeds.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_len, past_len + seq_len, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._prepare_decoder_attention_mask(
            attention_mask, (inputs_embeds.shape[0], seq_len), inputs_embeds, past_len
        )
        hidden_states = inputs_embeds
        pos_emb = self.rotary_emb(hidden_states, position_ids)
        all_hidden = () if output_hidden_states else None
        all_attn = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            layer_outputs = layer(
                hidden_states, pos_emb, causal_mask,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attn += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)
        next_cache = past_key_values if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden, all_attn] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden,
            attentions=all_attn,
        )
|
| 369 |
+
|
| 370 |
+
@add_start_docstrings(""" BiBo Model with CausalLM head. """, BIBO_START_DOCSTRING)
class BiBoForCausalLM(BiBoPreTrainedModel, GenerationMixin):
    """BiBo backbone plus a linear LM head; the loss is computed by the caller."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.model = BiBoModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    # Standard embedding / decoder accessors expected by HF utilities.
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, logits_to_keep=0, return_dict=None,):
        r""" Loss calculation (CrossEntropy) must happen outside this function. """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
        )
        hidden_states = model_outputs.last_hidden_state if return_dict else model_outputs[0]
        # Optionally project only the trailing `logits_to_keep` positions.
        if isinstance(logits_to_keep, int) and logits_to_keep != 0:
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Loss is intentionally left as None; it is computed externally.
        loss = None
        if labels is not None:
            warnings.warn("Labels provided but loss calculation must be done externally.")
        if not return_dict:
            return (loss,) + (logits,) + model_outputs[1:]
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
|