KitsuVp committed on
Commit
189f7d6
·
verified ·
1 Parent(s): 91342c4

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +22 -8
modeling_neollm.py CHANGED
@@ -563,20 +563,34 @@ class NeoLLMRotaryEmbedding(nn.Module):
563
  return inv_freq, attention_scaling
564
 
565
  @torch.no_grad()
566
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
567
  def forward(self, x, position_ids):
568
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
569
- position_ids_expanded = position_ids[:, None, :].float()
570
-
 
 
 
 
 
 
571
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
572
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
573
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
574
- emb = torch.cat((freqs, freqs), dim=-1)
 
 
 
 
 
 
 
575
  cos = emb.cos() * self.attention_scaling
576
  sin = emb.sin() * self.attention_scaling
577
-
578
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
579
 
 
580
  def rotate_half(x):
581
  """Rotates half the hidden dims of the input."""
582
  x1 = x[..., : x.shape[-1] // 2]
 
563
  return inv_freq, attention_scaling
564
 
565
  @torch.no_grad()
566
+ @dynamic_rope_update
567
  def forward(self, x, position_ids):
568
+ # Asegura forma [B, S]
569
+ if position_ids.dim() == 1:
570
+ position_ids = position_ids.unsqueeze(0) # [1, S]
571
+
572
+ B = x.shape[0]
573
+ if position_ids.shape[0] != B:
574
+ # Replica posiciones idénticas por batch (semántica correcta)
575
+ position_ids = position_ids.expand(B, -1) # [B, S]
576
+
577
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
578
+
579
+ # inv_freq en float32 en el device correcto (sin expand con stride 0)
580
+ inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32) # [d/2]
581
+
582
+ with torch.autocast(device_type=device_type, enabled=False): # fuerza float32
583
+ # Θ[b,s,i] = position_ids[b,s] * inv_freq[i]
584
+ freqs = position_ids.to(dtype=torch.float32).unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)
585
+ # freqs: [B, S, d/2]
586
+
587
+ emb = torch.cat((freqs, freqs), dim=-1) # [B, S, d]
588
  cos = emb.cos() * self.attention_scaling
589
  sin = emb.sin() * self.attention_scaling
590
+
591
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
592
 
593
+
594
  def rotate_half(x):
595
  """Rotates half the hidden dims of the input."""
596
  x1 = x[..., : x.shape[-1] // 2]