gary-boon and Claude Opus 4.6 (1M context) committed
Commit 82349c1 · Parent: 6f48db0

Add FFN contribution metrics and gate activation stats to research endpoint


- Compute attn_contribution/ffn_contribution ratios from existing
attention and MLP output norm hooks (no new forward passes needed)
- Add gate activation hook for SwiGLU models (LLaMA/Mistral) capturing
sparsity, mean, and max of gate values per layer per token (see the sketch below)
- Expose ffnType and intermediateSize in modelInfo response
- Clear gate_activation_stats each generation step alongside existing dicts

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
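
For reference, a minimal standalone sketch of the gate-statistics computation the new hook performs, assuming a LLaMA-style SwiGLU MLP that exposes a gate_proj linear layer; the module and the hidden/intermediate sizes here are illustrative stand-ins, not the service's actual objects:

import torch
import torch.nn.functional as F

# Stand-in for one layer's SwiGLU gate projection; real models expose
# this as layer.mlp.gate_proj (sizes here are made up for the example).
hidden_size, intermediate_size = 64, 172
gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)

x = torch.randn(hidden_size)         # hidden state of the last token
gate_out = F.silu(gate_proj(x))      # SwiGLU gate path: silu(W_gate x)
abs_gate = gate_out.abs()

stats = {
    "sparsity": round(float((abs_gate < 0.01).float().mean().item()), 4),  # fraction of near-zero gates
    "mean": round(float(gate_out.mean().item()), 4),
    "max": round(float(gate_out.max().item()), 4),
}
print(stats)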

Files changed (1): backend/model_service.py (+48 −5)
backend/model_service.py CHANGED
@@ -2663,9 +2663,10 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
     except Exception as hook_error:
         logger.warning(f"Could not register QKV hooks: {hook_error}")

-    # Phase 4: Hooks for attention and MLP output norms
+    # Phase 4: Hooks for attention and MLP output norms + gate activation stats
     attn_output_norms = {}
     mlp_output_norms = {}
+    gate_activation_stats = {}

     def make_attn_output_hook(layer_idx):
         def hook(module, input, output):
@@ -2689,6 +2690,30 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 pass
         return hook

+    def make_gate_hook(layer_idx):
+        """Capture gate activation stats for SwiGLU FFN (LLaMA/Mistral)."""
+        def hook(module, input, output):
+            try:
+                inp = input[0] if isinstance(input, tuple) else input
+                if inp.dim() == 3:
+                    inp = inp[0, -1]  # last token
+                elif inp.dim() == 2:
+                    inp = inp[-1]
+                if hasattr(module, 'gate_proj'):
+                    gate_out = torch.nn.functional.silu(module.gate_proj(inp))
+                    abs_gate = gate_out.abs()
+                    gate_activation_stats[layer_idx] = {
+                        "sparsity": round(float((abs_gate < 0.01).float().mean().item()), 4),
+                        "mean": round(float(gate_out.mean().item()), 4),
+                        "max": round(float(gate_out.max().item()), 4),
+                    }
+            except Exception:
+                pass
+        return hook
+
+    # Detect FFN type from first layer
+    ffn_type = "gelu"  # default
+
     try:
         # CodeGen style
         if hasattr(manager.model, 'transformer') and hasattr(manager.model.transformer, 'h'):
@@ -2708,7 +2733,13 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             if hasattr(layer, 'mlp'):
                 hook = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
                 hooks.append(hook)
+                # Gate hook for SwiGLU models
+                if hasattr(layer.mlp, 'gate_proj'):
+                    hook = layer.mlp.register_forward_hook(make_gate_hook(layer_idx))
+                    hooks.append(hook)
+                    if layer_idx == 0:
+                        ffn_type = "swiglu"
-        logger.info(f"Registered attn/MLP output hooks for contribution tracking")
+        logger.info(f"Registered attn/MLP output hooks for contribution tracking (ffn_type={ffn_type})")
     except Exception as hook_error:
         logger.warning(f"Could not register attn/MLP hooks: {hook_error}")

@@ -2728,6 +2759,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
             qkv_captures.clear()
             attn_output_norms.clear()
             mlp_output_norms.clear()
+            gate_activation_stats.clear()

             # Forward pass with full outputs
             outputs = manager.model(
@@ -3195,11 +3227,20 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                     "delta_norm": delta_norm,
                     "margin_contribution": margin_contribution,
                 }
-                # Phase 4: Attention and MLP output norms
+                # Phase 4: Attention and MLP output norms + contribution ratios
                 if layer_idx in attn_output_norms:
-                    layer_entry["attn_output_norm"] = attn_output_norms[layer_idx]
+                    layer_entry["attn_output_norm"] = round(attn_output_norms[layer_idx], 4)
                 if layer_idx in mlp_output_norms:
-                    layer_entry["mlp_output_norm"] = mlp_output_norms[layer_idx]
+                    layer_entry["mlp_output_norm"] = round(mlp_output_norms[layer_idx], 4)
+                if layer_idx in attn_output_norms and layer_idx in mlp_output_norms:
+                    attn_n = attn_output_norms[layer_idx]
+                    mlp_n = mlp_output_norms[layer_idx]
+                    total = attn_n + mlp_n
+                    if total > 0:
+                        layer_entry["attn_contribution"] = round(attn_n / total, 4)
+                        layer_entry["ffn_contribution"] = round(mlp_n / total, 4)
+                if layer_idx in gate_activation_stats:
+                    layer_entry["gate_stats"] = gate_activation_stats[layer_idx]

                 # Phase 5: Logit lens at sampled layers (every 8th layer)
                 logit_lens_stride = max(1, n_layers // 5)
@@ -3490,6 +3531,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
                 "headDim": head_dim,
                 "vocabSize": manager.model.config.vocab_size,
                 "tunedLensAvailable": tuned_lens_runtime.available,
+                "ffnType": ffn_type,
+                "intermediateSize": getattr(manager.model.config, 'intermediate_size', None),
             },
             "generationTime": generation_time,
             "numTokensGenerated": len(generated_tokens),