| """ |
| Capture routing decisions (top-k experts + gating weights) at every token. |
| |
Qwen3-MoE gate path: model.model.layers[i].mlp.gate
- Input: hidden_states, shape (B, S, D), or flattened to (B*S, D)
- Output: router logits, shape (B, S, num_experts), or (B*S, num_experts)
| |
| We save per-layer top-k (ID + gating weight) for each token, then filter down |
| to decision points in post-processing. |
| |
| Memory strategy: |
| - Hook dumps to CPU fp16 immediately |
| - For each CoT, save a dict: {layer_id: (S, top_k) expert_ids, (S, top_k) gates} |
- Sharded to disk every 50 samples (see the save_routing_shards sketch at the bottom of this file)
| """ |
import torch
from typing import Dict, List, Optional

from configs.model import MODEL_CONFIG


class RoutingCapture:
| """ |
| Hook all MoE router gates. For each forward pass: |
| captured[layer_id] = { |
| "topk_ids": (S, top_k) int16 tensor, |
| "topk_gates": (S, top_k) float16 tensor, |
| } |
| """ |
|
|
| def __init__(self, model, top_k: Optional[int] = None): |
| self.model = model |
| self.top_k = top_k or MODEL_CONFIG["num_experts_per_tok"] |
| self.handles = [] |
| self.captured: Dict[int, Dict[str, torch.Tensor]] = {} |
        self.num_layers = MODEL_CONFIG["num_layers"]

    def _make_hook(self, layer_id: int):
        top_k = self.top_k

        def fn(module, inputs, output):
            # Some gate implementations return (logits, ...) tuples; unwrap.
            logits = output[0] if isinstance(output, tuple) else output

            # Qwen3-MoE flattens tokens before the gate, so logits typically
            # arrive as (B*S, num_experts); restore a leading batch dim.
            if logits.dim() == 2:
                logits = logits.unsqueeze(0)

            topk_vals, topk_ids = logits.topk(top_k, dim=-1)

            # Softmax over the top-k logits equals the full softmax restricted
            # to the selected experts and renormalized (see the sanity check
            # below the class): exact when the model normalizes gates over
            # top-k (norm_topk_prob=True), otherwise an approximation.
            topk_gates = torch.softmax(topk_vals, dim=-1)

            # Ship to CPU in compact dtypes right away so activations never
            # accumulate on GPU; squeeze(0) assumes a batch-size-1 prefill.
            self.captured[layer_id] = {
                "topk_ids": topk_ids.squeeze(0).to(torch.int16).cpu(),
                "topk_gates": topk_gates.squeeze(0).to(torch.float16).cpu(),
            }

        return fn

    def start(self):
        # Remove hooks left over from a previous start() so that gates are
        # not registered twice.
        for h in self.handles:
            h.remove()
        self.handles = []
        self.captured = {}

        gate_path = MODEL_CONFIG["gate_attr_path"]
        parts = gate_path.split(".")
        for li in range(self.num_layers):
            # Walk the dotted attribute path (e.g. "mlp.gate") from each decoder layer.
            mod = self.model.model.layers[li]
            for p in parts:
                if not hasattr(mod, p):
                    raise AttributeError(
                        f"Layer {li}: attribute path '{gate_path}' not found at '{p}'. "
                        f"Use `print(dict(model.model.layers[0].named_modules()))` to explore."
                    )
                mod = getattr(mod, p)
            self.handles.append(mod.register_forward_hook(self._make_hook(li)))

    def stop(self):
| for h in self.handles: |
| h.remove() |
| self.handles = [] |
| return self.captured |
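

# Sanity check for the gate weights computed in RoutingCapture._make_hook:
# softmax over the top-k logits equals the full softmax restricted to the
# selected experts and renormalized. A standalone verification sketch; it is
# not called by the capture path, and the function name is our own.
def _check_topk_softmax_identity(num_experts: int = 8, k: int = 2) -> None:
    logits = torch.randn(num_experts)
    full = torch.softmax(logits, dim=-1)
    vals, ids = logits.topk(k, dim=-1)
    renormalized = full[ids] / full[ids].sum()
    assert torch.allclose(torch.softmax(vals, dim=-1), renormalized, atol=1e-6)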


def run_forward_and_capture_routing(model, tokenizer, text: str) -> Dict:
| """ |
| Run a single prefill forward pass on `text` and capture per-layer routing. |
| |
| Returns: |
| { |
| "input_ids": (S,) int tensor, |
| "routing": {layer_id: {"topk_ids": (S, k), "topk_gates": (S, k)}, ...} |
| } |
| """ |
| enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False) |
    input_ids = enc["input_ids"].to(model.device)

    cap = RoutingCapture(model)
    cap.start()
    try:
        with torch.no_grad():
            # Prefill-only forward pass; no KV cache is needed since we never decode.
            _ = model(input_ids, use_cache=False)
    finally:
        routing = cap.stop()

    return {
| "input_ids": input_ids.squeeze(0).cpu(), |
| "routing": routing, |
| } |
|
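
# The module docstring promises sharding "to disk every 50 samples". A minimal
# sketch of that step, assuming `records` is a list of dicts as returned by
# run_forward_and_capture_routing; the function name, file layout, and
# torch.save format are our assumptions rather than an established interface.
def save_routing_shards(records: List[Dict], out_dir: str, shard_size: int = 50) -> None:
    import os

    os.makedirs(out_dir, exist_ok=True)
    for start in range(0, len(records), shard_size):
        shard = records[start : start + shard_size]
        path = os.path.join(out_dir, f"routing_shard_{start // shard_size:04d}.pt")
        torch.save(shard, path)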
|
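
# Hedged end-to-end usage sketch. The checkpoint name, dtype, and device_map
# below are assumptions; substitute your own Qwen3-MoE model and make sure
# configs.model.MODEL_CONFIG matches it.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "Qwen/Qwen3-30B-A3B"  # assumed Qwen3-MoE checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(
        name, torch_dtype=torch.bfloat16, device_map="auto"
    )
    record = run_forward_and_capture_routing(model, tokenizer, "2 + 2 = 4 because ...")
    print(record["input_ids"].shape[0], "tokens,", len(record["routing"]), "layers hooked")
    save_routing_shards([record], out_dir="routing_shards")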