File size: 1,810 Bytes
3a19a3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from configuration_bbb import BBBConfig
from modeling_bbb import BBBModelForSequenceClassification
import os

# Architecture hyperparameters shared by every converted checkpoint. They are
# copied into each BBBConfig by convert_model and must match the dimensions the
# original checkpoint was trained with, or the strict state-dict load fails.
# (Presumably tabular / image / text input widths and the projection width —
# confirm against modeling_bbb.)
BASE_ARCH_PARAMS = dict(
  d_tab=384,
  d_img=2048,
  d_txt=768,
  proj_dim=2048,
)

def _adjust_state_dict_keys(old_state_dict: dict) -> dict:
  """Return a copy of ``old_state_dict`` with keys in the HF model's naming.

  The original checkpoint and the HF model currently use identical parameter
  names, so every entry is copied through unchanged. Keys that do not start
  with a known prefix are still copied, but a warning is printed so that
  unexpected parameters are visible before the strict load.
  """
  # 'proj*' and 'attention_pooling*' map 1:1; the original 'fc' layer already
  # lives under 'classifier.'.
  known_prefixes = ("proj", "attention_pooling", "classifier.")
  new_state_dict = {}
  for key, value in old_state_dict.items():
    if not key.startswith(known_prefixes):
      print(f"[Warning] Unmapped key found: {key}")
    new_state_dict[key] = value
  return new_state_dict


def convert_model(checkpoint_path: str, task_name: str, problem_type: str, dropout: float, save_directory: str) -> None:
  """Convert an original BBB checkpoint into a Hugging Face-format model directory.

  Args:
    checkpoint_path: Path to the original checkpoint, loaded with ``torch.load``
      and expected to be a state dict (a mapping of parameter name to tensor).
    task_name: Stored as ``task`` on the ``BBBConfig``.
    problem_type: Stored as ``problem_type`` on the ``BBBConfig``
      (e.g. ``"single_label_classification"``).
    dropout: Stored as ``dropout`` on the ``BBBConfig``.
    save_directory: Directory passed to ``save_pretrained``.

  Returns:
    None. A missing checkpoint or a state-dict mismatch is reported on stdout,
    and the function returns without saving anything.
  """
  # Check before building the model, so that a bad path is reported instead of
  # silently ending the run. isfile (not exists) also rejects directories,
  # which torch.load cannot read.
  if not os.path.isfile(checkpoint_path):
    print(f"[Error] Checkpoint file not found: {checkpoint_path}")
    return

  config_params = {
    **BASE_ARCH_PARAMS,
    "task": task_name,
    "problem_type": problem_type,
    "dropout": dropout,
  }
  config = BBBConfig(**config_params)

  hf_model = BBBModelForSequenceClassification(config)
  hf_model.eval()

  # NOTE(review): torch.load unpickles the file, so only convert checkpoints
  # from a trusted source (consider weights_only=True on torch >= 1.13).
  old_state_dict = torch.load(checkpoint_path, map_location="cpu")

  new_state_dict = _adjust_state_dict_keys(old_state_dict)
  print("State dict key names adjusted.")

  # strict=True: any missing or unexpected key, or any shape mismatch, raises
  # RuntimeError. A shape mismatch usually means BASE_ARCH_PARAMS is wrong.
  try:
    hf_model.load_state_dict(new_state_dict, strict=True)
  except RuntimeError as e:
    print("\n--- ERROR LOADING STATE DICT ---")
    print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
    print(e)
    return
  print("State dict loaded successfully into HF")

  print(f"Saving HF-formatted model to {save_directory}")
  hf_model.save_pretrained(save_directory)

if __name__ == "__main__":
  # Script entry point: convert the single-label classification checkpoint in
  # the working directory into ./classification.
  convert_model(
    checkpoint_path="model_classification.pth",
    task_name="classification",
    problem_type="single_label_classification",
    dropout=0.1,
    save_directory="./classification",
  )