improved3?
Browse files
- modeling_stldec.py +13 -4
modeling_stldec.py
CHANGED
|
@@ -12,6 +12,7 @@ class STLPreTrainedModel(PreTrainedModel):
|
|
| 12 |
config_class = STLDecoderConfig
|
| 13 |
base_model_prefix = "model"
|
| 14 |
def _init_weights(self, module):
|
|
|
|
| 15 |
if isinstance(module, nn.Linear):
|
| 16 |
torch.nn.init.xavier_uniform_(module.weight)
|
| 17 |
if module.bias is not None:
|
|
@@ -83,9 +84,10 @@ class STLDecoderBlock(nn.Module):
|
|
| 83 |
return self.internal_forward(hidden_states, encoder_hidden_states, past_key_value, attention_mask)
|
| 84 |
|
| 85 |
def internal_forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None, attention_mask=None):
|
|
|
|
| 86 |
# 1. Self-Attention
|
| 87 |
residual = hidden_states
|
| 88 |
-
hidden_states = self.ln1(hidden_states) # LN PRIMA
|
| 89 |
hidden_states, pkv = self.self_attn(hidden_states, past_key_value=past_key_value, attention_mask=attention_mask)
|
| 90 |
hidden_states = residual + self.dropout(hidden_states)
|
| 91 |
|
|
@@ -107,12 +109,14 @@ class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
|
|
| 107 |
def __init__(self, config):
|
| 108 |
super().__init__(config)
|
| 109 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
|
|
|
| 110 |
self.layers = nn.ModuleList([STLDecoderBlock(config) for _ in range(config.num_hidden_layers)])
|
| 111 |
self.norm = nn.LayerNorm(config.hidden_size)
|
| 112 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 113 |
self.post_init()
|
| 114 |
|
| 115 |
def get_sinusoidal_embeddings(self, seq_len, d_model, device):
|
|
|
|
| 116 |
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model)).to(device)
|
| 117 |
pos = torch.arange(seq_len, device=device).type_as(inv_freq)
|
| 118 |
sin_inp = torch.einsum("i,j->ij", pos, inv_freq)
|
|
@@ -140,9 +144,11 @@ class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
|
|
| 140 |
|
| 141 |
hidden_states = self.embed_tokens(input_ids)
|
| 142 |
|
|
|
|
| 143 |
pos_emb = self.get_sinusoidal_embeddings(seq_len, self.config.hidden_size, input_ids.device)
|
| 144 |
hidden_states = hidden_states + pos_emb[:, :seq_len, :]
|
| 145 |
|
|
|
|
| 146 |
causal_mask = torch.full((seq_len, seq_len + past_len), float("-inf"), device=input_ids.device, dtype=hidden_states.dtype)
|
| 147 |
causal_mask.triu_(diagonal=past_len + 1)
|
| 148 |
causal_mask = causal_mask[None, None, :, :]
|
|
@@ -161,11 +167,14 @@ class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
|
|
| 161 |
shift_logits = logits[..., :-1, :].contiguous()
|
| 162 |
shift_labels = labels[..., 1:].contiguous()
|
| 163 |
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
| 165 |
loss = F.cross_entropy(
|
| 166 |
-
shift_logits.view(-1, self.config.vocab_size),
|
| 167 |
shift_labels.view(-1),
|
| 168 |
-
ignore_index
|
| 169 |
)
|
| 170 |
|
| 171 |
if not return_dict:
|
|
|
|
| 12 |
config_class = STLDecoderConfig
|
| 13 |
base_model_prefix = "model"
|
| 14 |
def _init_weights(self, module):
|
| 15 |
+
"""Migliorata con Xavier Uniform per evitare gradienti esplosivi nelle fasi iniziali."""
|
| 16 |
if isinstance(module, nn.Linear):
|
| 17 |
torch.nn.init.xavier_uniform_(module.weight)
|
| 18 |
if module.bias is not None:
|
|
|
|
| 84 |
return self.internal_forward(hidden_states, encoder_hidden_states, past_key_value, attention_mask)
|
| 85 |
|
| 86 |
def internal_forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None, attention_mask=None):
|
| 87 |
+
"""Modificata in Pre-Norm per garantire la stabilità del gradiente."""
|
| 88 |
# 1. Self-Attention
|
| 89 |
residual = hidden_states
|
| 90 |
+
hidden_states = self.ln1(hidden_states) # LN PRIMA
|
| 91 |
hidden_states, pkv = self.self_attn(hidden_states, past_key_value=past_key_value, attention_mask=attention_mask)
|
| 92 |
hidden_states = residual + self.dropout(hidden_states)
|
| 93 |
|
|
|
|
| 109 |
def __init__(self, config):
|
| 110 |
super().__init__(config)
|
| 111 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
| 112 |
+
# Posizionali Sinusoidali rimossi dal __init__ perché calcolati dinamicamente nel forward
|
| 113 |
self.layers = nn.ModuleList([STLDecoderBlock(config) for _ in range(config.num_hidden_layers)])
|
| 114 |
self.norm = nn.LayerNorm(config.hidden_size)
|
| 115 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 116 |
self.post_init()
|
| 117 |
|
| 118 |
def get_sinusoidal_embeddings(self, seq_len, d_model, device):
|
| 119 |
+
"""Genera posizioni matematiche stabili, evitando errori di indice della tabella fixed."""
|
| 120 |
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model)).to(device)
|
| 121 |
pos = torch.arange(seq_len, device=device).type_as(inv_freq)
|
| 122 |
sin_inp = torch.einsum("i,j->ij", pos, inv_freq)
|
|
|
|
| 144 |
|
| 145 |
hidden_states = self.embed_tokens(input_ids)
|
| 146 |
|
| 147 |
+
# Sostituzione con sinusoidali (più robusti dello Script 2)
|
| 148 |
pos_emb = self.get_sinusoidal_embeddings(seq_len, self.config.hidden_size, input_ids.device)
|
| 149 |
hidden_states = hidden_states + pos_emb[:, :seq_len, :]
|
| 150 |
|
| 151 |
+
# Maschera causale ottimizzata
|
| 152 |
causal_mask = torch.full((seq_len, seq_len + past_len), float("-inf"), device=input_ids.device, dtype=hidden_states.dtype)
|
| 153 |
causal_mask.triu_(diagonal=past_len + 1)
|
| 154 |
causal_mask = causal_mask[None, None, :, :]
|
|
|
|
| 167 |
shift_logits = logits[..., :-1, :].contiguous()
|
| 168 |
shift_labels = labels[..., 1:].contiguous()
|
| 169 |
|
| 170 |
+
# --- MODIFICA DI SICUREZZA ---
|
| 171 |
+
# Prendiamo il vocab_size dai logits correnti, non dal config fisso,
|
| 172 |
+
# per evitare l'Assertion t < n_classes se hai fatto un resize_token_embeddings.
|
| 173 |
+
current_vocab_size = logits.size(-1)
|
| 174 |
loss = F.cross_entropy(
|
| 175 |
+
shift_logits.view(-1, current_vocab_size),
|
| 176 |
shift_labels.view(-1),
|
| 177 |
+
ignore_index=-100 # Standard HF per ignorare padding trasformato
|
| 178 |
)
|
| 179 |
|
| 180 |
if not return_dict:
|