# SeqScreen-Frozen / convert_weights.py
# Uploaded by gabrielbianchin ("upload files", commit c1ea99a).
import os
import torch
from configuration_seqscreen import SeqScreenConfig
from modeling_seqscreen import SeqScreenModel
def convert_model(checkpoint_path, save_directory):
    """Convert a raw PyTorch checkpoint into a ``save_pretrained`` SeqScreen directory.

    Builds a ``SeqScreenModel`` from a default ``SeqScreenConfig``. Only the
    checkpoint entries whose keys start with ``"proj_prot."`` or ``"proj_mol."``
    are kept; every other key is printed as skipped. The kept keys must match
    the model's own state-dict keys exactly. They are then loaded strictly, and
    the model and config are written to ``save_directory``.

    Args:
        checkpoint_path: Path to a file loadable with ``torch.load`` whose
            content is a state dict (parameter name -> tensor).
        save_directory: Output directory; created if it does not exist.

    Raises:
        TypeError: If the checkpoint does not deserialize to a dict.
        RuntimeError: If the filtered checkpoint lacks keys the model expects,
            or contains keys the model does not define.
    """
    config = SeqScreenConfig()
    hf_model = SeqScreenModel(config)
    hf_model.eval()

    # SECURITY NOTE(review): torch.load is pickle-based. Depending on the
    # installed torch version (and whether weights_only is enabled), loading
    # an untrusted file can execute arbitrary code. Only convert checkpoints
    # from a trusted source.
    old_state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Fail clearly on checkpoints that are not plain state dicts (e.g. a whole
    # pickled model), instead of an AttributeError on .items() below.
    if not isinstance(old_state_dict, dict):
        raise TypeError(
            "Expected the checkpoint to contain a state dict, got "
            f"{type(old_state_dict).__name__}"
        )

    # Only the two projection heads are transferred; everything else is
    # reported and dropped.
    expected_prefixes = ("proj_prot.", "proj_mol.")
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key.startswith(expected_prefixes):
            new_state_dict[key] = value
        else:
            print(f"[Skip] {key}")

    # Compare key sets before loading so a mismatch yields a readable,
    # deterministically ordered report. The model's state dict is built once.
    model_keys = set(hf_model.state_dict())
    loaded_keys = set(new_state_dict)
    missing = sorted(model_keys - loaded_keys)
    unexpected = sorted(loaded_keys - model_keys)
    if missing:
        raise RuntimeError(f"Missing keys in checkpoint: {missing}")
    if unexpected:
        raise RuntimeError(f"Unexpected keys after filtering: {unexpected}")

    hf_model.load_state_dict(new_state_dict, strict=True)
    print("State dict loaded successfully.")

    os.makedirs(save_directory, exist_ok=True)
    hf_model.save_pretrained(save_directory)
    config.save_pretrained(save_directory)
    print(f"Model saved to: {save_directory}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point. The defaults reproduce the previously
    # hard-coded paths, so running the script with no arguments is unchanged.
    parser = argparse.ArgumentParser(
        description="Convert a SeqScreen PyTorch checkpoint into a save_pretrained directory."
    )
    parser.add_argument(
        "--checkpoint-path",
        default="model.pt",
        help="Path to the source checkpoint (default: %(default)s).",
    )
    parser.add_argument(
        "--save-directory",
        default="./seqscreen_hf",
        help="Directory the converted model is written to (default: %(default)s).",
    )
    args = parser.parse_args()
    convert_model(
        checkpoint_path=args.checkpoint_path,
        save_directory=args.save_directory,
    )