"""Merge a pretrained NeuroCLR encoder checkpoint with a ResNet1D
classification-head checkpoint into a single HuggingFace
`NeuroCLRForSequenceClassification` model and save it in HF format."""

import torch

from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRForSequenceClassification

# -------- EDIT THESE PATHS + nhead if needed ----------
PRETRAIN_CKPT = ""
HEAD_CKPT = ""
OUT_DIR = "."

CFG = dict(
    # encoder MUST match the pretrained export
    TSlength=128,
    nhead=2,          # change if needed
    nlayer=2,         # we confirmed this from your pretraining ckpt
    projector_out1=128,
    projector_out2=64,
    pooling="flatten",
    normalize_input=True,
    # classification
    n_rois=200,
    num_labels=2,
    freeze_encoder=True,  # encoder frozen by default
    # ResNet1D head (your exact settings)
    base_filters=256,
    kernel_size=16,
    stride=2,
    groups=32,
    n_block=48,
    downsample_gap=6,
    increasefilter_gap=12,
    use_bn=True,
    use_do=True,
)
# -----------------------------------------------------

_MODULE_PREFIX = "module."


def _strip_module_prefix(key):
    """Drop a leading DataParallel ``module.`` prefix from *key*, if present.

    Uses an explicit prefix check instead of ``str.replace`` so that a
    ``module.`` substring appearing *later* in the key is left untouched
    (``replace`` would have removed every occurrence).
    """
    if key.startswith(_MODULE_PREFIX):
        return key[len(_MODULE_PREFIX):]
    return key


def load_model_state_dict(path):
    """Load a torch checkpoint from *path* and return its state dict.

    Handles checkpoints saved either as a bare state dict or wrapped in a
    dict under ``model_state_dict`` / ``state_dict``.

    NOTE(security): ``torch.load`` unpickles arbitrary objects — only load
    checkpoint files you trust.
    """
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict):
        if "model_state_dict" in ckpt:
            return ckpt["model_state_dict"]
        if "state_dict" in ckpt:
            return ckpt["state_dict"]
    return ckpt


def remap_encoder(sd):
    """Remap pretraining-checkpoint keys onto the model's ``encoder.`` namespace.

    Keeps only ``transformer_encoder.*`` / ``projector.*`` entries and
    prefixes them with ``encoder.``; a DataParallel ``module.`` prefix is
    stripped first. All other keys are dropped.
    """
    new = {}
    for k, v in sd.items():
        k2 = _strip_module_prefix(k)
        if k2.startswith(("transformer_encoder.", "projector.")):
            new["encoder." + k2] = v
    return new


def remap_head(sd):
    """Remap classification-head checkpoint keys onto the ``head.`` namespace.

    Recognized ResNet1D key prefixes (first block / basic blocks / final
    bn-relu / dense) get ``head.`` prepended; keys already under ``head.``
    are kept as-is. A DataParallel ``module.`` prefix is stripped first.
    """
    head_prefixes = (
        "first_block_conv.",
        "first_block_bn.",
        "first_block_relu.",
        "basicblock_list.",
        "final_bn.",
        "final_relu.",
        "dense.",
    )
    new = {}
    for k, v in sd.items():
        k2 = _strip_module_prefix(k)
        if k2.startswith(head_prefixes):
            new["head." + k2] = v
        elif k2.startswith("head."):
            # checkpoint already uses the target namespace — keep it
            new[k2] = v
    return new


def main():
    """Build the HF model, splice in encoder + head weights, and save it."""
    config = NeuroCLRConfig(**CFG)
    # Enables HF auto-classes loading from this folder
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification",
    }

    model = NeuroCLRForSequenceClassification(config)

    # 1) Load encoder weights from pretraining ckpt
    enc_sd = remap_encoder(load_model_state_dict(PRETRAIN_CKPT))

    # 2) Load head weights from classification ckpt
    head_sd = remap_head(load_model_state_dict(HEAD_CKPT))

    # 3) Merge and load; strict=False so gaps are reported, not fatal
    merged = {**enc_sd, **head_sd}
    missing, unexpected = model.load_state_dict(merged, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    # Save to HF folder
    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF classification model to:", OUT_DIR)


if __name__ == "__main__":
    main()