jiang-cc committed on
Commit
6e620d4
·
verified ·
1 Parent(s): 6a496d2

Upload YangJianVLForConditionalGeneration

Browse files
Files changed (1) hide show
  1. modeling_yangjian.py +2 -2
modeling_yangjian.py CHANGED
@@ -231,9 +231,9 @@ class OptimizedCrossAttention(nn.Module):
231
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
232
 
233
  # 构造 cu_seqlens 参数(FlashAttention 必需)
234
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
235
  if self.is_cross_attention and key_value_states is not None:
236
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * seq_len_kv, step=seq_len_kv, dtype=torch.int32, device=k.device)
237
  else:
238
  cu_seqlens_k = cu_seqlens_q
239
 
 
231
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
232
 
233
  # 构造 cu_seqlens 参数(FlashAttention 必需)
234
+ cu_seqlens_q = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
235
  if self.is_cross_attention and key_value_states is not None:
236
+ cu_seqlens_k = torch.arange(0, (batch_size*self.num_heads + 1) * seq_len_kv, step=seq_len_kv, dtype=torch.int32, device=k.device)
237
  else:
238
  cu_seqlens_k = cu_seqlens_q
239