File size: 1,276 Bytes
c1ea99a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import torch
from configuration_seqscreen import SeqScreenConfig
from modeling_seqscreen import SeqScreenModel


def convert_model(checkpoint_path, save_directory):
  """Convert a raw checkpoint into a directory written by ``save_pretrained``.

  Builds a ``SeqScreenModel`` from a default ``SeqScreenConfig``, loads the
  checkpoint on the CPU, keeps only the entries whose keys start with
  ``proj_prot.`` or ``proj_mol.`` (every other key is reported with a
  ``[Skip]`` line), verifies that the kept keys match the model's own
  state-dict keys exactly, loads them, and saves both the model and the
  config into ``save_directory``.

  Args:
    checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
    save_directory: Output directory; it is created if it does not exist.

  Raises:
    RuntimeError: If the model has keys the filtered checkpoint lacks, or
      the filtered checkpoint has keys the model does not have.
  """
  cfg = SeqScreenConfig()
  model = SeqScreenModel(cfg)
  model.eval()

  # NOTE(review): assumes the checkpoint is a flat mapping of parameter name
  # to value (it is iterated with ``.items()``) -- confirm against the file.
  checkpoint = torch.load(checkpoint_path, map_location="cpu")

  kept_prefixes = ("proj_prot.", "proj_mol.")

  # Keep only the entries under the expected prefixes; report the rest.
  filtered = {}
  for name, tensor in checkpoint.items():
    if not name.startswith(kept_prefixes):
      print(f"[Skip] {name}")
      continue
    filtered[name] = tensor

  # Compare key sets up front so a mismatch is reported with a clear message.
  model_keys = set(model.state_dict().keys())
  checkpoint_keys = set(filtered.keys())
  missing = model_keys - checkpoint_keys
  unexpected = checkpoint_keys - model_keys

  if missing:
    raise RuntimeError(f"Missing keys in checkpoint: {missing}")
  if unexpected:
    raise RuntimeError(f"Unexpected keys after filtering: {unexpected}")

  model.load_state_dict(filtered, strict=True)
  print("State dict loaded successfully.")

  os.makedirs(save_directory, exist_ok=True)
  model.save_pretrained(save_directory)
  cfg.save_pretrained(save_directory)
  print(f"Model saved to: {save_directory}")


if __name__ == "__main__":
  # Script entry point: convert the local checkpoint using the default paths.
  convert_model(checkpoint_path="model.pt", save_directory="./seqscreen_hf")