| | import torch |
| | from configuration_neuroclr import NeuroCLRConfig |
| | from modeling_neuroclr import NeuroCLRForSequenceClassification |
| |
|
| | |
# Checkpoint locations and output directory. The two checkpoint paths are
# empty placeholders and must be filled in before the script can load anything.
PRETRAIN_CKPT = ""  # checkpoint passed through remap_encoder() in main()
HEAD_CKPT = ""  # checkpoint passed through remap_head() in main()
OUT_DIR = "."  # directory handed to save_pretrained() in main()

# Keyword arguments forwarded verbatim to NeuroCLRConfig(**CFG).
# NOTE(review): the three groups below follow the blank-line grouping of the
# original; the group labels are inferred from the key names -- confirm
# against NeuroCLRConfig.
CFG = {
    # presumably encoder / projector settings
    "TSlength": 128,
    "nhead": 2,
    "nlayer": 2,
    "projector_out1": 128,
    "projector_out2": 64,
    "pooling": "flatten",
    "normalize_input": True,
    # presumably classification-task settings
    "n_rois": 200,
    "num_labels": 2,
    "freeze_encoder": True,
    # presumably classification-head architecture settings
    "base_filters": 256,
    "kernel_size": 16,
    "stride": 2,
    "groups": 32,
    "n_block": 48,
    "downsample_gap": 6,
    "increasefilter_gap": 12,
    "use_bn": True,
    "use_do": True,
}
| | |
| |
|
def load_model_state_dict(path):
    """Load a checkpoint onto the CPU and return its state dict.

    Args:
        path: Location of the checkpoint, passed straight to ``torch.load``.

    Returns:
        ``ckpt["model_state_dict"]`` if that key exists, otherwise
        ``ckpt["state_dict"]`` if that key exists, otherwise the loaded
        object itself (this includes the case where it is not a dict).
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        return checkpoint
    # Wrapper keys are probed in priority order; the first one present wins.
    for wrapper_key in ("model_state_dict", "state_dict"):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint
| |
|
def remap_encoder(sd):
    """Pick the encoder weights out of a state dict and re-key them.

    Every ``module`` path component that is followed by a dot (the kind of
    wrapper segment the original ``replace("module.", "")`` was targeting) is
    dropped from each key. Entries whose cleaned name starts with
    ``transformer_encoder.`` or ``projector.`` are kept and re-keyed as
    ``encoder.<name>``; all other entries are discarded.

    Args:
        sd: Mapping of parameter name to value.

    Returns:
        A new dict with the selected entries under their ``encoder.`` keys.
        ``sd`` itself is not modified.
    """
    wanted_prefixes = ("transformer_encoder.", "projector.")
    remapped = {}
    for key, value in sd.items():
        # Remove "module" only as a whole, non-final path component. A plain
        # substring replace would also mangle names that merely contain
        # "module." (e.g. "submodule.w" -> "subw").
        *parents, leaf = key.split(".")
        name = ".".join([part for part in parents if part != "module"] + [leaf])
        if name.startswith(wanted_prefixes):
            remapped["encoder." + name] = value
    return remapped
| |
|
def remap_head(sd):
    """Pick the classification-head weights out of a state dict and re-key them.

    Every ``module`` path component that is followed by a dot is dropped from
    each key. Then:

    * names starting with one of the bare head prefixes (``first_block_conv.``,
      ``first_block_bn.``, ``first_block_relu.``, ``basicblock_list.``,
      ``final_bn.``, ``final_relu.``, ``dense.``) are stored as ``head.<name>``;
    * names already starting with ``head.`` are stored unchanged;
    * everything else is discarded.

    Args:
        sd: Mapping of parameter name to value.

    Returns:
        A new dict with the selected entries. ``sd`` itself is not modified.
    """
    # Built once: the tuple does not depend on the loop variable.
    head_prefixes = (
        "first_block_conv.", "first_block_bn.", "first_block_relu.",
        "basicblock_list.", "final_bn.", "final_relu.", "dense.",
    )
    remapped = {}
    for key, value in sd.items():
        # Remove "module" only as a whole, non-final path component. A plain
        # substring replace would also mangle names that merely contain
        # "module." (e.g. "submodule.w" -> "subw").
        *parents, leaf = key.split(".")
        name = ".".join([part for part in parents if part != "module"] + [leaf])
        if name.startswith(head_prefixes):
            remapped["head." + name] = value
        elif name.startswith("head."):
            # Already namespaced; keep the key as-is.
            remapped[name] = value
    return remapped
| |
|
def main():
    """Build the classification model, load both checkpoints, and save it.

    Steps, in order:
      1. Create a ``NeuroCLRConfig`` from ``CFG`` and attach an ``auto_map``.
      2. Instantiate ``NeuroCLRForSequenceClassification`` from that config.
      3. Load and remap the checkpoint at ``PRETRAIN_CKPT`` (encoder weights).
      4. Load and remap the checkpoint at ``HEAD_CKPT`` (head weights).
      5. Merge both dicts (head entries win on key collisions) and load them
         non-strictly, printing the missing and unexpected keys.
      6. Save the model to ``OUT_DIR`` with ``safe_serialization=True``.
    """
    config = NeuroCLRConfig(**CFG)

    # NOTE(review): presumably lets the Auto* classes locate the custom
    # config/model code next to the saved weights -- confirm.
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification",
    }

    model = NeuroCLRForSequenceClassification(config)

    # Encoder checkpoint first, then head checkpoint (same order as before).
    encoder_weights = remap_encoder(load_model_state_dict(PRETRAIN_CKPT))
    head_weights = remap_head(load_model_state_dict(HEAD_CKPT))

    # Later unpacking overrides earlier entries, matching update() ordering.
    combined = {**encoder_weights, **head_weights}

    # strict=False: keys absent from either side are reported, not fatal.
    load_report = model.load_state_dict(combined, strict=False)
    missing, unexpected = load_report
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF classification model to:", OUT_DIR)


if __name__ == "__main__":
    main()
| |
|