gary-boon Claude Opus 4.5 committed on
Commit
d1d37a8
·
1 Parent(s): 929ba88

fix: add QKV extraction support for Mistral/Devstral architecture

Browse files

- Add make_separate_proj_hook for models with separate q_proj, k_proj, v_proj
- Handle GQA (Grouped Query Attention) by expanding K/V heads to match Q heads
- Support both CodeGen (combined qkv_proj) and Mistral (separate projections)
- Update both streaming and non-streaming endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +128 -17
backend/model_service.py CHANGED
@@ -1644,42 +1644,89 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1644
  qkv_captures = {}
1645
  hooks = []
1646
 
1647
- def make_qkv_hook(layer_idx):
 
1648
  def hook(module, input, output):
1649
  try:
1650
- # output shape: [batch, seq_len, 3 * hidden_size]
1651
- # Split into Q, K, V
1652
  if output.dim() != 3:
1653
- return # Skip if unexpected shape
1654
  batch_size, seq_len, hidden = output.shape
1655
  expected_hidden = 3 * n_heads * head_dim
1656
  if hidden != expected_hidden:
1657
- return # Skip if dimensions don't match QKV format
1658
  qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
1659
- # Separate Q, K, V: [batch, seq_len, n_heads, head_dim]
1660
  q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1661
  qkv_captures[layer_idx] = {
1662
- 'q': q[0].detach().cpu(), # Remove batch dim
1663
  'k': k[0].detach().cpu(),
1664
  'v': v[0].detach().cpu()
1665
  }
1666
  except Exception:
1667
- # Silently skip QKV capture if it fails - it's optional data
1668
  pass
1669
  return hook
1670
 
1671
- # Register hooks on all qkv_proj modules (if available)
1672
- # This is model-specific - CodeGen uses different architecture
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1673
  try:
 
1674
  if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
1675
  for layer_idx, layer in enumerate(manager.model.transformer.h):
1676
  if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
1677
- hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
1678
  hooks.append(hook)
1679
  elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
1680
- # GPT-2 style attention
1681
- hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
1682
  hooks.append(hook)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1683
  except Exception as hook_error:
1684
  logger.warning(f"Could not register QKV hooks: {hook_error}")
1685
 
@@ -2116,7 +2163,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2116
  qkv_captures = {}
2117
  hooks = []
2118
 
2119
- def make_qkv_hook(layer_idx):
 
2120
  def hook(module, input, output):
2121
  try:
2122
  if output.dim() != 3:
@@ -2136,16 +2184,79 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2136
  pass
2137
  return hook
2138
 
2139
- # Register hooks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2140
  try:
 
2141
  if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
2142
  for layer_idx, layer in enumerate(manager.model.transformer.h):
2143
  if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
2144
- hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
2145
  hooks.append(hook)
2146
  elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
2147
- hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
2148
  hooks.append(hook)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2149
  except Exception as hook_error:
2150
  logger.warning(f"Could not register QKV hooks: {hook_error}")
2151
 
 
1644
  qkv_captures = {}
1645
  hooks = []
1646
 
1647
# Hook for combined QKV projection (CodeGen style)
def make_combined_qkv_hook(layer_idx):
    """Create a forward hook that captures Q/K/V from a fused qkv projection.

    Expects the hooked module's output to be [batch, seq_len, 3 * n_heads * head_dim];
    splits it into per-head Q, K and V tensors and stores them (batch dim removed,
    moved to CPU) in ``qkv_captures[layer_idx]``. Capture is best-effort: any
    failure is logged at debug level and otherwise ignored so the forward pass
    is never broken.
    """
    def hook(module, input, output):
        try:
            # Only a [batch, seq, hidden] output matches the fused-QKV layout.
            if output.dim() != 3:
                return
            batch_size, seq_len, hidden = output.shape
            expected_hidden = 3 * n_heads * head_dim
            if hidden != expected_hidden:
                return  # Not a fused QKV output for this head configuration.
            qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
            # Each slice: [batch, seq_len, n_heads, head_dim]
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
            qkv_captures[layer_idx] = {
                'q': q[0].detach().cpu(),  # drop batch dimension
                'k': k[0].detach().cpu(),
                'v': v[0].detach().cpu()
            }
        except Exception as e:
            # QKV capture is optional; don't break inference, but leave a
            # trace instead of swallowing silently (matches the streaming
            # endpoint's hook).
            logger.debug(f"QKV capture error for layer {layer_idx}: {e}")
    return hook
1667
 
1668
# Hooks for separate Q, K, V projections (Mistral/LLaMA style)
def make_separate_proj_hook(layer_idx, proj_type, num_kv_heads=None):
    """Create hook for separate Q/K/V projection modules.

    For GQA models, K and V have fewer heads than Q, so we need to
    expand them to match Q's head count for consistent visualization.
    Capture is best-effort: failures are logged at debug level and ignored.
    """
    def hook(module, input, output):
        try:
            if output.dim() != 3:
                return
            batch_size, seq_len, hidden = output.shape

            # Determine number of heads for this projection: Q always uses
            # n_heads; K and V may have fewer heads under GQA.
            if proj_type == 'q':
                proj_heads = n_heads
            else:
                proj_heads = num_kv_heads if num_kv_heads else n_heads

            proj_head_dim = hidden // proj_heads
            if hidden != proj_heads * proj_head_dim:
                return  # Hidden size not divisible by head count; skip.

            # Reshape to [batch, seq, heads, head_dim]
            proj_output = output.reshape(batch_size, seq_len, proj_heads, proj_head_dim)

            # For GQA, repeat each KV head so K/V match Q's head count.
            if proj_type != 'q' and num_kv_heads and num_kv_heads < n_heads:
                repeat_factor = n_heads // num_kv_heads
                proj_output = proj_output.repeat_interleave(repeat_factor, dim=2)

            # Initialize layer entry if needed.
            if layer_idx not in qkv_captures:
                qkv_captures[layer_idx] = {}

            qkv_captures[layer_idx][proj_type] = proj_output[0].detach().cpu()
        except Exception as e:
            # Best-effort capture; log instead of swallowing silently
            # (matches the streaming endpoint's hook).
            logger.debug(f"QKV capture error for layer {layer_idx} {proj_type}: {e}")
    return hook
