""" Capture routing decisions (top-k experts + gating weights) at every token. Qwen3-MoE gate path: model.model.layers[i].mlp.gate - Input: hidden_states, shape (B, S, D) - Output: router logits, shape (B, S, num_experts) We save per-layer top-k (ID + gating weight) for each token, then filter down to decision points in post-processing. Memory strategy: - Hook dumps to CPU fp16 immediately - For each CoT, save a dict: {layer_id: (S, top_k) expert_ids, (S, top_k) gates} - Sharded to disk per 50 samples """ import torch import torch.nn.functional as F from typing import Dict, List, Optional from configs.model import MODEL_CONFIG class RoutingCapture: """ Hook all MoE router gates. For each forward pass: captured[layer_id] = { "topk_ids": (S, top_k) int16 tensor, "topk_gates": (S, top_k) float16 tensor, } """ def __init__(self, model, top_k: Optional[int] = None): self.model = model self.top_k = top_k or MODEL_CONFIG["num_experts_per_tok"] self.handles = [] self.captured: Dict[int, Dict[str, torch.Tensor]] = {} self.num_layers = MODEL_CONFIG["num_layers"] def _make_hook(self, layer_id: int): top_k = self.top_k def fn(module, inputs, output): # Capture router output (logits over experts). # # NOTE on gate weights: We compute softmax(all_logits) then top-k. # Qwen3 actually does top-k FIRST then softmax over the selected k. # The captured `topk_gates` is therefore an APPROXIMATION and should # NOT be used for downstream weighted analysis. Use `topk_ids` only # for selection-frequency analysis (which is what expert_select.py does). if isinstance(output, tuple): logits = output[0] else: logits = output # Normalize shape to (B, S, E) if logits.dim() == 2: # Flattened: can't recover B/S here; assume B=1 logits = logits.unsqueeze(0) B, S, E = logits.shape # top-k selection (correct: indices match what the model actually routes to) topk_vals, topk_ids = logits.topk(top_k, dim=-1) # (B, S, top_k) # Approximate gates (NOT model's real gating weights) topk_gates_approx = torch.softmax(topk_vals, dim=-1) # Assume B=1 (we process one CoT at a time) self.captured[layer_id] = { "topk_ids": topk_ids.squeeze(0).to(torch.int16).cpu(), "topk_gates": topk_gates_approx.squeeze(0).to(torch.float16).cpu(), } return fn def start(self): self.captured = {} gate_path = MODEL_CONFIG["gate_attr_path"] # "mlp.gate" parts = gate_path.split(".") for li in range(self.num_layers): layer = self.model.model.layers[li] mod = layer for p in parts: if not hasattr(mod, p): raise AttributeError( f"Layer {li}: attribute path '{gate_path}' not found at '{p}'. " f"Use `print(dict(model.model.layers[0].named_modules()))` to explore." ) mod = getattr(mod, p) h = mod.register_forward_hook(self._make_hook(li)) self.handles.append(h) def stop(self): for h in self.handles: h.remove() self.handles = [] return self.captured def run_forward_and_capture_routing(model, tokenizer, text: str) -> Dict: """ Run a single prefill forward pass on `text` and capture per-layer routing. Returns: { "input_ids": (S,) int tensor, "routing": {layer_id: {"topk_ids": (S, k), "topk_gates": (S, k)}, ...} } """ enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False) input_ids = enc["input_ids"].to(model.device) cap = RoutingCapture(model) cap.start() try: with torch.no_grad(): _ = model(input_ids) finally: routing = cap.stop() return { "input_ids": input_ids.squeeze(0).cpu(), "routing": routing, }