a0y0346 committed on
Commit
685194e
·
1 Parent(s): 374d38b

fix: Add fallback SDPA benchmark when attention layer fails

Browse files

- Improved benchmark_attention_layer error handling with logging
- Added fallback to F.scaled_dot_product_attention using real model
dimensions when attention layer forward pass fails
- This ensures benchmarks still work with model's actual head
configuration even if layer-level benchmarking has issues

Files changed (2) hide show
  1. src/attention_utils.py +22 -10
  2. src/benchmark.py +106 -19
src/attention_utils.py CHANGED
@@ -143,6 +143,18 @@ def benchmark_attention_layer(
143
 
144
  enable_math, enable_flash, enable_mem_efficient = backend_flags[backend]
145
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  try:
147
  # Warmup
148
  with torch.backends.cuda.sdp_kernel(
@@ -152,11 +164,7 @@ def benchmark_attention_layer(
152
  ):
153
  with torch.no_grad():
154
  for _ in range(warmup_iterations):
155
- _ = attention_layer(
156
- hidden_states,
157
- position_ids=position_ids,
158
- attention_mask=attention_mask,
159
- )
160
 
161
  torch.cuda.synchronize()
162
  torch.cuda.reset_peak_memory_stats()
@@ -173,11 +181,7 @@ def benchmark_attention_layer(
173
  with torch.no_grad():
174
  start.record()
175
  for _ in range(num_iterations):
176
- output = attention_layer(
177
- hidden_states,
178
- position_ids=position_ids,
179
- attention_mask=attention_mask,
180
- )
181
  end.record()
182
 
183
  torch.cuda.synchronize()
@@ -193,7 +197,10 @@ def benchmark_attention_layer(
193
  }
194
 
195
  except Exception as e:
 
196
  error_msg = str(e)
 
 
197
  # Common error: Flash attention not available on certain GPUs
198
  if "flash" in error_msg.lower() or "sm75" in error_msg.lower():
199
  return {
@@ -202,6 +209,11 @@ def benchmark_attention_layer(
202
  "status": f"unsupported: {error_msg[:80]}",
203
  "backend": backend,
204
  }
 
 
 
 
 
205
  return {
206
  "time_ms": None,
207
  "memory_mb": None,
 
143
 
144
  enable_math, enable_flash, enable_mem_efficient = backend_flags[backend]
145
 
146
+ def run_attention():
147
+ """Run attention with fallback for different call signatures."""
148
+ try:
149
+ # Try standard call with position_ids
150
+ return attention_layer(
151
+ hidden_states,
152
+ position_ids=position_ids,
153
+ )
154
+ except TypeError:
155
+ # Fallback: just hidden_states
156
+ return attention_layer(hidden_states)
157
+
158
  try:
159
  # Warmup
160
  with torch.backends.cuda.sdp_kernel(
 
164
  ):
165
  with torch.no_grad():
166
  for _ in range(warmup_iterations):
167
+ _ = run_attention()
 
 
 
 
168
 
169
  torch.cuda.synchronize()
170
  torch.cuda.reset_peak_memory_stats()
 
181
  with torch.no_grad():
182
  start.record()
183
  for _ in range(num_iterations):
184
+ output = run_attention()
 
 
 
 
185
  end.record()
186
 
187
  torch.cuda.synchronize()
 
197
  }
198
 
199
  except Exception as e:
200
+ import traceback
201
  error_msg = str(e)
202
+ tb = traceback.format_exc()
203
+
204
  # Common error: Flash attention not available on certain GPUs
205
  if "flash" in error_msg.lower() or "sm75" in error_msg.lower():
206
  return {
 
209
  "status": f"unsupported: {error_msg[:80]}",
210
  "backend": backend,
211
  }
212
+
213
+ # Log detailed error for debugging
214
+ print(f"[benchmark_attention_layer] Error for {backend}: {error_msg}")
215
+ print(f"[benchmark_attention_layer] Traceback: {tb[:500]}")
216
+
217
  return {
218
  "time_ms": None,
219
  "memory_mb": None,
src/benchmark.py CHANGED
@@ -190,41 +190,128 @@ def run_attention_benchmark(
190
  device = torch.device("cuda")
191
  dtype = torch.float16
192
 
193
- # If model_name is provided, use real model attention layer
194
  if model_name is not None and model_name in MODEL_CONFIGS:
195
  try:
196
  # Load the real HuggingFace model
197
  model = load_model(model_name)
198
 
199
- # Extract attention layer from layer 0
200
- attention_layer = extract_attention_layer(model, layer_idx=0)
201
-
202
- # Get model attention info
203
  attn_info = get_model_attention_info(model)
204
 
205
- # Create proper inputs for the attention layer
206
- hidden_states, position_ids = create_attention_inputs(
207
- model, batch_size, seq_len, device, dtype
208
- )
209
 
210
  results = {"model_name": model_name, "using_real_model": True}
211
  results["model_info"] = attn_info
212
 
213
- # Benchmark each backend using the real attention layer
214
- for backend in ["math", "flash", "mem_efficient"]:
215
- result = benchmark_attention_layer(
 
 
 
 
 
 
 
216
  attention_layer=attention_layer,
217
  hidden_states=hidden_states,
218
  position_ids=position_ids,
219
- backend=backend,
220
- num_iterations=num_iterations,
221
- warmup_iterations=warmup_iterations,
222
  )
223
- results[backend] = result
 
 
 
 
 
 
 
 
 
224
 
225
- # Clean up inputs
226
- del hidden_states, position_ids
227
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  # Calculate speedups
230
  if results.get("math", {}).get("time_ms"):
 
190
  device = torch.device("cuda")
191
  dtype = torch.float16
192
 
193
+ # If model_name is provided, use real model dimensions for benchmarking
194
  if model_name is not None and model_name in MODEL_CONFIGS:
195
  try:
196
  # Load the real HuggingFace model
197
  model = load_model(model_name)
198
 
199
+ # Get model attention info for real dimensions
 
 
 
200
  attn_info = get_model_attention_info(model)
201
 
202
+ # Extract dimensions from real model
203
+ model_num_heads = attn_info["num_attention_heads"]
204
+ model_head_dim = attn_info["head_dim"]
 
205
 
206
  results = {"model_name": model_name, "using_real_model": True}
207
  results["model_info"] = attn_info
208
 
209
+ # First try: Use actual attention layer forward pass
210
+ attention_layer_works = False
211
+ try:
212
+ attention_layer = extract_attention_layer(model, layer_idx=0)
213
+ hidden_states, position_ids = create_attention_inputs(
214
+ model, batch_size, seq_len, device, dtype
215
+ )
216
+
217
+ # Test if attention layer works with first backend
218
+ test_result = benchmark_attention_layer(
219
  attention_layer=attention_layer,
220
  hidden_states=hidden_states,
221
  position_ids=position_ids,
222
+ backend="flash",
223
+ num_iterations=2,
224
+ warmup_iterations=1,
225
  )
226
+
227
+ if test_result.get("time_ms") is not None:
228
+ attention_layer_works = True
229
+
230
+ del hidden_states, position_ids
231
+ torch.cuda.empty_cache()
232
+
233
+ except Exception as layer_error:
234
+ print(f"[run_attention_benchmark] Attention layer extraction failed: {layer_error}")
235
+ attention_layer_works = False
236
 
237
+ if attention_layer_works:
238
+ # Use actual attention layer
239
+ hidden_states, position_ids = create_attention_inputs(
240
+ model, batch_size, seq_len, device, dtype
241
+ )
242
+
243
+ for backend in ["math", "flash", "mem_efficient"]:
244
+ result = benchmark_attention_layer(
245
+ attention_layer=attention_layer,
246
+ hidden_states=hidden_states,
247
+ position_ids=position_ids,
248
+ backend=backend,
249
+ num_iterations=num_iterations,
250
+ warmup_iterations=warmup_iterations,
251
+ )
252
+ results[backend] = result
253
+
254
+ del hidden_states, position_ids
255
+ torch.cuda.empty_cache()
256
+ else:
257
+ # Fallback: Use F.scaled_dot_product_attention with real model dimensions
258
+ print(f"[run_attention_benchmark] Falling back to SDPA with model dimensions")
259
+ results["fallback_mode"] = True
260
+
261
+ # Create Q, K, V tensors with real model dimensions
262
+ Q = torch.randn(batch_size, model_num_heads, seq_len, model_head_dim, device=device, dtype=dtype)
263
+ K = torch.randn(batch_size, model_num_heads, seq_len, model_head_dim, device=device, dtype=dtype)
264
+ V = torch.randn(batch_size, model_num_heads, seq_len, model_head_dim, device=device, dtype=dtype)
265
+
266
+ backends = [
267
+ ("math", True, False, False),
268
+ ("flash", False, True, False),
269
+ ("mem_efficient", False, False, True),
270
+ ]
271
+
272
+ for backend_name, enable_math, enable_flash, enable_mem_efficient in backends:
273
+ try:
274
+ torch.cuda.reset_peak_memory_stats()
275
+ torch.cuda.synchronize()
276
+
277
+ with torch.backends.cuda.sdp_kernel(
278
+ enable_flash=enable_flash,
279
+ enable_math=enable_math,
280
+ enable_mem_efficient=enable_mem_efficient
281
+ ):
282
+ # Warmup
283
+ for _ in range(warmup_iterations):
284
+ _ = F.scaled_dot_product_attention(Q, K, V)
285
+ torch.cuda.synchronize()
286
+
287
+ # Timed runs
288
+ start = torch.cuda.Event(enable_timing=True)
289
+ end = torch.cuda.Event(enable_timing=True)
290
+
291
+ start.record()
292
+ for _ in range(num_iterations):
293
+ _ = F.scaled_dot_product_attention(Q, K, V)
294
+ end.record()
295
+ torch.cuda.synchronize()
296
+
297
+ time_ms = start.elapsed_time(end) / num_iterations
298
+ memory_mb = torch.cuda.max_memory_allocated() / 1e6
299
+
300
+ results[backend_name] = {
301
+ "time_ms": round(time_ms, 3),
302
+ "memory_mb": round(memory_mb, 1),
303
+ "status": "success"
304
+ }
305
+
306
+ except Exception as e:
307
+ results[backend_name] = {
308
+ "time_ms": None,
309
+ "memory_mb": None,
310
+ "status": f"error: {str(e)[:50]}"
311
+ }
312
+
313
+ del Q, K, V
314
+ torch.cuda.empty_cache()
315
 
316
  # Calculate speedups
317
  if results.get("math", {}).get("time_ms"):