klemenk commited on
Commit
a2a774e
·
verified ·
1 Parent(s): f552508

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -1
modeling_auristream.py CHANGED
@@ -532,7 +532,9 @@ class Rotary(torch.nn.Module):
532
  def __init__(self, dim, base=10000, learned=False):
533
  super().__init__()
534
  # Compute the base inverse frequencies as before.
535
- self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
 
 
536
 
537
  def forward(self, x):
538
  seq_len = x.shape[1]
 
532
  def __init__(self, dim, base=10000, learned=False):
533
  super().__init__()
534
  # Compute the base inverse frequencies as before.
535
+ # self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
536
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
537
+ self.register_buffer("inv_freq", inv_freq)
538
 
539
  def forward(self, x):
540
  seq_len = x.shape[1]