Spaces:
Running on T4
perf: use float16 instead of bfloat16 in SDPA for T4 tensor core acceleration
Browse files- model_manager.py +3 -0
model_manager.py
CHANGED
|
@@ -122,6 +122,9 @@ class ModelManager:
         # SDPA fallback when flash-attn is not available (e.g., T4 GPU)
         if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:
             out_dtype = q.dtype
+            # T4 lacks native bfloat16; use float16 for tensor core acceleration
+            if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+                dtype = torch.float16
             if q_lens is not None or k_lens is not None:
                 warnings.warn("Padding mask disabled with scaled_dot_product_attention")
             q = q.transpose(1, 2).to(dtype)