|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
from typing import List, Optional, Tuple, Union, Dict |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions |
|
|
from transformers.generation import GenerationMixin |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from .configuration_stldec import STLDecoderConfig |
|
|
|
|
|
class STLPreTrainedModel(PreTrainedModel):
    """Base class hooking STL decoder models into Hugging Face ``PreTrainedModel``.

    Declares the config class and base-model prefix, and defines the weight
    initialisation that ``post_init()`` applies to every submodule.
    """

    config_class = STLDecoderConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialise the parameters of ``module`` in place.

        * ``nn.Linear``: Xavier-uniform weights, chosen to avoid exploding
          gradients in the early training phase, and a zero bias.
        * ``nn.Embedding``: normal(0, 0.02) weights. The ``padding_idx`` row, if
          any, is reset to zero so it matches the zero vector ``nn.Embedding``
          itself reserves for padding.
        * ``nn.LayerNorm``: unit weight and zero bias (the PyTorch defaults), so
          norms are also valid when weights are materialised only by this hook.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # normal_ just overwrote the zero padding vector; restore it.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Norms may be built without affine params (or without a bias).
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
|
class STLAttention(nn.Module):
    """Multi-head scaled dot-product attention for both self- and cross-attention.

    When ``encoder_hidden_states`` is None the module performs self-attention
    and supports an incremental key/value cache. Otherwise keys and values are
    projected from ``encoder_hidden_states`` on every call and no cache is
    returned.
    """

    def __init__(self, config):
        """Build the Q/K/V/output projections (all bias-free) and attention dropout.

        Raises:
            ValueError: if ``config.hidden_size`` is not divisible by
                ``config.num_attention_heads``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        if self.hidden_size % self.num_heads != 0:
            # Otherwise the .view() calls in forward() fail with an opaque shape error.
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )
        self.head_dim = self.hidden_size // self.num_heads
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.attention_dropout)

    def forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None, attention_mask=None):
        """Attend from ``hidden_states`` to itself or to ``encoder_hidden_states``.

        Args:
            hidden_states: query input of shape (bsz, q_len, hidden_size).
            encoder_hidden_states: optional key/value source of shape
                (bsz, kv_len, hidden_size); selects cross-attention.
            past_key_value: optional ``(k, v)`` self-attention cache, each of shape
                (bsz, num_heads, past_len, head_dim). It is ignored for cross-attention.
            attention_mask: optional additive mask broadcastable to
                (bsz, num_heads, q_len, total_kv_len).

        Returns:
            ``(output, present_kv)``. ``output`` has shape (bsz, q_len, hidden_size).
            ``present_kv`` is the updated ``(k, v)`` cache for self-attention, or
            None for cross-attention.
        """
        bsz, q_len, _ = hidden_states.size()
        key_value_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        kv_len = key_value_states.size(1)

        # Project and split into heads: (bsz, num_heads, len, head_dim).
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) * self.scaling
        k = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Only self-attention keeps a growing cache along the sequence axis (dim 2).
        if past_key_value is not None and encoder_hidden_states is None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        present_kv = (k, v) if encoder_hidden_states is None else None

        attn_weights = torch.matmul(q, k.transpose(-1, -2))
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(self.dropout(attn_probs), v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

        return self.out_proj(attn_output), present_kv
|
|
|
|
|
class STLDecoderBlock(nn.Module):
    """A single pre-norm transformer decoder layer.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``:
      1. self-attention (with optional cache and additive mask),
      2. cross-attention over ``encoder_hidden_states`` (skipped when absent),
      3. a GELU feed-forward network.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size

        # Self-attention sub-layer.
        self.ln1 = nn.LayerNorm(width)
        self.self_attn = STLAttention(config)

        # Cross-attention sub-layer, used only when encoder states are supplied.
        self.ln_cross = nn.LayerNorm(width)
        self.cross_attn = STLAttention(config)

        # Position-wise feed-forward sub-layer.
        self.ln2 = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, width)

        self.dropout = nn.Dropout(config.dropout)

        # When True (and in training mode), activations are recomputed during backward.
        self.gradient_checkpointing = False

    def forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None, attention_mask=None):
        """Run the layer, wrapping it in activation checkpointing when enabled during training.

        Returns whatever ``internal_forward`` returns: ``(hidden_states, self_attn_cache)``.
        """
        layer_inputs = (hidden_states, encoder_hidden_states, past_key_value, attention_mask)
        if not (self.training and self.gradient_checkpointing):
            return self.internal_forward(*layer_inputs)
        return checkpoint(self.internal_forward, *layer_inputs, use_reentrant=False)

    def internal_forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None, attention_mask=None):
        """Pre-norm forward pass (normalisation before each sub-layer, for gradient stability).

        Returns:
            ``(hidden_states, present_kv)``, where ``present_kv`` is the cache returned
            by the self-attention sub-layer.
        """
        x = hidden_states

        self_out, present_kv = self.self_attn(
            self.ln1(x), past_key_value=past_key_value, attention_mask=attention_mask
        )
        x = x + self.dropout(self_out)

        if encoder_hidden_states is not None:
            cross_out, _ = self.cross_attn(self.ln_cross(x), encoder_hidden_states=encoder_hidden_states)
            x = x + self.dropout(cross_out)

        ffn_out = self.fc2(F.gelu(self.fc1(self.ln2(x))))
        x = x + self.dropout(ffn_out)
        return x, present_kv
