Upload YangJianVLForConditionalGeneration
Browse files - modeling_yangjian.py +2 -2
modeling_yangjian.py
CHANGED
|
@@ -231,9 +231,9 @@ class OptimizedCrossAttention(nn.Module):
|
|
| 231 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 232 |
|
| 233 |
# 构造 cu_seqlens 参数(FlashAttention 必需)
|
| 234 |
-
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
|
| 235 |
if self.is_cross_attention and key_value_states is not None:
|
| 236 |
-
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seq_len_kv, step=seq_len_kv, dtype=torch.int32, device=k.device)
|
| 237 |
else:
|
| 238 |
cu_seqlens_k = cu_seqlens_q
|
| 239 |
|
|
|
|
| 231 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 232 |
|
| 233 |
# 构造 cu_seqlens 参数(FlashAttention 必需)
|
| 234 |
+
cu_seqlens_q = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
|
| 235 |
if self.is_cross_attention and key_value_states is not None:
|
| 236 |
+
cu_seqlens_k = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_kv, step=seq_len_kv, dtype=torch.int32, device=k.device)
|
| 237 |
else:
|
| 238 |
cu_seqlens_k = cu_seqlens_q
|
| 239 |
|