|
|
"""Command line tool to push locally save model checkpoints to huggingface |
|
|
|
|
|
use: |
|
|
python checkpoint_to_huggingface.py "path/to/model/checkpoints" \ |
|
|
--huggingface-repo="openclimatefix/pvnet_uk_region" \ |
|
|
--wandb-repo="openclimatefix/pvnet2.1" \ |
|
|
--local-path="~/tmp/this_model" \ |
|
|
--no-push-to-hub |
|
|
""" |
|
|
|
|
|
import tempfile |
|
|
|
|
|
import typer |
|
|
import wandb |
|
|
|
|
|
from pvnet.load_model import get_model_from_checkpoints |
|
|
|
|
|
app = typer.Typer(pretty_exceptions_show_locals=False) |
|
|
|
|
|
@app.command() |
|
|
def push_to_huggingface( |
|
|
checkpoint_dir_paths: list[str], |
|
|
huggingface_repo: str = "openclimatefix/pvnet_uk_region", |
|
|
wandb_repo: str = "openclimatefix/pvnet2.1", |
|
|
val_best: bool = True, |
|
|
wandb_ids: list[str] = [], |
|
|
local_path: str = None, |
|
|
push_to_hub: bool = True, |
|
|
): |
|
|
"""Push a local model to a huggingface model repo |
|
|
|
|
|
Args: |
|
|
checkpoint_dir_paths: Path(s) of the checkpoint directory(ies) |
|
|
huggingface_repo: Name of the HuggingFace repo to push the model to |
|
|
wandb_repo: Name of the wandb repo which has training logs |
|
|
val_best: Use best model according to val loss, else last saved model |
|
|
wandb_ids: The wandb ID code(s) |
|
|
local_path: Where to save the local copy of the model |
|
|
push_to_hub: Whether to push the model to the hub or just create local version. |
|
|
""" |
|
|
|
|
|
assert push_to_hub or local_path is not None |
|
|
|
|
|
is_ensemble = len(checkpoint_dir_paths) > 1 |
|
|
|
|
|
|
|
|
if wandb_ids == []: |
|
|
all_wandb_ids = [run.id for run in wandb.Api().runs(path=wandb_repo)] |
|
|
for path in checkpoint_dir_paths: |
|
|
dirname = path.split("/")[-1] |
|
|
if dirname in all_wandb_ids: |
|
|
wandb_ids.append(dirname) |
|
|
else: |
|
|
wandb_ids.append(None) |
|
|
|
|
|
model, model_config, data_config = get_model_from_checkpoints(checkpoint_dir_paths, val_best) |
|
|
|
|
|
if not is_ensemble: |
|
|
wandb_ids = wandb_ids[0] |
|
|
|
|
|
|
|
|
if local_path is None: |
|
|
temp_dir = tempfile.TemporaryDirectory() |
|
|
model_output_dir = temp_dir.name |
|
|
else: |
|
|
model_output_dir = local_path |
|
|
|
|
|
model.save_pretrained( |
|
|
model_output_dir, |
|
|
config=model_config, |
|
|
data_config=data_config, |
|
|
wandb_repo=wandb_repo, |
|
|
wandb_ids=wandb_ids, |
|
|
push_to_hub=push_to_hub, |
|
|
repo_id=huggingface_repo if push_to_hub else None, |
|
|
) |
|
|
|
|
|
if local_path is None: |
|
|
temp_dir.cleanup() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
app() |
|
|
|