"""Tuned Lens Runtime — load and apply per-layer affine probes for improved
intermediate-layer predictions.

Each probe applies a learned linear correction A_l(x) = x @ W_l^T + b_l
(W_l initialised to the identity and b_l to zero during training) that is
trained to minimise the KL divergence between the corrected layer's
predictions and the model's final-layer predictions.

See scripts/train_tuned_lens.py for the training pipeline.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Root directory holding one sub-directory of tuned lens weights per model id.
# The TUNED_LENS_DIR environment variable overrides the default location.
TUNED_LENS_DIR = os.getenv("TUNED_LENS_DIR", "./tuned_lens_weights")
class TunedLensRuntime:
    """Load, cache, and apply per-layer affine probes at inference time.

    Probes are held as ``{layer_idx: (weight, bias)}`` and applied as
    ``hidden @ weight^T + bias``. Every call to :meth:`load` first discards
    previously loaded probes, so a failed load leaves the runtime unavailable
    (identity fallback) instead of serving probes from an earlier model.
    """

    def __init__(self):
        # layer index -> (weight, bias), already on the target device/dtype
        self._probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        # Parsed metadata.json ({} if absent); None until a load succeeds
        self._metadata: Optional[dict] = None
        self._available = False

    @property
    def available(self) -> bool:
        """True once a checkpoint with at least one valid probe has been loaded."""
        return self._available

    def _reset(self) -> None:
        """Drop any previously loaded probes and metadata."""
        self._probes = {}
        self._metadata = None
        self._available = False

    def load(self, model_id: str, device: torch.device, dtype: torch.dtype,
             weights_dir: Optional[str] = None) -> bool:
        """Load the tuned lens checkpoint for *model_id*.

        Looks in ``<weights_dir or TUNED_LENS_DIR>/<model_id>/`` for files
        matching ``tuned_lens_*.pt`` (the lexicographically first is used) and
        an optional ``metadata.json``. Probe tensors are moved to *device* and
        cast to *dtype*.

        Any previously loaded state is cleared before loading, and new state
        is committed only if loading succeeds.

        Returns:
            True if at least one layer probe was loaded, False otherwise.
            Failure is non-fatal — the system falls back to raw logit lens.
        """
        self._reset()

        base_dir = Path(weights_dir or TUNED_LENS_DIR)
        model_dir = base_dir / model_id
        if not model_dir.exists():
            logger.info(f"Tuned lens: no weights directory for {model_id} at {model_dir}")
            return False

        # Find the checkpoint — pick the first .pt file in sorted order
        pt_files = sorted(model_dir.glob("tuned_lens_*.pt"))
        if not pt_files:
            logger.info(f"Tuned lens: no .pt checkpoint found in {model_dir}")
            return False
        checkpoint_path = pt_files[0]
        if len(pt_files) > 1:
            # Make the implicit choice visible when several checkpoints coexist.
            logger.warning(f"Tuned lens: {len(pt_files)} checkpoints in {model_dir}; "
                           f"using {checkpoint_path.name}")

        try:
            metadata = self._read_metadata(model_dir / "metadata.json")
            # weights_only=True refuses arbitrary pickled objects in the file.
            state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
            probes = self._parse_probes(state_dict, device, dtype)
        except Exception as e:
            # Deliberate best-effort: any load failure just disables the lens.
            logger.warning(f"Tuned lens: failed to load checkpoint — {e}")
            return False

        if not probes:
            logger.warning("Tuned lens: checkpoint loaded but no layer probes found")
            return False

        # Commit only after everything above succeeded.
        self._probes = probes
        self._metadata = metadata
        self._available = True
        logger.info(f"Tuned lens: loaded {len(self._probes)} layer probes from {checkpoint_path} "
                    f"(device={device}, dtype={dtype})")
        return True

    @staticmethod
    def _read_metadata(metadata_path: Path) -> dict:
        """Return parsed metadata.json, or {} if it is missing or unusable."""
        if not metadata_path.exists():
            return {}
        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except (OSError, ValueError) as e:
            # Metadata is optional, so a corrupt file should not block the probes.
            logger.warning(f"Tuned lens: ignoring unreadable metadata {metadata_path} — {e}")
            return {}
        if not isinstance(metadata, dict):
            logger.warning(f"Tuned lens: ignoring metadata {metadata_path} — expected a JSON object")
            return {}
        logger.info(f"Tuned lens: metadata loaded — {metadata.get('n_layers')} layers, "
                    f"d_model={metadata.get('d_model')}")
        return metadata

    @staticmethod
    def _parse_probes(state_dict, device: torch.device,
                      dtype: torch.dtype) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """Extract ``layer_N.weight`` / ``layer_N.bias`` pairs from *state_dict*.

        Keys that do not match ``layer_<int>.<name>`` are ignored. Layers with
        a missing tensor, or a weight/bias shape mismatch, are skipped with a
        warning rather than aborting the whole load.

        Raises:
            TypeError: if *state_dict* is not a dict.
        """
        if not isinstance(state_dict, dict):
            raise TypeError(f"expected a state dict, got {type(state_dict).__name__}")

        layer_indices = set()
        for key in state_dict:
            if not isinstance(key, str):
                continue
            parts = key.split(".")
            if len(parts) != 2 or not parts[0].startswith("layer_"):
                continue
            suffix = parts[0][len("layer_"):]
            # isdecimal() (not isdigit()) guarantees int() will accept it.
            if suffix.isdecimal():
                layer_indices.add(int(suffix))

        probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        for idx in sorted(layer_indices):
            weight = state_dict.get(f"layer_{idx}.weight")
            bias = state_dict.get(f"layer_{idx}.bias")
            if not isinstance(weight, torch.Tensor) or not isinstance(bias, torch.Tensor):
                logger.warning(f"Tuned lens: layer {idx} lacks a weight or bias tensor; skipping")
                continue
            # An affine probe needs a 2-D weight and a bias matching its output width.
            if weight.dim() != 2 or bias.shape != (weight.shape[0],):
                logger.warning(f"Tuned lens: layer {idx} has weight {tuple(weight.shape)} and "
                               f"bias {tuple(bias.shape)}; skipping")
                continue
            probes[idx] = (weight.to(device=device, dtype=dtype),
                           bias.to(device=device, dtype=dtype))
        return probes

    def apply(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the affine probe for *layer_idx*: ``hidden_state @ W^T + b``.

        The last dimension of *hidden_state* must match the probe's input width,
        and its dtype/device must match the probe's (as with the plain matmul).
        If no probe exists for this layer, the hidden state is returned
        unchanged (identity fallback).
        """
        probe = self._probes.get(layer_idx)
        if probe is None:
            return hidden_state
        weight, bias = probe
        # linear(x, W, b) == x @ W^T + b, fused, for any leading dimensions.
        return nn.functional.linear(hidden_state, weight, bias)

    def get_info(self) -> dict:
        """Return metadata dict for health/debug endpoints."""
        return {
            "available": self._available,
            "num_probes": len(self._probes),
            "layer_indices": sorted(self._probes.keys()),
            "metadata": self._metadata or {},
        }
# Global singleton shared by importers of this module; it reports
# available == False until a call to load() succeeds.
tuned_lens_runtime = TunedLensRuntime()