1698
+
1699
# Register hooks based on model architecture
try:
    # CodeGen style: model.transformer.h[layer].attn.qkv_proj
    if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
        for layer_idx, layer in enumerate(manager.model.transformer.h):
            if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
                hook = layer.attn.qkv_proj.register_forward_hook(make_combined_qkv_hook(layer_idx))
                hooks.append(hook)
            elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
                # GPT-2 style fused attention projection
                hook = layer.attn.c_attn.register_forward_hook(make_combined_qkv_hook(layer_idx))
                hooks.append(hook)

    # Mistral/LLaMA style: model.model.layers[layer].self_attn.{q,k,v}_proj
    elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'):
        num_kv_heads = getattr(manager.model.config, 'num_key_value_heads', None)
        # Count layers actually hooked instead of len(hooks)//3, which
        # miscounts when a layer is missing one of the q/k/v projections.
        hooked_layers = 0
        for layer_idx, layer in enumerate(manager.model.model.layers):
            if hasattr(layer, 'self_attn'):
                attn = layer.self_attn
                registered = False
                for proj_type in ('q', 'k', 'v'):
                    proj = getattr(attn, f'{proj_type}_proj', None)
                    if proj is not None:
                        hooks.append(proj.register_forward_hook(
                            make_separate_proj_hook(layer_idx, proj_type, num_kv_heads)))
                        registered = True
                if registered:
                    hooked_layers += 1
        logger.info(f"Registered QKV hooks for {hooked_layers} Mistral layers (GQA: {num_kv_heads} KV heads)")
except Exception as hook_error:
    # Hook registration is optional instrumentation; never fail the request.
    logger.warning(f"Could not register QKV hooks: {hook_error}")
1732
 
 
2163
  qkv_captures = {}
2164
  hooks = []
2165
 
2166
+ # Hook for combined QKV projection (CodeGen style)
2167
+ def make_combined_qkv_hook(layer_idx):
2168
  def hook(module, input, output):
2169
  try:
2170
  if output.dim() != 3:
 
2184
  pass
2185
  return hook
2186
 
2187
# Hooks for separate Q, K, V projections (Mistral/LLaMA style)
def make_separate_proj_hook(layer_idx, proj_type, num_kv_heads=None):
    """Create hook for separate Q/K/V projection modules.

    For GQA models, K and V have fewer heads than Q, so we need to
    expand them to match Q's head count for consistent visualization.
    """
    is_query = proj_type == 'q'

    def hook(module, input, output):
        try:
            if output.dim() != 3:
                return
            bsz, seq, width = output.shape

            # Q always carries the full head count; K/V may be reduced (GQA).
            heads = n_heads if is_query else (num_kv_heads or n_heads)

            dim_per_head, remainder = divmod(width, heads)
            if remainder:
                return

            # [batch, seq, heads, head_dim]
            reshaped = output.reshape(bsz, seq, heads, dim_per_head)

            # Expand K/V heads so every capture matches Q's head count.
            if not is_query and num_kv_heads and num_kv_heads < n_heads:
                reshaped = reshaped.repeat_interleave(n_heads // num_kv_heads, dim=2)

            # Store with batch dim dropped, on CPU.
            qkv_captures.setdefault(layer_idx, {})[proj_type] = reshaped[0].detach().cpu()
        except Exception as e:
            logger.debug(f"QKV capture error for layer {layer_idx} {proj_type}: {e}")
    return hook
2228
+
2229
# Register hooks based on model architecture
try:
    # CodeGen style: model.transformer.h[layer].attn.qkv_proj
    if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
        for layer_idx, layer in enumerate(manager.model.transformer.h):
            if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
                hook = layer.attn.qkv_proj.register_forward_hook(make_combined_qkv_hook(layer_idx))
                hooks.append(hook)
            elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
                # GPT-2 style fused attention projection
                hook = layer.attn.c_attn.register_forward_hook(make_combined_qkv_hook(layer_idx))
                hooks.append(hook)

    # Mistral/LLaMA style: model.model.layers[layer].self_attn.{q,k,v}_proj
    elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'):
        num_kv_heads = getattr(manager.model.config, 'num_key_value_heads', None)
        # Count layers actually hooked instead of len(hooks)//3, which
        # miscounts when a layer is missing one of the q/k/v projections.
        hooked_layers = 0
        for layer_idx, layer in enumerate(manager.model.model.layers):
            if hasattr(layer, 'self_attn'):
                attn = layer.self_attn
                registered = False
                for proj_type in ('q', 'k', 'v'):
                    proj = getattr(attn, f'{proj_type}_proj', None)
                    if proj is not None:
                        hooks.append(proj.register_forward_hook(
                            make_separate_proj_hook(layer_idx, proj_type, num_kv_heads)))
                        registered = True
                if registered:
                    hooked_layers += 1
        logger.info(f"Registered QKV hooks for {hooked_layers} Mistral layers (GQA: {num_kv_heads} KV heads)")
except Exception as hook_error:
    # Hook registration is optional instrumentation; never fail the request.
    logger.warning(f"Could not register QKV hooks: {hook_error}")
2262