|
|
import torch |
|
|
from configuration_bbb import BBBConfig |
|
|
from modeling_bbb import BBBModelForSequenceClassification |
|
|
import os |
|
|
|
|
|
# Fixed architecture hyperparameters shared by every converted checkpoint.
# convert_model() copies this dict and layers the task-specific fields
# (task, problem_type, dropout) on top before building the BBBConfig.
BASE_ARCH_PARAMS = {

    "d_tab": 384,  # presumably the tabular-branch feature dim — TODO confirm against BBBConfig

    "d_img": 2048,  # presumably the image-branch feature dim — TODO confirm

    "d_txt": 768,  # presumably the text-branch feature dim — TODO confirm

    "proj_dim": 2048  # shared projection dimension — TODO confirm

}
|
|
|
|
|
def convert_model(checkpoint_path: str, task_name: str, problem_type: str, dropout: float, save_directory: str) -> None:
    """Convert a raw PyTorch state-dict checkpoint into a HF-style saved model.

    Builds a ``BBBConfig`` from ``BASE_ARCH_PARAMS`` plus the task-specific
    settings, instantiates ``BBBModelForSequenceClassification``, loads the
    checkpoint's state dict into it (strict), and writes the result to
    *save_directory* via ``save_pretrained``.

    Args:
        checkpoint_path: Path to the original ``.pth`` state-dict file.
        task_name: Task identifier stored in the config (e.g. "classification").
        problem_type: HF problem type (e.g. "single_label_classification").
        dropout: Dropout probability stored in the config.
        save_directory: Output directory passed to ``save_pretrained``.
    """
    # Start from a copy so the module-level constant is never mutated.
    config_params = BASE_ARCH_PARAMS.copy()
    config_params["task"] = task_name
    config_params["problem_type"] = problem_type
    config_params["dropout"] = dropout

    config = BBBConfig(**config_params)

    hf_model = BBBModelForSequenceClassification(config)
    hf_model.eval()

    # Fix: previously a missing checkpoint returned silently, leaving no clue
    # why nothing was saved. Report the problem explicitly before bailing out.
    if not os.path.exists(checkpoint_path):
        print(f"[Error] Checkpoint not found: {checkpoint_path}")
        return

    # SECURITY NOTE(review): torch.load is pickle-based and can execute
    # arbitrary code from an untrusted file. For plain state dicts, prefer
    # torch.load(..., weights_only=True) where the installed torch supports it.
    old_state_dict = torch.load(checkpoint_path, map_location="cpu")

    # All recognized prefixes are copied through unchanged; anything else is
    # still copied but flagged so strict loading failures are explainable.
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if not key.startswith(("proj", "attention_pooling", "classifier.")):
            print(f"[Warning] Unmapped key found: {key}")
        new_state_dict[key] = value

    print("State dict key names adjusted.")

    try:
        # strict=True: any missing/unexpected key raises, which we want —
        # a partial load would silently produce a broken model.
        hf_model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully into HF")
    except RuntimeError as e:
        print("\n--- ERROR LOADING STATE DICT ---")
        print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
        print(e)
        return

    print(f"Saving HF-formatted model to {save_directory}")
    hf_model.save_pretrained(save_directory)
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: convert the single-label classification checkpoint.
    conversion_args = {
        "checkpoint_path": "model_classification.pth",
        "task_name": "classification",
        "problem_type": "single_label_classification",
        "dropout": 0.1,
        "save_directory": "./classification",
    }
    convert_model(**conversion_args)