Update rotary.py
Browse files
rotary.py
CHANGED
|
@@ -495,7 +495,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 495 |
def base(self, new_base):
|
| 496 |
new_base = float(new_base)
|
| 497 |
if new_base > 0:
|
| 498 |
-
if self._base != new_base:
|
| 499 |
self._base = new_base
|
| 500 |
self._update_cos_sin_cache(
|
| 501 |
self._seq_len_cached,
|
|
|
|
| 495 |
def base(self, new_base):
|
| 496 |
new_base = float(new_base)
|
| 497 |
if new_base > 0:
|
| 498 |
+
if self._base != new_base: # only update if the base value has changed
|
| 499 |
self._base = new_base
|
| 500 |
self._update_cos_sin_cache(
|
| 501 |
self._seq_len_cached,
|