Commit: "Upload processor"
File changed: modeling_yangjian.py (+3 −3)
@@ -226,9 +226,9 @@ class OptimizedCrossAttention(nn.Module):
         # q, k, v: [batch_size, num_heads, seq_len, head_dim]

         # 选择 attention 实现
-        attention_interface: Callable =
-        if hasattr(self.config, '_attn_implementation') and self.config._attn_implementation != "eager":
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
+        # if hasattr(self.config, '_attn_implementation') and self.config._attn_implementation != "eager":
+        #     attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         # 构造 cu_seqlens 参数(FlashAttention 必需)
         cu_seqlens_q = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)

NOTE(extraction): the three removed ("-") lines above are incompletely recovered from the page
scrape — the original line 229 is cut off after "Callable =" and the body of original line 231
was lost entirely; the "+" side is complete. The change hard-codes the SDPA attention
implementation via ALL_ATTENTION_FUNCTIONS["sdpa"] and comments out the previous
config-driven (`_attn_implementation`) selection.