v2 / src /residual_capture.py
JulianHJR's picture
Upload folder using huggingface_hub
e53f10b verified
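"""
Capture the post-layer residual stream for analysis at decision points.

Scope:
- Target layers only (those containing the top-K experts from Stage 1).
- Positions of interest are decision-point tokens (plan / mon / exec / general
  newline). NOTE(review): the hooks in this module store every sequence
  position, so selecting the decision-point positions must happen in the caller.

Memory strategy:
- Store activations as fp16 on the CPU.
- Intended to keep only decision-point positions rather than the full sequence.
  NOTE(review): this is not enforced here; the captured tensors cover the full
  sequence.
"""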
import torch
from typing import Dict, List, Optional
from configs.model import MODEL_CONFIG
class ResidualCapture:
    """
    Hook decoder layers' output (post-layer residual stream).

    After a forward pass, ``captured[layer_id]`` holds that layer's hidden
    states as an fp16 CPU tensor. The leading batch dimension is squeezed, so
    a batch of one gives (S, D); assuming the layer emits (B, S, D), a larger
    batch stays (B, S, D).

    Usage::

        cap = ResidualCapture(model, target_layers=[3, 7])
        cap.start()
        model(input_ids)
        residuals = cap.stop()

    It can also be used as a context manager (``with cap: model(input_ids)``),
    after which the results are in ``cap.captured``.
    """

    def __init__(self, model, target_layers: Optional[List[int]] = None):
        """
        Args:
            model: object exposing its decoder layers as ``model.model.layers``,
                indexable by layer id, with ``torch.nn.Module`` elements.
            target_layers: layer indices to hook. ``None`` means every layer
                (``MODEL_CONFIG["num_layers"]``); an empty list hooks nothing.
        """
        self.model = model
        if target_layers is None:
            target_layers = range(MODEL_CONFIG["num_layers"])
        # Copy so that later mutation of the caller's list cannot change which
        # layers get hooked.
        self.target_layers: List[int] = list(target_layers)
        self.handles = []
        self.captured: Dict[int, torch.Tensor] = {}

    def _make_hook(self, layer_id: int):
        """Build a forward hook that stores ``layer_id``'s output in ``self.captured``."""
        def fn(module, inputs, output):
            # A decoder layer may return a bare tensor or a tuple whose first
            # element is the hidden states.
            h = output[0] if isinstance(output, tuple) else output
            # detach(): a stored activation must never keep an autograd graph
            # alive if capture runs with gradients enabled. The cast happens
            # before .cpu() so the device-to-host copy moves fp16 data.
            # A later forward pass overwrites this layer's entry.
            self.captured[layer_id] = h.detach().squeeze(0).to(torch.float16).cpu()
        return fn

    def start(self):
        """Register a hook on every target layer and clear previous captures.

        Hooks left over from an earlier ``start()`` are removed first, so
        repeated calls never stack duplicate hooks. If registration fails
        (e.g. an out-of-range layer id), the hooks registered so far are
        removed before the error propagates.
        """
        self._remove_hooks()
        self.captured = {}
        try:
            for li in self.target_layers:
                layer = self.model.model.layers[li]
                self.handles.append(layer.register_forward_hook(self._make_hook(li)))
        except Exception:
            self._remove_hooks()
            raise

    def stop(self) -> Dict[int, torch.Tensor]:
        """Remove all hooks and return the captured tensors (``self.captured``)."""
        self._remove_hooks()
        return self.captured

    def _remove_hooks(self):
        """Remove every registered hook handle and forget the handles."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
        # Never suppress an exception raised inside the ``with`` body.
        return False
def run_forward_and_capture_residuals(
    model, tokenizer, text: str, target_layers: List[int]
) -> Dict:
    """Tokenize ``text``, run one forward pass, and record residuals at ``target_layers``.

    Args:
        model: model to run; it is called as ``model(input_ids)``, must expose
            ``.device``, and is handed to ``ResidualCapture`` for hooking.
        tokenizer: tokenizer invoked with ``return_tensors="pt"``,
            ``add_special_tokens=False`` and ``truncation=False``.
        text: raw text to encode.
        target_layers: decoder-layer indices to hook.

    Returns:
        A dict with two entries:
          ``"input_ids"``: token ids with the leading batch dimension squeezed,
          on the CPU.
          ``"residuals"``: the mapping returned by ``ResidualCapture.stop()``,
          i.e. ``{layer_id: fp16 CPU tensor}``.
    """
    token_ids = tokenizer(
        text, return_tensors="pt", add_special_tokens=False, truncation=False
    )["input_ids"].to(model.device)

    capture = ResidualCapture(model, target_layers=target_layers)
    capture.start()
    try:
        # Only the hooks' side effects matter; the model output is discarded.
        with torch.no_grad():
            model(token_ids)
    finally:
        # Always unhook, even if the forward pass raises.
        layer_residuals = capture.stop()

    return {"input_ids": token_ids.squeeze(0).cpu(), "residuals": layer_residuals}