"""Convert a raw NeuroCLR training checkpoint into a HuggingFace-loadable model.

Loads a ``torch.save`` checkpoint, remaps its state-dict keys onto the
``NeuroCLRModel`` wrapper, and writes the result with ``save_pretrained``
so it can later be loaded via ``AutoModel.from_pretrained``.
"""

import torch

from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRModel

# ---- EDIT these to match your training ----
CFG = dict(
    TSlength=128,
    nhead=2,
    nlayer=2,
    projector_out1=128,
    projector_out2=64,
    pooling="flatten",  # because input is [B,1,128]
    normalize_input=True,
)
CKPT_PATH = ""  # path to the torch.save checkpoint — must be set before running
OUT_DIR = "."   # output directory for save_pretrained (current directory by default)
# ------------------------------------------


def remap_state_dict(sd):
    """Remap raw-training state-dict keys to the HF wrapper's namespace.

    - Strips a leading ``"module."`` prefix (present when the checkpoint was
      saved from a DDP-wrapped model). ``removeprefix`` is used deliberately
      so keys that merely *contain* ``"module."`` are not mangled.
    - Prefixes ``transformer_encoder.*`` / ``projector.*`` keys with
      ``"neuroclr."`` to match ``NeuroCLRModel``'s submodule layout.
    - Any other key is kept unchanged (usually there are none).
    """
    new_sd = {}
    for k, v in sd.items():
        k2 = k.removeprefix("module.")  # if DDP ever used
        if k2.startswith(("transformer_encoder.", "projector.")):
            new_sd["neuroclr." + k2] = v
        else:
            new_sd[k2] = v
    return new_sd


def main():
    """Build the HF model, load the checkpoint weights, and save to OUT_DIR.

    Raises:
        ValueError: if CKPT_PATH has not been filled in.
    """
    if not CKPT_PATH:
        raise ValueError("CKPT_PATH is empty — set it to your checkpoint file before running.")

    config = NeuroCLRConfig(**CFG)
    # This enables AutoModel loading from this folder
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModel": "modeling_neuroclr.NeuroCLRModel",
    }
    model = NeuroCLRModel(config)

    # weights_only=True refuses to unpickle arbitrary objects (security).
    # Plain tensor/state-dict checkpoints load fine; a legacy checkpoint that
    # pickled custom classes would need weights_only=False — only do that for
    # checkpoints you trust.
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=True)

    # Accept the common container layouts: {"model_state_dict": ...},
    # {"state_dict": ...}, or a bare state dict.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        sd = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        sd = ckpt["state_dict"]
    else:
        sd = ckpt

    sd = remap_state_dict(sd)

    # strict=False so a partial match still loads; print the diff so the
    # operator can verify nothing important was dropped.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF pretraining model to:", OUT_DIR)


if __name__ == "__main__":
    main()