""" Capture post-layer residual stream at decision points. Only captures: - target layers (those containing top-K experts from Stage 1) - decision point tokens (plan / mon / exec / general newline) Memory strategy: - Save fp16 - Only decision-point positions, not full sequence """ import torch from typing import Dict, List, Optional from configs.model import MODEL_CONFIG class ResidualCapture: """ Hook decoder layers' output (post-layer residual stream). For each forward pass, captured[layer_id] = (S, D) fp16 tensor. """ def __init__(self, model, target_layers: Optional[List[int]] = None): self.model = model self.target_layers = target_layers if target_layers else list(range(MODEL_CONFIG["num_layers"])) self.handles = [] self.captured: Dict[int, torch.Tensor] = {} def _make_hook(self, layer_id: int): def fn(module, inputs, output): # Decoder layer output: could be (hidden_states,) or (hidden_states, ...) if isinstance(output, tuple): h = output[0] else: h = output # h shape: (B, S, D) self.captured[layer_id] = h.squeeze(0).to(torch.float16).cpu() return fn def start(self): self.captured = {} for li in self.target_layers: layer = self.model.model.layers[li] h = layer.register_forward_hook(self._make_hook(li)) self.handles.append(h) def stop(self): for h in self.handles: h.remove() self.handles = [] return self.captured def run_forward_and_capture_residuals( model, tokenizer, text: str, target_layers: List[int] ) -> Dict: """ Run a forward pass and capture post-layer residual at target_layers only. Returns {layer_id: (S, D) fp16 tensor} """ enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False) input_ids = enc["input_ids"].to(model.device) cap = ResidualCapture(model, target_layers=target_layers) cap.start() try: with torch.no_grad(): _ = model(input_ids) finally: residuals = cap.stop() return { "input_ids": input_ids.squeeze(0).cpu(), "residuals": residuals, }