guanwenyu1995 committed on
Commit
603f8f5
·
verified ·
1 Parent(s): b832296

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +1 -1
modeling_llama.py CHANGED
@@ -280,7 +280,7 @@ def eager_attention_forward(
280
 
281
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
282
  attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
283
- attn_output = torch.matmtes)
284
  attn_output = attn_output.transpose(1, 2).contiguous()
285
 
286
  return attn_output, attn_weights
 
280
 
281
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
282
  attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
283
+ attn_output = torch.matmul(attn_weights, value_states)
284
  attn_output = attn_output.transpose(1, 2).contiguous()
285
 
286
  return attn_output, attn_weights