KitsuVp committed on
Commit
804ee22
·
verified ·
1 Parent(s): 40bd072

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +44 -18
modeling_neollm.py CHANGED
@@ -510,30 +510,56 @@ class NeoLLMRotaryEmbedding(nn.Module):
510
 
511
  def __init__(self, config: NeoLLMConfig, device=None):
512
  super().__init__()
513
- self.config = config
514
  self.max_seq_len_cached = config.max_position_embeddings
515
  self.original_max_seq_len = config.max_position_embeddings
516
-
517
- rope_type = None
 
 
518
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
519
  rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
 
 
 
 
 
 
 
520
 
521
- if rope_type and rope_type != "default" and rope_type in ROPE_INIT_FUNCTIONS:
522
- self.rope_type = rope_type
523
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
524
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
525
- else:
526
- self.rope_type = "default"
527
- self.attention_scaling = 1.0
528
- inv_freq = self.compute_default_rope_parameters(config, device)[0]
529
-
530
- self.register_buffer("inv_freq", inv_freq, persistent=False)
531
- self.original_inv_freq = self.inv_freq
532
 
533
- def compute_default_rope_parameters(self, config, device=None):
534
- dim = int(config.head_dim * config.partial_rotary_factor)
535
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
536
- attention_scaling = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
  return inv_freq, attention_scaling
538
 
539
  @torch.no_grad()
 
510
 
511
  def __init__(self, config: NeoLLMConfig, device=None):
512
  super().__init__()
 
513
  self.max_seq_len_cached = config.max_position_embeddings
514
  self.original_max_seq_len = config.max_position_embeddings
515
+ self.config = config
516
+
517
+ # Determine rope_type from rope_scaling config
518
+ self.rope_type = "default"
519
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
520
  rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
521
+ if rope_type and rope_type in ROPE_INIT_FUNCTIONS:
522
+ self.rope_type = rope_type
523
+
524
+ # Initialize rope parameters
525
+ rope_init_fn = self.compute_default_rope_parameters
526
+ if self.rope_type != "default":
527
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
528
 
529
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
 
 
 
 
 
 
 
 
 
530
 
531
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
532
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
533
+
534
+ @staticmethod
535
+ def compute_default_rope_parameters(
536
+ config: NeoLLMConfig = None,
537
+ device: Optional["torch.device"] = None,
538
+ seq_len: int = None,
539
+ ) -> tuple["torch.Tensor", float]:
540
+ """
541
+ Computes the inverse frequencies according to the original RoPE implementation
542
+
543
+ Args:
544
+ config: The model configuration.
545
+ device: The device to use for initialization of the inverse frequencies.
546
+ seq_len: The current sequence length. Unused for this type of RoPE.
547
+
548
+ Returns:
549
+ Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE
550
+ embeddings and the post-processing scaling factor applied to the computed cos/sin.
551
+ """
552
+ base = config.rope_theta
553
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
554
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
555
+ dim = int(dim * partial_rotary_factor)
556
+
557
+ attention_scaling = 1.0 # Unused in default RoPE
558
+
559
+ # Compute the inverse frequencies
560
+ inv_freq = 1.0 / (
561
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
562
+ )
563
  return inv_freq, attention_scaling
564
 
565
  @torch.no_grad()