|
|
import torch |
|
|
from configuration_neuroclr import NeuroCLRConfig |
|
|
from modeling_neuroclr import NeuroCLRForSequenceClassification |
|
|
|
|
|
|
|
|
# Path to the self-supervised pretraining checkpoint (source of encoder weights).
PRETRAIN_CKPT = ""

# Path to the fine-tuned classifier-head checkpoint.
HEAD_CKPT = ""

# Directory where the merged Hugging Face model is written.
OUT_DIR = "."

# Keyword arguments for NeuroCLRConfig. Values presumably mirror the
# hyperparameters used at training time -- confirm against the training
# scripts before exporting.
CFG = {
    # Transformer encoder / projector settings.
    "TSlength": 128,
    "nhead": 2,
    "nlayer": 2,
    "projector_out1": 128,
    "projector_out2": 64,
    "pooling": "flatten",
    "normalize_input": True,
    # Classification task settings.
    "n_rois": 200,
    "num_labels": 2,
    "freeze_encoder": True,
    # Classifier-head settings (conv-block architecture).
    "base_filters": 256,
    "kernel_size": 16,
    "stride": 2,
    "groups": 32,
    "n_block": 48,
    "downsample_gap": 6,
    "increasefilter_gap": 12,
    "use_bn": True,
    "use_do": True,
}
|
|
|
|
|
|
|
|
def load_model_state_dict(path):
    """Load a checkpoint from *path* and return its model state dict.

    Handles the common wrapping conventions: a checkpoint dict carrying the
    weights under ``"model_state_dict"`` or ``"state_dict"`` is unwrapped
    (checked in that order); anything else is returned unchanged, on the
    assumption that it already is a state dict.

    NOTE(review): torch >= 2.6 defaults ``torch.load(weights_only=True)`` --
    confirm the checkpoints used here still load under that default.
    """
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict):
        for wrapper_key in ("model_state_dict", "state_dict"):
            if wrapper_key in ckpt:
                return ckpt[wrapper_key]
    return ckpt
|
|
|
|
|
def remap_encoder(sd):
    """Re-key pretraining-checkpoint entries into the ``encoder.`` namespace.

    Keeps only the transformer-encoder and projector weights; every other
    entry in the pretraining checkpoint is dropped.

    Args:
        sd: raw state dict loaded from the pretraining checkpoint.

    Returns:
        dict with the retained entries re-keyed as ``encoder.<key>``.
    """
    new = {}
    for k, v in sd.items():
        # Strip only a *leading* DataParallel/DDP "module." prefix. The
        # previous str.replace() removed the substring anywhere in the key
        # (e.g. "submodule." -> "sub"), which could corrupt parameter names.
        if k.startswith("module."):
            k = k[len("module."):]
        if k.startswith(("transformer_encoder.", "projector.")):
            new["encoder." + k] = v
    return new
|
|
|
|
|
def remap_head(sd):
    """Re-key classifier-head checkpoint entries into the ``head.`` namespace.

    Entries whose keys start with one of the known head submodule prefixes are
    re-keyed as ``head.<key>``; keys already under ``head.`` are kept as-is.
    All other entries (e.g. encoder weights saved alongside the head) are
    dropped.

    Args:
        sd: raw state dict loaded from the fine-tuned head checkpoint.

    Returns:
        dict containing only head parameters, all under the ``head.`` prefix.
    """
    # Invariant across the loop -- build the prefix tuple once.
    head_prefixes = (
        "first_block_conv.", "first_block_bn.", "first_block_relu.",
        "basicblock_list.", "final_bn.", "final_relu.", "dense.",
    )
    new = {}
    for k, v in sd.items():
        # Strip only a *leading* DataParallel/DDP "module." prefix. The
        # previous str.replace() removed the substring anywhere in the key
        # (e.g. "submodule." -> "sub"), which could corrupt parameter names.
        if k.startswith("module."):
            k = k[len("module."):]
        if k.startswith(head_prefixes):
            new["head." + k] = v
        elif k.startswith("head."):
            new[k] = v
    return new
|
|
|
|
|
def main():
    """Assemble and save a HF sequence-classification model from two checkpoints.

    Builds the model from CFG, loads encoder weights from PRETRAIN_CKPT and
    classifier-head weights from HEAD_CKPT, merges the two state dicts, and
    writes the result to OUT_DIR in safetensors format.
    """
    config = NeuroCLRConfig(**CFG)

    # auto_map lets transformers' Auto* classes locate the custom code files
    # shipped alongside the saved model.
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification",
    }

    model = NeuroCLRForSequenceClassification(config)

    # Re-key each checkpoint into the namespace the combined model expects.
    encoder_state = remap_encoder(load_model_state_dict(PRETRAIN_CKPT))
    head_state = remap_head(load_model_state_dict(HEAD_CKPT))

    # Head entries win on any key collision (same precedence as applying
    # dict.update with the head last).
    combined = {**encoder_state, **head_state}

    # strict=False: report, rather than fail on, keys that don't line up.
    missing, unexpected = model.load_state_dict(combined, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF classification model to:", OUT_DIR)
|
|
|
|
|
# Script entry point: build the config/model, merge the two checkpoints,
# and save the resulting HF model.
if __name__ == "__main__":
    main()
|
|
|