import os

import torch

from configuration_seqscreen import SeqScreenConfig
from modeling_seqscreen import SeqScreenModel


def convert_model(checkpoint_path, save_directory) -> None:
    """Convert a raw PyTorch checkpoint into a saved SeqScreen model directory.

    A ``SeqScreenModel`` is built from a default ``SeqScreenConfig``. The
    checkpoint at ``checkpoint_path`` is loaded onto the CPU and only the
    entries whose keys start with ``"proj_prot."`` or ``"proj_mol."`` are
    kept; every other key is reported on stdout with a ``[Skip]`` prefix.
    The kept keys must match the model's own state-dict keys exactly before
    the weights are loaded and the model and config are written out.

    Args:
        checkpoint_path: Location of the checkpoint passed to ``torch.load``.
        save_directory: Output directory; created if it does not exist.

    Raises:
        RuntimeError: If the model expects keys that the filtered checkpoint
            lacks, or if the filtered checkpoint holds keys the model does
            not have.
    """
    config = SeqScreenConfig()
    hf_model = SeqScreenModel(config)
    hf_model.eval()

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code depending on the torch version's ``weights_only`` default. Only
    # run this on checkpoints from a trusted source.
    # NOTE(review): consider passing ``weights_only=True`` if the installed
    # torch version supports it and the checkpoint is a plain tensor dict.
    old_state_dict = torch.load(checkpoint_path, map_location="cpu")

    expected_prefixes = ("proj_prot.", "proj_mol.")
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key.startswith(expected_prefixes):
            new_state_dict[key] = value
        else:
            print(f"[Skip] {key}")

    # Build the model's key set once; it is needed for both directions of
    # the comparison below.
    model_keys = set(hf_model.state_dict())
    checkpoint_keys = set(new_state_dict)
    missing = model_keys - checkpoint_keys
    unexpected = checkpoint_keys - model_keys

    # Fail with an explicit list of offending keys before attempting the
    # strict load, so mismatches are reported in this script's own terms.
    if missing:
        raise RuntimeError(f"Missing keys in checkpoint: {missing}")
    if unexpected:
        raise RuntimeError(f"Unexpected keys after filtering: {unexpected}")

    hf_model.load_state_dict(new_state_dict, strict=True)
    print("State dict loaded successfully.")

    os.makedirs(save_directory, exist_ok=True)
    hf_model.save_pretrained(save_directory)
    # NOTE(review): the model's save_pretrained may already write the config;
    # confirm whether this explicit call is still required.
    config.save_pretrained(save_directory)
    print(f"Model saved to: {save_directory}")


if __name__ == "__main__":
    convert_model(
        checkpoint_path="model.pt",
        save_directory="./seqscreen_hf",
    )