|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint.checkpoint is always resolvable
|
|
|
|
|
from transformers import PreTrainedModel |
|
|
from transformers.generation.utils import GenerationMixin |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
from .configuration_veronica import VeronicaConfig |
|
|
from .modeling_components import PolymorphicMLP, router_aux_loss, Fp32LayerNorm, apply_rotary_pos_emb |
|
|
|
|
|
|
|
|
class MultiHeadSelfAttention(nn.Module): |
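    """Causal multi-head self-attention with rotary position embeddings (RoPE) and an optional key/value cache."""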
|
|
def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0, max_position_embeddings: int = 4096, rope_theta: float = 10000.0): |
|
|
super().__init__() |
|
|
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
|
|
self.num_heads = num_heads |
|
|
self.head_dim = hidden_size // num_heads |
|
|
self.scale = 1.0 / math.sqrt(self.head_dim) |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.rope_theta = rope_theta |
|
|
|
|
|
self.qkv = nn.Linear(hidden_size, 3 * hidden_size) |
|
|
self.out_proj = nn.Linear(hidden_size, hidden_size) |
|
|
self.attn_drop = nn.Dropout(dropout) |
|
|
self.resid_drop = nn.Dropout(dropout) |
|
|
|
|
|
|
|
|
self._rope_cached_seq_len = 0 |
|
|
self._rope_cos_cached = None |
|
|
self._rope_sin_cached = None |
|
|
|
|
|
def _split_heads(self, x: torch.Tensor) -> torch.Tensor: |
|
|
B, T, C = x.shape |
|
|
x = x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
return x |
|
|
|
|
|
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: |
|
|
B, nh, T, hd = x.shape |
|
|
return x.transpose(1, 2).contiguous().view(B, T, nh * hd) |
|
|
|
|
|
def _get_rope_cos_sin(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Genera o recupera dalla cache cos/sin per RoPE.""" |
|
|
if seq_len <= self._rope_cached_seq_len and self._rope_cos_cached is not None: |
|
|
return self._rope_cos_cached[:, :, :seq_len, :].to(device=device, dtype=dtype), \ |
|
|
self._rope_sin_cached[:, :, :seq_len, :].to(device=device, dtype=dtype) |
|
|
|
|
|
|
|
|
self._rope_cached_seq_len = max(seq_len, self.max_position_embeddings) |
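        # Build the full fp32 cos/sin tables up to max_position_embeddings using the standard RoPE frequencies theta^(-2i/d).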
|
|
|
|
|
|
|
|
dim = self.head_dim |
|
|
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) |
|
|
|
|
|
|
|
|
t = torch.arange(self._rope_cached_seq_len, dtype=torch.float32, device=device) |
|
|
|
|
|
|
|
|
freqs = torch.outer(t, inv_freq) |
|
|
|
|
|
|
|
|
emb = torch.cat([freqs, freqs], dim=-1) |
|
|
|
|
|
|
|
|
cos = emb.cos().unsqueeze(0).unsqueeze(0) |
|
|
sin = emb.sin().unsqueeze(0).unsqueeze(0) |
|
|
|
|
|
self._rope_cos_cached = cos |
|
|
self._rope_sin_cached = sin |
|
|
|
|
|
return cos[:, :, :seq_len, :].to(dtype=dtype), sin[:, :, :seq_len, :].to(dtype=dtype) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
attn_mask: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
use_cache: bool = False, |
|
|
position_offset: int = 0, |
|
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
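        """Apply causal self-attention to x of shape (B, T, C); returns (output, present), where present is the updated (k, v) cache when use_cache=True."""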
|
|
B, T, C = x.shape |
|
|
qkv = self.qkv(x) |
|
|
q, k, v = qkv.split(C, dim=-1) |
|
|
q = self._split_heads(q) |
|
|
k = self._split_heads(k) |
|
|
v = self._split_heads(v) |
|
|
|
|
|
|
|
|
cos, sin = self._get_rope_cos_sin(position_offset + T, q.device, q.dtype) |
|
|
|
|
|
cos = cos[:, :, position_offset:position_offset+T, :] |
|
|
sin = sin[:, :, position_offset:position_offset+T, :] |
|
|
q, k = apply_rotary_pos_emb(q, k, cos, sin) |
|
|
|
|
|
present = None |
|
|
if past_key_value is not None: |
|
|
pk, pv = past_key_value |
|
|
k = torch.cat([pk, k], dim=-2) |
|
|
v = torch.cat([pv, v], dim=-2) |
|
|
if use_cache: |
|
|
present = (k, v) |
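        # Attention scores below are computed in fp32 for numerical stability, then cast back to the value dtype before the weighted sum.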
|
|
|
|
|
att = (q @ k.transpose(-2, -1)) * self.scale |
|
|
att = att.float() |
|
|
if attn_mask is not None: |
|
|
att = att + attn_mask |
|
|
att = F.softmax(att, dim=-1) |
|
|
att = self.attn_drop(att) |
|
|
att = att.to(v.dtype) |
|
|
y = att @ v |
|
|
y = self._merge_heads(y) |
|
|
y = self.out_proj(y) |
|
|
y = self.resid_drop(y) |
|
|
return y, present |
|
|
|
|
|
|
|
|
class VeronicaBlock(nn.Module): |
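    """Pre-LayerNorm transformer block: RoPE self-attention followed by a PolymorphicMLP whose router weights (alpha) are returned for logging."""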
|
|
def __init__(self, config: VeronicaConfig): |
|
|
super().__init__() |
|
|
self.ln_1 = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) |
|
|
self.attn = MultiHeadSelfAttention( |
|
|
config.n_embd, |
|
|
config.n_head, |
|
|
dropout=config.dropout, |
|
|
max_position_embeddings=config.max_position_embeddings, |
|
|
rope_theta=getattr(config, 'rope_theta', 10000.0) |
|
|
) |
|
|
self.ln_2 = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) |
|
|
self.mlp = PolymorphicMLP( |
|
|
hidden_size=config.n_embd, |
|
|
mlp_mult=config.mlp_mult, |
|
|
num_funcs=config.num_funcs, |
|
|
router_dim=config.router_dim, |
|
|
dropout=config.dropout, |
|
|
use_channel_attention=config.use_channel_attention, |
|
|
router_tau=config.router_tau, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
attn_mask: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
use_cache: bool = False, |
|
|
position_offset: int = 0, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
|
|
h = self.ln_1(x) |
|
|
attn_out, present = self.attn(h, attn_mask, past_key_value=past_key_value, use_cache=use_cache, position_offset=position_offset) |
|
|
x = x + attn_out |
|
|
h = self.ln_2(x) |
|
|
y, alpha = self.mlp(h) |
|
|
x = x + y |
|
|
return x, alpha, present |
|
|
|
|
|
|
|
|
class VeronicaModel(PreTrainedModel): |
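    """Decoder-only backbone: token embeddings, a stack of VeronicaBlocks, and a final fp32 LayerNorm."""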
|
|
config_class = VeronicaConfig |
|
|
|
|
|
def __init__(self, config: VeronicaConfig): |
|
|
super().__init__(config) |
|
|
self.embed_dim = config.n_embd |
|
|
self.wte = nn.Embedding(config.vocab_size, config.n_embd) |
|
|
|
|
|
self.drop = nn.Dropout(config.dropout) |
|
|
self.blocks = nn.ModuleList([VeronicaBlock(config) for _ in range(config.n_layer)]) |
|
|
self.ln_f = Fp32LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) |
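        # Lower-triangular causal mask kept as a non-persistent buffer; the additive bias used in forward is built per call by _build_attn_mask.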
|
|
|
|
|
self.register_buffer( |
|
|
"causal_mask", |
|
|
torch.tril( |
|
|
torch.ones( |
|
|
config.max_position_embeddings, |
|
|
config.max_position_embeddings, |
|
|
dtype=torch.uint8, |
|
|
) |
|
|
).view(1, 1, config.max_position_embeddings, config.max_position_embeddings), |
|
|
persistent=False, |
|
|
) |
|
|
|
|
|
|
|
|
self.router_alpha_entropy: Optional[torch.Tensor] = None |
|
|
self.router_alpha_mean: Optional[torch.Tensor] = None |
|
|
|
|
|
self._use_gradient_checkpointing: bool = getattr(config, "gradient_checkpointing", False) |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.wte |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.wte = value |
|
|
|
|
|
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): |
|
|
self._use_gradient_checkpointing = True |
|
|
|
|
|
def gradient_checkpointing_disable(self): |
|
|
self._use_gradient_checkpointing = False |
|
|
|
|
|
def _build_attn_mask( |
|
|
self, |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
seq_len: int, |
|
|
past_kv_len: int, |
|
|
device: torch.device, |
|
|
dtype: torch.dtype, |
|
|
) -> torch.Tensor: |
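        """Build an additive attention bias of shape (B or 1, 1, T, T+P): 0 where attention is allowed, a large negative value where it is masked."""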
|
|
|
|
|
T, P = seq_len, past_kv_len |
|
|
causal = torch.full((T, T + P), float("-inf"), device=device, dtype=dtype) |
|
|
causal = torch.triu(causal, diagonal=1 + P) |
|
|
|
|
|
if attention_mask is None: |
|
|
return causal.unsqueeze(0).unsqueeze(1) |
|
|
|
|
|
|
|
|
attn_full = attention_mask.to(dtype) |
|
|
pad_add = (1.0 - attn_full) * torch.finfo(dtype).min |
|
|
pad_add = pad_add.unsqueeze(1).unsqueeze(2) |
|
|
causal = causal.unsqueeze(0).unsqueeze(1) |
|
|
return causal + pad_add |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
output_router_stats: bool = True, |
|
|
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
**kwargs, |
|
|
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
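        """Encode input_ids and return (hidden_states, past_key_values); router statistics are stored on the module as a side effect."""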
|
|
device = input_ids.device |
|
|
B, T = input_ids.shape |
|
|
|
|
|
if use_cache is None: |
|
|
            use_cache = not self.training
|
|
|
|
|
pkv_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None |
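        # P counts the key/value positions already present in the cache (0 on the first forward pass).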
|
|
|
|
|
P = 0 |
|
|
if ( |
|
|
past_key_values is not None |
|
|
and len(past_key_values) > 0 |
|
|
and past_key_values[0] is not None |
|
|
and isinstance(past_key_values[0], (tuple, list)) |
|
|
and past_key_values[0][0] is not None |
|
|
): |
|
|
P = past_key_values[0][0].size(-2) |
|
|
|
|
|
|
|
|
x = self.wte(input_ids) |
|
|
x = self.drop(x) |
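        # A length-T padding mask is extended with ones over the P cached positions so it covers the full key length T + P.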
|
|
|
|
|
|
|
|
attn_full = None |
|
|
if attention_mask is not None: |
|
|
if attention_mask.size(-1) == T + P: |
|
|
attn_full = attention_mask |
|
|
elif attention_mask.size(-1) == T: |
|
|
if P > 0: |
|
|
ones = torch.ones((B, P), dtype=attention_mask.dtype, device=attention_mask.device) |
|
|
attn_full = torch.cat([ones, attention_mask], dim=-1) |
|
|
else: |
|
|
attn_full = attention_mask |
|
|
else: |
|
|
attn_full = None |
|
|
|
|
|
attn_bias = self._build_attn_mask(attn_full, T, P, device, torch.float32) |
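        # While training, each block's router auxiliary loss is accumulated so VeronicaForCausalLM can add it to the LM loss.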
|
|
|
|
|
alpha_list: List[torch.Tensor] = [] |
|
|
if self.training: |
|
|
self._acc_aux_sum = 0.0 |
|
|
self._acc_aux_count = 0 |
|
|
|
|
|
if getattr(self, "_use_gradient_checkpointing", False) and self.training: |
|
|
def create_custom_forward(module, pkv): |
|
|
def custom_forward(x): |
|
|
out_x, out_alpha, _ = module(x, attn_bias, past_key_value=pkv, use_cache=False, position_offset=P) |
|
|
return out_x, out_alpha |
|
|
|
|
|
return custom_forward |
|
|
|
|
|
if past_key_values is not None: |
|
|
curr_past = [ |
|
|
pkv |
|
|
if (pkv is not None and isinstance(pkv, (tuple, list)) and pkv[0] is not None and pkv[1] is not None) |
|
|
else None |
|
|
for pkv in past_key_values |
|
|
] |
|
|
else: |
|
|
curr_past = [None] * len(self.blocks) |
|
|
for layer_idx, block in enumerate(self.blocks): |
|
|
x, alpha = torch.utils.checkpoint.checkpoint( |
|
|
create_custom_forward(block, curr_past[layer_idx]), x, use_reentrant=False |
|
|
) |
|
|
alpha_list.append(alpha) |
|
|
if self.training and getattr(block.mlp, "last_aux", None) is not None: |
|
|
self._acc_aux_sum = self._acc_aux_sum + block.mlp.last_aux |
|
|
self._acc_aux_count += 1 |
|
|
else: |
|
|
if past_key_values is not None: |
|
|
curr_past = [ |
|
|
pkv |
|
|
if (pkv is not None and isinstance(pkv, (tuple, list)) and pkv[0] is not None and pkv[1] is not None) |
|
|
else None |
|
|
for pkv in past_key_values |
|
|
] |
|
|
else: |
|
|
curr_past = [None] * len(self.blocks) |
|
|
for layer_idx, block in enumerate(self.blocks): |
|
|
x, alpha, present = block(x, attn_bias, past_key_value=curr_past[layer_idx], use_cache=use_cache, position_offset=P) |
|
|
alpha_list.append(alpha) |
|
|
if self.training and getattr(block.mlp, "last_aux", None) is not None: |
|
|
self._acc_aux_sum = self._acc_aux_sum + block.mlp.last_aux |
|
|
self._acc_aux_count += 1 |
|
|
if use_cache and pkv_list is not None: |
|
|
pkv_list.append(present) |
|
|
|
|
|
x = self.ln_f(x) |
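        # Router statistics: stack the per-layer router weights, keep a detached per-function mean, and compute the aux value over the layer-averaged weights.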
|
|
|
|
|
|
|
|
if output_router_stats and len(alpha_list) > 0: |
|
|
alpha_stack = torch.stack(alpha_list, dim=0) |
|
|
alpha_mean = alpha_stack.mean(dim=(0, 1, 2)) |
|
|
self.router_alpha_mean = alpha_mean.detach() |
|
|
self.router_alpha_entropy = router_aux_loss(alpha_stack.mean(dim=0)) |
|
|
|
|
|
|
|
|
if hasattr(self, "_acc_aux_sum"): |
|
|
if self._acc_aux_count > 0: |
|
|
self._last_router_aux = self._acc_aux_sum / self._acc_aux_count |
|
|
else: |
|
|
self._last_router_aux = None |
|
|
delattr(self, "_acc_aux_sum") |
|
|
delattr(self, "_acc_aux_count") |
|
|
|
|
|
return x, pkv_list |
|
|
|
|
|
|
|
|
class VeronicaForCausalLM(VeronicaModel, GenerationMixin): |
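    """VeronicaModel with a language-modeling head (optionally tied to the input embeddings); returns CausalLMOutputWithPast and, when labels are given, adds the weighted router auxiliary loss."""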
|
|
def __init__(self, config: VeronicaConfig): |
|
|
super().__init__(config) |
|
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) |
|
|
self.post_init() |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.lm_head |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.lm_head = new_embeddings |
|
|
|
|
|
def tie_weights(self): |
|
|
self._tie_or_clone_weights(self.lm_head, self.get_input_embeddings()) |
|
|
|
|
|
def prepare_inputs_for_generation( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
**kwargs, |
|
|
): |
|
|
if past_key_values is not None and len(past_key_values) > 0: |
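            # With a populated cache, only the most recent token needs to run through the model.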
|
|
input_ids = input_ids[:, -1:] |
|
|
return { |
|
|
"input_ids": input_ids, |
|
|
"past_key_values": past_key_values, |
|
|
"attention_mask": attention_mask, |
|
|
"use_cache": True, |
|
|
} |
|
|
|
|
|
def _reorder_cache(self, past_key_values, beam_idx: torch.LongTensor): |
|
|
if past_key_values is None: |
|
|
return past_key_values |
|
|
reordered = [] |
|
|
for (k, v) in past_key_values: |
|
|
reordered.append((k.index_select(0, beam_idx), v.index_select(0, beam_idx))) |
|
|
return reordered |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, |
|
|
**kwargs, |
|
|
) -> CausalLMOutputWithPast: |
|
|
hidden_states, present = super().forward( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
labels=None, |
|
|
use_cache=use_cache, |
|
|
past_key_values=past_key_values, |
|
|
**kwargs, |
|
|
) |
|
|
logits = self.lm_head(hidden_states) |
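        # Standard causal LM loss: logits at position t predict the token at t+1; positions labelled -100 are ignored.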
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
shift_logits = logits[:, :-1, :].contiguous() |
|
|
shift_labels = labels[:, 1:].contiguous() |
|
|
loss = F.cross_entropy( |
|
|
shift_logits.view(-1, shift_logits.size(-1)), |
|
|
shift_labels.view(-1), |
|
|
ignore_index=-100, |
|
|
) |
|
|
|
|
|
aux = getattr(self, "_last_router_aux", None) |
|
|
        if loss is not None and aux is not None and getattr(self.config, "router_aux_weight", 0.0) > 0:
|
|
if not torch.is_tensor(aux): |
|
|
aux = torch.as_tensor(aux, device=logits.device, dtype=logits.dtype) |
|
|
else: |
|
|
aux = aux.to(device=logits.device, dtype=logits.dtype) |
|
|
aux = aux.clamp_min(0.0) |
|
|
loss = loss + float(self.config.router_aux_weight) * aux |
|
|
|
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
            past_key_values=present,
|
|
hidden_states=None, |
|
|
attentions=None, |
|
|
) |
|
|
|