# coding=utf-8
# Copyright 2024 The BiBo Authors and The HuggingFace Inc. team. All rights reserved.

""" PyTorch BiBo model (Based on Qwen2 with MoE modifications).
we can use MoEwithoutput class; """
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from .configuration_bibo import BiBoConfig


try:
    import torch_xla.core.xla_model as xm
    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
    can_return_tuple,
)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "BiBo-MoE-Model" 
_CONFIG_FOR_DOC = "BiBoConfig"


class BiBoMLP(nn.Module):
    """Standard SwiGLU MLP used for dense layers."""
    def __init__(self, config: BiBoConfig): 
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size 
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MLPExpert(nn.Module):
    """SwiGLU based MLP Expert for MoE Layers"""
    def __init__(self, config: BiBoConfig): 
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size 
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class ModifiedConvolutionalExpert(nn.Module):
    """Causal Convolutional 'Expert' (Shared) for MoE Layers"""
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size 
        self.kernel_size_gate = config.kernel_size
        self.causal_padding_gate = self.kernel_size_gate - 1 
        self.gate_conv = nn.Conv1d(self.hidden_size, self.intermediate_size, self.kernel_size_gate, padding=0, bias=False) 
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, hidden_dim = x.shape
        # Conv1d expects (batch, channels, length).
        x_perm = x.permute(0, 2, 1)
        # Left-pad so the convolution only sees current and past positions (causal).
        x_padded = F.pad(x_perm, (self.causal_padding_gate, 0))
        gate_conv_out = self.gate_conv(x_padded)
        gate_activated = self.act_fn(gate_conv_out)
        gate_ready = gate_activated.permute(0, 2, 1)
        up_linear_out = self.up_proj(x)
        intermediate = gate_ready * up_linear_out
        output = self.down_proj(intermediate)
        if output.shape[1] != seq_len:
            raise RuntimeError("ModifiedConvolutionalExpert output length mismatch")
        return output
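
# Causal-padding arithmetic for the gate conv above (a sketch, not executed):
# with kernel_size = 3 the input is left-padded by 2, so the Conv1d output
# length is (seq_len + 2) - 3 + 1 = seq_len, and position t only sees the
# inputs at t-2, t-1, and t.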

class IdentityExpert(nn.Module):
    """Pass-through expert: returns its input unchanged."""
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x



class BiBoMoERouter(nn.Module):
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.num_experts = config.num_routed_experts
        self.top_k = config.num_experts_per_tok
        self.temperature = config.router_temperature
        self.router_noise = config.router_noise
        self.bias = nn.Parameter(torch.zeros(self.num_experts))
        self.gate_proj = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    
    def forward(self, hidden_states: torch.Tensor):
        """Routes tokens: noisy logits (training only), learned bias, temperature, top-k."""
        bsz, seq_len, _ = hidden_states.shape
        num_tokens = bsz * seq_len
        noise_variance = self.router_noise
        flat_hidden = hidden_states.view(num_tokens, -1)
        router_logits = self.gate_proj(flat_hidden).float()

        # No clamping for now.
        # TODO(@aloobun): make the clamp range dynamic based on mean/median/mode/std of current logits.
        # if self.logit_clamp_val > 0:
        #     router_logits = torch.clamp(router_logits, min=-self.logit_clamp_val, max=self.logit_clamp_val)

        if self.training and noise_variance > 0:
            noise_stddev = math.sqrt(noise_variance)
            noise = torch.randn_like(router_logits) * noise_stddev
            router_logits = router_logits + noise.detach()

        router_logits = router_logits + self.bias
        if self.temperature != 1.0:
            router_logits = router_logits / self.temperature
        routing_weights = F.softmax(router_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the selected weights so they sum to ~1 per token.
        norm_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

        return top_k_indices.long(), norm_weights.to(hidden_states.dtype)
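
# Router shape contract (a sketch; `cfg` is a hypothetical BiBoConfig instance):
#   router = BiBoMoERouter(cfg)
#   idx, w = router(h)        # h: (bsz, seq_len, hidden_size)
#   idx: (bsz * seq_len, cfg.num_experts_per_tok), long expert ids
#   w:   (bsz * seq_len, cfg.num_experts_per_tok), renormalized to sum ~1 per token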
        

class BiBoMoELayer(nn.Module):
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts_per_tok = config.num_experts_per_tok
        self.num_routed_experts = config.num_routed_experts
        # Assumed config field (not defined in this file); 0.0 disables the bias update.
        self.bias_update_factor = getattr(config, "bias_update_factor", 0.0)
        self.routed_experts = nn.ModuleList()
        # The last routed slot is an identity expert; the rest are SwiGLU MLPs.
        num_mlp_routed = config.num_routed_experts - 1
        for _ in range(num_mlp_routed):
            self.routed_experts.append(MLPExpert(config))
        self.routed_experts.append(IdentityExpert(config))
        if len(self.routed_experts) != config.num_routed_experts:
            raise ValueError("Routed experts mismatch")
        self.shared_experts_list = nn.ModuleList()
        if config.num_shared_experts > 0:
            if config.num_shared_experts != 1:
                warnings.warn("Expected exactly 1 shared expert; using a single convolutional expert.")
            self.shared_experts_list.append(ModifiedConvolutionalExpert(config))
        self.gate = BiBoMoERouter(config)


    @torch.no_grad()  # The bias update must not track gradients.
    def update_bias(self, tokens_per_expert):
        """
        Nudges the router's bias toward under-loaded experts based on the
        observed token distribution.
        Ref: https://gist.github.com/joey00072/f9e65f7fe05b763a19e4824bda29c975
        """
        if not hasattr(self.gate, "bias") or self.bias_update_factor <= 0:
            return
        counts = tokens_per_expert.detach().float()
        # Positive deviation means the expert received fewer tokens than average.
        deviation = counts.mean() - counts
        # Update bias: add_(factor * sign(deviation))
        self.gate.bias.add_(self.bias_update_factor * deviation.sign())
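
    # Worked example of the rule above (a sketch, not executed): with
    # tokens_per_expert = [10, 2, 0] the mean load is 4, so the deviation is
    # [-6, 2, 4]; sign() yields [-1, 1, 1], and the bias is nudged down for the
    # overloaded expert 0 and up for the under-used experts 1 and 2.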

    
    def forward(self, hidden_states: torch.Tensor):
        """Returns the combined routed + shared expert output."""
        bsz, seq_len, hidden_dim = hidden_states.shape
        num_tokens = bsz * seq_len
        flat_hidden = hidden_states.view(num_tokens, -1)
        # The router reads its own noise setting; no extra argument is needed.
        top_k_indices, top_k_weights = self.gate(hidden_states)

        tokens_per_expert = None
        if self.training and hasattr(self.gate, "bias") and self.bias_update_factor > 0:
            tokens_per_expert = torch.bincount(
                top_k_indices.view(-1), minlength=self.num_routed_experts
            )

        # Dispatch each token to its selected experts and weight-sum the outputs.
        final_routed = torch.zeros_like(flat_hidden)
        flat_expert_indices = top_k_indices.view(-1)
        flat_token_indices = torch.arange(
            num_tokens, device=hidden_states.device
        ).repeat_interleave(self.num_experts_per_tok)
        for i, expert in enumerate(self.routed_experts):
            mask = flat_expert_indices == i
            if mask.any():
                tokens_idx = flat_token_indices[mask]
                unique_tokens, orig_indices = torch.unique(tokens_idx, return_inverse=True)
                inputs = flat_hidden[unique_tokens]
                outputs = expert(inputs)[orig_indices]
                weights = top_k_weights.view(-1)[mask].unsqueeze(1)
                final_routed.scatter_add_(
                    0, tokens_idx.unsqueeze(1).expand(-1, hidden_dim), outputs * weights
                )
        final_routed = final_routed.view(bsz, seq_len, hidden_dim)

        # The shared (convolutional) expert processes every token.
        shared_combined = torch.zeros_like(hidden_states)
        if self.shared_experts_list:
            shared_combined = self.shared_experts_list[0](hidden_states)
        final_output = final_routed + shared_combined

        if tokens_per_expert is not None:
            self.update_bias(tokens_per_expert)

        return final_output



def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies rotary position embeddings to the query and key tensors."""
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def repeat_kv(x: torch.Tensor, n: int) -> torch.Tensor:
    """Repeats key/value heads n times to match the number of query heads (GQA)."""
    if n == 1:
        return x
    b, nk, s, h = x.shape
    return x[:, :, None, :, :].expand(b, nk, n, s, h).reshape(b, nk * n, s, h)


def eager_attention_forward(m, q, k, v, mask, scale, dropout=0.0, **kw):
    """Eager scaled dot-product attention: softmax(q @ k^T * scale + mask) @ v."""
    k, v = repeat_kv(k, m.num_key_value_groups), repeat_kv(v, m.num_key_value_groups)
    key_len = k.shape[-2]
    if mask is not None:
        mask = mask[:, :, :, :key_len]
    w = torch.matmul(q, k.transpose(2, 3)) * scale
    if mask is not None:
        if mask.size() != (q.shape[0], 1, q.shape[2], k.shape[2]):
            raise ValueError("Mask shape mismatch")
        w = w + mask
    # Softmax in float32 for numerical stability, then cast back.
    w = F.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
    w = F.dropout(w, p=dropout, training=m.training)
    o = torch.matmul(w, v).transpose(1, 2).contiguous()
    return o, w



class BiBoAttention(nn.Module):
    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, pos_emb, mask=None, kv_cache=None, output_attentions=False, use_cache=False, cache_position=None, **kw):
        b, q_len, _ = hidden_states.size()
        query = self.q_proj(hidden_states).view(b, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(b, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(b, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        cos, sin = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        if kv_cache is not None:
            key, value = kv_cache.update(key, value, self.layer_idx, {"sin": sin, "cos": cos, "cache_position": cache_position})
        out, weights = eager_attention_forward(self, query, key, value, mask, self.scaling, self.attention_dropout)
        out = out.reshape(b, q_len, self.hidden_size)
        out = self.o_proj(out)
        return out, weights if output_attentions else None

class BiBoRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x.to(dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

class BiBoDecoderLayer(nn.Module):
    def __init__(self, config: BiBoConfig, layer_idx: int): 
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BiBoAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx
        self.num_hidden_layers = config.num_hidden_layers
        is_first_layer = layer_idx == 0
        is_last_layer = layer_idx == config.num_hidden_layers - 1
        # Conditional MLP/MoE Instantiation
        if is_first_layer or is_last_layer:
            self.mlp = BiBoMLP(config) 
            self.is_moe_layer = False
        else:
            self.mlp = BiBoMoELayer(config) 
            self.is_moe_layer = True

    
    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None):
        """Returns a tuple: (hidden_states,) or (hidden_states, attn_weights)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs, attn_weights = self.self_attn(hidden_states, position_embeddings, attention_mask, past_key_value, output_attentions, use_cache, cache_position)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Dense (BiBoMLP) and MoE (BiBoMoELayer) blocks share the same call signature.
        ffn_output = self.mlp(hidden_states)
        hidden_states = residual + ffn_output
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs



class BiBoRotaryEmbedding(nn.Module):
    def __init__(self, config: BiBoConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_type = rope_scaling.get("rope_type", "default") if rope_scaling else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        pos_ids = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Compute in float32 with autocast disabled to preserve RoPE precision.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq.float() @ pos_ids.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


BIBO_START_DOCSTRING = r""" BiBo model... """
BIBO_INPUTS_DOCSTRING = r""" Standard arguments... """

@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoPreTrainedModel(PreTrainedModel):
    config_class = BiBoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BiBoDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, BiBoRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                module.bias.data.zero_()

@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING)
class BiBoModel(BiBoPreTrainedModel):
    def __init__(self, config: BiBoConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([BiBoDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BiBoRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value
        
    def _prepare_decoder_attention_mask(self, mask, shape, embeds, past_len):
        # The private helpers `nn.functional._make_causal_mask` / `_expand_mask`
        # do not exist in torch; use the transformers utility, which builds the
        # same combined causal + padding mask filled with the dtype minimum.
        return _prepare_4d_causal_attention_mask(mask, shape, embeds, past_len)

    
    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, return_dict=None):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Disabling `use_cache`.")
            use_cache = False
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("past_key_values must be a Cache instance or None")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_len = inputs_embeds.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_len, past_len + seq_len, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._prepare_decoder_attention_mask(attention_mask, (inputs_embeds.shape[0], seq_len), inputs_embeds, past_len)
        hidden_states = inputs_embeds
        pos_emb = self.rotary_emb(hidden_states, position_ids)
        all_hidden = () if output_hidden_states else None
        all_attn = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            layer_outputs = layer(hidden_states, pos_emb, causal_mask, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attn += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)
        next_cache = past_key_values if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden, all_attn] if v is not None)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden, attentions=all_attn)

@add_start_docstrings(""" BiBo Model with CausalLM head. """, BIBO_START_DOCSTRING)
class BiBoForCausalLM(BiBoPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    def __init__(self, config: BiBoConfig): 
        super().__init__(config)
        self.model = BiBoModel(config) 
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
    # Standard HF accessor methods.
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model


    
    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, logits_to_keep=0, return_dict=None):
        r"""Loss calculation (cross-entropy) must happen outside this function."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0] if not return_dict else model_outputs.last_hidden_state
        # Only compute logits for the last `logits_to_keep` positions (0 keeps all).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep != 0 else slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Loss is intentionally left as None; compute it externally from the logits.
        loss = None
        if labels is not None:
            warnings.warn("Labels were provided, but loss calculation must be done externally.")
        if not return_dict:
            output = (logits,) + model_outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=model_outputs.past_key_values, hidden_states=model_outputs.hidden_states, attentions=model_outputs.attentions)
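

# Minimal smoke-test sketch. It assumes BiBoConfig accepts the fields used
# throughout this file; every value below is hypothetical and chosen only to
# exercise shapes, not to match any released checkpoint.
if __name__ == "__main__":
    config = BiBoConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        moe_intermediate_size=96,
        num_hidden_layers=3,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_routed_experts=4,
        num_experts_per_tok=2,
        num_shared_experts=1,
        kernel_size=3,
        router_temperature=1.0,
        router_noise=0.0,
        max_position_embeddings=256,
    )
    model = BiBoForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 128])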