gary-boon Claude committed on
Commit
343dd57
·
1 Parent(s): f5ba954

Make QKV hook robust against shape mismatches

Browse files

- Add try-except around entire hook body
- Validate tensor dimensions before reshape
- Skip QKV capture silently if format doesn't match
- Prevents crashes during forward pass with different model architectures

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +20 -11
backend/model_service.py CHANGED
@@ -1314,17 +1314,26 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1314
 
1315
  def make_qkv_hook(layer_idx):
1316
  def hook(module, input, output):
1317
- # output shape: [batch, seq_len, 3 * hidden_size]
1318
- # Split into Q, K, V
1319
- batch_size, seq_len, _ = output.shape
1320
- qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
1321
- # Separate Q, K, V: [batch, seq_len, n_heads, head_dim]
1322
- q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1323
- qkv_captures[layer_idx] = {
1324
- 'q': q[0].detach().cpu(), # Remove batch dim
1325
- 'k': k[0].detach().cpu(),
1326
- 'v': v[0].detach().cpu()
1327
- }
 
 
 
 
 
 
 
 
 
1328
  return hook
1329
 
1330
  # Register hooks on all qkv_proj modules (if available)
 
1314
 
1315
  def make_qkv_hook(layer_idx):
1316
  def hook(module, input, output):
1317
+ try:
1318
+ # output shape: [batch, seq_len, 3 * hidden_size]
1319
+ # Split into Q, K, V
1320
+ if output.dim() != 3:
1321
+ return # Skip if unexpected shape
1322
+ batch_size, seq_len, hidden = output.shape
1323
+ expected_hidden = 3 * n_heads * head_dim
1324
+ if hidden != expected_hidden:
1325
+ return # Skip if dimensions don't match QKV format
1326
+ qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
1327
+ # Separate Q, K, V: [batch, seq_len, n_heads, head_dim]
1328
+ q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1329
+ qkv_captures[layer_idx] = {
1330
+ 'q': q[0].detach().cpu(), # Remove batch dim
1331
+ 'k': k[0].detach().cpu(),
1332
+ 'v': v[0].detach().cpu()
1333
+ }
1334
+ except Exception:
1335
+ # Silently skip QKV capture if it fails - it's optional data
1336
+ pass
1337
  return hook
1338
 
1339
  # Register hooks on all qkv_proj modules (if available)