Update modeling_neollm.py
Browse files- modeling_neollm.py +9 -6
modeling_neollm.py
CHANGED
|
@@ -505,7 +505,6 @@ class StackMemory(nn.Module):
|
|
| 505 |
return output, new_stack[:, -1], new_mask[:, -1]
|
| 506 |
|
| 507 |
# ==================== ROTARY EMBEDDING ====================
|
| 508 |
-
|
| 509 |
class NeoLLMRotaryEmbedding(nn.Module):
|
| 510 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 511 |
|
|
@@ -524,14 +523,19 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
| 524 |
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
|
| 525 |
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 526 |
else:
|
| 527 |
-
self.rope_type =
|
| 528 |
self.attention_scaling = 1.0
|
| 529 |
-
|
| 530 |
-
dim = int(config.head_dim * config.partial_rotary_factor)
|
| 531 |
-
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
| 532 |
|
| 533 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 534 |
self.original_inv_freq = self.inv_freq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
@torch.no_grad()
|
| 536 |
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 537 |
def forward(self, x, position_ids):
|
|
@@ -547,7 +551,6 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
| 547 |
|
| 548 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 549 |
|
| 550 |
-
|
| 551 |
def rotate_half(x):
|
| 552 |
"""Rotates half the hidden dims of the input."""
|
| 553 |
x1 = x[..., : x.shape[-1] // 2]
|
|
|
|
| 505 |
return output, new_stack[:, -1], new_mask[:, -1]
|
| 506 |
|
| 507 |
# ==================== ROTARY EMBEDDING ====================
|
|
|
|
| 508 |
class NeoLLMRotaryEmbedding(nn.Module):
|
| 509 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 510 |
|
|
|
|
| 523 |
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
|
| 524 |
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 525 |
else:
|
| 526 |
+
self.rope_type = "default"
|
| 527 |
self.attention_scaling = 1.0
|
| 528 |
+
inv_freq = self.compute_default_rope_parameters(config, device)[0]
|
|
|
|
|
|
|
| 529 |
|
| 530 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 531 |
self.original_inv_freq = self.inv_freq
|
| 532 |
+
|
| 533 |
+
def compute_default_rope_parameters(self, config, device=None):
    """Compute inverse frequencies for the default (unscaled) rotary embedding.

    Mirrors the upstream transformers helper ``_compute_default_rope_parameters``:
    frequencies follow ``rope_theta ** (-2i / dim)`` for ``i`` in ``[0, dim / 2)``.

    Args:
        config: Model config providing ``rope_theta`` and either ``head_dim``
            or (``hidden_size``, ``num_attention_heads``).
            ``partial_rotary_factor`` is optional and defaults to 1.0.
        device: Optional device on which to create the frequency tensor.

    Returns:
        Tuple ``(inv_freq, attention_scaling)`` where ``inv_freq`` is a
        float32 tensor of shape ``(dim // 2,)`` and ``attention_scaling`` is
        always 1.0 (the default rope type applies no frequency scaling).
    """
    # Fall back gracefully when the config lacks the optional attributes,
    # matching the upstream transformers helper's behavior.
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    dim = int(head_dim * partial_rotary_factor)

    inv_freq = 1.0 / (
        config.rope_theta
        ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
    )
    attention_scaling = 1.0
    return inv_freq, attention_scaling
|
| 538 |
+
|
| 539 |
@torch.no_grad()
|
| 540 |
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 541 |
def forward(self, x, position_ids):
|
|
|
|
| 551 |
|
| 552 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 553 |
|
|
|
|
| 554 |
def rotate_half(x):
|
| 555 |
"""Rotates half the hidden dims of the input."""
|
| 556 |
x1 = x[..., : x.shape[-1] // 2]
|