File size: 4,710 Bytes
6f48db0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Tuned Lens Runtime — load and apply per-layer affine probes for improved
intermediate-layer predictions.

Each probe applies a learned linear correction A_l(x) = x @ W_l^T + b_l
(initialised to identity + zero during training) that is trained to minimise
KL divergence between the corrected layer's predictions and the model's
final-layer predictions.

See scripts/train_tuned_lens.py for the training pipeline.
"""

import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default root directory for tuned-lens checkpoints. The TUNED_LENS_DIR
# environment variable, when set, overrides the built-in relative path.
TUNED_LENS_DIR = os.getenv("TUNED_LENS_DIR", "./tuned_lens_weights")


class TunedLensRuntime:
    """Load, cache, and apply per-layer affine probes at inference time.

    Probes are kept in memory as ``{layer_idx: (weight, bias)}``. A runtime
    starts empty; :meth:`load` populates it from a checkpoint on disk and
    :meth:`apply` uses the cached tensors. Every :meth:`load` call replaces
    whatever was loaded before, so a failed load never leaves probes from an
    earlier model behind.
    """

    def __init__(self):
        self._probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        self._metadata: Optional[dict] = None
        self._available = False

    @property
    def available(self) -> bool:
        """Return True if the most recent :meth:`load` yielded probes."""
        return self._available

    def _reset(self) -> None:
        """Drop all cached probes and metadata and mark the runtime unavailable."""
        self._probes = {}
        self._metadata = None
        self._available = False

    def load(self, model_id: str, device: torch.device, dtype: torch.dtype,
             weights_dir: Optional[str] = None) -> bool:
        """Load tuned lens checkpoint for *model_id*.

        The checkpoint is looked up in ``<weights_dir>/<model_id>/`` (falling
        back to ``TUNED_LENS_DIR`` when *weights_dir* is not given). The first
        file matching ``tuned_lens_*.pt`` in sorted order is used, together
        with an optional ``metadata.json`` in the same directory.

        Args:
            model_id: Sub-directory name identifying the model.
            device: Device the probe tensors are moved to.
            dtype: Dtype the probe tensors are cast to.
            weights_dir: Optional override for the weights root directory.

        Returns:
            True if weights were loaded successfully, False otherwise.
            Failure is non-fatal — the system falls back to raw logit lens.
            On any failure the runtime is left empty and unavailable.
        """
        # Clear first: every early return below must leave the runtime empty,
        # otherwise probes of a previously loaded model would keep being
        # applied after a failed (re)load.
        self._reset()

        base_dir = Path(weights_dir or TUNED_LENS_DIR)
        model_dir = base_dir / model_id

        if not model_dir.exists():
            logger.info("Tuned lens: no weights directory for %s at %s",
                        model_id, model_dir)
            return False

        # Find the checkpoint — pick the first .pt file
        pt_files = sorted(model_dir.glob("tuned_lens_*.pt"))
        if not pt_files:
            logger.info("Tuned lens: no .pt checkpoint found in %s", model_dir)
            return False

        checkpoint_path = pt_files[0]
        metadata_path = model_dir / "metadata.json"

        try:
            # Load metadata (optional)
            if metadata_path.exists():
                with open(metadata_path, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
                logger.info("Tuned lens: metadata loaded — %s layers, d_model=%s",
                            metadata.get("n_layers"), metadata.get("d_model"))
            else:
                metadata = {}

            # Load state dict on CPU; tensors are moved to *device* below.
            state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

            # Collect the layer indices named by layer_N.<param> entries.
            layer_indices = set()
            for key in state_dict:
                parts = key.split(".")
                if len(parts) != 2 or not parts[0].startswith("layer_"):
                    continue
                try:
                    layer_indices.add(int(parts[0].split("_")[1]))
                except ValueError:
                    # e.g. "layer_norm.weight": not a per-layer probe entry,
                    # so skip it instead of rejecting the whole checkpoint.
                    continue

            # A probe needs both its weight and its bias to be usable.
            probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
            for idx in sorted(layer_indices):
                w_key = f"layer_{idx}.weight"
                b_key = f"layer_{idx}.bias"
                if w_key in state_dict and b_key in state_dict:
                    weight = state_dict[w_key].to(device=device, dtype=dtype)
                    bias = state_dict[b_key].to(device=device, dtype=dtype)
                    probes[idx] = (weight, bias)

        except Exception as e:
            # Top-level boundary: loading is best-effort by design, so any
            # error is logged and reported as an unsuccessful load.
            logger.warning("Tuned lens: failed to load checkpoint — %s", e)
            self._reset()
            return False

        if not probes:
            logger.warning("Tuned lens: checkpoint loaded but no layer probes found")
            return False

        # Commit only after everything parsed, so state is never half-updated.
        self._probes = probes
        self._metadata = metadata
        self._available = True
        logger.info("Tuned lens: loaded %s layer probes from %s (device=%s, dtype=%s)",
                    len(probes), checkpoint_path, device, dtype)
        return True

    def apply(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the affine probe for *layer_idx*: hidden @ W^T + b.

        If no probe exists for this layer, returns the hidden state unchanged
        (identity fallback).
        """
        probe = self._probes.get(layer_idx)
        if probe is None:
            return hidden_state
        weight, bias = probe
        return hidden_state @ weight.T + bias

    def get_info(self) -> dict:
        """Return metadata dict for health/debug endpoints.

        Keys: ``available``, ``num_probes``, ``layer_indices`` (sorted) and
        ``metadata`` (an empty dict when none was loaded).
        """
        return {
            "available": self._available,
            "num_probes": len(self._probes),
            "layer_indices": sorted(self._probes),
            "metadata": self._metadata or {},
        }


# Global singleton: a module-level TunedLensRuntime created at import time.
# It starts with no probes loaded; call its load() method to populate it.
tuned_lens_runtime = TunedLensRuntime()