Update modeling_neollm.py
Browse files- modeling_neollm.py +44 -18
modeling_neollm.py
CHANGED
|
@@ -510,30 +510,56 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
def __init__(self, config: NeoLLMConfig, device=None):
    """Initialize rotary-embedding state and the inverse-frequency buffer.

    Args:
        config: Model configuration; must provide ``max_position_embeddings``,
            ``rope_theta``, attention-head geometry, and optionally a
            ``rope_scaling`` dict naming a RoPE variant.
        device: Optional device on which to build the inverse frequencies.
    """
    super().__init__()
    self.config = config
    self.max_seq_len_cached = config.max_position_embeddings
    self.original_max_seq_len = config.max_position_embeddings

    # Resolve the RoPE variant, degrading to "default" when rope_scaling is
    # absent, malformed, or names a type with no registered init function.
    # (The previous code indexed ROPE_INIT_FUNCTIONS unchecked and raised
    # KeyError for an unknown or None rope_type.)
    self.rope_type = "default"
    if isinstance(getattr(config, "rope_scaling", None), dict):
        rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        if rope_type and rope_type in ROPE_INIT_FUNCTIONS:
            self.rope_type = rope_type

    if self.rope_type != "default":
        rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
    else:
        rope_init_fn = self.compute_default_rope_parameters
    # Keep the attribute the original code exposed, in case other methods
    # (out of view here) still read self.rope_init_fn — TODO confirm.
    self.rope_init_fn = rope_init_fn

    inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

    # Register both tensors as non-persistent buffers so they follow
    # .to()/.cuda() moves without being saved in checkpoints; the clone
    # preserves a pristine copy for dynamic-scaling resets.
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
| 532 |
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
return inv_freq, attention_scaling
|
| 538 |
|
| 539 |
@torch.no_grad()
|
|
|
|
def __init__(self, config: NeoLLMConfig, device=None):
    """Set up RoPE state: sequence-length caches, rope type, and frequency buffers."""
    super().__init__()
    self.max_seq_len_cached = config.max_position_embeddings
    self.original_max_seq_len = config.max_position_embeddings
    self.config = config

    # Pick the RoPE variant from rope_scaling; anything missing or
    # unrecognised falls back to the plain ("default") formulation.
    self.rope_type = "default"
    scaling = config.rope_scaling if hasattr(config, "rope_scaling") else None
    if scaling is not None and isinstance(scaling, dict):
        candidate = scaling.get("rope_type", scaling.get("type"))
        if candidate and candidate in ROPE_INIT_FUNCTIONS:
            self.rope_type = candidate

    # Select the matching initializer and compute the frequencies once.
    if self.rope_type == "default":
        init_fn = self.compute_default_rope_parameters
    else:
        init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
    inv_freq, self.attention_scaling = init_fn(self.config, device)

    # Non-persistent buffers: they track .to()/.cuda() device moves but are
    # excluded from the state dict; the clone keeps an untouched original.
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
@staticmethod
def compute_default_rope_parameters(
    config: NeoLLMConfig = None,
    device: Optional["torch.device"] = None,
    seq_len: int = None,
) -> tuple["torch.Tensor", float]:
    """Compute inverse frequencies for the original (unscaled) RoPE.

    Args:
        config: The model configuration (theta, head geometry, optional
            ``partial_rotary_factor``).
        device: Device on which to materialize the inverse frequencies.
        seq_len: The current sequence length. Unused for this type of RoPE.

    Returns:
        Tuple of (torch.Tensor, float): the inverse frequencies for the RoPE
        embeddings and the post-processing scaling factor applied to the
        computed cos/sin (always 1.0 for the default variant).
    """
    theta = config.rope_theta
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    # Only a fraction of the head dimension may be rotated.
    rotary_dim = int(head_dim * getattr(config, "partial_rotary_factor", 1.0))

    # Even indices 0, 2, ..., rotary_dim-2 define the frequency bands.
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64)
    exponents = exponents.to(device=device, dtype=torch.float)
    inv_freq = 1.0 / (theta ** (exponents / rotary_dim))

    # Default RoPE applies no cos/sin post-scaling.
    return inv_freq, 1.0
| 564 |
|
| 565 |
@torch.no_grad()
|