Spaces:
Sleeping
Sleeping
gary-boon
Claude Opus 4.5
committed on
Commit
·
d1d37a8
1
Parent(s):
929ba88
fix: add QKV extraction support for Mistral/Devstral architecture
Browse files- Add make_separate_proj_hook for models with separate q_proj, k_proj, v_proj
- Handle GQA (Grouped Query Attention) by expanding K/V heads to match Q heads
- Support both CodeGen (combined qkv_proj) and Mistral (separate projections)
- Update both streaming and non-streaming endpoints
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- backend/model_service.py +128 -17
backend/model_service.py
CHANGED
|
@@ -1644,42 +1644,89 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1644 |
qkv_captures = {}
|
| 1645 |
hooks = []
|
| 1646 |
|
| 1647 |
-
|
|
|
|
| 1648 |
def hook(module, input, output):
|
| 1649 |
try:
|
| 1650 |
-
# output shape: [batch, seq_len, 3 * hidden_size]
|
| 1651 |
-
# Split into Q, K, V
|
| 1652 |
if output.dim() != 3:
|
| 1653 |
-
return
|
| 1654 |
batch_size, seq_len, hidden = output.shape
|
| 1655 |
expected_hidden = 3 * n_heads * head_dim
|
| 1656 |
if hidden != expected_hidden:
|
| 1657 |
-
return
|
| 1658 |
qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
|
| 1659 |
-
# Separate Q, K, V: [batch, seq_len, n_heads, head_dim]
|
| 1660 |
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
| 1661 |
qkv_captures[layer_idx] = {
|
| 1662 |
-
'q': q[0].detach().cpu(),
|
| 1663 |
'k': k[0].detach().cpu(),
|
| 1664 |
'v': v[0].detach().cpu()
|
| 1665 |
}
|
| 1666 |
except Exception:
|
| 1667 |
-
# Silently skip QKV capture if it fails - it's optional data
|
| 1668 |
pass
|
| 1669 |
return hook
|
| 1670 |
|
| 1671 |
-
#
|
| 1672 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1673 |
try:
|
|
|
|
| 1674 |
if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
|
| 1675 |
for layer_idx, layer in enumerate(manager.model.transformer.h):
|
| 1676 |
if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
|
| 1677 |
-
hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
|
| 1678 |
hooks.append(hook)
|
| 1679 |
elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
|
| 1680 |
-
|
| 1681 |
-
hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
|
| 1682 |
hooks.append(hook)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1683 |
except Exception as hook_error:
|
| 1684 |
logger.warning(f"Could not register QKV hooks: {hook_error}")
|
| 1685 |
|
|
@@ -2116,7 +2163,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2116 |
qkv_captures = {}
|
| 2117 |
hooks = []
|
| 2118 |
|
| 2119 |
-
|
|
|
|
| 2120 |
def hook(module, input, output):
|
| 2121 |
try:
|
| 2122 |
if output.dim() != 3:
|
|
@@ -2136,16 +2184,79 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2136 |
pass
|
| 2137 |
return hook
|
| 2138 |
|
| 2139 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2140 |
try:
|
|
|
|
| 2141 |
if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
|
| 2142 |
for layer_idx, layer in enumerate(manager.model.transformer.h):
|
| 2143 |
if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
|
| 2144 |
-
hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
|
| 2145 |
hooks.append(hook)
|
| 2146 |
elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
|
| 2147 |
-
hook = layer.attn.c_attn.register_forward_hook(make_qkv_hook(layer_idx))
|
| 2148 |
hooks.append(hook)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2149 |
except Exception as hook_error:
|
| 2150 |
logger.warning(f"Could not register QKV hooks: {hook_error}")
|
| 2151 |
|
|
|
|
| 1644 |
qkv_captures = {}
|
| 1645 |
hooks = []
|
| 1646 |
|
| 1647 |
+
# Hook for combined QKV projection (CodeGen style)
|
| 1648 |
+
def make_combined_qkv_hook(layer_idx):
|
| 1649 |
def hook(module, input, output):
|
| 1650 |
try:
|
|
|
|
|
|
|
| 1651 |
if output.dim() != 3:
|
| 1652 |
+
return
|
| 1653 |
batch_size, seq_len, hidden = output.shape
|
| 1654 |
expected_hidden = 3 * n_heads * head_dim
|
| 1655 |
if hidden != expected_hidden:
|
| 1656 |
+
return
|
| 1657 |
qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
|
|
|
|
| 1658 |
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
| 1659 |
qkv_captures[layer_idx] = {
|
| 1660 |
+
'q': q[0].detach().cpu(),
|
| 1661 |
'k': k[0].detach().cpu(),
|
| 1662 |
'v': v[0].detach().cpu()
|
| 1663 |
}
|
| 1664 |
except Exception:
|
|
|
|
| 1665 |
pass
|
| 1666 |
return hook
|
| 1667 |
|
| 1668 |
+
# Hooks for separate Q, K, V projections (Mistral/LLaMA style)
|
| 1669 |
+
def make_separate_proj_hook(layer_idx, proj_type, num_kv_heads=None):
|
| 1670 |
+
def hook(module, input, output):
|
| 1671 |
+
try:
|
| 1672 |
+
if output.dim() != 3:
|
| 1673 |
+
return
|
| 1674 |
+
batch_size, seq_len, hidden = output.shape
|
| 1675 |
+
|
| 1676 |
+
if proj_type == 'q':
|
| 1677 |
+
proj_heads = n_heads
|
| 1678 |
+
else:
|
| 1679 |
+
proj_heads = num_kv_heads if num_kv_heads else n_heads
|
| 1680 |
+
|
| 1681 |
+
proj_head_dim = hidden // proj_heads
|
| 1682 |
+
if hidden != proj_heads * proj_head_dim:
|
| 1683 |
+
return
|
| 1684 |
+
|
| 1685 |
+
proj_output = output.reshape(batch_size, seq_len, proj_heads, proj_head_dim)
|
| 1686 |
+
|
| 1687 |
+
if proj_type != 'q' and num_kv_heads and num_kv_heads < n_heads:
|
| 1688 |
+
repeat_factor = n_heads // num_kv_heads
|
| 1689 |
+
proj_output = proj_output.repeat_interleave(repeat_factor, dim=2)
|
| 1690 |
+
|
| 1691 |
+
if layer_idx not in qkv_captures:
|
| 1692 |
+
qkv_captures[layer_idx] = {}
|
| 1693 |
+
|
| 1694 |
+
qkv_captures[layer_idx][proj_type] = proj_output[0].detach().cpu()
|
| 1695 |
+
except Exception:
|
| 1696 |
+
pass
|
| 1697 |
+
return hook
|
| 1698 |
+
|
| 1699 |
+
# Register hooks based on model architecture
|
| 1700 |
try:
|
| 1701 |
+
# CodeGen style: model.transformer.h[layer].attn.qkv_proj
|
| 1702 |
if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
|
| 1703 |
for layer_idx, layer in enumerate(manager.model.transformer.h):
|
| 1704 |
if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
|
| 1705 |
+
hook = layer.attn.qkv_proj.register_forward_hook(make_combined_qkv_hook(layer_idx))
|
| 1706 |
hooks.append(hook)
|
| 1707 |
elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
|
| 1708 |
+
hook = layer.attn.c_attn.register_forward_hook(make_combined_qkv_hook(layer_idx))
|
|
|
|
| 1709 |
hooks.append(hook)
|
| 1710 |
+
|
| 1711 |
+
# Mistral/LLaMA style: model.model.layers[layer].self_attn.{q,k,v}_proj
|
| 1712 |
+
elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'):
|
| 1713 |
+
num_kv_heads = getattr(manager.model.config, 'num_key_value_heads', None)
|
| 1714 |
+
for layer_idx, layer in enumerate(manager.model.model.layers):
|
| 1715 |
+
if hasattr(layer, 'self_attn'):
|
| 1716 |
+
attn = layer.self_attn
|
| 1717 |
+
if hasattr(attn, 'q_proj'):
|
| 1718 |
+
hook = attn.q_proj.register_forward_hook(
|
| 1719 |
+
make_separate_proj_hook(layer_idx, 'q', num_kv_heads))
|
| 1720 |
+
hooks.append(hook)
|
| 1721 |
+
if hasattr(attn, 'k_proj'):
|
| 1722 |
+
hook = attn.k_proj.register_forward_hook(
|
| 1723 |
+
make_separate_proj_hook(layer_idx, 'k', num_kv_heads))
|
| 1724 |
+
hooks.append(hook)
|
| 1725 |
+
if hasattr(attn, 'v_proj'):
|
| 1726 |
+
hook = attn.v_proj.register_forward_hook(
|
| 1727 |
+
make_separate_proj_hook(layer_idx, 'v', num_kv_heads))
|
| 1728 |
+
hooks.append(hook)
|
| 1729 |
+
logger.info(f"Registered QKV hooks for {len(hooks)//3} Mistral layers (GQA: {num_kv_heads} KV heads)")
|
| 1730 |
except Exception as hook_error:
|
| 1731 |
logger.warning(f"Could not register QKV hooks: {hook_error}")
|
| 1732 |
|
|
|
|
| 2163 |
qkv_captures = {}
|
| 2164 |
hooks = []
|
| 2165 |
|
| 2166 |
+
# Hook for combined QKV projection (CodeGen style)
|
| 2167 |
+
def make_combined_qkv_hook(layer_idx):
|
| 2168 |
def hook(module, input, output):
|
| 2169 |
try:
|
| 2170 |
if output.dim() != 3:
|
|
|
|
| 2184 |
pass
|
| 2185 |
return hook
|
| 2186 |
|
| 2187 |
+
# Hooks for separate Q, K, V projections (Mistral/LLaMA style)
|
| 2188 |
+
def make_separate_proj_hook(layer_idx, proj_type, num_kv_heads=None):
|
| 2189 |
+
"""Create hook for separate Q/K/V projection modules.
|
| 2190 |
+
|
| 2191 |
+
For GQA models, K and V have fewer heads than Q, so we need to
|
| 2192 |
+
expand them to match Q's head count for consistent visualization.
|
| 2193 |
+
"""
|
| 2194 |
+
def hook(module, input, output):
|
| 2195 |
+
try:
|
| 2196 |
+
if output.dim() != 3:
|
| 2197 |
+
return
|
| 2198 |
+
batch_size, seq_len, hidden = output.shape
|
| 2199 |
+
|
| 2200 |
+
# Determine number of heads for this projection
|
| 2201 |
+
if proj_type == 'q':
|
| 2202 |
+
proj_heads = n_heads
|
| 2203 |
+
else:
|
| 2204 |
+
# K and V may have fewer heads (GQA)
|
| 2205 |
+
proj_heads = num_kv_heads if num_kv_heads else n_heads
|
| 2206 |
+
|
| 2207 |
+
proj_head_dim = hidden // proj_heads
|
| 2208 |
+
if hidden != proj_heads * proj_head_dim:
|
| 2209 |
+
return
|
| 2210 |
+
|
| 2211 |
+
# Reshape to [batch, seq, heads, head_dim]
|
| 2212 |
+
proj_output = output.reshape(batch_size, seq_len, proj_heads, proj_head_dim)
|
| 2213 |
+
|
| 2214 |
+
# For GQA, expand K/V to match Q's head count
|
| 2215 |
+
if proj_type != 'q' and num_kv_heads and num_kv_heads < n_heads:
|
| 2216 |
+
# Repeat each KV head to match Q heads
|
| 2217 |
+
repeat_factor = n_heads // num_kv_heads
|
| 2218 |
+
proj_output = proj_output.repeat_interleave(repeat_factor, dim=2)
|
| 2219 |
+
|
| 2220 |
+
# Initialize layer entry if needed
|
| 2221 |
+
if layer_idx not in qkv_captures:
|
| 2222 |
+
qkv_captures[layer_idx] = {}
|
| 2223 |
+
|
| 2224 |
+
qkv_captures[layer_idx][proj_type] = proj_output[0].detach().cpu()
|
| 2225 |
+
except Exception as e:
|
| 2226 |
+
logger.debug(f"QKV capture error for layer {layer_idx} {proj_type}: {e}")
|
| 2227 |
+
return hook
|
| 2228 |
+
|
| 2229 |
+
# Register hooks based on model architecture
|
| 2230 |
try:
|
| 2231 |
+
# CodeGen style: model.transformer.h[layer].attn.qkv_proj
|
| 2232 |
if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
|
| 2233 |
for layer_idx, layer in enumerate(manager.model.transformer.h):
|
| 2234 |
if hasattr(layer, 'attn') and hasattr(layer.attn, 'qkv_proj'):
|
| 2235 |
+
hook = layer.attn.qkv_proj.register_forward_hook(make_combined_qkv_hook(layer_idx))
|
| 2236 |
hooks.append(hook)
|
| 2237 |
elif hasattr(layer, 'attn') and hasattr(layer.attn, 'c_attn'):
|
| 2238 |
+
hook = layer.attn.c_attn.register_forward_hook(make_combined_qkv_hook(layer_idx))
|
| 2239 |
hooks.append(hook)
|
| 2240 |
+
|
| 2241 |
+
# Mistral/LLaMA style: model.model.layers[layer].self_attn.{q,k,v}_proj
|
| 2242 |
+
elif hasattr(manager.model, 'model') and hasattr(manager.model.model, 'layers'):
|
| 2243 |
+
num_kv_heads = getattr(manager.model.config, 'num_key_value_heads', None)
|
| 2244 |
+
for layer_idx, layer in enumerate(manager.model.model.layers):
|
| 2245 |
+
if hasattr(layer, 'self_attn'):
|
| 2246 |
+
attn = layer.self_attn
|
| 2247 |
+
if hasattr(attn, 'q_proj'):
|
| 2248 |
+
hook = attn.q_proj.register_forward_hook(
|
| 2249 |
+
make_separate_proj_hook(layer_idx, 'q', num_kv_heads))
|
| 2250 |
+
hooks.append(hook)
|
| 2251 |
+
if hasattr(attn, 'k_proj'):
|
| 2252 |
+
hook = attn.k_proj.register_forward_hook(
|
| 2253 |
+
make_separate_proj_hook(layer_idx, 'k', num_kv_heads))
|
| 2254 |
+
hooks.append(hook)
|
| 2255 |
+
if hasattr(attn, 'v_proj'):
|
| 2256 |
+
hook = attn.v_proj.register_forward_hook(
|
| 2257 |
+
make_separate_proj_hook(layer_idx, 'v', num_kv_heads))
|
| 2258 |
+
hooks.append(hook)
|
| 2259 |
+
logger.info(f"Registered QKV hooks for {len(hooks)//3} Mistral layers (GQA: {num_kv_heads} KV heads)")
|
| 2260 |
except Exception as hook_error:
|
| 2261 |
logger.warning(f"Could not register QKV hooks: {hook_error}")
|
| 2262 |
|