gary-boon and Claude Opus 4.6 (1M context) committed

Commit 82349c1 · Parent(s): 6f48db0

Add FFN contribution metrics and gate activation stats to research endpoint
- Compute attn_contribution/ffn_contribution ratios from the existing attention and MLP output norm hooks (no new forward passes needed)
- Add a gate activation hook for SwiGLU models (LLaMA/Mistral) capturing sparsity, mean, and max of gate values per layer per token (see the sketch below)
- Expose ffnType and intermediateSize in the modelInfo response
- Clear gate_activation_stats at each generation step alongside the existing dicts

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
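
As background for the gate-statistics bullet above, a minimal standalone sketch of the same math, assuming a SwiGLU-style MLP that exposes a gate_proj linear layer and a single last-token hidden state x. The layer sizes and the input tensor are made up for illustration; the silu(gate_proj(x)) form and the 0.01 sparsity threshold follow the diff below.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes only (assumption); real values come from the model config.
    hidden_size, intermediate_size = 64, 256
    gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
    x = torch.randn(hidden_size)  # stand-in for the last token's hidden state

    # SwiGLU gate activations, as captured by the hook in the diff below.
    gate_out = F.silu(gate_proj(x))
    abs_gate = gate_out.abs()

    stats = {
        # Fraction of gate units that are effectively inactive (|activation| < 0.01).
        "sparsity": round(float((abs_gate < 0.01).float().mean().item()), 4),
        "mean": round(float(gate_out.mean().item()), 4),
        "max": round(float(gate_out.max().item()), 4),
    }
    print(stats)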
backend/model_service.py  CHANGED  (+48 -5)
@@ -2663,9 +2663,10 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
         except Exception as hook_error:
             logger.warning(f"Could not register QKV hooks: {hook_error}")
 
-        # Phase 4: Hooks for attention and MLP output norms
+        # Phase 4: Hooks for attention and MLP output norms + gate activation stats
         attn_output_norms = {}
         mlp_output_norms = {}
+        gate_activation_stats = {}
 
         def make_attn_output_hook(layer_idx):
             def hook(module, input, output):
@@ -2689,6 +2690,30 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 pass
             return hook
 
+        def make_gate_hook(layer_idx):
+            """Capture gate activation stats for SwiGLU FFN (LLaMA/Mistral)."""
+            def hook(module, input, output):
+                try:
+                    inp = input[0] if isinstance(input, tuple) else input
+                    if inp.dim() == 3:
+                        inp = inp[0, -1]  # last token
+                    elif inp.dim() == 2:
+                        inp = inp[-1]
+                    if hasattr(module, 'gate_proj'):
+                        gate_out = torch.nn.functional.silu(module.gate_proj(inp))
+                        abs_gate = gate_out.abs()
+                        gate_activation_stats[layer_idx] = {
+                            "sparsity": round(float((abs_gate < 0.01).float().mean().item()), 4),
+                            "mean": round(float(gate_out.mean().item()), 4),
+                            "max": round(float(gate_out.max().item()), 4),
+                        }
+                except Exception:
+                    pass
+            return hook
+
+        # Detect FFN type from first layer
+        ffn_type = "gelu"  # default
+
         try:
             # CodeGen style
             if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
@@ -2708,7 +2733,13 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 if hasattr(layer, 'mlp'):
                     hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
                     hooks.append(hook)
-
+                    # Gate hook for SwiGLU models
+                    if hasattr(layer.mlp, 'gate_proj'):
+                        hook = layer.mlp.register_forward_hook(make_gate_hook(layer_idx))
+                        hooks.append(hook)
+                        if layer_idx == 0:
+                            ffn_type = "swiglu"
+            logger.info(f"Registered attn/MLP output hooks for contribution tracking (ffn_type={ffn_type})")
         except Exception as hook_error:
             logger.warning(f"Could not register attn/MLP hooks: {hook_error}")
 
@@ -2728,6 +2759,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             qkv_captures.clear()
             attn_output_norms.clear()
             mlp_output_norms.clear()
+            gate_activation_stats.clear()
 
             # Forward pass with full outputs
             outputs = manager.model(
@@ -3195,11 +3227,20 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 "delta_norm": delta_norm,
                 "margin_contribution": margin_contribution,
             }
-            # Phase 4: Attention and MLP output norms
+            # Phase 4: Attention and MLP output norms + contribution ratios
            if layer_idx in attn_output_norms:
-                layer_entry["attn_output_norm"] = attn_output_norms[layer_idx]
+                layer_entry["attn_output_norm"] = round(attn_output_norms[layer_idx], 4)
             if layer_idx in mlp_output_norms:
-                layer_entry["mlp_output_norm"] = mlp_output_norms[layer_idx]
+                layer_entry["mlp_output_norm"] = round(mlp_output_norms[layer_idx], 4)
+            if layer_idx in attn_output_norms and layer_idx in mlp_output_norms:
+                attn_n = attn_output_norms[layer_idx]
+                mlp_n = mlp_output_norms[layer_idx]
+                total = attn_n + mlp_n
+                if total > 0:
+                    layer_entry["attn_contribution"] = round(attn_n / total, 4)
+                    layer_entry["ffn_contribution"] = round(mlp_n / total, 4)
+            if layer_idx in gate_activation_stats:
+                layer_entry["gate_stats"] = gate_activation_stats[layer_idx]
 
             # Phase 5: Logit lens at sampled layers (every 8th layer)
             logit_lens_stride = max(1, n_layers // 5)
@@ -3490,6 +3531,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 "headDim": head_dim,
                 "vocabSize": manager.model.config.vocab_size,
                 "tunedLensAvailable": tuned_lens_runtime.available,
+                "ffnType": ffn_type,
+                "intermediateSize": getattr(manager.model.config, 'intermediate_size', None),
             },
             "generationTime": generation_time,
             "numTokensGenerated": len(generated_tokens),
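
For orientation, a sketch of what the enriched per-layer entry and the extended modelInfo block might look like to a client. Field names, rounding, and the 0.01 sparsity threshold follow the diff above; the surrounding response structure and all values are assumptions for illustration only.

    # Hypothetical example values; only the field names are taken from the diff above.
    layer_entry = {
        "attn_output_norm": 3.2146,   # attention block output norm, rounded to 4 decimals
        "mlp_output_norm": 5.0913,    # MLP block output norm, rounded to 4 decimals
        "attn_contribution": 0.387,   # attn_norm / (attn_norm + mlp_norm)
        "ffn_contribution": 0.613,    # mlp_norm / (attn_norm + mlp_norm)
        "gate_stats": {               # present only when a SwiGLU gate hook fired
            "sparsity": 0.2417,       # share of gate activations with |value| < 0.01
            "mean": 0.0312,
            "max": 4.8210,
        },
    }

    model_info_extra = {
        "ffnType": "swiglu",          # "gelu" unless layer 0 exposes gate_proj
        "intermediateSize": 11008,    # config.intermediate_size, or None if absent
    }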