gary-boon (Claude Opus 4.5) committed
Commit: decb5ab · Parent(s): 9056859
Limit QKV matrices to top 5 heads per layer to reduce response size
The QKV visualization fix caused response sizes to explode because we were
sending Q/K/V matrices for all 32 heads × 40 layers = 1280 sets of matrices.
For a 50-token sequence with head_dim=128, that's ~25 million floats per
token step, causing 504 Gateway Timeouts during streaming.
Now only the top 5 heads (by attention weight) per layer retain their QKV
matrices. This reduces the QKV data by ~85% while still providing meaningful
visualization data for the most important attention patterns.
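The numbers above can be checked with some quick arithmetic. This is a back-of-envelope sketch using only the shapes stated in the commit message (40 layers, 32 heads, head_dim=128, a 50-token sequence, Q/K/V per head); the variable names are illustrative, not from the codebase.

```python
# Payload size before and after limiting QKV to the top 5 heads per layer.
n_layers, n_heads, head_dim, seq_len = 40, 32, 128, 50
floats_per_head = 3 * seq_len * head_dim          # one Q, K, and V matrix

all_heads = n_layers * n_heads                    # 1280 head/layer pairs
full_payload = all_heads * floats_per_head        # ~24.6M floats per token step

top_k = 5                                         # heads that keep QKV per layer
pruned_payload = n_layers * top_k * floats_per_head

reduction = 1 - pruned_payload / full_payload
print(f"{full_payload:,} -> {pruned_payload:,} floats ({reduction:.0%} smaller)")
```

This reproduces the figures in the message: ~24.6M floats per step before pruning, and an ~84% reduction (the "~85%" quoted above) after it.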
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- backend/model_service.py +15 -0
backend/model_service.py CHANGED
@@ -1882,6 +1882,14 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
     # Sort by max_weight (return all heads, frontend will decide how many to display)
     critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
+    # Only keep QKV matrices for top 5 heads to avoid massive response sizes
+    # (40 layers × 32 heads × 3 matrices × seq_len × head_dim is too much data)
+    for i, head in enumerate(critical_heads):
+        if i >= 5:  # Keep QKV only for top 5 heads
+            head["q_matrix"] = None
+            head["k_matrix"] = None
+            head["v_matrix"] = None
+
     # Detect layer-level pattern (percentage-based for any layer count)
     layer_pattern = None
     layer_fraction = (layer_idx + 1) / n_layers  # 1-indexed fraction
@@ -2423,6 +2431,13 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
 
     critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
+    # Only keep QKV matrices for top 5 heads to avoid massive response sizes
+    for i, head in enumerate(critical_heads):
+        if i >= 5:
+            head["q_matrix"] = None
+            head["k_matrix"] = None
+            head["v_matrix"] = None
+
     layer_pattern = None
     layer_fraction = (layer_idx + 1) / n_layers
     if layer_idx == 0:
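The pruning logic in the patch can be exercised in isolation. This is a standalone sketch on dummy data: the head dicts are simplified stand-ins for the entries built in analyze_research_attention, and the "head" index field is invented here for readability; only "max_weight" and the q/k/v matrix keys come from the diff.

```python
# Dummy critical_heads list: one dict per head, with a placeholder matrix
# standing in for the real seq_len × head_dim Q/K/V arrays.
critical_heads = [
    {"head": i, "max_weight": w,
     "q_matrix": [[0.0]], "k_matrix": [[0.0]], "v_matrix": [[0.0]]}
    for i, w in enumerate([0.9, 0.1, 0.5, 0.7, 0.3, 0.8, 0.2])
]

# Same two steps as the patch: sort by attention weight, then drop QKV
# for everything past the top 5 heads.
critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
for i, head in enumerate(critical_heads):
    if i >= 5:
        head["q_matrix"] = None
        head["k_matrix"] = None
        head["v_matrix"] = None

kept = [h["head"] for h in critical_heads if h["q_matrix"] is not None]
print(kept)  # heads 0, 5, 3, 2, 4 keep their matrices; 6 and 1 are pruned
```

Note that the pruned heads stay in the response with their max_weight intact, so the frontend can still rank all heads; only the bulky matrices are nulled out.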