File size: 2,556 Bytes
cbe6208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Command line tool to push locally saved model checkpoints to Hugging Face

use:
python checkpoint_to_huggingface.py "path/to/model/checkpoints" \
    --huggingface-repo="openclimatefix/pvnet_uk_region" \
    --wandb-repo="openclimatefix/pvnet2.1" \
    --local-path="~/tmp/this_model" \
    --no-push-to-hub
"""

import tempfile

import typer
import wandb

from pvnet.load_model import get_model_from_checkpoints

# Typer application object; the `push_to_huggingface` command below is registered on it.
# NOTE(review): `pretty_exceptions_show_locals=False` presumably keeps local variable values
# out of Typer's formatted tracebacks — confirm against the Typer version in use.
app = typer.Typer(pretty_exceptions_show_locals=False)

@app.command()
def push_to_huggingface(
    checkpoint_dir_paths: list[str],
    huggingface_repo: str = "openclimatefix/pvnet_uk_region",  # e.g. openclimatefix/windnet_india
    wandb_repo: str = "openclimatefix/pvnet2.1",
    val_best: bool = True,
    # NOTE: the signature is kept as-is so the CLI options do not change. The list default is
    # never mutated below; a private copy is used instead.
    wandb_ids: list[str] = [],
    local_path: str = None,
    push_to_hub: bool = True,
):
    """Push a local model to a huggingface model repo

    The model is loaded from the given checkpoint directory(ies), saved to `local_path` (or to a
    temporary directory when no `local_path` is given) and, if requested, pushed to the hub.

    Args:
        checkpoint_dir_paths: Path(s) of the checkpoint directory(ies). More than one path is
            treated as an ensemble.
        huggingface_repo: Name of the HuggingFace repo to push the model to
        wandb_repo: Name of the wandb repo which has training logs
        val_best: Use best model according to val loss, else last saved model
        wandb_ids: The wandb ID code(s). If empty, each checkpoint directory name is looked up
            among the run IDs of `wandb_repo`; directories which do not match a run get `None`.
        local_path: Where to save the local copy of the model
        push_to_hub: Whether to push the model to the hub or just create local version.

    Raises:
        typer.BadParameter: If `push_to_hub` is False and no `local_path` is given, since the
            model would then only be written to a temporary directory and discarded.
    """

    # Explicit check rather than `assert`, which is stripped when Python runs with -O
    if not push_to_hub and local_path is None:
        raise typer.BadParameter(
            "Nothing to do: provide --local-path when using --no-push-to-hub."
        )

    is_ensemble = len(checkpoint_dir_paths) > 1

    if wandb_ids:
        # Work on a copy so the caller's list is left untouched
        run_ids = list(wandb_ids)
    else:
        # Check if checkpoint dir name is wandb run ID
        known_run_ids = {run.id for run in wandb.Api().runs(path=wandb_repo)}
        run_ids = []
        for path in checkpoint_dir_paths:
            # Strip trailing slashes (e.g. from shell tab-completion) so that the final path
            # component is the directory name rather than an empty string
            dirname = path.rstrip("/").split("/")[-1]
            run_ids.append(dirname if dirname in known_run_ids else None)

    model, model_config, data_config = get_model_from_checkpoints(checkpoint_dir_paths, val_best)

    # A single (non-ensemble) model is saved with a single ID rather than a list of IDs
    saved_wandb_ids = run_ids if is_ensemble else run_ids[0]

    # Only create a temporary directory when the user did not ask for a local copy
    temp_dir = tempfile.TemporaryDirectory() if local_path is None else None
    model_output_dir = local_path if temp_dir is None else temp_dir.name

    try:
        # Save locally and (optionally) push to hub
        model.save_pretrained(
            model_output_dir,
            config=model_config,
            data_config=data_config,
            wandb_repo=wandb_repo,
            wandb_ids=saved_wandb_ids,
            push_to_hub=push_to_hub,
            repo_id=huggingface_repo if push_to_hub else None,
        )
    finally:
        # Remove the temporary directory even if saving or pushing failed
        if temp_dir is not None:
            temp_dir.cleanup()


# Run the Typer CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    app()