Update modeling_neollm.py
Browse files- modeling_neollm.py +22 -8
modeling_neollm.py
CHANGED
|
@@ -563,20 +563,34 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
| 563 |
return inv_freq, attention_scaling
|
| 564 |
|
| 565 |
@torch.no_grad()
|
| 566 |
-
@dynamic_rope_update
|
| 567 |
def forward(self, x, position_ids):
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
cos = emb.cos() * self.attention_scaling
|
| 576 |
sin = emb.sin() * self.attention_scaling
|
| 577 |
-
|
| 578 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 579 |
|
|
|
|
| 580 |
def rotate_half(x):
|
| 581 |
"""Rotates half the hidden dims of the input."""
|
| 582 |
x1 = x[..., : x.shape[-1] // 2]
|
|
|
|
| 563 |
return inv_freq, attention_scaling
|
| 564 |
|
| 565 |
@torch.no_grad()
@dynamic_rope_update
def forward(self, x, position_ids):
    """Compute rotary position embedding tables (cos, sin) for the given positions.

    Args:
        x: activation tensor; only its batch size, device and dtype are used here.
        position_ids: tensor of token positions, shape [S] or [B, S].

    Returns:
        Tuple ``(cos, sin)``, each of shape [B, S, d], cast to ``x.dtype`` and
        pre-scaled by ``self.attention_scaling``.
    """
    # Normalize position_ids to shape [B, S].
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)  # [1, S]

    batch_size = x.shape[0]
    if position_ids.shape[0] != batch_size:
        # Replicate the same positions for every batch element.
        position_ids = position_ids.expand(batch_size, -1)  # [B, S]

    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"

    # inv_freq in float32 on the right device (no stride-0 expand).
    inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)  # [d/2]

    # Disable autocast so the angle math stays in float32.
    with torch.autocast(device_type=device_type, enabled=False):
        # theta[b, s, i] = position_ids[b, s] * inv_freq[i]  ->  [B, S, d/2]
        angles = position_ids.float().unsqueeze(-1) * inv_freq.view(1, 1, -1)
        emb = torch.cat((angles, angles), dim=-1)  # [B, S, d]
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling

    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 592 |
|
| 593 |
+
|
| 594 |
def rotate_half(x):
|
| 595 |
"""Rotates half the hidden dims of the input."""
|
| 596 |
x1 = x[..., : x.shape[-1] // 2]
|