""" PyTorch BiBo model (Based on Qwen2 with MoE modifications). |
|
|
we can use MoEwithoutput class; """ |
|
|
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from .configuration_bibo import BiBoConfig

try:
    import torch_xla.core.xla_model as xm

    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache, SlidingWindowCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
    can_return_tuple,
)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "BiBo-MoE-Model"
_CONFIG_FOR_DOC = "BiBoConfig"
|
class BiBoMLP(nn.Module):
    """Standard SwiGLU MLP used for dense layers."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class MLPExpert(nn.Module):
    """SwiGLU MLP expert for MoE layers."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class ModifiedConvolutionalExpert(nn.Module):
    """Causal convolutional (shared) expert for MoE layers."""

    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.moe_intermediate_size
        self.kernel_size_gate = config.kernel_size
        # Left-pad by kernel_size - 1 so the convolution never sees future positions.
        self.causal_padding_gate = self.kernel_size_gate - 1
        self.gate_conv = nn.Conv1d(self.hidden_size, self.intermediate_size, self.kernel_size_gate, padding=0, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, hidden_dim = x.shape
        # Conv1d expects (batch, channels, length).
        x_perm = x.permute(0, 2, 1)

        x_padded = F.pad(x_perm, (self.causal_padding_gate, 0))
        gate_conv_out = self.gate_conv(x_padded)
        gate_activated = self.act_fn(gate_conv_out)
        gate_ready = gate_activated.permute(0, 2, 1)
        up_linear_out = self.up_proj(x)
        intermediate = gate_ready * up_linear_out
        output = self.down_proj(intermediate)
        if output.shape[1] != seq_len:
            raise RuntimeError("ModifiedConvolutionalExpert output length mismatch")
        return output
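
# Shape walkthrough for the causal gate convolution above (illustrative sizes,
# not taken from BiBoConfig): for x of shape (B=2, S=5, H=8) with kernel_size=3,
#   permute            -> (2, 8, 5)
#   left-pad by k-1=2  -> (2, 8, 7)
#   Conv1d(8->16, k=3) -> (2, 16, 5)   # length preserved; position t sees only t-2..t
# so the gate is strictly causal and matches the autoregressive attention mask.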
|
|
|
|
|
class IdentityExpert(nn.Module):
    """Identity (no-op) expert: returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

class BiBoMoERouter(nn.Module):
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.num_experts = config.num_routed_experts
        self.top_k = config.num_experts_per_tok
        self.temperature = config.router_temperature
        self.router_noise = config.router_noise
        # Balancing bias, updated manually in BiBoMoELayer.update_bias.
        self.bias = nn.Parameter(torch.zeros(self.num_experts))
        self.gate_proj = nn.Linear(config.hidden_size, self.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """Forward pass with optional training noise, balancing bias, and temperature."""
        bsz, seq_len, _ = hidden_states.shape
        num_tokens = bsz * seq_len
        noise_variance = self.router_noise
        flat_hidden = hidden_states.view(num_tokens, -1)
        router_logits = self.gate_proj(flat_hidden).float()

        # No clamping for now.
        # TODO(@aloobun): make the clamp range dynamic based on the
        # mean/median/mode/std of the current logits.

        if self.training and noise_variance > 0:
            noise_stddev = math.sqrt(noise_variance)
            noise = torch.randn_like(router_logits) * noise_stddev
            router_logits = router_logits + noise.detach()

        router_logits = router_logits + self.bias
        if self.temperature != 1.0:
            router_logits = router_logits / self.temperature
        routing_weights = F.softmax(router_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the selected weights so they sum to ~1 per token.
        norm_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

        return top_k_indices.long(), norm_weights.to(hidden_states.dtype)
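
# Router output sketch (illustrative sizes): with num_routed_experts=8 and
# num_experts_per_tok=2, a (2, 5, hidden_size) input yields
#   top_k_indices: (10, 2) int64 expert ids, one row per token
#   norm_weights:  (10, 2) weights summing to ~1.0 per row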
|
|
|
|
|
|
|
|
class BiBoMoELayer(nn.Module):
    def __init__(self, config: BiBoConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts_per_tok = config.num_experts_per_tok
        self.num_routed_experts = config.num_routed_experts
        # Step size for the loss-free balancing bias update; the attribute name is
        # assumed from its usage below, with 0.0 (disabled) as the fallback.
        self.bias_update_factor = getattr(config, "bias_update_factor", 0.0)
        self.routed_experts = nn.ModuleList()
        # The last routed expert is an identity; the rest are SwiGLU MLPs.
        num_mlp_routed = config.num_routed_experts - 1
        for _ in range(num_mlp_routed):
            self.routed_experts.append(MLPExpert(config))
        self.routed_experts.append(IdentityExpert(config))
        if len(self.routed_experts) != config.num_routed_experts:
            raise ValueError("Routed experts mismatch")
        self.shared_experts_list = nn.ModuleList()
        if config.num_shared_experts > 0:
            if config.num_shared_experts != 1:
                warnings.warn("Expected 1 shared expert, using 1 Conv.")
            self.shared_experts_list.append(ModifiedConvolutionalExpert(config))
        self.gate = BiBoMoERouter(config)

    @torch.no_grad()
    def update_bias(self, tpe):
        """
        Updates the router's balancing bias from the per-expert token counts:
        overloaded experts are nudged down, underloaded experts up.
        Ref: https://gist.github.com/joey00072/f9e65f7fe05b763a19e4824bda29c975
        """
        if not hasattr(self.gate, "bias") or self.bias_update_factor <= 0:
            return
        c = tpe.detach().float()
        # Positive error means the expert received fewer tokens than average.
        e = c.mean() - c

        self.gate.bias.add_(self.bias_update_factor * e.sign())
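
    # Worked example of the update rule (illustrative counts): with
    # tpe = [9, 1, 2] over 3 experts, mean = 4 and e = [-5, 3, 2], so
    # sign(e) = [-1, +1, +1] and, with bias_update_factor = 0.01, the bias
    # shifts by [-0.01, +0.01, +0.01], nudging future tokens away from the
    # overloaded expert without any auxiliary loss term.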
|
|
|
|
|
|
|
|
    def forward(self, hidden_states: torch.Tensor):
        """Returns: final_output tensor."""
        bsz, seq_len, hidden_dim = hidden_states.shape
        num_tokens = bsz * seq_len
        flat_hidden = hidden_states.view(num_tokens, -1)
        # The router reads its own noise setting internally; BiBoMoERouter.forward
        # takes no noise_variance argument.
        top_k_indices, top_k_weights = self.gate(hidden_states)

        tokens_per_expert = None
        if self.training and hasattr(self.gate, "bias") and self.bias_update_factor > 0:
            tokens_per_expert = torch.bincount(top_k_indices.view(-1), minlength=self.num_routed_experts)

        # Dispatch: run each expert on its assigned tokens and scatter the
        # weighted outputs back into token order.
        final_routed = torch.zeros_like(flat_hidden)
        flat_expert_indices = top_k_indices.view(-1)
        flat_token_indices = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(self.num_experts_per_tok)
        for i, expert in enumerate(self.routed_experts):
            mask = flat_expert_indices == i
            if mask.any():
                tokens_idx = flat_token_indices[mask]
                # Run every distinct token through the expert once, then map the
                # results back to the per-assignment rows.
                unique_tokens, inverse_indices = torch.unique(tokens_idx, return_inverse=True)
                outputs = expert(flat_hidden[unique_tokens])[inverse_indices]
                weights = top_k_weights.view(-1)[mask].unsqueeze(1)
                final_routed.scatter_add_(0, tokens_idx.unsqueeze(1).expand(-1, hidden_dim), outputs * weights)
        final_routed = final_routed.view(bsz, seq_len, hidden_dim)

        # Shared (always-on) expert path.
        shared_combined = torch.zeros_like(hidden_states)
        if self.shared_experts_list:
            shared_combined = self.shared_experts_list[0](hidden_states)
        final_output = final_routed + shared_combined

        if tokens_per_expert is not None:
            self.update_bias(tokens_per_expert)

        return final_output
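
# MoE layer sketch (hypothetical config values): a (B, S, H) input is routed
# token-by-token to its top-k experts, the always-on shared conv expert is
# added, and the output keeps the input shape:
#   moe = BiBoMoELayer(config)
#   out = moe(torch.randn(2, 5, config.hidden_size))  # -> (2, 5, hidden_size)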
|
|
|
|
|
|
|
|
|
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expands (batch, num_kv_heads, seq_len, head_dim) to (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)
    seq_len_k = key.shape[-2]
    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, :seq_len_k]
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        if attention_mask.size() != (query.shape[0], 1, query.shape[2], key.shape[2]):
            raise ValueError("Attention mask shape mismatch")
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights
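
# GQA sketch (illustrative head counts): with 16 query heads and 4 KV heads,
# num_key_value_groups = 4 and repeat_kv expands K/V from (B, 4, S, D) to
# (B, 16, S, D), so groups of 4 query heads share one key/value projection.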
|
|
|
|
|
|
|
|
|
|
|
class BiBoAttention(nn.Module):
    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim**-0.5
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, pos_emb, mask=None, kv_cache=None, output_attentions=False, use_cache=False, cache_position=None, **kwargs):
        bsz, q_len, _ = hidden_states.size()
        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        cos, sin = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
        if kv_cache is not None:
            key, value = kv_cache.update(key, value, self.layer_idx, {"sin": sin, "cos": cos, "cache_position": cache_position})
        out, weights = eager_attention_forward(self, query, key, value, mask, self.scaling, self.attention_dropout)
        out = out.reshape(bsz, q_len, self.hidden_size)
        out = self.o_proj(out)
        return out, (weights if output_attentions else None)

class BiBoRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
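
# RMSNorm computes y = weight * x / sqrt(mean(x^2) + eps); unlike LayerNorm it
# neither subtracts the mean nor applies a bias term.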
|
|
|
|
|
class BiBoDecoderLayer(nn.Module):
    def __init__(self, config: BiBoConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BiBoAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx
        self.num_hidden_layers = config.num_hidden_layers
        is_first_layer = layer_idx == 0
        is_last_layer = layer_idx == config.num_hidden_layers - 1

        # The first and last layers stay dense; every intermediate layer uses MoE.
        if is_first_layer or is_last_layer:
            self.mlp = BiBoMLP(config)
            self.is_moe_layer = False
        else:
            self.mlp = BiBoMoELayer(config)
            self.is_moe_layer = True

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None):
        """Returns tuple: (hidden_states,) or (hidden_states, attn_weights)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs, attn_weights = self.self_attn(hidden_states, position_embeddings, attention_mask, past_key_value, output_attentions, use_cache, cache_position)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        # Dense and MoE blocks share the same call signature, so no branch is needed.
        ffn_output = self.mlp(hidden_states)
        hidden_states = residual + ffn_output

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

class BiBoRotaryEmbedding(nn.Module):
    def __init__(self, config: BiBoConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_type = rope_scaling.get("rope_type", "default") if rope_scaling else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Outer product of inverse frequencies and positions: (bsz, seq_len, head_dim // 2).
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
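
# RoPE sketch (illustrative): with head_dim=4 and rope_theta=10000, the inverse
# frequencies are 1/10000^(0/4) = 1.0 and 1/10000^(2/4) = 0.01, so position p
# produces angles (p, 0.01*p), duplicated across both halves of the head dim.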
|
|
|
|
|
|
|
|
BIBO_START_DOCSTRING = r""" BiBo model... """ |
|
|
BIBO_INPUTS_DOCSTRING = r""" Standard arguments... """ |
|
|
|
|
|
@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING) |
|
|
class BiBoPreTrainedModel(PreTrainedModel): |
|
|
config_class = BiBoConfig |
|
|
base_model_prefix = "model"; supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["BiBoDecoderLayer"]; _skip_keys_device_placement = ["past_key_values"] |
|
|
_supports_flash_attn_2 = False; _supports_sdpa = True; _supports_cache_class = True |
|
|
_supports_quantized_cache = True; _supports_static_cache = True |
|
|
def _init_weights(self, module): |
|
|
std = self.config.initializer_range |
|
|
if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std); module.bias.data.zero_() if module.bias is not None else None |
|
|
elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std); module.weight.data[module.padding_idx].zero_() if module.padding_idx is not None else None |
|
|
elif isinstance(module, BiBoRMSNorm): module.weight.data.fill_(1.0) |
|
|
elif isinstance(module, nn.Conv1d): nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)); module.bias.data.zero_() if module.bias is not None else None |
|
|
|
|
|
@add_start_docstrings("The bare BiBo Model", BIBO_START_DOCSTRING) |
|
|
class BiBoModel(BiBoPreTrainedModel): |
|
|
def __init__(self, config: BiBoConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.padding_idx = config.pad_token_id; self.vocab_size = config.vocab_size |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
|
self.layers = nn.ModuleList([BiBoDecoderLayer(config, i) for i in range(config.num_hidden_layers)]) |
|
|
self.norm = BiBoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.rotary_emb = BiBoRotaryEmbedding(config=config) |
|
|
self.gradient_checkpointing = False; self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): return self.embed_tokens |
|
|
def set_input_embeddings(self, value): self.embed_tokens = value |
|
|
|
|
|
def _prepare_decoder_attention_mask(self, mask, shape, embeds, past_len): |
|
|
combined_mask=None; L=shape[-1] |
|
|
if L>1: combined_mask=nn.functional._make_causal_mask(shape,embeds.dtype,device=embeds.device,past_key_values_length=past_len).to(embeds.device) |
|
|
if mask is not None: |
|
|
expanded_mask=nn.functional._expand_mask(mask,embeds.dtype,tgt_len=L).to(embeds.device) |
|
|
combined_mask=(expanded_mask if combined_mask is None else expanded_mask+combined_mask) |
|
|
if combined_mask is not None: bool_mask=combined_mask<0; combined_mask=combined_mask.masked_fill(bool_mask,torch.finfo(embeds.dtype).min) |
|
|
return combined_mask |
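
    # Mask sketch for tgt_len=3, past_len=0 (m = most negative finite value):
    #   [[0, m, m],
    #    [0, 0, m],
    #    [0, 0, 0]]
    # A padding mask additionally fills the whole key column of padded positions.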
|
|
|
|
|
|
|
|
    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, return_dict=None):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Disabling `use_cache`.")
            use_cache = False
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("past_key_values must be a Cache instance or None")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_len = inputs_embeds.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_len, past_len + seq_len, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._prepare_decoder_attention_mask(attention_mask, (inputs_embeds.shape[0], seq_len), inputs_embeds, past_len)
        hidden_states = inputs_embeds
        pos_emb = self.rotary_emb(hidden_states, position_ids)
        all_hidden = () if output_hidden_states else None
        all_attn = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden += (hidden_states,)
            layer_outputs = layer(hidden_states, pos_emb, causal_mask, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attn += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden += (hidden_states,)
        next_cache = past_key_values if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden, all_attn] if v is not None)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden, attentions=all_attn)

@add_start_docstrings(""" BiBo Model with CausalLM head. """, BIBO_START_DOCSTRING) |
|
|
class BiBoForCausalLM(BiBoPreTrainedModel, GenerationMixin): |
|
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
def __init__(self, config: BiBoConfig): |
|
|
super().__init__(config) |
|
|
self.model = BiBoModel(config) |
|
|
self.vocab_size = config.vocab_size |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): return self.model.embed_tokens |
|
|
def set_input_embeddings(self, value): self.model.embed_tokens = value |
|
|
def get_output_embeddings(self): return self.lm_head |
|
|
def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings |
|
|
def set_decoder(self, decoder): self.model = decoder |
|
|
def get_decoder(self): return self.model |
|
|
|
|
|
|
|
|
|
|
|
    @can_return_tuple
    @add_start_docstrings_to_model_forward(BIBO_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, cache_position=None, logits_to_keep=0, return_dict=None):
        r"""Loss calculation (cross-entropy) must happen outside this function."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0] if not return_dict else model_outputs.last_hidden_state
        # Only project the last `logits_to_keep` positions (0 keeps all positions).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep != 0 else slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            warnings.warn("Labels were provided, but loss calculation must be done externally.")
        if not return_dict:
            other_outputs = model_outputs[1:]
            return (loss,) + (logits,) + other_outputs
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=model_outputs.past_key_values, hidden_states=model_outputs.hidden_states, attentions=model_outputs.attentions)
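

# Usage sketch (hypothetical config values; BiBoConfig's actual fields are
# defined in configuration_bibo.py):
#   config = BiBoConfig(vocab_size=32000, hidden_size=512, num_hidden_layers=4)
#   model = BiBoForCausalLM(config)
#   input_ids = torch.randint(0, config.vocab_size, (1, 16))
#   out = model(input_ids=input_ids, return_dict=True)
#   out.logits.shape  # -> (1, 16, config.vocab_size)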