KitsuVp committed on
Commit
4651776
·
verified ·
1 Parent(s): 98ad47e

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +9 -6
modeling_neollm.py CHANGED
@@ -505,7 +505,6 @@ class StackMemory(nn.Module):
505
  return output, new_stack[:, -1], new_mask[:, -1]
506
 
507
  # ==================== ROTARY EMBEDDING ====================
508
-
509
  class NeoLLMRotaryEmbedding(nn.Module):
510
  inv_freq: torch.Tensor # fix linting for `register_buffer`
511
 
@@ -524,14 +523,19 @@ class NeoLLMRotaryEmbedding(nn.Module):
524
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
525
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
526
  else:
527
- self.rope_type = None
528
  self.attention_scaling = 1.0
529
-
530
- dim = int(config.head_dim * config.partial_rotary_factor)
531
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
532
 
533
  self.register_buffer("inv_freq", inv_freq, persistent=False)
534
  self.original_inv_freq = self.inv_freq
 
 
 
 
 
 
 
535
  @torch.no_grad()
536
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
537
  def forward(self, x, position_ids):
@@ -547,7 +551,6 @@ class NeoLLMRotaryEmbedding(nn.Module):
547
 
548
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
549
 
550
-
551
  def rotate_half(x):
552
  """Rotates half the hidden dims of the input."""
553
  x1 = x[..., : x.shape[-1] // 2]
 
505
  return output, new_stack[:, -1], new_mask[:, -1]
506
 
507
  # ==================== ROTARY EMBEDDING ====================
 
508
  class NeoLLMRotaryEmbedding(nn.Module):
509
  inv_freq: torch.Tensor # fix linting for `register_buffer`
510
 
 
523
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
524
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
525
  else:
526
+ self.rope_type = "default"
527
  self.attention_scaling = 1.0
528
+ inv_freq = self.compute_default_rope_parameters(config, device)[0]
 
 
529
 
530
  self.register_buffer("inv_freq", inv_freq, persistent=False)
531
  self.original_inv_freq = self.inv_freq
532
+
533
def compute_default_rope_parameters(self, config, device=None):
    """Return the default RoPE inverse frequencies and attention scaling.

    Args:
        config: model configuration exposing ``head_dim``,
            ``partial_rotary_factor`` and ``rope_theta``.
        device: optional device on which the frequency tensor is created.

    Returns:
        Tuple ``(inv_freq, attention_scaling)`` where ``inv_freq`` has shape
        ``(rotary_dim // 2,)`` and ``attention_scaling`` is always ``1.0``
        for the default (non-scaled) RoPE variant.
    """
    # Number of head dimensions that actually receive rotary embeddings.
    rotary_dim = int(config.head_dim * config.partial_rotary_factor)
    # Even-indexed exponents 0, 2, 4, ... produce one frequency per rotated pair.
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=device)
    inv_freq = (config.rope_theta ** (exponents / rotary_dim)).reciprocal()
    attention_scaling = 1.0
    return inv_freq, attention_scaling
538
+
539
  @torch.no_grad()
540
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
541
  def forward(self, x, position_ids):
 
551
 
552
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
553
 
 
554
  def rotate_half(x):
555
  """Rotates half the hidden dims of the input."""
556
  x1 = x[..., : x.shape[-1] // 2]