Commit: Update modeling_phi.py

Changed file: modeling_phi.py (+1 line, −1 line)
Diff:

@@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
         attn_weights = torch.matmul(
             query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
         ) / math.sqrt(self.head_dim)
Resulting code after the change (new line numbers):

    308         past_key_value: Optional[Cache] = None,
    309         output_attentions: bool = False,
    310         use_cache: bool = False,
    311     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    312         bsz, q_len, _ = hidden_states.size()

    357         key_states = repeat_kv(key_states, self.num_key_value_groups)
    358         value_states = repeat_kv(value_states, self.num_key_value_groups)
    359
    360         # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
    361         attn_weights = torch.matmul(
    362             query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
    363         ) / math.sqrt(self.head_dim)