fhai50032 committed
Commit 7dd85ba · verified · 1 Parent(s): f2f9ba0

Create modelling_bibo.py

Files changed (1)
  1. modelling_bibo.py +405 -0
modelling_bibo.py ADDED
@@ -0,0 +1,405 @@
# coding=utf-8
# Copyright 2024 The BiBo Authors and The HuggingFace Inc. team. All rights reserved.

"""PyTorch BiBo model (based on Qwen2 with MoE modifications).

Note: a dedicated MoE output class could be used here.
"""
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from .configuration_bibo import BiBoConfig


try:
    import torch_xla.core.xla_model as xm
    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter  # used for causal/padding mask construction
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
    can_return_tuple,
)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "BiBo-MoE-Model"
_CONFIG_FOR_DOC = "BiBoConfig"

class BiBoMLP(nn.Module):
    """Standard SwiGLU MLP used for dense layers."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MLPExpert(nn.Module):
    """SwiGLU based MLP Expert for MoE Layers"""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class ModifiedConvolutionalExpert(nn.Module):
    """Causal Convolutional 'Expert' (Shared) for MoE Layers"""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.kernel_size_gate = config.kernel_size
        self.causal_padding_gate = self.kernel_size_gate - 1
        self.gate_conv = nn.Conv1d(self.hidden_size, self.intermediate_size, self.kernel_size_gate, padding=0, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, hidden_dim = x.shape
        x_perm = x.permute(0, 2, 1)
        # Apply causal (left-only) padding so the convolution never sees future tokens.
        x_padded = F.pad(x_perm, (self.causal_padding_gate, 0))
        gate_conv_out = self.gate_conv(x_padded)
        gate_activated = self.act_fn(gate_conv_out)
        gate_ready = gate_activated.permute(0, 2, 1)
        up_linear_out = self.up_proj(x)
        intermediate = gate_ready * up_linear_out
        output = self.down_proj(intermediate)
        if output.shape[1] != seq_len:
            raise RuntimeError("ModifiedConvExpert length mismatch")
        return output


class IdentityExpert(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class BiBoMoERouter(nn.Module):
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.num_experts = config.num_routed_experts
        self.top_k = config.num_experts_per_tok
        self.temperature = config.router_temperature
        self.router_noise = config.router_noise
        self.bias = nn.Parameter(torch.zeros(self.num_experts))
        self.gate_proj = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """Forward pass with noise, bias, clamping, temperature."""
        bsz, seq_len, _ = hidden_states.shape
        num_tokens = bsz * seq_len
        noise_variance = self.router_noise
        flat_hidden = hidden_states.view(num_tokens, -1)
        router_logits = self.gate_proj(flat_hidden).float()

        # No clamping for now.
        # TODO: @aloobun make clamp range dynamic based on mean/median/mode/std of current logits
        # if self.logit_clamp_val > 0:
        #     router_logits = torch.clamp(router_logits, min=-self.logit_clamp_val, max=self.logit_clamp_val)

        if self.training and noise_variance > 0:
            noise_stddev = math.sqrt(noise_variance)
            noise = torch.randn_like(router_logits) * noise_stddev
            router_logits = router_logits + noise.detach()

        router_logits = router_logits + self.bias
        if self.temperature != 1.0:
            router_logits = router_logits / self.temperature
        routing_weights = F.softmax(router_logits, dim=1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        norm_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

        return top_k_indices.long(), norm_weights.to(hidden_states.dtype)

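# Shape note (summarizing the router above): for an input of shape (bsz, seq_len, hidden),
# logits are computed over num_tokens = bsz * seq_len rows, and both returned tensors,
# `top_k_indices` and `norm_weights`, have shape (num_tokens, top_k); the top-k weights are
# renormalized per token, with the 1e-6 term guarding against division by zero.
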
class BiBoMoELayer(nn.Module):
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts_per_tok = config.num_experts_per_tok
        self.num_routed_experts = config.num_routed_experts
        # Assumption: fall back to 0.0 (bias updates disabled) if the config does not define this field.
        self.bias_update_factor = getattr(config, "bias_update_factor", 0.0)
        self.routed_experts = nn.ModuleList()
        num_mlp_routed = config.num_routed_experts - 1
        for _ in range(num_mlp_routed):
            self.routed_experts.append(MLPExpert(config))
        self.routed_experts.append(IdentityExpert(config))
        if len(self.routed_experts) != config.num_routed_experts:
            raise ValueError("Routed experts mismatch")
        self.shared_experts_list = nn.ModuleList()
        if config.num_shared_experts > 0:
            if config.num_shared_experts != 1:
                warnings.warn("Expected 1 shared expert, using 1 Conv.")
            self.shared_experts_list.append(ModifiedConvolutionalExpert(config))
        self.gate = BiBoMoERouter(config)

    @torch.no_grad()  # Bias update should not track gradients
    def update_bias(self, tpe):
        """
        Updates the router's learnable bias based on the token distribution.
        Ref: https://gist.github.com/joey00072/f9e65f7fe05b763a19e4824bda29c975
        """
        if not hasattr(self.gate, 'bias') or self.bias_update_factor <= 0:
            return
        c = tpe.detach().float()
        e = c.mean() - c
        # Update bias: add_(factor * sign(deviation))
        self.gate.bias.add_(self.bias_update_factor * e.sign())

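    # Worked example (illustrative numbers, not from the source): with 4 routed experts
    # and counts tpe = [10, 2, 0, 4], the mean is 4, so e = mean - counts = [-6, 2, 4, 0]
    # and sign(e) = [-1, +1, +1, 0]. The over-used expert 0 has its bias pushed down and
    # the under-used experts 1 and 2 have theirs pushed up by `bias_update_factor`,
    # nudging future routing toward balance.
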
    def forward(self, hidden_states: torch.Tensor):
        """Returns: final_output tensor."""
        bsz, seq_len, hidden_dim = hidden_states.shape
        num_tokens = bsz * seq_len
        flat_hidden = hidden_states.view(num_tokens, -1)
        # The router applies its configured noise internally during training.
        top_k_indices, top_k_weights = self.gate(hidden_states)

        tokens_per_expert = None
        if self.training and hasattr(self.gate, 'bias') and self.bias_update_factor > 0:
            tpe = torch.bincount(top_k_indices.view(-1), minlength=self.num_routed_experts)
            tokens_per_expert = tpe

        final_routed = torch.zeros_like(flat_hidden)
        flat_expert_indices = top_k_indices.view(-1)
        flat_token_indices = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(self.num_experts_per_tok)
        for i, expert in enumerate(self.routed_experts):
            mask = (flat_expert_indices == i)
            if mask.any():
                tokens_idx = flat_token_indices[mask]
                unique_tokens, orig_indices = torch.unique(tokens_idx, return_inverse=True)
                inputs = flat_hidden[unique_tokens]
                outputs = expert(inputs)[orig_indices]
                weights = top_k_weights.view(-1)[mask].unsqueeze(1)
                final_routed.scatter_add_(0, tokens_idx.unsqueeze(1).expand(-1, hidden_dim), outputs * weights)
        final_routed = final_routed.view(bsz, seq_len, hidden_dim)

        shared_combined = torch.zeros_like(hidden_states)
        if self.shared_experts_list:
            shared_combined = self.shared_experts_list[0](hidden_states)
        final_output = final_routed + shared_combined

        if tokens_per_expert is not None:
            self.update_bias(tokens_per_expert)

        return final_output


208
+ def rotate_half(x): x1,x2=x[...,:x.shape[-1]//2],x[...,x.shape[-1]//2:]; return torch.cat((-x2,x1),dim=-1)
209
+ def apply_rotary_pos_emb(q,k,cos,sin,position_ids=None,unsqueeze_dim=1): cos,sin=cos.unsqueeze(unsqueeze_dim),sin.unsqueeze(unsqueeze_dim); return (q*cos)+(rotate_half(q)*sin),(k*cos)+(rotate_half(k)*sin)
210
+ def repeat_kv(x:torch.Tensor,n:int)->torch.Tensor: b,nk,s,h=x.shape; return x[:,:,None,:,:].expand(b,nk,n,s,h).reshape(b,nk*n,s,h) if n!=1 else x
211
+ def eager_attention_forward(m,q,k,v,mask,scale,dropout=0.0,**kw):
212
+ k,v=repeat_kv(k,m.num_key_value_groups),repeat_kv(v,m.num_key_value_groups); slk=k.shape[-2]
213
+ if mask is not None: mask=mask[:,:,:,:slk]
214
+ w=torch.matmul(q,k.transpose(2,3))*scale
215
+ if mask is not None:
216
+ if mask.size()!=(q.shape[0],1,q.shape[2],k.shape[2]): raise ValueError("Mask shape mismatch")
217
+ w=w+mask
218
+ w=F.softmax(w,dim=-1,dtype=torch.float32).to(q.dtype); w=F.dropout(w,p=dropout,training=m.training)
219
+ o=torch.matmul(w,v).transpose(1,2).contiguous(); return o,w
220
+
221
+
222
+
223
+ class BiBoAttention(nn.Module):
224
+ def __init__(self, config: BiBoConfig, layer_idx: int):
225
+ super().__init__(); self.config=config; self.layer_idx=layer_idx
226
+ self.hidden_size=config.hidden_size; self.num_heads=config.num_attention_heads; self.head_dim=self.hidden_size//self.num_heads
227
+ self.num_key_value_heads=config.num_key_value_heads; self.num_key_value_groups=self.num_heads//self.num_key_value_heads
228
+ self.max_position_embeddings=config.max_position_embeddings; self.rope_theta=config.rope_theta; self.is_causal=True
229
+ self.attention_dropout=config.attention_dropout; self.scaling=self.head_dim**-0.5
230
+ self.q_proj=nn.Linear(self.hidden_size,self.num_heads*self.head_dim,bias=True); self.k_proj=nn.Linear(self.hidden_size,self.num_key_value_heads*self.head_dim,bias=True)
231
+ self.v_proj=nn.Linear(self.hidden_size,self.num_key_value_heads*self.head_dim,bias=True); self.o_proj=nn.Linear(self.num_heads*self.head_dim,self.hidden_size,bias=False)
232
+
233
+
234
+ def forward(self, hidden_states, pos_emb, mask=None, kv_cache=None, output_attentions=False, use_cache=False, cache_position=None, **kw):
235
+ b,q,_=hidden_states.size(); query=self.q_proj(hidden_states).view(b,q,self.num_heads,self.head_dim).transpose(1,2)
236
+ key=self.k_proj(hidden_states).view(b,q,self.num_key_value_heads,self.head_dim).transpose(1,2); value=self.v_proj(hidden_states).view(b,q,self.num_key_value_heads,self.head_dim).transpose(1,2)
237
+ cos,sin=pos_emb; query,key=apply_rotary_pos_emb(query,key,cos,sin)
238
+ if kv_cache is not None: key,value=kv_cache.update(key,value,self.layer_idx,{"sin":sin,"cos":cos,"cache_position":cache_position})
239
+ out,weights=eager_attention_forward(self,query,key,value,mask,self.scaling,self.attention_dropout)
240
+ out=out.reshape(b,q,self.hidden_size); out=self.o_proj(out); return out,weights if output_attentions else None
241
+
242
+ class BiBoRMSNorm(nn.Module):
243
+ def __init__(self, hidden_size, eps=1e-6): super().__init__(); self.weight=nn.Parameter(torch.ones(hidden_size)); self.variance_epsilon=eps
244
+ def forward(self, x): dt=x.dtype; x=x.to(torch.float32); v=x.pow(2).mean(-1,keepdim=True); x=x*torch.rsqrt(v+self.variance_epsilon); return self.weight*x.to(dt)
245
+ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
246
+
class BiBoDecoderLayer(nn.Module):
    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BiBoAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx
        self.num_hidden_layers = config.num_hidden_layers
        is_first_layer = layer_idx == 0
        is_last_layer = layer_idx == config.num_hidden_layers - 1
        # Conditional MLP/MoE instantiation: first and last layers stay dense, the rest are MoE.
        if is_first_layer or is_last_layer:
            self.mlp = BiBoMLP(config)
            self.is_moe_layer = False
        else:
            self.mlp = BiBoMoELayer(config)
            self.is_moe_layer = True

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None):
        """Returns a tuple: (hidden_states,) or (hidden_states, attn_weights)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs, attn_weights = self.self_attn(hidden_states, position_embeddings, attention_mask, past_key_value, output_attentions, use_cache, cache_position)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Feed-forward: self.mlp is either a dense BiBoMLP or a BiBoMoELayer; both share the same call.
        ffn_output = self.mlp(hidden_states)
        hidden_states = residual + ffn_output
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


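# Illustrative note (the layer count is an assumption, not from the config): with
# num_hidden_layers = 24, layers 0 and 23 instantiate the dense BiBoMLP while layers
# 1-22 instantiate BiBoMoELayer, so only the interior layers are mixture-of-experts.
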
class BiBoRotaryEmbedding(nn.Module):
    def __init__(self, config: BiBoConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_type = rope_scaling.get("rope_type", "default") if rope_scaling else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        pos_ids = position_ids[:, None, :].float()
        dev_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=dev_type, enabled=False):
            freqs = (inv_freq.float() @ pos_ids.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


BIBO_START_DOCSTRING = r""" BiBo model... """
BIBO_INPUTS_DOCSTRING = r""" Standard arguments... """

@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoPreTrainedModel(PreTrainedModel):
    config_class = BiBoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BiBoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, BiBoRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                module.bias.data.zero_()

@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoModel(BiBoPreTrainedModel):
    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BiBoDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BiBoRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, mask, shape, embeds, past_len):
        # Uses the (private) transformers mask helpers; torch.nn.functional has no such utilities.
        combined_mask = None
        L = shape[-1]
        if L > 1:
            combined_mask = AttentionMaskConverter._make_causal_mask(
                shape, embeds.dtype, device=embeds.device, past_key_values_length=past_len
            ).to(embeds.device)
        if mask is not None:
            expanded_mask = AttentionMaskConverter._expand_mask(mask, embeds.dtype, tgt_len=L).to(embeds.device)
            combined_mask = expanded_mask if combined_mask is None else expanded_mask + combined_mask
        if combined_mask is not None:
            bool_mask = combined_mask < 0
            combined_mask = combined_mask.masked_fill(bool_mask, torch.finfo(embeds.dtype).min)
        return combined_mask

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, return_dict=None):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing; setting `use_cache=False`.")
            use_cache = False
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("past_key_values must be a Cache instance or None")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_len = inputs_embeds.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_len, past_len + seq_len, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._prepare_decoder_attention_mask(attention_mask, (inputs_embeds.shape[0], seq_len), inputs_embeds, past_len)
        hidden_states = inputs_embeds
        pos_emb = self.rotary_emb(hidden_states, position_ids)
        all_hidden = () if output_hidden_states else None
        all_attn = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            layer_outputs = layer(hidden_states, pos_emb, causal_mask, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attn += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)
        next_cache = past_key_values if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden, all_attn] if v is not None)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden, attentions=all_attn)

@add_start_docstrings("""BiBo Model with CausalLM head.""", BIBO_START_DOCSTRING)
class BiBoForCausalLM(BiBoPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.model = BiBoModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    # Standard accessor methods remain the same.
    def get_input_embeddings(self): return self.model.embed_tokens
    def set_input_embeddings(self, value): self.model.embed_tokens = value
    def get_output_embeddings(self): return self.lm_head
    def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
    def set_decoder(self, decoder): self.model = decoder
    def get_decoder(self): return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, logits_to_keep=0, return_dict=None):  # TODO: add noise arg with default
        r"""Loss calculation (cross-entropy) must happen outside this function."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        model_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, return_dict=return_dict)
        hidden_states = model_outputs[0] if not return_dict else model_outputs.last_hidden_state
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep != 0 else slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Loss is always None here; compute it externally from the returned logits.
        loss = None
        if labels is not None:
            warnings.warn("Labels provided but loss calculation must be done externally.")
        if not return_dict:
            other_outputs = model_outputs[1:]
            output = (logits,) + other_outputs
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=model_outputs.past_key_values, hidden_states=model_outputs.hidden_states, attentions=model_outputs.attentions)
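

# ---------------------------------------------------------------------------
# Usage sketch (assumptions: the repository also ships `configuration_bibo.py`
# and a `config.json` whose `auto_map` registers BiBoConfig / BiBoForCausalLM;
# the repo id below is hypothetical).
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("fhai50032/BiBo", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("fhai50032/BiBo", trust_remote_code=True)
#   inputs = tokenizer("Hello, BiBo!", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=16)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------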