import os

import torch

from configuration_bbb import BBBConfig
from modeling_bbb import BBBModelForSequenceClassification

# Base architecture dimensions shared by every converted checkpoint.
BASE_ARCH_PARAMS = {
    "d_tab": 384,
    "d_img": 2048,
    "d_txt": 768,
    "proj_dim": 2048,
}


def convert_model(checkpoint_path: str, task_name: str, problem_type: str, dropout: float, save_directory: str):
    config_params = BASE_ARCH_PARAMS.copy()
    config_params["task"] = task_name
    config_params["problem_type"] = problem_type
    config_params["dropout"] = dropout
    config = BBBConfig(**config_params)

    hf_model = BBBModelForSequenceClassification(config)
    hf_model.eval()

    if not os.path.exists(checkpoint_path):
        print(f"[Error] Checkpoint not found: {checkpoint_path}")
        return

    old_state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Route keys from the original checkpoint into the HF module layout.
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key.startswith("proj") or key.startswith("attention_pooling"):
            new_state_dict[key] = value
        elif key.startswith("classifier."):
            # The 'fc' layer is already in the correct place.
            new_state_dict[key] = value
        else:
            print(f"[Warning] Unmapped key found: {key}")
            new_state_dict[key] = value
    print("State dict key names adjusted.")

    try:
        hf_model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully into HF")
    except RuntimeError as e:
        print("\n--- ERROR LOADING STATE DICT ---")
        print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
        print(e)
        return

    print(f"Saving HF-formatted model to {save_directory}")
    hf_model.save_pretrained(save_directory)


if __name__ == "__main__":
    convert_model(
        checkpoint_path="model_classification.pth",
        task_name="classification",
        dropout=0.1,
        problem_type="single_label_classification",
        save_directory="./classification",
    )
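
# Follow-up sketch (not part of the original script): once convert_model has
# written the save directory, the converted weights should be loadable through
# the standard Hugging Face from_pretrained API. The "./classification" path
# matches the call above; everything else here is illustrative.
#
#   from modeling_bbb import BBBModelForSequenceClassification
#   model = BBBModelForSequenceClassification.from_pretrained("./classification")
#   model.eval()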