Update modeling_neollm.py
Browse files - modeling_neollm.py (+12 -2)
modeling_neollm.py
CHANGED
|
@@ -1406,18 +1406,28 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
| 1406 |
B = x.shape[0]
|
| 1407 |
if position_ids.shape[0] != B:
|
| 1408 |
position_ids = position_ids.expand(B, -1)
|
| 1409 |
-
|
| 1410 |
device_type = (x.device.type
|
| 1411 |
if isinstance(x.device.type, str) and x.device.type != "mps"
|
| 1412 |
else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1413 |
inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
|
| 1414 |
-
|
| 1415 |
with torch.autocast(device_type=device_type, enabled=False):
|
| 1416 |
freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1)
|
| 1417 |
* inv_freq.unsqueeze(0).unsqueeze(0))
|
| 1418 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 1419 |
cos = emb.cos() * self.attention_scaling
|
| 1420 |
sin = emb.sin() * self.attention_scaling
|
|
|
|
|
|
|
| 1421 |
|
| 1422 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 1423 |
|
|
|
|
| 1406 |
B = x.shape[0]
|
| 1407 |
if position_ids.shape[0] != B:
|
| 1408 |
position_ids = position_ids.expand(B, -1)
|
| 1409 |
+
|
| 1410 |
device_type = (x.device.type
|
| 1411 |
if isinstance(x.device.type, str) and x.device.type != "mps"
|
| 1412 |
else "cpu")
|
| 1413 |
+
|
| 1414 |
+
if self.inv_freq.device.type == "meta":
|
| 1415 |
+
inv_freq_data, _ = self.compute_default_rope_parameters(
|
| 1416 |
+
self.config, device=x.device
|
| 1417 |
+
)
|
| 1418 |
+
self.register_buffer("inv_freq", inv_freq_data, persistent=False)
|
| 1419 |
+
self.register_buffer("original_inv_freq", inv_freq_data.clone(), persistent=False)
|
| 1420 |
+
|
| 1421 |
inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
|
| 1422 |
+
|
| 1423 |
with torch.autocast(device_type=device_type, enabled=False):
|
| 1424 |
freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1)
|
| 1425 |
* inv_freq.unsqueeze(0).unsqueeze(0))
|
| 1426 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 1427 |
cos = emb.cos() * self.attention_scaling
|
| 1428 |
sin = emb.sin() * self.attention_scaling
|
| 1429 |
+
|
| 1430 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 1431 |
|
| 1432 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 1433 |
|