|
|
|
|
|
|
|
|
""" |
|
|
Universal Checkpoint Loader for ASA Models |
|
|
|
|
|
Loads checkpoints into either training or analysis harness. |
|
|
|
|
|
Repository: https://github.com/DigitalDaimyo/AddressedStateAttention |
|
|
""" |
|
|
|
|
|
from typing import Any, Dict, Literal, Optional, Tuple

import torch
|
|
|
|
|
|
|
|
__all__ = ['load_asm_checkpoint'] |
|
|
|
|
|
|
|
|
def load_asm_checkpoint(
    checkpoint_path: str,
    mode: Literal["train", "analysis"] = "train",
    device: Optional[str] = None,
) -> Tuple[Any, Any, Dict]:
    """
    Universal ASM checkpoint loader.

    Reads a checkpoint file, rebuilds the model from the stored config using
    the harness selected by ``mode``, loads the stored weights non-strictly,
    and returns the model on ``device`` in eval mode.

    Args:
        checkpoint_path: Path to .pt checkpoint file. The checkpoint must be
            a dict containing a ``"cfg"`` entry (keyword arguments for
            ``ASMTrainConfig``) and a ``"model"`` entry (the state dict).
        mode: "train" (efficient) or "analysis" (intervention harness).
            Selects whether ``ASMTrainConfig`` / ``build_model_from_cfg``
            come from the ``.training`` or the ``.analysis`` submodule.
        device: Device to load on (defaults to cuda if available, else cpu).

    Returns:
        model: Loaded ASMLanguageModel, moved to ``device`` and switched to
            eval mode (in both modes; call ``model.train()`` to resume
            training).
        cfg: ASMTrainConfig object built from the checkpoint's ``"cfg"``.
        ckpt: Full checkpoint dict (for step, loss metadata).

    Raises:
        ValueError: If ``mode`` is neither "train" nor "analysis".
        TypeError: If the checkpoint file does not contain a dict.
        KeyError: If the checkpoint lacks a ``"cfg"`` or ``"model"`` entry.

    Example:
        >>> model, cfg, ckpt = load_asm_checkpoint(
        ...     "best.pt", mode="analysis", device="cuda"
        ... )
        >>> print(f"Step {ckpt['step']}, Loss {ckpt['val_loss']:.3f}")
    """
    # Reject unknown modes up front; otherwise a typo such as "training"
    # would silently select the analysis harness.
    if mode not in ("train", "analysis"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'train' or 'analysis'."
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code depending on the installed torch version's `weights_only`
    # default. Only load checkpoints from sources you trust.
    # Tensors are mapped to CPU first; the model is moved to `device` after
    # the weights are loaded.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected checkpoint to be a dict, got {type(ckpt).__name__}."
        )

    # Validate both required entries before importing a harness or building
    # the model, so a malformed checkpoint fails fast.
    cfg_dict = ckpt.get("cfg")
    if cfg_dict is None:
        raise KeyError(f"Missing 'cfg' key. Available: {list(ckpt.keys())}")

    state_dict = ckpt.get("model")
    if state_dict is None:
        raise KeyError(f"Missing 'model' key. Available: {list(ckpt.keys())}")

    # Imported lazily so that only the requested harness is loaded.
    if mode == "train":
        from .training import ASMTrainConfig, build_model_from_cfg
    else:
        from .analysis import ASMTrainConfig, build_model_from_cfg

    cfg = ASMTrainConfig(**cfg_dict)
    model = build_model_from_cfg(cfg)

    # strict=False: key mismatches are reported below rather than raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        print(f"⚠ Missing keys: {len(missing)}")
    if unexpected:
        print(f"⚠ Unexpected keys: {len(unexpected)}")

    model = model.to(device).eval()

    return model, cfg, ckpt
|
|
|