klemenk committed on
Commit
90a4535
·
verified ·
1 Parent(s): 6bb6180

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -10
modeling_auristream.py CHANGED
@@ -529,18 +529,10 @@ class MLP(nn.Module):
529
 
530
 
531
  class Rotary(torch.nn.Module):
532
- def __init__(self, dim, base=500000, learned=True):
533
  super().__init__()
534
  # Compute the base inverse frequencies as before.
535
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
536
- # If learned is True, register as a parameter; otherwise, as a buffer.
537
- if learned:
538
- # Initialize randomly and register as a parameter.
539
- self.inv_freq = torch.nn.Parameter(inv_freq)
540
- nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
541
- else:
542
- self.register_buffer("inv_freq", inv_freq)
543
- self.learned = learned # (optional) Save the flag if needed later
544
 
545
  def forward(self, x):
546
  seq_len = x.shape[1]
@@ -552,6 +544,7 @@ class Rotary(torch.nn.Module):
552
  sin_cached = freqs.sin()
553
  return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
554
 
 
555
  def apply_rotary_emb(x, cos, sin):
556
  assert x.ndim == 4 # multihead attention expected
557
  d = x.shape[3] // 2
 
529
 
530
 
531
  class Rotary(torch.nn.Module):
532
+ def __init__(self, dim, base=10000, learned=False):
533
  super().__init__()
534
  # Compute the base inverse frequencies as before.
535
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
 
 
 
 
 
 
 
 
536
 
537
  def forward(self, x):
538
  seq_len = x.shape[1]
 
544
  sin_cached = freqs.sin()
545
  return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
546
 
547
+
548
  def apply_rotary_emb(x, cos, sin):
549
  assert x.ndim == 4 # multihead attention expected
550
  d = x.shape[3] // 2