File size: 1,739 Bytes
c319d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRModel
# ---- EDIT these to match your training ----
# Keyword arguments forwarded verbatim to NeuroCLRConfig(**CFG) in main().
CFG = {
    "TSlength": 128,
    "nhead": 2,
    "nlayer": 2,
    "projector_out1": 128,
    "projector_out2": 64,
    # "flatten" because input is [B,1,128]
    "pooling": "flatten",
    "normalize_input": True,
}

# Checkpoint file handed to torch.load(); empty until filled in by the user.
CKPT_PATH = ""

# Directory handed to save_pretrained(); "." is the current working directory
# (per the original note, the script is meant to save into the pretraining/ folder).
OUT_DIR = "."
# ------------------------------------------
def remap_state_dict(sd):
    """Rename checkpoint keys so they match the wrapped model's parameter names.

    Two rewrites are applied to every key:

    1. Any leading ``"module."`` prefix (added when the model was wrapped,
       e.g. by DDP) is stripped.
    2. Keys that then start with ``"transformer_encoder."`` or
       ``"projector."`` are re-rooted under ``"neuroclr."``.

    All other keys are kept under their (prefix-stripped) name.

    Args:
        sd: Mapping from parameter name to value (a checkpoint state dict).

    Returns:
        A new dict with the renamed keys; ``sd`` itself is not modified.
    """
    wrapper_prefix = "module."
    rerooted_prefixes = ("transformer_encoder.", "projector.")

    new_sd = {}
    for key, value in sd.items():
        # Strip the wrapper prefix only at the START of the key. A blanket
        # str.replace would also eat "module." inside names such as
        # "encoder.submodule.weight" and silently corrupt them.
        name = key
        while name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]

        if name.startswith(rerooted_prefixes):
            name = "neuroclr." + name
        new_sd[name] = value
    return new_sd
def main():
    """Convert the training checkpoint at CKPT_PATH into a saved HF model folder.

    Steps: build a NeuroCLRConfig from CFG, attach the ``auto_map`` entries,
    instantiate NeuroCLRModel, load the checkpoint on CPU, remap its keys with
    remap_state_dict, load them non-strictly, print the missing/unexpected
    keys, and call ``save_pretrained`` on OUT_DIR.

    Raises:
        FileNotFoundError: if CKPT_PATH is still the empty placeholder.
    """
    # Fail fast with an actionable message instead of building the model and
    # then hitting an opaque "No such file or directory: ''" inside torch.load.
    if not CKPT_PATH:
        raise FileNotFoundError(
            "CKPT_PATH is empty; set it to the training checkpoint file "
            "before running this script."
        )

    config = NeuroCLRConfig(**CFG)
    # This enables AutoModel loading from this folder
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModel": "modeling_neuroclr.NeuroCLRModel",
    }
    model = NeuroCLRModel(config)

    # NOTE(review): torch.load unpickles the file, so only point CKPT_PATH at
    # checkpoints from a trusted source.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")

    # The weights may be nested under "model_state_dict" (checked first) or
    # "state_dict"; otherwise the loaded object is used as the state dict.
    sd = ckpt
    if isinstance(ckpt, dict):
        for wrapper_key in ("model_state_dict", "state_dict"):
            if wrapper_key in ckpt:
                sd = ckpt[wrapper_key]
                break

    sd = remap_state_dict(sd)

    # strict=False: mismatched keys are reported below rather than raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF pretraining model to:", OUT_DIR)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|