Update rotary.py
Browse files
rotary.py
CHANGED
|
@@ -494,14 +494,15 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 494 |
@base.setter
|
| 495 |
def base(self, new_base):
|
| 496 |
new_base = float(new_base)
|
| 497 |
-
if new_base > 0
|
| 498 |
-
self._base = new_base
|
| 499 |
-
|
| 500 |
-
self.
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
|
|
|
| 505 |
else:
|
| 506 |
raise ValueError("Rotary base value must be positive")
|
| 507 |
|
|
|
|
| 494 |
@base.setter
|
| 495 |
def base(self, new_base):
|
| 496 |
new_base = float(new_base)
|
| 497 |
+
if new_base > 0:
|
| 498 |
+
if self._base != new_base:
|
| 499 |
+
self._base = new_base
|
| 500 |
+
self._update_cos_sin_cache(
|
| 501 |
+
self._seq_len_cached,
|
| 502 |
+
device=self.inv_freq.device,
|
| 503 |
+
dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
|
| 504 |
+
rotary_base_changed=True,
|
| 505 |
+
)
|
| 506 |
else:
|
| 507 |
raise ValueError("Rotary base value must be positive")
|
| 508 |
|