# NeuroCLR / pretraining / export_pretraining_to_hf.py
# Uploaded by: falmuqhim
# Commit message: Upload folder using huggingface_hub
# Commit: c319d57 (verified)
from pathlib import Path

import torch

from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRModel
# ---- EDIT these to match your training ----
# Keyword arguments passed verbatim to NeuroCLRConfig(**CFG). They should match
# the run that produced the checkpoint; a mismatch shows up as missing/unexpected
# keys (or a size-mismatch error) when the weights are loaded.
CFG = dict(
    TSlength=128,
    nhead=2,
    nlayer=2,
    projector_out1=128,
    projector_out2=64,
    pooling="flatten",  # because input is [B,1,128]
    normalize_input=True,
)
# Path to the training checkpoint file. Placeholder: must be set before running.
CKPT_PATH = ""
# Export destination: the directory containing this script (pretraining/), NOT
# the current working directory. config.auto_map names the modules
# configuration_neuroclr / modeling_neuroclr, which are imported as top-level
# modules and so presumably sit next to this script; the exported config.json
# must live in that same folder for AutoModel to resolve them.
OUT_DIR = str(Path(__file__).resolve().parent)
# ------------------------------------------
def remap_state_dict(sd):
    """Rename training-checkpoint keys to match NeuroCLRModel's parameter names.

    For each key:

    1. A single leading ``"module."`` is stripped. That prefix is added when
       the training model was wrapped, e.g. by DistributedDataParallel.
    2. Keys starting with ``"transformer_encoder."`` or ``"projector."`` are
       prefixed with ``"neuroclr."``. This assumes NeuroCLRModel holds the
       encoder/projector under a ``neuroclr`` attribute; TODO confirm against
       modeling_neuroclr.
    3. Every other key is passed through unchanged.

    Args:
        sd: Mapping of parameter name -> tensor from the training checkpoint.

    Returns:
        A new dict with remapped keys. The values are the same objects, not
        copies.
    """
    wrapper_prefix = "module."
    remapped = {}
    for key, value in sd.items():
        # Only strip the wrapper prefix at the start of the key: a global
        # str.replace would also mangle names like "encoder.submodule.weight".
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        if key.startswith(("transformer_encoder.", "projector.")):
            key = "neuroclr." + key
        remapped[key] = value
    return remapped
def _extract_state_dict(ckpt):
    """Return the weights mapping from a loaded checkpoint object.

    Handles a dict wrapping the weights under ``"model_state_dict"`` (checked
    first) or ``"state_dict"``. Anything else is assumed to already be a bare
    state dict and is returned as-is.
    """
    if isinstance(ckpt, dict):
        for key in ("model_state_dict", "state_dict"):
            if key in ckpt:
                return ckpt[key]
    return ckpt


def main():
    """Load the pretraining checkpoint into NeuroCLRModel and save it in HF format.

    Steps:

    1. Build a NeuroCLRConfig from CFG and register the AutoConfig/AutoModel
       classes via ``auto_map``.
    2. Load CKPT_PATH on CPU.
    3. Remap its keys and load them non-strictly into a fresh model.
    4. Write the result to OUT_DIR with safetensors serialization.

    Raises:
        ValueError: If CKPT_PATH has not been set.
    """
    if not CKPT_PATH:
        # Fail with a clear message instead of an opaque error from torch.load("").
        raise ValueError(
            "CKPT_PATH is empty; set it to the training checkpoint path "
            "before running this export script."
        )

    config = NeuroCLRConfig(**CFG)
    # This enables AutoModel loading from this folder
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModel": "modeling_neuroclr.NeuroCLRModel",
    }
    model = NeuroCLRModel(config)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    sd = remap_state_dict(_extract_state_dict(ckpt))

    # strict=False lets a partial match still export, so inspect these lists:
    # non-empty output may mean CFG does not match the training run and some
    # parameters were left at their initial values.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF pretraining model to:", OUT_DIR)
if __name__ == "__main__":
    # Run the export only when executed as a script, not on import.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())