Commit
·
2380737
1 Parent(s): 3c52426
Rotary embedding correction
Browse files
- modeling_phi.py (+6 -4)
modeling_phi.py
CHANGED
|
@@ -414,10 +414,12 @@ class MHA(nn.Module):
|
|
| 414 |
super().__init__()
|
| 415 |
|
| 416 |
# Rotary embedding
|
| 417 |
-
self.
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
| 421 |
|
| 422 |
# MLP
|
| 423 |
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
|
|
|
|
| 414 |
super().__init__()
|
| 415 |
|
| 416 |
# Rotary embedding
|
| 417 |
+
self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
| 418 |
+
if self.rotary_dim > 0:
|
| 419 |
+
self.rotary_emb = RotaryEmbedding(
|
| 420 |
+
d_rotary=math.ceil((rotary_dim // n_head) / 2), # d_rotary is half of d_head
|
| 421 |
+
initial_cos_sin_cache_len=config.n_positions,
|
| 422 |
+
)
|
| 423 |
|
| 424 |
# MLP
|
| 425 |
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
|