| """ |
| Capture post-layer residual stream at decision points. |
| |
| Only captures: |
| - target layers (those containing top-K experts from Stage 1) |
| - decision point tokens (plan / mon / exec / general newline) |
| |
| Memory strategy: |
| - Save fp16 |
| - Only decision-point positions, not full sequence |
| """ |
| import torch |
| from typing import Dict, List, Optional |
| from configs.model import MODEL_CONFIG |
|
|
|
|
class ResidualCapture:
    """
    Hook decoder layers' outputs (the post-layer residual stream).

    After ``start()`` and one forward pass, ``captured[layer_id]`` holds that
    layer's output hidden state as a detached fp16 CPU tensor. With batch
    size 1 this is (S, D); only a leading dimension of size 1 is squeezed, so
    larger batches keep their batch dimension.

    Also usable as a context manager::

        with ResidualCapture(model, [3, 7]) as cap:
            model(input_ids)
        residuals = cap.captured
    """

    def __init__(self, model, target_layers: Optional[List[int]] = None):
        """
        Args:
            model: model whose decoder layers live at ``model.model.layers``.
            target_layers: layer indices to hook. ``None`` hooks every layer
                (``range(MODEL_CONFIG["num_layers"])``); an explicit empty
                list hooks nothing.
        """
        self.model = model
        # `is None` rather than truthiness: an empty selection must not
        # silently expand to "capture every layer".
        if target_layers is None:
            target_layers = list(range(MODEL_CONFIG["num_layers"]))
        self.target_layers = list(target_layers)
        self.handles = []
        self.captured: Dict[int, torch.Tensor] = {}

    def _make_hook(self, layer_id: int):
        """Build a forward hook that stores ``layer_id``'s output in ``captured``."""
        def fn(module, inputs, output):
            # A layer may return a tuple whose first element is the hidden
            # state, or the hidden-state tensor itself.
            h = output[0] if isinstance(output, tuple) else output
            # detach() so a capture made outside no_grad does not keep the
            # autograd graph alive; cast to fp16 before the host copy so the
            # device->CPU transfer moves half the bytes.
            self.captured[layer_id] = h.detach().squeeze(0).to(torch.float16).cpu()
        return fn

    def _remove_hooks(self):
        """Detach every registered hook and forget the handles."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def start(self):
        """
        Register a forward hook on each target layer and reset ``captured``.

        Hooks left over from a previous ``start()`` are removed first, so
        repeated calls never stack duplicate hooks.

        Raises:
            IndexError: if a target layer index does not exist. No hooks are
                left registered in that case.
        """
        self._remove_hooks()
        self.captured = {}
        layers = self.model.model.layers
        # Resolve all layers before registering anything so a bad index
        # fails without leaving a partial set of hooks attached.
        selected = [(li, layers[li]) for li in self.target_layers]
        for li, layer in selected:
            self.handles.append(layer.register_forward_hook(self._make_hook(li)))

    def stop(self):
        """Remove all hooks and return the ``{layer_id: tensor}`` captures."""
        self._remove_hooks()
        return self.captured

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
        # Never suppress exceptions raised inside the `with` body.
        return False
|
|
|
|
def run_forward_and_capture_residuals(
    model, tokenizer, text: str, target_layers: List[int]
) -> Dict:
    """
    Run one forward pass over ``text`` and capture the post-layer residual
    stream at ``target_layers`` only.

    Args:
        model: model whose decoder layers are reachable as
            ``model.model.layers`` and which exposes ``model.device``.
        tokenizer: tokenizer called with ``return_tensors="pt"``; its output
            must contain ``"input_ids"`` (expected shape (1, S) for a single
            string).
        text: raw text. Tokenized with no special tokens and no truncation,
            so the whole text is fed to the model.
        target_layers: decoder-layer indices to hook.

    Returns:
        A dict with two entries:
          - ``"input_ids"``: the token ids, with the leading batch dim of
            size 1 squeezed, on CPU.
          - ``"residuals"``: ``{layer_id: fp16 CPU tensor}``, (S, D) for a
            single sequence.
        Every token position is returned; this function does not filter to
        decision-point positions.
    """
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
    input_ids = enc["input_ids"].to(model.device)

    cap = ResidualCapture(model, target_layers=target_layers)
    cap.start()
    try:
        with torch.no_grad():
            # The model output is not needed; not binding it lets it be
            # freed immediately instead of living until return.
            model(input_ids)
    finally:
        # Always remove the hooks, even if the forward pass raises.
        residuals = cap.stop()

    return {
        "input_ids": input_ids.squeeze(0).cpu(),
        "residuals": residuals,
    }
|
|