Update modeling_auristream.py
Browse files
- modeling_auristream.py (+3 -10)
modeling_auristream.py
CHANGED
|
@@ -529,18 +529,10 @@ class MLP(nn.Module):
|
|
| 529 |
|
| 530 |
|
| 531 |
class Rotary(torch.nn.Module):
|
| 532 |
-
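def __init__(self, dim, base=… [rest of this removed line was lost in extraction; the removed body below also reads a `learned` argument, so the old signature took one — default values unknown]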
|
| 533 |
super().__init__()
|
| 534 |
# Compute the base inverse frequencies as before.
|
| 535 |
-
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 536 |
-
# If learned is True, register as a parameter; otherwise, as a buffer.
|
| 537 |
-
if learned:
|
| 538 |
-
# Initialize randomly and register as a parameter.
|
| 539 |
-
self.inv_freq = torch.nn.Parameter(inv_freq)
|
| 540 |
-
nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
|
| 541 |
-
else:
|
| 542 |
-
self.register_buffer("inv_freq", inv_freq)
|
| 543 |
-
self.learned = learned # (optional) Save the flag if needed later
|
| 544 |
|
| 545 |
def forward(self, x):
|
| 546 |
seq_len = x.shape[1]
|
|
@@ -552,6 +544,7 @@ class Rotary(torch.nn.Module):
|
|
| 552 |
sin_cached = freqs.sin()
|
| 553 |
return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
|
| 554 |
|
|
|
|
| 555 |
def apply_rotary_emb(x, cos, sin):
|
| 556 |
assert x.ndim == 4 # multihead attention expected
|
| 557 |
d = x.shape[3] // 2
|
|
|
|
| 529 |
|
| 530 |
|
| 531 |
class Rotary(torch.nn.Module):
|
| 532 |
+
def __init__(self, dim, base=10000, learned=False):
|
| 533 |
super().__init__()
|
| 534 |
# Compute the base inverse frequencies as before.
|
| 535 |
+
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
|
| 537 |
def forward(self, x):
|
| 538 |
seq_len = x.shape[1]
|
|
|
|
| 544 |
sin_cached = freqs.sin()
|
| 545 |
return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
|
| 546 |
|
| 547 |
+
|
| 548 |
def apply_rotary_emb(x, cos, sin):
|
| 549 |
assert x.ndim == 4 # multihead attention expected
|
| 550 |
d = x.shape[3] // 2
|