jiang-cc committed on
Commit
6a496d2
·
verified ·
1 Parent(s): 9a6e3c7

Upload processor

Browse files
Files changed (2) hide show
  1. modeling_yangjian.py +13 -1
  2. tokenizer_config.json +4 -0
modeling_yangjian.py CHANGED
@@ -230,6 +230,13 @@ class OptimizedCrossAttention(nn.Module):
230
  if hasattr(self.config, '_attn_implementation') and self.config._attn_implementation != "eager":
231
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
232
 
 
 
 
 
 
 
 
233
  # 执行 attention 计算
234
  attn_output, _ = attention_interface(
235
  self,
@@ -237,13 +244,18 @@ class OptimizedCrossAttention(nn.Module):
237
  k,
238
  v,
239
  attention_mask=attention_mask,
 
 
 
 
240
  dropout=0.0 if not self.training else self.attention_dropout,
241
  scaling=self.scaling,
242
  is_causal=False,
243
  **kwargs,
244
  )
245
 
246
- # 重塑输出
 
247
  attn_output = attn_output.transpose(1, 2).contiguous() # [batch_size, seq_len_q, num_heads, head_dim]
248
  attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim) # [batch_size, seq_len_q, hidden_size]
249
 
 
230
  if hasattr(self.config, '_attn_implementation') and self.config._attn_implementation != "eager":
231
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
232
 
233
+ # 构造 cu_seqlens 参数(FlashAttention 必需)
234
+ cu_seqlens_q = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
235
+ if self.is_cross_attention and key_value_states is not None:
236
+ cu_seqlens_k = torch.arange(0, (batch_size + 1) * seq_len_kv, step=seq_len_kv, dtype=torch.int32, device=k.device)
237
+ else:
238
+ cu_seqlens_k = cu_seqlens_q
239
+
240
  # 执行 attention 计算
241
  attn_output, _ = attention_interface(
242
  self,
 
244
  k,
245
  v,
246
  attention_mask=attention_mask,
247
+ cu_seqlens_q=cu_seqlens_q,
248
+ cu_seqlens_k=cu_seqlens_k,
249
+ max_seqlen_q=seq_len_q,
250
+ max_seqlen_k=seq_len_kv if self.is_cross_attention and key_value_states is not None else seq_len_q,
251
  dropout=0.0 if not self.training else self.attention_dropout,
252
  scaling=self.scaling,
253
  is_causal=False,
254
  **kwargs,
255
  )
256
 
257
+ attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len_q, self.head_dim)
258
+
259
  attn_output = attn_output.transpose(1, 2).contiguous() # [batch_size, seq_len_q, num_heads, head_dim]
260
  attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim) # [batch_size, seq_len_q, hidden_size]
261
 
tokenizer_config.json CHANGED
@@ -202,8 +202,12 @@
202
  "eos_token": "<|im_end|>",
203
  "errors": "replace",
204
  "extra_special_tokens": {},
 
205
  "model_max_length": 131072,
 
206
  "pad_token": "<|endoftext|>",
 
 
207
  "processor_class": "YangJianProcessor",
208
  "split_special_tokens": false,
209
  "tokenizer_class": "Qwen2Tokenizer",
 
202
  "eos_token": "<|im_end|>",
203
  "errors": "replace",
204
  "extra_special_tokens": {},
205
+ "max_length": null,
206
  "model_max_length": 131072,
207
+ "pad_to_multiple_of": null,
208
  "pad_token": "<|endoftext|>",
209
+ "pad_token_type_id": 0,
210
+ "padding_side": "right",
211
  "processor_class": "YangJianProcessor",
212
  "split_special_tokens": false,
213
  "tokenizer_class": "Qwen2Tokenizer",