H-Liu1997 committed on
Commit
6c6483b
·
1 Parent(s): 7237651

fix: build proper attention mask in SDPA fallback for text cross-attention

Browse files
Files changed (1) hide show
  1. model_manager.py +12 -3
model_manager.py CHANGED
@@ -122,13 +122,22 @@ class ModelManager:
122
  " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
123
  " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
124
  " out_dtype = q.dtype\n"
125
- " if q_lens is not None or k_lens is not None:\n"
126
- ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
127
  " q = q.transpose(1, 2).to(dtype)\n"
128
  " k = k.transpose(1, 2).to(dtype)\n"
129
  " v = v.transpose(1, 2).to(dtype)\n"
 
 
 
 
 
 
 
 
 
130
  " out = torch.nn.functional.scaled_dot_product_attention(\n"
131
- " q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"
132
  " )\n"
133
  " return out.transpose(1, 2).contiguous().to(out_dtype)\n"
134
  "\n"
 
122
  " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
123
  " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
124
  " out_dtype = q.dtype\n"
125
+ " b, lq, nq, c = q.shape\n"
126
+ " lk = k.size(1)\n"
127
  " q = q.transpose(1, 2).to(dtype)\n"
128
  " k = k.transpose(1, 2).to(dtype)\n"
129
  " v = v.transpose(1, 2).to(dtype)\n"
130
+ " attn_mask = None\n"
131
+ " is_causal_flag = causal\n"
132
+ " if k_lens is not None:\n"
133
+ " valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
134
+ " attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
135
+ " is_causal_flag = False\n"
136
+ " if causal:\n"
137
+ " cm = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)\n"
138
+ " attn_mask = attn_mask.masked_fill(cm[None, None, :, :], float('-inf'))\n"
139
  " out = torch.nn.functional.scaled_dot_product_attention(\n"
140
+ " q, k, v, attn_mask=attn_mask, is_causal=is_causal_flag, dropout_p=dropout_p\n"
141
  " )\n"
142
  " return out.transpose(1, 2).contiguous().to(out_dtype)\n"
143
  "\n"