Update modeling_auristream.py
Browse files- modeling_auristream.py +3 -1
modeling_auristream.py
CHANGED
|
@@ -532,7 +532,9 @@ class Rotary(torch.nn.Module):
|
|
| 532 |
def __init__(self, dim, base=10000, learned=False):
|
| 533 |
super().__init__()
|
| 534 |
# Compute the base inverse frequencies as before.
|
| 535 |
-
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
|
|
|
|
|
| 536 |
|
| 537 |
def forward(self, x):
|
| 538 |
seq_len = x.shape[1]
|
|
|
|
| 532 |
def __init__(self, dim, base=10000, learned=False):
|
| 533 |
super().__init__()
|
| 534 |
# Compute the base inverse frequencies as before.
|
| 535 |
+
# self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 536 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 537 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 538 |
|
| 539 |
def forward(self, x):
|
| 540 |
seq_len = x.shape[1]
|