File size: 4,268 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Capture routing decisions (top-k experts + gating weights) at every token.

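Qwen3-MoE gate path: model.model.layers[i].mlp.gate
  - Input:  hidden_states, shape (B, S, D), or flattened to (B*S, D) by some
            implementations (the hook treats a 2-D output as a single sequence, B=1)
  - Output: router logits, shape (B, S, num_experts) or (B*S, num_experts)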

We save per-layer top-k (ID + gating weight) for each token, then filter down
to decision points in post-processing.

Memory strategy:
  - Hook copies results to CPU immediately (expert IDs as int16, gates as fp16)
  - For each CoT, save a dict: {layer_id: (S, top_k) expert_ids, (S, top_k) gates}
  - Sharded to disk per 50 samples
"""
import torch
import torch.nn.functional as F
from typing import Dict, List, Optional
from configs.model import MODEL_CONFIG


class RoutingCapture:
    """
    Hook all MoE router gates and record per-layer top-k routing.

    After a forward pass run between ``start()`` and ``stop()``:
      captured[layer_id] = {
          "topk_ids":   (S, top_k) int16 tensor,    # selected expert indices, on CPU
          "topk_gates": (S, top_k) float16 tensor,  # softmax over the k selected logits, on CPU
      }
    Each forward pass overwrites the previous entry for a layer, so only the
    most recent pass is kept. The (S, top_k) shapes assume batch size 1.
    """

    def __init__(self, model, top_k: Optional[int] = None):
        """
        Args:
            model: model exposing ``model.model.layers[i]`` (assumes the HF
                decoder layout — TODO confirm for non-Qwen3 models).
            top_k: number of experts to record per token; falls back to
                ``MODEL_CONFIG["num_experts_per_tok"]`` when None (or 0).
        """
        self.model = model
        self.top_k = top_k or MODEL_CONFIG["num_experts_per_tok"]
        self.handles = []  # handles returned by register_forward_hook
        self.captured: Dict[int, Dict[str, torch.Tensor]] = {}
        self.num_layers = MODEL_CONFIG["num_layers"]

    def _make_hook(self, layer_id: int):
        """Build a forward hook that records top-k routing into ``self.captured[layer_id]``."""
        top_k = self.top_k

        def fn(module, inputs, output):
            # Some gate implementations return a tuple; the router logits are
            # taken from its first element.
            logits = output[0] if isinstance(output, tuple) else output
            # Normalize to (B, S, E). A 2-D output cannot be split back into
            # batch/sequence here, so it is treated as one sequence (B=1).
            if logits.dim() == 2:
                logits = logits.unsqueeze(0)
            # Select and normalize in fp32: bf16/fp16 softmax loses precision.
            logits = logits.float()
            # Indices are exact: softmax is monotonic, so top-k over raw logits
            # picks the same experts as top-k over routing probabilities.
            topk_vals, topk_ids = logits.topk(top_k, dim=-1)  # (B, S, top_k)
            # Softmax over the k selected logits. Mathematically this equals
            # softmax over ALL experts, followed by top-k selection and
            # renormalization to sum to 1. It therefore matches the model's
            # real gating weights only if the model renormalizes its top-k
            # probabilities (HF Qwen3-MoE: `norm_topk_prob`) — TODO confirm
            # for the config in use before doing weighted analysis. `topk_ids`
            # is exact either way.
            topk_gates = torch.softmax(topk_vals, dim=-1)
            # squeeze(0) drops the batch dim only when B == 1 (one CoT per pass).
            self.captured[layer_id] = {
                "topk_ids":   topk_ids.squeeze(0).to(torch.int16).cpu(),
                "topk_gates": topk_gates.squeeze(0).to(torch.float16).cpu(),
            }

        return fn

    def _resolve_gate_modules(self, gate_path: str) -> List[torch.nn.Module]:
        """Return the submodule at dotted ``gate_path`` for every layer, in layer order.

        Raises:
            AttributeError: if any layer lacks a component of ``gate_path``.
        """
        parts = gate_path.split(".")
        gates = []
        for li in range(self.num_layers):
            mod = self.model.model.layers[li]
            for p in parts:
                if not hasattr(mod, p):
                    raise AttributeError(
                        f"Layer {li}: attribute path '{gate_path}' not found at '{p}'. "
                        f"Use `print(dict(model.model.layers[0].named_modules()))` to explore."
                    )
                mod = getattr(mod, p)
            gates.append(mod)
        return gates

    def start(self):
        """Register a forward hook on every layer's router gate and reset ``captured``.

        All gate modules are resolved before any hook is registered, so a bad
        ``gate_attr_path`` raises without leaving stray hooks behind. Hooks
        left over from an earlier ``start()`` without ``stop()`` are removed
        first, so each gate is hooked exactly once.

        Raises:
            AttributeError: if the configured gate path does not exist.
        """
        for h in self.handles:
            h.remove()
        self.handles = []
        self.captured = {}
        gate_path = MODEL_CONFIG["gate_attr_path"]  # e.g. "mlp.gate"
        gate_modules = self._resolve_gate_modules(gate_path)
        for li, mod in enumerate(gate_modules):
            self.handles.append(mod.register_forward_hook(self._make_hook(li)))

    def stop(self):
        """Remove all registered hooks and return the captured routing dict."""
        for h in self.handles:
            h.remove()
        self.handles = []
        return self.captured


def run_forward_and_capture_routing(model, tokenizer, text: str) -> Dict:
    """
    Run a single prefill forward pass on `text` and capture per-layer routing.

    Args:
        model: MoE model whose router gates RoutingCapture can hook; must
            expose `.device`, where the input ids are moved.
        tokenizer: callable tokenizer; "input_ids" in its result is expected
            to be a (1, S) tensor when called with return_tensors="pt".
        text: text to encode, with no special tokens added and no truncation.

    Returns:
      {
        "input_ids": (S,) int tensor on CPU,
        "routing":   {layer_id: {"topk_ids": (S, k), "topk_gates": (S, k)}, ...}
      }
      If `text` encodes to zero tokens, no forward pass is run and
      "routing" is an empty dict.

    Raises:
        AttributeError: propagated from RoutingCapture.start() when the
            configured gate path does not exist. Any hooks already
            registered are removed before it propagates.
    """
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
    input_ids = enc["input_ids"].to(model.device)

    if input_ids.shape[-1] == 0:
        # Nothing to route; a zero-length prefill may fail inside the model.
        return {"input_ids": input_ids.squeeze(0).cpu(), "routing": {}}

    cap = RoutingCapture(model)
    try:
        # start() sits inside the try so that stop() in `finally` removes
        # every hook even when registration itself fails part-way.
        cap.start()
        with torch.no_grad():
            model(input_ids)
    finally:
        routing = cap.stop()

    return {
        "input_ids": input_ids.squeeze(0).cpu(),
        "routing":   routing,
    }