H-Liu1997 committed on
Commit
e843211
·
1 Parent(s): bb7e158

fix: cast SDPA output back to original dtype + idempotent patching

Browse files
Files changed (1) hide show
  1. model_manager.py +13 -4
model_manager.py CHANGED
@@ -104,11 +104,18 @@ class ModelManager:
104
  ),
105
  ]
106
 
107
- target = ' assert q.device.type == "cuda" and q.size(-1) <= 256'
108
- sdpa_fallback = target + "\n" + (
 
 
 
 
 
 
109
  "\n"
110
  " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
111
  " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
 
112
  " if q_lens is not None or k_lens is not None:\n"
113
  ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
114
  " q = q.transpose(1, 2).to(dtype)\n"
@@ -117,7 +124,9 @@ class ModelManager:
117
  " out = torch.nn.functional.scaled_dot_product_attention(\n"
118
  " q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"
119
  " )\n"
120
- " return out.transpose(1, 2).contiguous()\n"
 
 
121
  )
122
 
123
  for pattern in patterns:
@@ -128,7 +137,7 @@ class ModelManager:
128
  print(f"Already patched: {filepath}")
129
  continue
130
  if target in content:
131
- content = content.replace(target, sdpa_fallback, 1)
132
  with open(filepath, "w") as f:
133
  f.write(content)
134
  print(f"Patched with SDPA fallback: {filepath}")
 
104
  ),
105
  ]
106
 
107
+ # Use the assert + next line as target to ensure idempotent patching
108
+ target = (
109
+ ' assert q.device.type == "cuda" and q.size(-1) <= 256\n'
110
+ "\n"
111
+ " # params\n"
112
+ )
113
+ replacement = (
114
+ ' assert q.device.type == "cuda" and q.size(-1) <= 256\n'
115
  "\n"
116
  " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
117
  " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
118
+ " out_dtype = q.dtype\n"
119
  " if q_lens is not None or k_lens is not None:\n"
120
  ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
121
  " q = q.transpose(1, 2).to(dtype)\n"
 
124
  " out = torch.nn.functional.scaled_dot_product_attention(\n"
125
  " q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"
126
  " )\n"
127
+ " return out.transpose(1, 2).contiguous().to(out_dtype)\n"
128
+ "\n"
129
+ " # params\n"
130
  )
131
 
132
  for pattern in patterns:
 
137
  print(f"Already patched: {filepath}")
138
  continue
139
  if target in content:
140
+ content = content.replace(target, replacement, 1)
141
  with open(filepath, "w") as f:
142
  f.write(content)
143
  print(f"Patched with SDPA fallback: {filepath}")