Spaces:
Running on T4
Running on T4
fix: build proper attention mask in SDPA fallback for text cross-attention
Browse files- model_manager.py +12 -3
model_manager.py
CHANGED
|
@@ -122,13 +122,22 @@ class ModelManager:
|
|
| 122 |
" # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
|
| 123 |
" if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
|
| 124 |
" out_dtype = q.dtype\n"
|
| 125 |
-
"
|
| 126 |
-
|
| 127 |
" q = q.transpose(1, 2).to(dtype)\n"
|
| 128 |
" k = k.transpose(1, 2).to(dtype)\n"
|
| 129 |
" v = v.transpose(1, 2).to(dtype)\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
" out = torch.nn.functional.scaled_dot_product_attention(\n"
|
| 131 |
-
" q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"  <!-- NOTE(review): source line truncated at `attn_mask=` in the page extraction; completed from the commit's replacement line — verify against the actual repo diff -->
|
| 132 |
" )\n"
|
| 133 |
" return out.transpose(1, 2).contiguous().to(out_dtype)\n"
|
| 134 |
"\n"
|
|
|
|
| 122 |
" # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
|
| 123 |
" if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
|
| 124 |
" out_dtype = q.dtype\n"
|
| 125 |
+
" b, lq, nq, c = q.shape\n"
|
| 126 |
+
" lk = k.size(1)\n"
|
| 127 |
" q = q.transpose(1, 2).to(dtype)\n"
|
| 128 |
" k = k.transpose(1, 2).to(dtype)\n"
|
| 129 |
" v = v.transpose(1, 2).to(dtype)\n"
|
| 130 |
+
" attn_mask = None\n"
|
| 131 |
+
" is_causal_flag = causal\n"
|
| 132 |
+
" if k_lens is not None:\n"
|
| 133 |
+
" valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
|
| 134 |
+
" attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
|
| 135 |
+
" is_causal_flag = False\n"
|
| 136 |
+
" if causal:\n"
|
| 137 |
+
" cm = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)\n"
|
| 138 |
+
" attn_mask = attn_mask.masked_fill(cm[None, None, :, :], float('-inf'))\n"
|
| 139 |
" out = torch.nn.functional.scaled_dot_product_attention(\n"
|
| 140 |
+
" q, k, v, attn_mask=attn_mask, is_causal=is_causal_flag, dropout_p=dropout_p\n"
|
| 141 |
" )\n"
|
| 142 |
" return out.transpose(1, 2).contiguous().to(out_dtype)\n"
|
| 143 |
"\n"
|