# v2/src/routing_capture.py
# (Hugging Face Hub page residue converted to comments: uploaded by JulianHJR
#  via huggingface_hub, commit e53f10b, verified.)
"""
Capture routing decisions (top-k experts + gating weights) at every token.
Qwen3-MoE gate path: model.model.layers[i].mlp.gate
- Input: hidden_states, shape (B, S, D)
- Output: router logits, shape (B, S, num_experts)
We save per-layer top-k (ID + gating weight) for each token, then filter down
to decision points in post-processing.
Memory strategy:
- Hook dumps to CPU fp16 immediately
- For each CoT, save a dict: {layer_id: (S, top_k) expert_ids, (S, top_k) gates}
- Sharded to disk per 50 samples
"""
import torch
import torch.nn.functional as F
from typing import Dict, List, Optional
from configs.model import MODEL_CONFIG
class RoutingCapture:
    """
    Hook all MoE router gates. For each forward pass:
    captured[layer_id] = {
        "topk_ids": (S, top_k) int16 tensor,
        "topk_gates": (S, top_k) float16 tensor,
    }

    Usage: `start()` registers one forward hook per layer's gate module,
    run the model, then `stop()` removes the hooks and returns the capture.
    Assumes batch size 1 (one CoT per forward pass).
    """
    def __init__(self, model, top_k: Optional[int] = None):
        self.model = model
        # Fall back to the model's own routing fan-out when not overridden.
        self.top_k = top_k or MODEL_CONFIG["num_experts_per_tok"]
        self.handles: List = []
        self.captured: Dict[int, Dict[str, torch.Tensor]] = {}
        self.num_layers = MODEL_CONFIG["num_layers"]

    def _make_hook(self, layer_id: int):
        """Build a forward hook that records top-k routing for `layer_id`."""
        top_k = self.top_k

        def fn(module, inputs, output):
            # Capture router output (logits over experts).
            #
            # NOTE on gate weights: We compute softmax over the selected
            # top-k logits. Qwen3 actually does top-k FIRST then softmax over
            # the selected k as well, BUT the captured `topk_gates` here is
            # still an APPROXIMATION of the model's effective gates and should
            # NOT be used for downstream weighted analysis. Use `topk_ids`
            # only for selection-frequency analysis (expert_select.py).
            logits = output[0] if isinstance(output, tuple) else output
            # Normalize shape to (B, S, E).
            if logits.dim() == 2:
                # Flattened: can't recover B/S here; assume B=1.
                logits = logits.unsqueeze(0)
            B, S, E = logits.shape
            # FIX: the original squeeze(0) silently produced a (B, S, k)
            # tensor for B > 1, corrupting downstream per-token alignment.
            # Fail loudly instead — this capture path only supports B == 1.
            if B != 1:
                raise ValueError(
                    f"RoutingCapture assumes batch size 1, got B={B}"
                )
            # top-k selection (correct: indices match what the model routes to)
            topk_vals, topk_ids = logits.topk(top_k, dim=-1)  # (B, S, top_k)
            # Approximate gates (NOT the model's real gating weights).
            topk_gates_approx = torch.softmax(topk_vals, dim=-1)
            # Dump to CPU + narrow dtypes immediately (memory strategy).
            self.captured[layer_id] = {
                "topk_ids": topk_ids.squeeze(0).to(torch.int16).cpu(),
                "topk_gates": topk_gates_approx.squeeze(0).to(torch.float16).cpu(),
            }
        return fn

    def start(self):
        """Register a forward hook on every layer's gate module."""
        # FIX: remove any stale handles first, so calling start() twice does
        # not double-register hooks (the original leaked the old handles).
        if self.handles:
            self.stop()
        self.captured = {}
        gate_path = MODEL_CONFIG["gate_attr_path"]  # "mlp.gate"
        parts = gate_path.split(".")
        for li in range(self.num_layers):
            layer = self.model.model.layers[li]
            mod = layer
            for p in parts:
                if not hasattr(mod, p):
                    raise AttributeError(
                        f"Layer {li}: attribute path '{gate_path}' not found at '{p}'. "
                        f"Use `print(dict(model.model.layers[0].named_modules()))` to explore."
                    )
                mod = getattr(mod, p)
            h = mod.register_forward_hook(self._make_hook(li))
            self.handles.append(h)

    def stop(self):
        """Remove all hooks and return the per-layer capture dict."""
        for h in self.handles:
            h.remove()
        self.handles = []
        return self.captured
def run_forward_and_capture_routing(model, tokenizer, text: str) -> Dict:
    """
    Run a single prefill forward pass on `text` and capture per-layer routing.
    Returns:
        {
            "input_ids": (S,) int tensor,
            "routing": {layer_id: {"topk_ids": (S, k), "topk_gates": (S, k)}, ...}
        }
    """
    encoded = tokenizer(
        text, return_tensors="pt", add_special_tokens=False, truncation=False
    )
    token_ids = encoded["input_ids"].to(model.device)

    capture = RoutingCapture(model)
    capture.start()
    try:
        with torch.no_grad():
            model(token_ids)
    finally:
        # Always detach hooks, even if the forward pass raises.
        per_layer_routing = capture.stop()

    return {
        "input_ids": token_ids.squeeze(0).cpu(),
        "routing": per_layer_routing,
    }