Spaces:
Sleeping
Sleeping
gary-boon
Claude
commited on
Commit
·
343dd57
1
Parent(s):
f5ba954
Make QKV hook robust against shape mismatches
Browse files- Add try-except around entire hook body
- Validate tensor dimensions before reshape
- Skip QKV capture silently if format doesn't match
- Prevents crashes during forward pass with different model architectures
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- backend/model_service.py +20 -11
backend/model_service.py
CHANGED
|
@@ -1314,17 +1314,26 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1314 |
|
| 1315 |
def make_qkv_hook(layer_idx):
|
| 1316 |
def hook(module, input, output):
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1328 |
return hook
|
| 1329 |
|
| 1330 |
# Register hooks on all qkv_proj modules (if available)
|
|
|
|
| 1314 |
|
| 1315 |
def make_qkv_hook(layer_idx):
|
| 1316 |
def hook(module, input, output):
|
| 1317 |
+
try:
|
| 1318 |
+
# output shape: [batch, seq_len, 3 * hidden_size]
|
| 1319 |
+
# Split into Q, K, V
|
| 1320 |
+
if output.dim() != 3:
|
| 1321 |
+
return # Skip if unexpected shape
|
| 1322 |
+
batch_size, seq_len, hidden = output.shape
|
| 1323 |
+
expected_hidden = 3 * n_heads * head_dim
|
| 1324 |
+
if hidden != expected_hidden:
|
| 1325 |
+
return # Skip if dimensions don't match QKV format
|
| 1326 |
+
qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
|
| 1327 |
+
# Separate Q, K, V: [batch, seq_len, n_heads, head_dim]
|
| 1328 |
+
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
| 1329 |
+
qkv_captures[layer_idx] = {
|
| 1330 |
+
'q': q[0].detach().cpu(), # Remove batch dim
|
| 1331 |
+
'k': k[0].detach().cpu(),
|
| 1332 |
+
'v': v[0].detach().cpu()
|
| 1333 |
+
}
|
| 1334 |
+
except Exception:
|
| 1335 |
+
# Silently skip QKV capture if it fails - it's optional data
|
| 1336 |
+
pass
|
| 1337 |
return hook
|
| 1338 |
|
| 1339 |
# Register hooks on all qkv_proj modules (if available)
|