KitsuVp commited on
Commit
14f6d4e
·
verified ·
1 Parent(s): a66b911

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +6 -2
modeling_neollm.py CHANGED
@@ -515,9 +515,13 @@ class NeoLLMRotaryEmbedding(nn.Module):
515
  self.max_seq_len_cached = config.max_position_embeddings
516
  self.original_max_seq_len = config.max_position_embeddings
517
 
 
518
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
519
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
520
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
 
521
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
522
  else:
523
  self.rope_type = None
 
515
  self.max_seq_len_cached = config.max_position_embeddings
516
  self.original_max_seq_len = config.max_position_embeddings
517
 
518
+ rope_type = None
519
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
520
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
521
+
522
+ if rope_type and rope_type != "default" and rope_type in ROPE_INIT_FUNCTIONS:
523
+ self.rope_type = rope_type
524
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
525
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
526
  else:
527
  self.rope_type = None