CEDL / examples /load_checkpoint.py
Jasonjiao2023's picture
Upload CEDL research checkpoint
4ab8d58 verified
Raw
History Blame Contribute Delete
3.06 kB
import argparse
import json
import sys
from pathlib import Path
import torch
# Directory two levels above this file; put at the front of sys.path (once)
# so that the package import that follows can be resolved when the script
# is run directly.
ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
from CEDL import build_model
def constructor_kwargs(config_path):
    """Translate the ``memory_readout`` section of a JSON config into kwargs.

    The returned mapping is what ``main`` splats into ``build_model``.
    Every value is read from the ``memory_readout`` object with a fallback
    default, except ``sel_weight``, ``bg_weight`` and
    ``memory_readout_mode``, which are fixed here and not configurable.

    Args:
        config_path: Path (``str`` or ``os.PathLike``) of a JSON file whose
            top-level value is an object.

    Returns:
        dict: Keyword arguments with values coerced to ``bool``, ``int``,
        ``float`` or ``str`` as appropriate.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    # Treat both a missing section and an explicit JSON ``null`` as
    # "use the defaults" instead of failing on ``None.get``.
    mem = cfg.get("memory_readout")
    if mem is None:
        mem = {}

    source_name = str(mem.get("source", "contextual_memory_state"))
    # Descriptive config names -> short identifiers passed to the model.
    # Names not listed here are passed through unchanged.
    source_map = {
        "contextual_memory_state": "q_mem",
        "decoder_state": "h_d",
        "expanded_state": "h_e",
        "attractor_state": "q_attractor",
        "q_mem": "q_mem",
    }
    return dict(
        lambda_head=bool(mem.get("lambda_head", True)),
        lambda_head_hidden=int(mem.get("lambda_head_hidden", 160)),
        lambda_head_bias_init=float(mem.get("lambda_head_bias_init", -7.0)),
        lambda_head_w_init_std=float(
            mem.get("lambda_head_w_init_std", 0.05)),
        # True only for this exact objective name.
        bce_objective=(
            mem.get("selection_objective") == "binary_answer_background"),
        sel_weight=1.0,
        bg_weight=1.0,
        bg_target=float(mem.get("background_target", 0.01)),
        wt_sparsity_weight=float(mem.get("sparsity_weight", 0.05)),
        wt_sparsity_target=float(mem.get("sparsity_target", 0.05)),
        memory_head_enabled=bool(mem.get("enabled", True)),
        memory_ce_weight=float(mem.get("memory_ce_weight", 1.0)),
        memory_pair_ce_weight=float(mem.get("pair_ce_weight", 5.0)),
        memory_query_source=source_map.get(source_name, source_name),
        memory_readout_mode="direct",
        source_adapter=bool(mem.get("source_adapter", True)),
        context_adapter=bool(mem.get("context_adapter", True)),
        specialist_noinject=bool(mem.get("no_injection", True)),
    )
def unwrap_state_dict(obj):
    """Return the nested ``"model"`` dict of *obj* if there is one.

    When *obj* is a dict whose ``"model"`` entry is itself a dict, that
    inner dict is returned. Anything else is handed back unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    inner = obj.get("model")
    return inner if isinstance(inner, dict) else obj
def main(argv=None):
    """Build a CEDL model, load a checkpoint into it and print a summary.

    Args:
        argv: Optional list of command-line argument strings. ``None``
            (the default) lets argparse read ``sys.argv[1:]``, so running
            the script from the shell is unchanged.

    Raises:
        FileNotFoundError: If the checkpoint or the config path does not
            exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="pytorch_model.bin")
    parser.add_argument("--config", default="cedl_config.json")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args(argv)

    # Fail early with a clear message before building the model.
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    cfg_path = Path(args.config)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Config not found: {cfg_path}")

    model = build_model(
        "CEDL",
        vocab=50257,
        max_seq=1024,
        **constructor_kwargs(cfg_path),
    )

    # SECURITY NOTE(review): torch.load unpickles the file, which can run
    # arbitrary code for a malicious checkpoint on PyTorch versions where
    # ``weights_only`` defaults to False. Only load checkpoints from a
    # trusted source, or pass ``weights_only=True`` once the minimum
    # supported PyTorch version allows it.
    # Tensors are mapped to CPU first; the model is moved afterwards.
    state = unwrap_state_dict(torch.load(ckpt_path, map_location="cpu"))
    # strict=True raises on any key mismatch, so on success the two counts
    # printed below are always 0.
    result = model.load_state_dict(state, strict=True)
    model.to(args.device)
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded {ckpt_path}")
    print(f"Parameters: {n_params:,}")
    print(f"Missing keys: {len(result.missing_keys)}")
    print(f"Unexpected keys: {len(result.unexpected_keys)}")


if __name__ == "__main__":
    main()