DigitalDaimyo committed on
Commit
b14dbfc
·
verified ·
1 Parent(s): 9aaf11e

Upload universal_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. universal_loader.py +76 -0
universal_loader.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ """
4
+ Universal Checkpoint Loader for ASA Models
5
+
6
+ Loads checkpoints into either training or analysis harness.
7
+
8
+ Repository: https://github.com/DigitalDaimyo/AddressedStateAttention
9
+ """
10
+
11
+ import torch
12
+ from typing import Literal, Tuple, Dict, Any
13
+
14
+
15
+ __all__ = ['load_asm_checkpoint']
16
+
17
+
18
def load_asm_checkpoint(
    checkpoint_path: str,
    mode: Literal["train", "analysis"] = "train",
    device: str = None
) -> Tuple[Any, Any, Dict]:
    """
    Universal ASM checkpoint loader.

    Reads a checkpoint from disk, rebuilds the model from the stored
    config using the harness selected by ``mode``, loads the stored
    weights non-strictly, and returns the model in eval mode on ``device``.

    Args:
        checkpoint_path: Path to .pt checkpoint file
        mode: "train" (efficient) or "analysis" (intervention harness)
        device: Device to load on (defaults to cuda if available)

    Returns:
        model: Loaded ASMLanguageModel
        cfg: ASMTrainConfig object
        ckpt: Full checkpoint dict (for step, loss metadata)

    Raises:
        ValueError: If ``mode`` is neither "train" nor "analysis".
        KeyError: If the checkpoint lacks the 'cfg' or 'model' entry.

    Example:
        >>> model, cfg, ckpt = load_asm_checkpoint(
        ...     "best.pt", mode="analysis", device="cuda"
        ... )
        >>> print(f"Step {ckpt['step']}, Loss {ckpt['val_loss']:.3f}")
    """
    # Reject unknown modes up front; otherwise a typo such as "training"
    # would silently select the analysis harness.
    if mode not in ("train", "analysis"):
        raise ValueError(
            f"Invalid mode {mode!r}; expected 'train' or 'analysis'."
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from trusted sources (or pass weights_only=True where supported).
    # Tensors are mapped to CPU first and moved to `device` at the end.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    # Check both required entries before importing a harness or building
    # the model, so a malformed checkpoint fails fast.
    cfg_dict = ckpt.get("cfg")
    if cfg_dict is None:
        raise KeyError(f"Missing 'cfg' key. Available: {list(ckpt.keys())}")

    state_dict = ckpt.get("model")
    if state_dict is None:
        raise KeyError(f"Missing 'model' key. Available: {list(ckpt.keys())}")

    # Import appropriate harness (deferred so only the selected one is needed)
    if mode == "train":
        from training import ASMTrainConfig, build_model_from_cfg
    else:  # analysis
        from analysis import ASMTrainConfig, build_model_from_cfg

    # Build model using helper
    cfg = ASMTrainConfig(**cfg_dict)
    model = build_model_from_cfg(cfg)

    # Load weights; strict=False tolerates key mismatches, which are
    # reported below rather than raised.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        print(f"⚠ Missing keys: {len(missing)}")
    if unexpected:
        print(f"⚠ Unexpected keys: {len(unexpected)}")

    model = model.to(device).eval()

    return model, cfg, ckpt