# TITAN-BBB / convert_weights.py
# Author: gabrielbianchin
# Commit 3a19a3f — "Clean re-commit using proper Git LFS"
import torch
from configuration_bbb import BBBConfig
from modeling_bbb import BBBModelForSequenceClassification
import os
# Architecture hyperparameters shared by every converted checkpoint.
# NOTE(review): names suggest per-modality feature widths (tabular / image /
# text) plus a common projection size — confirm against modeling_bbb.
BASE_ARCH_PARAMS = dict(
    d_tab=384,
    d_img=2048,
    d_txt=768,
    proj_dim=2048,
)
def convert_model(checkpoint_path: str, task_name: str, problem_type: str, dropout: float, save_directory: str) -> None:
    """Convert a raw PyTorch ``.pth`` checkpoint into a Hugging Face model directory.

    Builds a ``BBBModelForSequenceClassification`` from ``BASE_ARCH_PARAMS``
    plus the task-specific settings, loads the checkpoint's state dict into it,
    and writes the result with ``save_pretrained``.

    Args:
        checkpoint_path: Path to the original state-dict file.
        task_name: Task identifier stored in the config (e.g. "classification").
        problem_type: HF problem type (e.g. "single_label_classification").
        dropout: Dropout probability stored in the config.
        save_directory: Output directory for the HF-formatted model.
    """
    if not os.path.exists(checkpoint_path):
        # Fail fast, before constructing the model, and loudly: the original
        # returned silently here, so a mistyped path looked like success.
        print(f"[Error] Checkpoint not found: {checkpoint_path}")
        return

    config_params = BASE_ARCH_PARAMS.copy()
    config_params["task"] = task_name
    config_params["problem_type"] = problem_type
    config_params["dropout"] = dropout

    config = BBBConfig(**config_params)
    hf_model = BBBModelForSequenceClassification(config)
    hf_model.eval()

    # weights_only=True restricts unpickling to tensors/primitives, preventing
    # arbitrary-code execution from a tampered checkpoint (and matching the
    # torch>=2.6 default). Plain state dicts load fine under this restriction.
    old_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    new_state_dict = {}
    for key, value in old_state_dict.items():
        # proj* / attention_pooling* / classifier.* names already match the
        # HF module layout, so they are copied through unchanged.
        if not key.startswith(("proj", "attention_pooling", "classifier.")):
            print(f"[Warning] Unmapped key found: {key}")
        new_state_dict[key] = value
    print("State dict key names adjusted.")

    try:
        hf_model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully into HF")
    except RuntimeError as e:
        print("\n--- ERROR LOADING STATE DICT ---")
        print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
        print(e)
        return

    print(f"Saving HF-formatted model to {save_directory}")
    hf_model.save_pretrained(save_directory)
if __name__ == "__main__":
convert_model(
checkpoint_path="model_classification.pth",
task_name="classification",
dropout=0.1,
problem_type="single_label_classification",
save_directory="./classification"
)