Update modeling_neollm.py
Browse files- modeling_neollm.py +13 -11
modeling_neollm.py
CHANGED
|
@@ -511,21 +511,23 @@ class NeoLLMRotaryEmbedding(nn.Module):
|
|
| 511 |
|
| 512 |
def __init__(self, config: NeoLLMConfig, device=None):
|
| 513 |
super().__init__()
|
| 514 |
-
|
| 515 |
-
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 516 |
-
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 517 |
-
else:
|
| 518 |
-
self.rope_type = "default"
|
| 519 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 520 |
self.original_max_seq_len = config.max_position_embeddings
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 527 |
self.original_inv_freq = self.inv_freq
|
| 528 |
-
|
| 529 |
@torch.no_grad()
|
| 530 |
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 531 |
def forward(self, x, position_ids):
|
|
|
|
| 511 |
|
| 512 |
def __init__(self, config: NeoLLMConfig, device=None):
|
| 513 |
super().__init__()
|
| 514 |
+
self.config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 516 |
self.original_max_seq_len = config.max_position_embeddings
|
| 517 |
+
|
| 518 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
|
| 519 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 520 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 521 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 522 |
+
else:
|
| 523 |
+
self.rope_type = None
|
| 524 |
+
self.attention_scaling = 1.0
|
| 525 |
+
|
| 526 |
+
dim = int(config.head_dim * config.partial_rotary_factor)
|
| 527 |
+
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
| 528 |
+
|
| 529 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 530 |
self.original_inv_freq = self.inv_freq
|
|
|
|
| 531 |
@torch.no_grad()
|
| 532 |
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 533 |
def forward(self, x, position_ids):
|