gabrielbianchin committed on
Commit
c6401f2
·
verified ·
1 Parent(s): 1b28a8c

Delete convert_weights.py

Browse files
Files changed (1) hide show
  1. convert_weights.py +0 -81
convert_weights.py DELETED
@@ -1,81 +0,0 @@
1
- import torch
2
- from config import BBBConfig
3
- from modeling import BBBModelForSequenceClassification
4
- import os
5
-
6
# Keyword arguments shared by every conversion: convert_model() copies this
# mapping, adds the per-task "task" and "problem_type" entries, and passes the
# result to BBBConfig. The values must match the checkpoint being converted,
# because the state dict is loaded with strict=True.
BASE_ARCH_PARAMS = {
    "gnn_layers": 4,
    "gnn_hidden": 512,
    "dropout": 0.1,
    "num_heads": 8,
    "input_dim": 768,
    "num_features": 2265,
    "proj_dim": 1024,
    "neurons_fc": [256],
    "num_labels": 1,
}
17
-
18
def convert_model(checkpoint_path: str, task_name: str, problem_type: str, save_directory: str):
    """Convert an original checkpoint into a Hugging Face-formatted model.

    Builds a ``BBBConfig`` from ``BASE_ARCH_PARAMS`` plus the given task
    settings, instantiates ``BBBModelForSequenceClassification``, renames the
    checkpoint's state-dict keys to the names that model expects, loads them
    with ``strict=True`` and writes the result with ``save_pretrained``.

    Args:
        checkpoint_path: Path to the file ``torch.load`` reads as a mapping
            of parameter names to values.
        task_name: Value stored under the ``"task"`` config key.
        problem_type: Value stored under the ``"problem_type"`` config key.
        save_directory: Directory handed to ``save_pretrained``.

    Returns:
        None. A missing checkpoint or a state-dict mismatch is reported on
        stdout and the function returns early without saving anything.
    """
    # Check the checkpoint first so a missing file is reported (instead of
    # being skipped silently) and no model is built needlessly.
    if not os.path.exists(checkpoint_path):
        print(f"[Warning] Checkpoint not found, skipping conversion: {checkpoint_path}")
        return

    config_params = {
        **BASE_ARCH_PARAMS,
        "task": task_name,
        "problem_type": problem_type,
    }
    config = BBBConfig(**config_params)

    print("Instantiating HF model")
    hf_model = BBBModelForSequenceClassification(config)
    hf_model.eval()

    print(f"Loading original checkpoint from {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file, so only run this on trusted
    # checkpoints (or pass weights_only=True where the installed torch
    # version supports it).
    old_state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Keys under these prefixes get the "base_model." prefix added.
    base_model_prefixes = ("gats.", "bns.", "proj_gnn.", "proj_feat.")

    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key.startswith(base_model_prefixes):
            new_state_dict["base_model." + key] = value
            continue

        # 'fc.' keys are already in the correct place. Anything else is
        # copied unchanged too, but flagged so that a strict-load failure
        # below is easy to diagnose.
        if not key.startswith("fc."):
            print(f"[Warning] Unmapped key found: {key}")
        new_state_dict[key] = value

    print("State dict key names adjusted.")

    try:
        hf_model.load_state_dict(new_state_dict, strict=True)
    except RuntimeError as e:
        print("\n--- ERROR LOADING STATE DICT ---")
        print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
        print(e)
        return
    print("State dict loaded successfully into HF")

    print(f"Saving HF-formatted model to {save_directory}")
    hf_model.save_pretrained(save_directory)
67
-
68
if __name__ == "__main__":
    # One entry per checkpoint to convert; entries are processed in order and
    # each is forwarded to convert_model() as keyword arguments.
    conversion_jobs = (
        {
            "checkpoint_path": "./model_classification.pth",
            "task_name": "classification",
            "problem_type": "single_label_classification",
            "save_directory": "./classification",
        },
        {
            "checkpoint_path": "./model_regression.pth",
            "task_name": "regression",
            "problem_type": "regression",
            "save_directory": "./regression",
        },
    )
    for job in conversion_jobs:
        convert_model(**job)