KitsuVp committed on
Commit
a66b911
·
verified ·
1 Parent(s): 51cf4ed

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +13 -11
modeling_neollm.py CHANGED
@@ -511,21 +511,23 @@ class NeoLLMRotaryEmbedding(nn.Module):
511
 
512
def __init__(self, config: NeoLLMConfig, device=None):
    """Initialize rotary position embeddings from the model config.

    Resolves the RoPE variant from ``config.rope_scaling`` (falling back
    to ``"default"``), then delegates inverse-frequency computation to the
    matching entry in ``ROPE_INIT_FUNCTIONS``.
    """
    super().__init__()
    # BC: "rope_type" was originally stored under the "type" key.
    scaling_cfg = config.rope_scaling if hasattr(config, "rope_scaling") else None
    if isinstance(scaling_cfg, dict):
        self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
    else:
        self.rope_type = "default"

    self.max_seq_len_cached = config.max_position_embeddings
    self.original_max_seq_len = config.max_position_embeddings
    self.config = config

    # Look up and run the initializer for the resolved RoPE variant.
    self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

    # Non-persistent: inv_freq is derived state, not part of the checkpoint.
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.original_inv_freq = self.inv_freq
528
-
529
  @torch.no_grad()
530
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
531
  def forward(self, x, position_ids):
 
511
 
512
def __init__(self, config: NeoLLMConfig, device=None):
    """Initialize rotary position embeddings from the model config.

    If ``config.rope_scaling`` is a dict, the RoPE variant named by its
    ``"rope_type"`` (BC: formerly ``"type"``) key is looked up in
    ``ROPE_INIT_FUNCTIONS`` and used to produce ``inv_freq`` and
    ``attention_scaling``. Otherwise plain RoPE is used: unit scaling and
    the classic ``theta``-based inverse-frequency schedule over the
    (partially) rotary head dimension.
    """
    super().__init__()
    self.config = config
    self.max_seq_len_cached = config.max_position_embeddings
    self.original_max_seq_len = config.max_position_embeddings

    # getattr with a default replaces the hasattr/None/isinstance chain;
    # isinstance() already rejects None, so no separate None check is needed.
    rope_scaling = getattr(config, "rope_scaling", None)
    if isinstance(rope_scaling, dict):
        # BC: "rope_type" was originally stored under the "type" key.
        self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
    else:
        # No rope_scaling configured: plain RoPE, no attention rescaling.
        self.rope_type = None
        self.attention_scaling = 1.0
        # NOTE(review): this fallback computation MUST stay inside the else
        # branch — run unconditionally it would clobber the inv_freq produced
        # by rope_init_fn above (the rendered diff's indentation was ambiguous).
        dim = int(config.head_dim * config.partial_rotary_factor)
        inv_freq = 1.0 / (
            config.rope_theta
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

    # Non-persistent: inv_freq is derived state, not part of the checkpoint.
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.original_inv_freq = self.inv_freq
 
531
  @torch.no_grad()
532
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
533
  def forward(self, x, position_ids):