|
|
|
|
|
class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
    """Pre-norm transformer decoder with a language-modelling head.

    Token embeddings plus fixed sinusoidal position encodings are passed
    through a stack of ``STLDecoderBlock`` layers (causal self-attention and
    optional cross-attention over ``encoder_hidden_states``). A final
    LayerNorm and a bias-free projection to the vocabulary follow.
    """

    # Plain class attribute, as PreTrainedModel declares it.
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList([STLDecoderBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_sinusoidal_embeddings(self, seq_len, d_model, device, offset=0):
        """Return fixed sinusoidal position encodings of shape (1, seq_len, d_model).

        Encodings are computed on the fly rather than read from a fixed table,
        so any length works. Positions run from ``offset`` to
        ``offset + seq_len - 1``; pass the number of cached tokens as
        ``offset`` during incremental decoding.

        Layout is all sines followed by all cosines (not interleaved). For an odd
        ``d_model`` the surplus trailing cosine column is dropped so the width
        matches ``d_model``.
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model)).to(device)
        pos = torch.arange(offset, offset + seq_len, device=device).type_as(inv_freq)
        sin_inp = torch.einsum("i,j->ij", pos, inv_freq)
        emb = torch.cat((sin_inp.sin(), sin_inp.cos()), dim=-1)[:, :d_model]
        return emb[None, :, :]

    @staticmethod
    def _make_causal_mask(seq_len, past_len, dtype, device):
        """Build the additive causal mask of shape (1, 1, seq_len, past_len + seq_len).

        Query ``i`` (absolute position ``past_len + i``) sees keys ``0 .. past_len + i``;
        later keys get ``-inf``.
        """
        mask = torch.full((seq_len, seq_len + past_len), float("-inf"), device=device, dtype=dtype)
        mask.triu_(diagonal=past_len + 1)
        return mask[None, None, :, :]

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle activation checkpointing on decoder blocks (old-style HF hook signature)."""
        if isinstance(module, STLDecoderBlock):
            module.gradient_checkpointing = value

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, v):
        self.embed_tokens = v

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (needed by resize_token_embeddings for untied heads)."""
        self.lm_head = new_embeddings

    def forward(self, input_ids=None, encoder_hidden_states=None, past_key_values=None, labels=None, use_cache=None, return_dict=None, **kwargs):
        """Decode ``input_ids`` and optionally compute the shifted next-token loss.

        Args:
            input_ids: token ids of shape (batch, seq_len).
            encoder_hidden_states: optional states attended to by every layer's
                cross-attention.
            past_key_values: optional per-layer ``(k, v)`` self-attention caches from a
                previous call. ``past_len`` is read from ``k.shape[2]``. An empty
                cache is treated as no cache.
            labels: optional target ids aligned with ``input_ids``. They are shifted
                by one position internally, and entries equal to -100 are ignored.
            use_cache: defaults to ``config.use_cache``.
            return_dict: defaults to ``config.use_return_dict``.
            **kwargs: accepted for Hugging Face API compatibility and ignored.
                NOTE(review): this includes any ``attention_mask``, so padding tokens
                are not masked out; confirm this is intended.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` or, when ``return_dict`` is false,
            ``(loss, logits, cache)`` / ``(logits, cache)``.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("STLDecoderModel.forward requires `input_ids`.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        _, seq_len = input_ids.size()
        if past_key_values is not None and len(past_key_values) == 0:
            past_key_values = None
        past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        hidden_states = self.embed_tokens(input_ids)
        # Offset positions by the cache length so incrementally decoded tokens get
        # their true positions, and keep the activation dtype (e.g. fp16/bf16)
        # instead of promoting to float32.
        pos_emb = self.get_sinusoidal_embeddings(seq_len, self.config.hidden_size, input_ids.device, offset=past_len)
        hidden_states = hidden_states + pos_emb.to(hidden_states.dtype)

        causal_mask = self._make_causal_mask(seq_len, past_len, hidden_states.dtype, input_ids.device)

        next_cache = () if use_cache else None
        for i, layer in enumerate(self.layers):
            layer_past = past_key_values[i] if past_key_values is not None else None
            hidden_states, layer_present = layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                past_key_value=layer_past,
                attention_mask=causal_mask,
            )
            if use_cache:
                next_cache += (layer_present,)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)
            vocab = logits.size(-1)
            loss = F.cross_entropy(
                shift_logits.view(-1, vocab),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            return (loss, logits, next_cache) if loss is not None else (logits, next_cache)
        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits, past_key_values=next_cache)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along the batch axis to follow beam-search selections."""
        return tuple(tuple(s.index_select(0, beam_idx.to(s.device)) for s in layer) for layer in past_key_values)