saracandu committed
Commit c04796c · verified · 1 Parent(s): 5230ede

improved3?

Files changed (1)
  1. modeling_stldec.py +13 -4
modeling_stldec.py CHANGED
@@ -12,6 +12,7 @@ class STLPreTrainedModel(PreTrainedModel):
     config_class = STLDecoderConfig
     base_model_prefix = "model"
     def _init_weights(self, module):
+        """Improved with Xavier uniform initialization to avoid exploding gradients in the early stages."""
         if isinstance(module, nn.Linear):
             torch.nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
@@ -83,9 +84,10 @@ class STLDecoderBlock(nn.Module):
         return self.internal_forward(hidden_states, encoder_hidden_states, past_key_value, attention_mask)
 
     def internal_forward(self, hidden_states, encoder_hidden_states=None, past_key_value=None, attention_mask=None):
+        """Changed to Pre-Norm to keep the gradient stable."""
         # 1. Self-Attention
         residual = hidden_states
-        hidden_states = self.ln1(hidden_states) # LN BEFORE the operation
+        hidden_states = self.ln1(hidden_states) # LN FIRST
         hidden_states, pkv = self.self_attn(hidden_states, past_key_value=past_key_value, attention_mask=attention_mask)
         hidden_states = residual + self.dropout(hidden_states)
 
@@ -107,12 +109,14 @@ class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
     def __init__(self, config):
         super().__init__(config)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+        # Sinusoidal positional embeddings removed from __init__ because they are computed dynamically in forward
         self.layers = nn.ModuleList([STLDecoderBlock(config) for _ in range(config.num_hidden_layers)])
         self.norm = nn.LayerNorm(config.hidden_size)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
 
     def get_sinusoidal_embeddings(self, seq_len, d_model, device):
+        """Generates stable, purely mathematical positions, avoiding index errors from a fixed embedding table."""
         inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model)).to(device)
         pos = torch.arange(seq_len, device=device).type_as(inv_freq)
         sin_inp = torch.einsum("i,j->ij", pos, inv_freq)
@@ -140,9 +144,11 @@ class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
 
         hidden_states = self.embed_tokens(input_ids)
 
+        # Replaced with sinusoidal embeddings (more robust than Script 2)
         pos_emb = self.get_sinusoidal_embeddings(seq_len, self.config.hidden_size, input_ids.device)
         hidden_states = hidden_states + pos_emb[:, :seq_len, :]
 
+        # Optimized causal mask
         causal_mask = torch.full((seq_len, seq_len + past_len), float("-inf"), device=input_ids.device, dtype=hidden_states.dtype)
         causal_mask.triu_(diagonal=past_len + 1)
         causal_mask = causal_mask[None, None, :, :]
@@ -161,11 +167,14 @@ class STLDecoderModel(STLPreTrainedModel, GenerationMixin):
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
-        vocab_size = logits.size(-1)
+        # --- SAFETY CHANGE ---
+        # Take the vocab_size from the current logits, not from the fixed config,
+        # to avoid the Assertion t < n_classes error if you have called resize_token_embeddings.
+        current_vocab_size = logits.size(-1)
         loss = F.cross_entropy(
-            shift_logits.view(-1, vocab_size),
+            shift_logits.view(-1, current_vocab_size),
             shift_labels.view(-1),
-            ignore_index=self.config.pad_token_id
+            ignore_index=-100  # HF standard for ignoring padding remapped in the labels
         )
 
         if not return_dict:
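
For reference, the Pre-Norm ordering introduced in internal_forward normalizes the input of the attention sub-layer and leaves the residual path untouched, instead of normalizing after the residual add. A minimal, self-contained sketch of that ordering; SelfAttnStub and the tensor shapes are made up for illustration and stand in for the block's real self_attn:

import torch
import torch.nn as nn

class SelfAttnStub(nn.Module):
    """Hypothetical stand-in for the block's real self-attention module."""
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)
    def forward(self, x):
        return self.proj(x)

class PreNormBlockSketch(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttnStub(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.ln1(hidden_states)        # LayerNorm BEFORE the sub-layer (Pre-Norm)
        hidden_states = self.self_attn(hidden_states)  # attention sees normalized activations
        return residual + self.dropout(hidden_states)  # residual path stays un-normalized

x = torch.randn(2, 5, 16)
print(PreNormBlockSketch(16)(x).shape)  # torch.Size([2, 5, 16])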
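
The loss change assumes the labels fed to the model already mark positions to skip with -100 (the usual Hugging Face convention for padded or masked label positions), since ignore_index no longer points at pad_token_id. A minimal sketch of that contract with made-up tensors:

import torch
import torch.nn.functional as F

vocab_size = 10
# Hypothetical shifted logits/labels for 2 sequences of length 4.
shift_logits = torch.randn(2, 4, vocab_size)
shift_labels = torch.tensor([[3, 7, 1, -100],          # padding already remapped to -100
                             [5, 2, -100, -100]])

loss = F.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)),      # vocab size read from the logits themselves
    shift_labels.view(-1),
    ignore_index=-100,                                  # -100 positions contribute nothing to the loss
)
print(loss)

If the labels still carry pad_token_id at padded positions instead of -100, those positions would now count toward the loss, so the remapping has to happen in the data pipeline.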