andersonbcdefg committed
Commit a640dac · 1 Parent(s): b4c7bb4

Upload modeling_flash_llama.py

Files changed (1): modeling_flash_llama.py (+8 -41)
modeling_flash_llama.py CHANGED
@@ -361,47 +361,14 @@ class LlamaAttention(nn.Module):
 
         past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None
 
-        if is_padded_inputs:
-
-            # varlen, ignore padding tokens, efficient for large batch with many paddings
-
-            assert attention_mask is not None
-
-            unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
-            unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
-            # cast kv and q to bf16 or fp16 if currently in float32
-            if unpadded_kv.dtype == torch.float32:
-                unpadded_kv = unpadded_kv.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
-                unpadded_q = unpadded_q.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
-
-
-            attn_outputs = flash_attn_varlen_kvpacked_func(
-                unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
-                max_seqlen_q, max_seqlen_k,
-                dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
-                causal=(not has_layer_past), return_attn_probs=output_attentions
-            )
-
-            attn_output = attn_outputs[0] if output_attentions else attn_outputs
-            attn_output = pad_input(
-                attn_output, indices_q, bsz, max_seqlen_q
-            ).reshape(bsz, q_len, h_size)
-            attn_weights = attn_outputs[2] if output_attentions else None
-
-        else:
-
-            # no padding tokens, more efficient
-            # cast to bf16 or fp16 if currently in float32
-            if kv.dtype == torch.float32:
-                kv = kv.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
-                q = q.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
-
-            attn_outputs = flash_attn_kvpacked_func(
-                q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
-
-            attn_output = attn_outputs[0] if output_attentions else attn_outputs
-            attn_output = attn_output.reshape(bsz, q_len, h_size)
-            attn_weights = attn_outputs[2] if output_attentions else None
+        # no padding tokens, more efficient
+        attn_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        attn_outputs = flash_attn_kvpacked_func(
+            q.type(attn_dtype), kv.type(attn_dtype), dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
+
+        attn_output = attn_outputs[0] if output_attentions else attn_outputs
+        attn_output = attn_output.reshape(bsz, q_len, h_size)
+        attn_weights = attn_outputs[2] if output_attentions else None
 
         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
 
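For reference, the single code path that survives this commit can be exercised on its own. The sketch below is illustrative, not part of the commit: the tensor shapes, the hard-coded CUDA device, and the head_dim ** -0.5 scale (standing in for 1.0/self.norm_factor) are assumptions, and it requires the flash-attn package and a GPU. It mirrors the new call: cast q and the packed kv to bf16/fp16, run flash_attn_kvpacked_func, then flatten the heads. Since the is_padded_inputs branch is removed, this call treats every batch as unpadded; padded batches would need the deleted unpad_input / flash_attn_varlen_kvpacked_func path.

    # Minimal sketch of the retained attention path (illustrative shapes;
    # assumes flash-attn is installed and a CUDA device is available).
    import torch
    from flash_attn import flash_attn_kvpacked_func

    bsz, q_len, kv_len, n_heads, head_dim = 2, 128, 128, 32, 128

    # q: (batch, seqlen_q, n_heads, head_dim); kv packs k and v along dim 2:
    # (batch, seqlen_k, 2, n_heads, head_dim), matching the layout in the diff.
    q = torch.randn(bsz, q_len, n_heads, head_dim, device="cuda")
    kv = torch.randn(bsz, kv_len, 2, n_heads, head_dim, device="cuda")

    # Flash attention kernels require fp16/bf16, hence the cast in the new code.
    attn_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    # softmax_scale = 1/sqrt(head_dim) plays the role of 1.0/self.norm_factor.
    out = flash_attn_kvpacked_func(
        q.type(attn_dtype), kv.type(attn_dtype),
        dropout_p=0.0,
        softmax_scale=head_dim ** -0.5,
        causal=True,  # the model uses (not has_layer_past): causal only without a KV cache
    )

    # out: (batch, seqlen_q, n_heads, head_dim) -> flatten heads as the model does.
    out = out.reshape(bsz, q_len, n_heads * head_dim)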