Update transformer.py
Browse files — transformer.py: +1, −3
transformer.py
CHANGED
|
@@ -189,10 +189,8 @@ class GroupedQueryAttention(nn.Module):
|
|
| 189 |
if attention_mask.dim() == 2:
|
| 190 |
attention_mask = attention_mask[:, None, None, :]
|
| 191 |
if attention_mask.dtype != torch.float:
|
| 192 |
-
# Assumes the input is 1 (keep) / 0 (mask)
|
| 193 |
extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
|
| 194 |
else:
|
| 195 |
-
# Assumes the input is already an additive mask (0 / -inf)
|
| 196 |
extended_mask = attention_mask
|
| 197 |
|
| 198 |
attn_scores = attn_scores + extended_mask
|
|
@@ -203,7 +201,7 @@ class GroupedQueryAttention(nn.Module):
|
|
| 203 |
torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
|
| 204 |
diagonal=1
|
| 205 |
)
|
| 206 |
-
causal_mask = causal_mask[-q.shape[2]:, :]
|
| 207 |
attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
|
| 208 |
|
| 209 |
attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
|
|
|
|
| 189 |
if attention_mask.dim() == 2:
|
| 190 |
attention_mask = attention_mask[:, None, None, :]
|
| 191 |
if attention_mask.dtype != torch.float:
|
|
|
|
| 192 |
extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
|
| 193 |
else:
|
|
|
|
| 194 |
extended_mask = attention_mask
|
| 195 |
|
| 196 |
attn_scores = attn_scores + extended_mask
|
|
|
|
| 201 |
torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
|
| 202 |
diagonal=1
|
| 203 |
)
|
| 204 |
+
causal_mask = causal_mask[-q.shape[2]:, :]
|
| 205 |
attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
|
| 206 |
|
| 207 |
attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
|