szxllm committed
Commit 1ef8665 · verified · 1 Parent(s): 02d752c

Update transformer.py

Files changed (1): transformer.py +1 -3
transformer.py CHANGED
@@ -189,10 +189,8 @@ class GroupedQueryAttention(nn.Module):
         if attention_mask.dim() == 2:
             attention_mask = attention_mask[:, None, None, :]
         if attention_mask.dtype != torch.float:
-            # Assumes the input is 1 (keep) / 0 (mask)
             extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
         else:
-            # Assumes the input is already an additive mask (0 / -inf)
             extended_mask = attention_mask
 
         attn_scores = attn_scores + extended_mask
@@ -203,7 +201,7 @@
             torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
             diagonal=1
         )
-        causal_mask = causal_mask[-q.shape[2]:, :]  # not yet understood
+        causal_mask = causal_mask[-q.shape[2]:, :]
         attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
 
         attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
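For readers skimming the first hunk: the surviving branch turns a 1 (keep) / 0 (mask) padding mask into an additive mask whose masked positions hold the most negative finite value of the score dtype, so they end up with ~0 probability after softmax. A minimal, self-contained sketch of that conversion, with toy shapes chosen here for illustration (not taken from the repo):

import torch

# Toy shapes: batch 1, 1 head, 4 query and 4 key positions.
attention_mask = torch.tensor([[1, 1, 1, 0]])      # 1 = keep, 0 = mask out
attention_mask = attention_mask[:, None, None, :]  # (B, 1, 1, K), broadcastable over heads/queries
attn_scores = torch.zeros(1, 1, 4, 4)              # placeholder (B, H, Q, K) logits

# Masked positions become finfo(...).min, kept positions become 0.
extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
probs = torch.softmax(attn_scores + extended_mask, dim=-1)
print(probs[0, 0])  # last key column is ~0 in every row; the rest renormalize to ~1/3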
 
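The changed line itself, `causal_mask = causal_mask[-q.shape[2]:, :]`, only loses its stray "not yet understood" note; its behavior is unchanged. What it does: when the query block is shorter than the key sequence (e.g., incremental decoding against a KV cache), query row i corresponds to key position seq_len_k - q_len + i, so only the last q_len rows of the square causal mask apply. A small sketch with assumed lengths (6 cached keys, 2 new queries):

import torch

# Assumed lengths for illustration: 6 key positions, 2 new queries.
seq_len_k, q_len = 6, 2
causal_mask = torch.triu(
    torch.ones(seq_len_k, seq_len_k, dtype=torch.bool),
    diagonal=1,
)
# Keep only the rows for the current queries; they sit at the end of
# the key sequence, so each new token attends to all cached keys plus
# itself, and never to positions after it.
causal_mask = causal_mask[-q_len:, :]
print(causal_mask)
# tensor([[False, False, False, False, False,  True],
#         [False, False, False, False, False, False]])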