andersonbcdefg committed
Commit 87cede6 · 1 Parent(s): 22143dc

Upload modeling_flash_llama.py

Files changed (1): modeling_flash_llama.py (+10 −0)
modeling_flash_llama.py CHANGED
@@ -369,6 +369,12 @@ class LlamaAttention(nn.Module):
 
             unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
             unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
+            # cast kv and q to bf16 or fp16 if currently in float32
+            if unpadded_kv.dtype == torch.float32:
+                unpadded_kv = unpadded_kv.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
+                unpadded_q = unpadded_q.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
+
+
             attn_outputs = flash_attn_varlen_kvpacked_func(
                 unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_q, max_seqlen_k,
@@ -385,6 +391,10 @@ class LlamaAttention(nn.Module):
         else:
 
             # no padding tokens, more efficient
+            # cast to bf16 or fp16 if currently in float32
+            if kv.dtype == torch.float32:
+                kv = kv.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
+                q = q.to(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
 
             attn_outputs = flash_attn_kvpacked_func(
                 q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
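
Both hunks make the same change: FlashAttention's CUDA kernels accept only fp16/bf16 inputs, so query and packed key/value tensors that arrive in float32 are downcast before the kernel call, preferring bfloat16 where the GPU supports it. A minimal sketch of the same pattern as a standalone helper (the downcast_for_flash_attn name and function are illustrative, not part of this commit):

import torch

def downcast_for_flash_attn(*tensors):
    # Hypothetical helper, not in the commit: FlashAttention kernels reject
    # float32, so fp32 tensors are cast down before the kernel call.
    # Prefer bfloat16 on GPUs that support it (Ampere and newer),
    # otherwise fall back to float16.
    target = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return tuple(t.to(target) if t.dtype == torch.float32 else t for t in tensors)

# Usage mirroring the no-padding branch of the diff:
# q, kv = downcast_for_flash_attn(q, kv)

Note that the cast in both hunks is gated on dtype == torch.float32, so inputs already in fp16 or bf16 pass through unchanged.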