jiang-cc committed on
Commit
b281ecf
·
verified ·
1 Parent(s): 6e620d4

Upload processor

Browse files
Files changed (1) hide show
  1. modeling_yangjian.py +3 -3
modeling_yangjian.py CHANGED
@@ -226,9 +226,9 @@ class OptimizedCrossAttention(nn.Module):
226
  # q, k, v: [batch_size, num_heads, seq_len, head_dim]
227
 
228
  # 选择 attention 实现
229
- attention_interface: Callable = eager_attention_forward
230
- if hasattr(self.config, '_attn_implementation') and self.config._attn_implementation != "eager":
231
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
232
 
233
  # 构造 cu_seqlens 参数(FlashAttention 必需)
234
  cu_seqlens_q = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
 
226
  # q, k, v: [batch_size, num_heads, seq_len, head_dim]
227
 
228
  # 选择 attention 实现
229
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
230
+ # if hasattr(self.config, '_attn_implementation') and self.config._attn_implementation != "eager":
231
+ # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
232
 
233
  # 构造 cu_seqlens 参数(FlashAttention 必需)
234
  cu_seqlens_q = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)