H-Liu1997 committed on
Commit
80c3e53
·
1 Parent(s): 3ab7701

perf: use float16 instead of bfloat16 in SDPA for T4 tensor core acceleration

Browse files
Files changed (1) hide show
  1. model_manager.py +3 -0
model_manager.py CHANGED
@@ -122,6 +122,9 @@ class ModelManager:
122
  " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
123
  " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
124
  " out_dtype = q.dtype\n"
 
 
 
125
  " if q_lens is not None or k_lens is not None:\n"
126
  ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
127
  " q = q.transpose(1, 2).to(dtype)\n"
 
122
  " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
123
  " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
124
  " out_dtype = q.dtype\n"
125
+ " # T4 lacks native bfloat16; use float16 for tensor core acceleration\n"
126
+ " if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():\n"
127
+ " dtype = torch.float16\n"
128
  " if q_lens is not None or k_lens is not None:\n"
129
  ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
130
  " q = q.transpose(1, 2).to(dtype)\n"