import os

import torch

from configuration_bbb import BBBConfig
from modeling_bbb import BBBModelForSequenceClassification

# Shared architecture hyperparameters used by every converted checkpoint.
BASE_ARCH_PARAMS = {
    "d_tab": 384,
    "d_img": 2048,
    "d_txt": 768,
    "proj_dim": 2048,
}


def convert_model(checkpoint_path: str, task_name: str, problem_type: str,
                  dropout: float, save_directory: str):
    """Convert a raw training checkpoint into a Hugging Face-formatted model."""
    # Build the task-specific config on top of the shared architecture params.
    config_params = BASE_ARCH_PARAMS.copy()
    config_params["task"] = task_name
    config_params["problem_type"] = problem_type
    config_params["dropout"] = dropout

    config = BBBConfig(**config_params)
    hf_model = BBBModelForSequenceClassification(config)
    hf_model.eval()

    if not os.path.exists(checkpoint_path):
        print(f"[Error] Checkpoint not found: {checkpoint_path}")
        return

    old_state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Carry the weights over key by key. Known prefixes map one-to-one onto the
    # HF module layout; anything else is copied as-is but flagged for review.
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if key.startswith("proj") or key.startswith("attention_pooling"):
            new_state_dict[key] = value
        elif key.startswith("classifier."):
            # The 'fc' layer is already in the correct place.
            new_state_dict[key] = value
        else:
            print(f"[Warning] Unmapped key found: {key}")
            new_state_dict[key] = value
    print("State dict keys checked; all weights carried over.")

    try:
        # strict=True surfaces any missing or unexpected keys immediately.
        hf_model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully into the HF model.")
    except RuntimeError as e:
        print("\n--- ERROR LOADING STATE DICT ---")
        print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
        print(e)
        return

    print(f"Saving HF-formatted model to {save_directory}")
    hf_model.save_pretrained(save_directory)


if __name__ == "__main__":
    convert_model(
        checkpoint_path="model_classification.pth",
        task_name="classification",
        problem_type="single_label_classification",
        dropout=0.1,
        save_directory="./classification",
    )
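
    # Optional round-trip sanity check (a minimal sketch; assumes
    # BBBModelForSequenceClassification subclasses transformers.PreTrainedModel,
    # so from_pretrained is available alongside save_pretrained). Only runs if
    # the conversion above actually wrote the output directory.
    if os.path.isdir("./classification"):
        reloaded = BBBModelForSequenceClassification.from_pretrained("./classification")
        reloaded.eval()
        print("Round-trip load of ./classification succeeded.")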