Hugging Face Space (hardware: T4 GPU)
Commit: "fix: cast SDPA output back to original dtype + idempotent patching"
Files changed: model_manager.py (+13, −4)
|
Before (old side of the diff, model_manager.py lines 104–134). NOTE: the page
extraction collapsed runs of spaces inside the quoted string literals, so the
indentation shown inside the strings is not the original; lines whose content
was lost entirely are marked.

@@ -104,11 +104,18 @@ class ModelManager:
 104 |             ),
 105 |         ]
 106 |
 107 | -        [removed line — content not captured by the extraction]
 108 | -        [removed line — content not captured by the extraction]
 109 |             "\n"
 110 |             " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
 111 |             " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
 112 |             " if q_lens is not None or k_lens is not None:\n"
 113 |             ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
 114 |             " q = q.transpose(1, 2).to(dtype)\n"

@@ -117,7 +124,9 @@ class ModelManager:
 117 |             " out = torch.nn.functional.scaled_dot_product_attention(\n"
 118 |             " q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"
 119 |             " )\n"
 120 | -            " return out.transpose(1, 2).contiguous()\n"
 121 |         )
 122 |
 123 |         for pattern in patterns:

@@ -128,7 +137,7 @@ class ModelManager:
 128 |                 print(f"Already patched: {filepath}")
 129 |                 continue
 130 |             if target in content:
 131 | -                content = content.replace(target,  [rest of line not captured by the extraction]
 132 |             with open(filepath, "w") as f:
 133 |                 f.write(content)
 134 |             print(f"Patched with SDPA fallback: {filepath}")
|
|
|
After (new side of the diff, model_manager.py lines 104–143). NOTE: the page
extraction collapsed runs of spaces inside the quoted string literals, so the
indentation shown inside the strings is not the original.

@@ -104,11 +104,18 @@ class ModelManager:
 104 |             ),
 105 |         ]
 106 |
 107 | +        # Use the assert + next line as target to ensure idempotent patching
 108 | +        target = (
 109 | +            ' assert q.device.type == "cuda" and q.size(-1) <= 256\n'
 110 | +            "\n"
 111 | +            " # params\n"
 112 | +        )
 113 | +        replacement = (
 114 | +            ' assert q.device.type == "cuda" and q.size(-1) <= 256\n'
 115 |             "\n"
 116 |             " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
 117 |             " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
 118 | +            " out_dtype = q.dtype\n"
 119 |             " if q_lens is not None or k_lens is not None:\n"
 120 |             ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
 121 |             " q = q.transpose(1, 2).to(dtype)\n"

@@ -117,7 +124,9 @@ class ModelManager:
 124 |             " out = torch.nn.functional.scaled_dot_product_attention(\n"
 125 |             " q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"
 126 |             " )\n"
 127 | +            " return out.transpose(1, 2).contiguous().to(out_dtype)\n"
 128 | +            "\n"
 129 | +            " # params\n"
 130 |         )
 131 |
 132 |         for pattern in patterns:

@@ -128,7 +137,7 @@ class ModelManager:
 137 |                 print(f"Already patched: {filepath}")
 138 |                 continue
 139 |             if target in content:
 140 | +                content = content.replace(target, replacement, 1)
 141 |             with open(filepath, "w") as f:
 142 |                 f.write(content)
 143 |             print(f"Patched with SDPA fallback: {filepath}")