"""Build a CEDL model from a JSON config, load a checkpoint, print a summary."""
import argparse
import json
import sys
from pathlib import Path

import torch

# The CEDL package is imported from two directory levels above this file
# (``parents[1]``), so that directory must be on sys.path before the import.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from CEDL import build_model  # noqa: E402  (must follow the sys.path setup)


def constructor_kwargs(config_path):
    """Translate the ``memory_readout`` config section into model kwargs.

    Args:
        config_path: Path to a JSON file whose top level is an object.

    Returns:
        A dict of keyword arguments to pass to ``build_model``. Every entry
        falls back to a built-in default when its key is absent from the
        ``memory_readout`` section (or when that section is absent or null).

    Raises:
        OSError: If the config file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    # Treat an explicit ``"memory_readout": null`` the same as a missing
    # section instead of failing with AttributeError on ``None.get``.
    mem = cfg.get("memory_readout")
    if mem is None:
        mem = {}

    # Human-readable source names in the config map to the short identifiers
    # passed on as ``memory_query_source``. Unknown names pass through as-is.
    source_name = str(mem.get("source", "contextual_memory_state"))
    source_map = {
        "contextual_memory_state": "q_mem",
        "decoder_state": "h_d",
        "expanded_state": "h_e",
        "attractor_state": "q_attractor",
        "q_mem": "q_mem",
    }

    return dict(
        lambda_head=bool(mem.get("lambda_head", True)),
        lambda_head_hidden=int(mem.get("lambda_head_hidden", 160)),
        lambda_head_bias_init=float(mem.get("lambda_head_bias_init", -7.0)),
        lambda_head_w_init_std=float(
            mem.get("lambda_head_w_init_std", 0.05)),
        bce_objective=(
            mem.get("selection_objective") == "binary_answer_background"),
        sel_weight=1.0,
        bg_weight=1.0,
        bg_target=float(mem.get("background_target", 0.01)),
        wt_sparsity_weight=float(mem.get("sparsity_weight", 0.05)),
        wt_sparsity_target=float(mem.get("sparsity_target", 0.05)),
        memory_head_enabled=bool(mem.get("enabled", True)),
        memory_ce_weight=float(mem.get("memory_ce_weight", 1.0)),
        memory_pair_ce_weight=float(mem.get("pair_ce_weight", 5.0)),
        memory_query_source=source_map.get(source_name, source_name),
        memory_readout_mode="direct",
        source_adapter=bool(mem.get("source_adapter", True)),
        context_adapter=bool(mem.get("context_adapter", True)),
        specialist_noinject=bool(mem.get("no_injection", True)),
    )


def unwrap_state_dict(obj):
    """Return the state dict held in *obj*.

    If *obj* is a dict with a ``"model"`` entry that is itself a dict, that
    inner dict is returned; otherwise *obj* is returned unchanged.
    """
    if isinstance(obj, dict) and "model" in obj and isinstance(obj["model"], dict):
        return obj["model"]
    return obj


def main():
    """Parse CLI arguments, build the model, load the checkpoint, and report.

    Raises:
        FileNotFoundError: If the checkpoint or config path does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="pytorch_model.bin")
    parser.add_argument("--config", default="cedl_config.json")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    # Fail early with a clear message before any model construction happens.
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    cfg_path = Path(args.config)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Config not found: {cfg_path}")

    model = build_model(
        "CEDL",
        vocab=50257,
        max_seq=1024,
        **constructor_kwargs(cfg_path),
    )

    # SECURITY NOTE(review): torch.load unpickles the file, and on PyTorch
    # versions where ``weights_only`` defaults to False this can execute
    # arbitrary code from an untrusted checkpoint. Consider passing
    # ``weights_only=True`` once it is confirmed the checkpoints contain only
    # tensors and plain containers.
    # Tensors are mapped to CPU first; the model is moved to the requested
    # device only after the weights are loaded.
    state = unwrap_state_dict(torch.load(ckpt_path, map_location="cpu"))
    result = model.load_state_dict(state, strict=True)
    model.to(args.device)
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded {ckpt_path}")
    print(f"Parameters: {n_params:,}")
    # With strict=True a key mismatch raises above, so reaching these lines
    # means both counts are zero; they are printed as an explicit confirmation.
    print(f"Missing keys: {len(result.missing_keys)}")
    print(f"Unexpected keys: {len(result.unexpected_keys)}")


if __name__ == "__main__":
    main()