Commit: "Update modeling_neollm.py" (+6 −2)
File changed: modeling_neollm.py
Hunk: @@ -515,9 +515,13 @@ class NeoLLMRotaryEmbedding(nn.Module):
Before (old lines 515-523):

    515       self.max_seq_len_cached = config.max_position_embeddings
    516       self.original_max_seq_len = config.max_position_embeddings
    517
    518       if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
    519 -         [removed line — content lost in page extraction]
    520 -         [removed line — content lost in page extraction]
    521           inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
    522       else:
    523           self.rope_type = None
|
|
|
After (new lines 515-527; indentation inferred from Python syntax — the scrape did not preserve it, verify against the repository):

    515       self.max_seq_len_cached = config.max_position_embeddings
    516       self.original_max_seq_len = config.max_position_embeddings
    517
    518 +     rope_type = None
    519       if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
    520 +         rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
    521 +
    522 +         if rope_type and rope_type != "default" and rope_type in ROPE_INIT_FUNCTIONS:
    523 +             self.rope_type = rope_type
    524 +             self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
    525           inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
    526       else:
    527           self.rope_type = None