BucketOfFish committed on
Commit
2380737
·
1 Parent(s): 3c52426

Rotary embedding correction

Browse files
Files changed (1) hide show
  1. modeling_phi.py +6 -4
modeling_phi.py CHANGED
@@ -414,10 +414,12 @@ class MHA(nn.Module):
414
  super().__init__()
415
 
416
  # Rotary embedding
417
- self.rotary_emb = RotaryEmbedding(
418
- d_rotary=math.ceil((rotary_dim // n_head) / 2), # d_rotary is half of d_head
419
- initial_cos_sin_cache_len=config.n_positions,
420
- )
 
 
421
 
422
  # MLP
423
  self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
 
414
  super().__init__()
415
 
416
  # Rotary embedding
417
+ self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
418
+ if self.rotary_dim > 0:
419
+ self.rotary_emb = RotaryEmbedding(
420
+ d_rotary=math.ceil((rotary_dim // n_head) / 2), # d_rotary is half of d_head
421
+ initial_cos_sin_cache_len=config.n_positions,
422
+ )
423
 
424
  # MLP
425
  self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(