Commit decb5ab by gary-boon (Claude Opus 4.5)
Parent: 9056859

Limit QKV matrices to top 5 heads per layer to reduce response size


The QKV visualization fix caused response sizes to explode: we were sending
Q/K/V matrices for all 32 heads × 40 layers = 1,280 sets of matrices. For a
50-token sequence with head_dim=128, that's ~25 million floats per token step,
causing 504 Gateway Timeout errors during streaming.

Now only the top 5 heads (by attention weight) per layer retain their QKV
matrices. This reduces the QKV data by ~85% while still providing meaningful
visualization data for the most important attention patterns.
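The figures above can be reproduced with a quick back-of-the-envelope check (the model dimensions are taken from the commit message, not measured output):

```python
# Response-size arithmetic from the commit message.
n_layers, n_heads = 40, 32
seq_len, head_dim = 50, 128

# Q, K and V each contribute a seq_len x head_dim matrix per head.
floats_per_head = 3 * seq_len * head_dim
total_floats = n_layers * n_heads * floats_per_head
print(f"{total_floats:,} floats per token step")  # 24,576,000 (~25 million)

# Keeping QKV matrices for only the top 5 of 32 heads per layer:
reduction = 1 - 5 / n_heads
print(f"QKV data reduced by ~{reduction * 100:.1f}%")  # ~84.4%
```

The exact reduction is 27/32 ≈ 84.4%, consistent with the ~85% quoted above.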

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1):
  1. backend/model_service.py +15 -0
backend/model_service.py CHANGED
@@ -1882,6 +1882,14 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     # Sort by max_weight (return all heads, frontend will decide how many to display)
     critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
+    # Only keep QKV matrices for top 5 heads to avoid massive response sizes
+    # (40 layers × 32 heads × 3 matrices × seq_len × head_dim is too much data)
+    for i, head in enumerate(critical_heads):
+        if i >= 5:  # Keep QKV only for top 5 heads
+            head["q_matrix"] = None
+            head["k_matrix"] = None
+            head["v_matrix"] = None
+
     # Detect layer-level pattern (percentage-based for any layer count)
     layer_pattern = None
     layer_fraction = (layer_idx + 1) / n_layers  # 1-indexed fraction
@@ -2423,6 +2431,13 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
     critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
+    # Only keep QKV matrices for top 5 heads to avoid massive response sizes
+    for i, head in enumerate(critical_heads):
+        if i >= 5:
+            head["q_matrix"] = None
+            head["k_matrix"] = None
+            head["v_matrix"] = None
+
     layer_pattern = None
     layer_fraction = (layer_idx + 1) / n_layers
     if layer_idx == 0: