| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
# Parent of the directory that contains this script (symlinks resolved).
# It is put at the front of ``sys.path``, only if not already present, so the
# project-local ``CEDL`` package imported below can be found when this file
# is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path[:0] = [str(ROOT)]
|
|
| from CEDL import build_model |
|
|
|
|
def constructor_kwargs(config_path):
    """Translate a CEDL JSON config into keyword arguments for ``build_model``.

    Only the ``memory_readout`` section of the config is consulted. A missing
    section, or one set to ``null``, means every option takes the default
    hard-coded below. The descriptive ``source`` names are mapped to the
    short query-source identifiers; unrecognised names pass through as-is.

    Args:
        config_path: path (``str`` or ``os.PathLike``) of the JSON config.

    Returns:
        dict of keyword arguments to forward to ``build_model``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    # ``cfg.get("memory_readout", {})`` would yield None for an explicit
    # ``"memory_readout": null`` and the ``mem.get`` calls below would crash.
    mem = cfg.get("memory_readout") or {}
    source_name = str(mem.get("source", "contextual_memory_state"))
    # Descriptive config names -> identifiers expected by the model.
    source_map = {
        "contextual_memory_state": "q_mem",
        "decoder_state": "h_d",
        "expanded_state": "h_e",
        "attractor_state": "q_attractor",
        "q_mem": "q_mem",
    }
    return dict(
        lambda_head=bool(mem.get("lambda_head", True)),
        lambda_head_hidden=int(mem.get("lambda_head_hidden", 160)),
        lambda_head_bias_init=float(mem.get("lambda_head_bias_init", -7.0)),
        lambda_head_w_init_std=float(
            mem.get("lambda_head_w_init_std", 0.05)),
        # Binary-cross-entropy objective only for this specific selection mode.
        bce_objective=(
            mem.get("selection_objective") == "binary_answer_background"),
        sel_weight=1.0,
        bg_weight=1.0,
        bg_target=float(mem.get("background_target", 0.01)),
        wt_sparsity_weight=float(mem.get("sparsity_weight", 0.05)),
        wt_sparsity_target=float(mem.get("sparsity_target", 0.05)),
        memory_head_enabled=bool(mem.get("enabled", True)),
        memory_ce_weight=float(mem.get("memory_ce_weight", 1.0)),
        memory_pair_ce_weight=float(mem.get("pair_ce_weight", 5.0)),
        memory_query_source=source_map.get(source_name, source_name),
        memory_readout_mode="direct",
        source_adapter=bool(mem.get("source_adapter", True)),
        context_adapter=bool(mem.get("context_adapter", True)),
        specialist_noinject=bool(mem.get("no_injection", True)),
    )
|
|
|
|
def unwrap_state_dict(obj):
    """Return the state dict nested under ``"model"`` when there is one.

    A checkpoint shaped like ``{"model": <dict>, ...}`` yields the inner
    dict. Anything else comes back unchanged: non-dicts, dicts without a
    ``"model"`` key, and dicts whose ``"model"`` value is not a dict.
    """
    if not isinstance(obj, dict):
        return obj
    nested = obj.get("model")
    return nested if isinstance(nested, dict) else obj
|
|
|
|
def main():
    """Build a CEDL model from a JSON config, load a checkpoint, report stats.

    Command-line options:
        --checkpoint  file passed to ``torch.load``
                      (default ``pytorch_model.bin``).
        --config      JSON config read by ``constructor_kwargs``
                      (default ``cedl_config.json``).
        --device      device the loaded model is moved to (default ``cpu``).
        --vocab       ``vocab`` passed to ``build_model`` (default 50257).
        --max-seq     ``max_seq`` passed to ``build_model`` (default 1024).
        --non-strict  load with ``strict=False`` and list any missing or
                      unexpected state-dict keys instead of raising.

    Prints the checkpoint path, the parameter count and the missing /
    unexpected key counts.

    Raises:
        FileNotFoundError: if the checkpoint or the config does not exist.
        RuntimeError: from ``load_state_dict`` on a key mismatch when
            ``--non-strict`` is not given.
    """
    parser = argparse.ArgumentParser(
        description="Build a CEDL model and load a checkpoint into it.")
    parser.add_argument("--checkpoint", default="pytorch_model.bin")
    parser.add_argument("--config", default="cedl_config.json")
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--vocab", type=int, default=50257,
                        help="vocabulary size passed to build_model")
    parser.add_argument("--max-seq", type=int, default=1024,
                        help="maximum sequence length passed to build_model")
    parser.add_argument(
        "--non-strict", action="store_true",
        help="allow missing/unexpected state-dict keys and list them")
    args = parser.parse_args()

    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    cfg_path = Path(args.config)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Config not found: {cfg_path}")

    model = build_model(
        "CEDL",
        vocab=args.vocab,
        max_seq=args.max_seq,
        **constructor_kwargs(cfg_path),
    )

    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from trusted sources -- before torch 2.6 the default weights_only=False
    # lets a crafted file execute arbitrary code while loading.
    state = unwrap_state_dict(torch.load(ckpt_path, map_location="cpu"))
    # strict=True makes load_state_dict raise on any key mismatch, so the
    # key counts printed below can only be non-zero with --non-strict.
    result = model.load_state_dict(state, strict=not args.non_strict)
    model.to(args.device)
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded {ckpt_path}")
    print(f"Parameters: {n_params:,}")
    print(f"Missing keys: {len(result.missing_keys)}")
    for key in result.missing_keys:
        print(f"  {key}")
    print(f"Unexpected keys: {len(result.unexpected_keys)}")
    for key in result.unexpected_keys:
        print(f"  {key}")


if __name__ == "__main__":
    main()
|
|