NeuroCLR / classification /export_classification_to_hf.py
falmuqhim's picture
Upload folder using huggingface_hub
c319d57 verified
import torch
from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRForSequenceClassification
# ---------------------------------------------------------------
# User settings: edit these paths (and nhead, if needed).
# ---------------------------------------------------------------
PRETRAIN_CKPT = ""  # pretraining checkpoint; main() takes the encoder weights from it
HEAD_CKPT = ""  # classification checkpoint; main() takes the head weights from it
OUT_DIR = "."  # folder main() saves the exported model into

# Keyword arguments that main() unpacks into NeuroCLRConfig.
CFG = {
    # --- encoder: MUST match the pretrained export ---
    "TSlength": 128,
    "nhead": 2,  # change if needed
    "nlayer": 2,  # author's note: confirmed from the pretraining checkpoint
    "projector_out1": 128,
    "projector_out2": 64,
    "pooling": "flatten",
    "normalize_input": True,
    # --- classification ---
    "n_rois": 200,
    "num_labels": 2,
    "freeze_encoder": True,  # encoder frozen by default
    # --- ResNet1D head (author's exact settings) ---
    "base_filters": 256,
    "kernel_size": 16,
    "stride": 2,
    "groups": 32,
    "n_block": 48,
    "downsample_gap": 6,
    "increasefilter_gap": 12,
    "use_bn": True,
    "use_do": True,
}
# ---------------------------------------------------------------
def load_model_state_dict(path):
    """Load a checkpoint onto the CPU and return the state dict stored in it.

    If the loaded object is a dict holding a ``"model_state_dict"`` entry,
    that entry is returned; otherwise a ``"state_dict"`` entry is returned
    if present. Any other object (a dict without those keys, or a non-dict)
    is returned unchanged.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")

    # Non-dict payloads are handed back as-is.
    if not isinstance(checkpoint, dict):
        return checkpoint

    # Wrapper keys are tried in priority order; the first one present wins.
    for wrapper_key in ("model_state_dict", "state_dict"):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]

    # No wrapper key found: treat the dict itself as the state dict.
    return checkpoint
def remap_encoder(sd):
    """Re-key a pretraining state dict so it fits under the model's ``encoder.``.

    Every ``module`` path component (the wrapper prefix that e.g.
    ``torch.nn.DataParallel`` puts in front of parameter names) is dropped
    from each key. Keys that then start with ``transformer_encoder.`` or
    ``projector.`` are kept and prefixed with ``encoder.``; all other keys
    are discarded.

    Args:
        sd: Mapping of parameter name -> value from the pretraining checkpoint.

    Returns:
        A new dict containing only the re-keyed encoder entries.
    """
    kept_prefixes = ("transformer_encoder.", "projector.")
    remapped = {}
    for key, value in sd.items():
        # Drop whole "module" components only. A plain substring replace of
        # "module." would also eat the tail of names such as "submodule."
        # and silently corrupt the key. The final component is never dropped.
        *parents, leaf = key.split(".")
        name = ".".join([part for part in parents if part != "module"] + [leaf])
        if name.startswith(kept_prefixes):
            remapped["encoder." + name] = value
    return remapped
def remap_head(sd):
    """Re-key a classification state dict so it fits under the model's ``head.``.

    Every ``module`` path component (the wrapper prefix that e.g.
    ``torch.nn.DataParallel`` puts in front of parameter names) is dropped
    from each key. Keys that then start with one of the known head layer
    names are prefixed with ``head.``; keys that already start with
    ``head.`` are kept unchanged; all other keys are discarded.

    Args:
        sd: Mapping of parameter name -> value from the head checkpoint.

    Returns:
        A new dict containing only the re-keyed head entries.
    """
    # Head checkpoint keys are expected to start with one of these layer names.
    # Built once here rather than on every loop iteration.
    head_prefixes = (
        "first_block_conv.", "first_block_bn.", "first_block_relu.",
        "basicblock_list.", "final_bn.", "final_relu.", "dense.",
    )
    remapped = {}
    for key, value in sd.items():
        # Drop whole "module" components only. A plain substring replace of
        # "module." would also eat the tail of names such as "submodule."
        # and silently corrupt the key. The final component is never dropped.
        *parents, leaf = key.split(".")
        name = ".".join([part for part in parents if part != "module"] + [leaf])
        if name.startswith(head_prefixes):
            remapped["head." + name] = value
        elif name.startswith("head."):
            # The checkpoint already carries the head.* prefix: keep as is.
            remapped[name] = value
    return remapped
def main():
    """Assemble the classification model from two checkpoints and save it.

    Builds the model from CFG, loads the remapped encoder weights from
    PRETRAIN_CKPT and the remapped head weights from HEAD_CKPT (non-strict,
    printing the missing and unexpected keys), then saves the model to
    OUT_DIR in safetensors format.
    """
    hf_config = NeuroCLRConfig(**CFG)
    # auto_map enables HF auto-classes loading from this folder.
    hf_config.auto_map = dict(
        AutoConfig="configuration_neuroclr.NeuroCLRConfig",
        AutoModelForSequenceClassification="modeling_neuroclr.NeuroCLRForSequenceClassification",
    )
    classifier = NeuroCLRForSequenceClassification(hf_config)

    # Step 1: encoder weights come from the pretraining checkpoint.
    encoder_weights = remap_encoder(load_model_state_dict(PRETRAIN_CKPT))

    # Step 2: head weights come from the classification checkpoint.
    head_weights = remap_head(load_model_state_dict(HEAD_CKPT))

    # Step 3: merge (head entries win on a key clash) and load non-strictly,
    # reporting whatever did not line up instead of raising.
    combined = {**encoder_weights, **head_weights}
    missing_keys, unexpected_keys = classifier.load_state_dict(combined, strict=False)
    print("Missing:", missing_keys)
    print("Unexpected:", unexpected_keys)

    # Write the model into the HF folder.
    classifier.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF classification model to:", OUT_DIR)


if __name__ == "__main__":
    main()