from typing import Optional, Tuple, List
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_veronica import VeronicaConfig
from .modeling_components import PolymorphicMLP, router_aux_loss, Fp32LayerNorm, apply_rotary_pos_emb


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and an optional KV cache.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to the
            output projection.
        max_position_embeddings: Minimum length of the lazily built RoPE tables.
        rope_theta: Base of the RoPE inverse-frequency schedule.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0,
                 max_position_embeddings: int = 4096, rope_theta: float = 10000.0):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by n_head"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        # RoPE cos/sin tables, built lazily on first use. They are plain attributes,
        # not buffers, so they are not saved in the state dict and `.to()` does not
        # move them; `_get_rope_cos_sin` rebuilds them on the device it is asked for.
        self._rope_cached_seq_len = 0
        self._rope_cos_cached: Optional[torch.Tensor] = None
        self._rope_sin_cached: Optional[torch.Tensor] = None

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, T, C) into (B, num_heads, T, head_dim)."""
        B, T, C = x.shape
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, num_heads, T, head_dim) back into (B, T, C)."""
        B, nh, T, hd = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, nh * hd)

    def _get_rope_cos_sin(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return RoPE ``(cos, sin)`` tables for positions ``[0, seq_len)``.

        Both tensors have shape ``(1, 1, seq_len, head_dim)`` and dtype ``dtype``.
        The float32 tables are cached; they are rebuilt when a longer sequence is
        requested or when the cache lives on a different device than ``device``
        (rebuilding once is cheaper than copying the cache across devices on
        every call).
        """
        cache_usable = (
            self._rope_cos_cached is not None
            and seq_len <= self._rope_cached_seq_len
            and self._rope_cos_cached.device == device
        )
        if not cache_usable:
            self._rope_cached_seq_len = max(seq_len, self.max_position_embeddings)
            dim = self.head_dim
            # inv_freq: (head_dim / 2,)
            inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
            # t: (cached_len,)
            t = torch.arange(self._rope_cached_seq_len, dtype=torch.float32, device=device)
            freqs = torch.outer(t, inv_freq)  # (cached_len, head_dim / 2)
            emb = torch.cat([freqs, freqs], dim=-1)  # duplicated to (cached_len, head_dim)
            self._rope_cos_cached = emb.cos().unsqueeze(0).unsqueeze(0)  # (1, 1, cached_len, head_dim)
            self._rope_sin_cached = emb.sin().unsqueeze(0).unsqueeze(0)

        cos = self._rope_cos_cached[:, :, :seq_len, :].to(dtype=dtype)
        sin = self._rope_sin_cached[:, :, :seq_len, :].to(dtype=dtype)
        return cos, sin

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # additive mask [B,1,T,S] in float32
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,  # absolute position of x[:, 0] (length of the KV cache)
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over ``x`` (plus cached keys/values, if any).

        Returns:
            ``(output, present)`` where ``output`` has the same shape as ``x`` and
            ``present`` is the concatenated ``(k, v)`` pair when ``use_cache`` is
            true, otherwise ``None``.
        """
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(C, dim=-1)
        q = self._split_heads(q)  # (B, nh, T, hd)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Rotate q and k by their absolute positions [position_offset, position_offset + T).
        cos, sin = self._get_rope_cos_sin(position_offset + T, q.device, q.dtype)
        cos = cos[:, :, position_offset:position_offset + T, :]
        sin = sin[:, :, position_offset:position_offset + T, :]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_value is not None:
            pk, pv = past_key_value  # (B, nh, Tp, hd)
            k = torch.cat([pk, k], dim=-2)
            v = torch.cat([pv, v], dim=-2)
        present = (k, v) if use_cache else None

        att = (q @ k.transpose(-2, -1)) * self.scale  # (B, nh, T, S)
        att = att.float()  # softmax in float32 for numerical stability
        if attn_mask is not None:
            att = att + attn_mask  # additive bias: -inf / large negative on masked positions
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        att = att.to(v.dtype)  # cast back to v's dtype (BF16/FP16/FP32)
        y = att @ v  # (B, nh, T, hd)

        y = self._merge_heads(y)
        y = self.out_proj(y)
        y = self.resid_drop(y)
        return y, present


class VeronicaBlock(nn.Module):
    """Pre-norm transformer block: RoPE self-attention followed by a PolymorphicMLP."""

    def __init__(self, config: VeronicaConfig):
        super().__init__()
        self.ln_1 = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadSelfAttention(
            config.n_embd,
            config.n_head,
            dropout=config.dropout,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=getattr(config, 'rope_theta', 10000.0),
        )
        self.ln_2 = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = PolymorphicMLP(
            hidden_size=config.n_embd,
            mlp_mult=config.mlp_mult,
            num_funcs=config.num_funcs,
            router_dim=config.router_dim,
            dropout=config.dropout,
            use_channel_attention=config.use_channel_attention,
            router_tau=config.router_tau,
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Return ``(hidden, router_alpha, present_kv)`` for one layer."""
        h = self.ln_1(x)
        attn_out, present = self.attn(h, attn_mask, past_key_value=past_key_value,
                                      use_cache=use_cache, position_offset=position_offset)
        x = x + attn_out
        h = self.ln_2(x)
        y, alpha = self.mlp(h)
        x = x + y
        return x, alpha, present


class VeronicaModel(PreTrainedModel):
    """Decoder-only Veronica backbone (token embeddings + blocks + final norm)."""

    config_class = VeronicaConfig

    def __init__(self, config: VeronicaConfig):
        super().__init__(config)
        self.embed_dim = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # RoPE replaces absolute positional embeddings (no `wpe`).
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([VeronicaBlock(config) for _ in range(config.n_layer)])
        self.ln_f = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # NOTE(review): this buffer is not read anywhere in this module (masks are built
        # by `_build_attn_mask`); it is kept only for compatibility with external code
        # that may reference `causal_mask`.
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_position_embeddings,
                    config.max_position_embeddings,
                    dtype=torch.uint8,
                )
            ).view(1, 1, config.max_position_embeddings, config.max_position_embeddings),
            persistent=False,
        )

        # Monitoring: refreshed by `forward` when `output_router_stats` is true.
        self.router_alpha_entropy: Optional[torch.Tensor] = None
        self.router_alpha_mean: Optional[torch.Tensor] = None
        self._use_gradient_checkpointing: bool = getattr(config, "gradient_checkpointing", False)

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self._use_gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self._use_gradient_checkpointing = False

    def _build_attn_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        seq_len: int,
        past_kv_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Build the additive attention bias for ``seq_len`` queries over ``past_kv_len + seq_len`` keys.

        Future positions get ``-inf``; positions where ``attention_mask`` is 0 get
        ``finfo(dtype).min``. Returns ``[1, 1, T, T+P]`` without a padding mask,
        ``[B, 1, T, T+P]`` with one.
        """
        T, P = seq_len, past_kv_len
        causal = torch.full((T, T + P), float("-inf"), device=device, dtype=dtype)
        causal = torch.triu(causal, diagonal=1 + P)  # -inf for future keys, 0 elsewhere
        causal = causal.unsqueeze(0).unsqueeze(1)  # [1, 1, T, T+P]
        if attention_mask is None:
            return causal

        # attention_mask: [B, T+P] with 1 = valid token, 0 = padding
        pad_add = (1.0 - attention_mask.to(dtype)) * torch.finfo(dtype).min  # [B, T+P]
        return causal + pad_add.unsqueeze(1).unsqueeze(2)  # [B, 1, T, T+P]

    @staticmethod
    def _infer_past_length(past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]) -> int:
        """Number of cached positions, read from the first layer's key tensor (0 if none)."""
        if (
            past_key_values is not None
            and len(past_key_values) > 0
            and past_key_values[0] is not None
            and isinstance(past_key_values[0], (tuple, list))
            and past_key_values[0][0] is not None
        ):
            return past_key_values[0][0].size(-2)
        return 0

    @staticmethod
    def _expand_attention_mask(
        attention_mask: Optional[torch.Tensor], batch_size: int, seq_len: int, past_len: int
    ) -> Optional[torch.Tensor]:
        """Return a ``[B, T+P]`` padding mask covering cached and new tokens, or ``None``.

        A mask already spanning ``T + P`` is used as-is; a mask spanning only the
        ``T`` new tokens is left-extended with ones for the cached positions; any
        other width is ignored.
        """
        if attention_mask is None:
            return None
        width = attention_mask.size(-1)
        if width == seq_len + past_len:
            return attention_mask
        if width == seq_len:
            if past_len == 0:
                return attention_mask
            ones = torch.ones((batch_size, past_len), dtype=attention_mask.dtype, device=attention_mask.device)
            return torch.cat([ones, attention_mask], dim=-1)
        return None

    def _per_layer_past(
        self, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
    ) -> List[Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """One ``(k, v)`` entry per layer, with malformed or empty entries replaced by ``None``."""
        if past_key_values is None:
            return [None] * len(self.blocks)
        return [
            pkv if (pkv is not None and isinstance(pkv, (tuple, list)) and pkv[0] is not None and pkv[1] is not None) else None
            for pkv in past_key_values
        ]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_router_stats: bool = True,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run the backbone.

        ``labels`` is accepted for interface compatibility and is not used here.
        ``use_cache`` defaults to ``not self.training``.

        Side effects: when ``output_router_stats`` is true, refreshes
        ``router_alpha_mean`` (detached) and ``router_alpha_entropy``. Always sets
        ``_last_router_aux``: the depth-averaged MLP aux loss in training mode,
        ``None`` otherwise.

        Returns:
            ``(hidden_states, present)`` where ``hidden_states`` is the output of
            the final layer norm and ``present`` is a list of per-layer ``(k, v)``
            caches when caching is enabled, otherwise ``None``.
        """
        device = input_ids.device
        B, T = input_ids.shape

        if use_cache is None:
            use_cache = not self.training
        pkv_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        P = self._infer_past_length(past_key_values)

        # Token embeddings only; RoPE inside attention handles positions.
        x = self.drop(self.wte(input_ids))

        attn_full = self._expand_attention_mask(attention_mask, B, T, P)
        attn_bias = self._build_attn_mask(attn_full, T, P, device, torch.float32)
        curr_past = self._per_layer_past(past_key_values)

        use_checkpointing = getattr(self, "_use_gradient_checkpointing", False) and self.training

        def create_custom_forward(module, pkv):
            # Bind module/pkv eagerly: non-reentrant checkpointing re-invokes this
            # function during backward, after the loop variables have moved on.
            def custom_forward(hidden):
                out_x, out_alpha, _ = module(hidden, attn_bias, past_key_value=pkv, use_cache=False, position_offset=P)
                return out_x, out_alpha
            return custom_forward

        alpha_list: List[torch.Tensor] = []
        aux_sum = 0.0
        aux_count = 0

        for layer_idx, block in enumerate(self.blocks):
            layer_past = curr_past[layer_idx]
            if use_checkpointing:
                # No KV cache is produced on the checkpointed path.
                x, alpha = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block, layer_past), x, use_reentrant=False
                )
            else:
                x, alpha, present = block(x, attn_bias, past_key_value=layer_past,
                                          use_cache=use_cache, position_offset=P)
                if pkv_list is not None:
                    pkv_list.append(present)

            alpha_list.append(alpha)
            if self.training and getattr(block.mlp, "last_aux", None) is not None:
                aux_sum = aux_sum + block.mlp.last_aux
                aux_count += 1

        x = self.ln_f(x)

        # Router stats
        if output_router_stats and alpha_list:
            alpha_stack = torch.stack(alpha_list, dim=0)  # (L, B, T, K)
            self.router_alpha_mean = alpha_stack.mean(dim=(0, 1, 2)).detach()  # (K,)
            self.router_alpha_entropy = router_aux_loss(alpha_stack.mean(dim=0))

        # Aux loss averaged over depth. It is only accumulated in training; clearing
        # it otherwise keeps an eval-mode loss from picking up the (graph-attached)
        # value left over from the last training step.
        if self.training and aux_count > 0:
            self._last_router_aux = aux_sum / aux_count
        else:
            self._last_router_aux = None

        return x, pkv_list


class VeronicaForCausalLM(VeronicaModel, GenerationMixin):
    """Veronica backbone with a tied language-modeling head."""

    def __init__(self, config: VeronicaConfig):
        super().__init__(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self):
        self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings())

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Feed only the newest token once a KV cache exists."""
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": True,
        }

    def _reorder_cache(self, past_key_values, beam_idx: torch.LongTensor):
        """Select cache rows along the batch dimension to follow beam search."""
        if past_key_values is None:
            return past_key_values
        return [(k.index_select(0, beam_idx), v.index_select(0, beam_idx)) for (k, v) in past_key_values]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute logits and, when ``labels`` are given, the next-token loss.

        The loss is the shifted cross-entropy (``-100`` ignored) plus
        ``config.router_aux_weight`` times the router aux loss, when that weight is
        positive and an aux value was produced by the backbone (training mode only).
        ``past_key_values`` in the output is whatever cache the backbone built; it
        is ``None`` when caching is off.
        """
        hidden_states, present = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            use_cache=use_cache,
            past_key_values=past_key_values,
            **kwargs,
        )  # (B, T, H)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Only add the aux term to an existing loss (avoids `None + tensor`).
            aux = getattr(self, "_last_router_aux", None)
            if aux is not None and getattr(self.config, "router_aux_weight", 0.0) > 0:
                if not torch.is_tensor(aux):
                    aux = torch.as_tensor(aux, device=logits.device, dtype=logits.dtype)
                else:
                    aux = aux.to(device=logits.device, dtype=logits.dtype)
                aux = aux.clamp_min(0.0)
                loss = loss + float(self.config.router_aux_weight) * aux

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            # `present` is already None when the backbone did not build a cache; using
            # it directly keeps the cache when `use_cache=None` resolved to True (eval).
            past_key_values=present,
            hidden_states=None,
            attentions=None,
        )