"""Push Borealis model to HuggingFace Hub."""

import os

import torch
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file, save_model

# Config
HF_REPO = "Vikhrmodels/Borealis-5b-it"
CHECKPOINT_PATH = "/home/alex/Borealis/borealis_instruct_ckpts/checkpoint-2898/pytorch_model.bin"
OUTPUT_DIR = "/home/alex/Borealis/hf_upload"


class DictModule(torch.nn.Module):
    """Wrap a flat state dict in a Module so it can be passed to ``save_model``.

    Buffer names may not contain dots, so every key is registered with ``.``
    replaced by ``__DOT__`` and mapped back again in ``state_dict``.

    NOTE: ``main`` below does not use this class; it serializes the dict
    directly with ``save_file``.
    """

    def __init__(self, state_dict):
        super().__init__()
        for k, v in state_dict.items():
            # Dots are not allowed in buffer names; encode them as "__DOT__".
            self.register_buffer(k.replace(".", "__DOT__"), v)

    def state_dict(self, *args, **kwargs):
        """Return the module's state dict with the original dotted key names."""
        sd = super().state_dict(*args, **kwargs)
        return {k.replace("__DOT__", "."): v for k, v in sd.items()}


def _untie_tensors(state_dict):
    """Return a new dict in which every tensor owns its own contiguous storage.

    safetensors refuses to serialize tensors that share memory (e.g. tied
    weights) or that are non-contiguous. Cloning into
    ``torch.contiguous_format`` addresses both. Entries are popped from
    *state_dict* (which is left empty) as they are copied. The original
    storage can then be freed as soon as no other entry references it,
    instead of keeping two full copies alive at once.
    """
    untied = {}
    for key in list(state_dict):
        tensor = state_dict.pop(key)
        # Plain clone() keeps a non-contiguous layout; force a contiguous copy.
        untied[key] = tensor.detach().clone(memory_format=torch.contiguous_format)
    return untied


def main(checkpoint_path=CHECKPOINT_PATH, output_dir=OUTPUT_DIR, repo_id=HF_REPO):
    """Convert a PyTorch checkpoint to safetensors and upload a folder to the Hub.

    Args:
        checkpoint_path: file loaded with ``torch.load``; presumably a flat
            ``{name: tensor}`` state dict (TODO confirm for other checkpoints).
        output_dir: local folder; ``model.safetensors`` is written here and
            the whole folder (including any files already in it) is uploaded.
        repo_id: target Hugging Face model repo.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    # NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code stored in the file. Acceptable only for our own training
    # output; prefer weights_only=True if the checkpoint is a plain tensor dict.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    print(f"Loaded {len(state_dict)} keys")

    # Break shared storage / non-contiguous layouts so safetensors accepts them.
    print("Handling shared tensors...")
    new_state_dict = _untie_tensors(state_dict)

    print("Converting to safetensors format...")
    # save_file does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    safetensors_path = os.path.join(output_dir, "model.safetensors")
    # Same metadata save_pretrained writes; some transformers versions reject
    # safetensors files that lack a "format" entry.
    save_file(new_state_dict, safetensors_path, metadata={"format": "pt"})
    print(f"Saved to {safetensors_path}")
    del new_state_dict  # Tensors are on disk now; release RAM before uploading.

    # Create repo
    print(f"\nCreating/accessing repo: {repo_id}")
    api = HfApi()
    try:
        create_repo(repo_id, repo_type="model", exist_ok=True)
    except Exception as e:
        # Deliberately best-effort: report and continue. The upload below
        # presumably fails loudly if the repo is actually unusable.
        print(f"Repo note: {e}")

    # Upload folder
    print(f"\nUploading to {repo_id}...")
    api.upload_folder(
        folder_path=output_dir,
        repo_id=repo_id,
        repo_type="model",
    )
    print(f"\nDone! Model available at: https://huggingface.co/{repo_id}")


if __name__ == "__main__":
    main()