File size: 4,002 Bytes
df27dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Command-line interface for ConvGRU-Ensemble inference and serving."""

import time

import fire
import numpy as np
import xarray as xr


def _load_model(checkpoint: str | None = None, hub_repo: str | None = None, device: str = "cpu"):
    """Load model from local checkpoint or HuggingFace Hub."""
    from .lightning_model import RadarLightningModel

    if hub_repo is not None:
        print(f"Loading model from HuggingFace Hub: {hub_repo}")
        return RadarLightningModel.from_pretrained(hub_repo, device=device)
    elif checkpoint is not None:
        print(f"Loading model from checkpoint: {checkpoint}")
        return RadarLightningModel.from_checkpoint(checkpoint, device=device)
    else:
        raise ValueError("Either --checkpoint or --hub-repo must be provided.")


def _read_rain_rate(path: str, variable: str) -> np.ndarray:
    """Read ``variable`` from the NetCDF file at ``path`` as a float32 3-D array.

    The dataset is opened in a ``with`` block, so its file handle is released
    as soon as the values have been materialized into memory.

    Raises:
        ValueError: If ``variable`` is not in the dataset or its data is not 3-D.
    """
    print(f"Loading input: {path}")
    with xr.open_dataset(path) as ds:
        if variable not in ds:
            available = list(ds.data_vars)
            raise ValueError(f"Variable '{variable}' not found. Available: {available}")
        # ``.values`` returns an in-memory NumPy array, so it stays valid after close.
        data = ds[variable].values
    if data.ndim != 3:
        raise ValueError(f"Expected 3D data (T, H, W), got shape {data.shape}")
    return data.astype(np.float32, copy=False)


def predict(
    input: str,
    checkpoint: str | None = None,
    hub_repo: str | None = None,
    variable: str = "RR",
    forecast_steps: int = 12,
    ensemble_size: int = 10,
    device: str = "cpu",
    output: str = "predictions.nc",
):
    """
    Run inference on a NetCDF input file and save predictions as NetCDF.

    Args:
        input: Path to input NetCDF file with rain rate data (T, H, W) or (T, Y, X).
        checkpoint: Path to local .ckpt checkpoint file.
        hub_repo: HuggingFace Hub repo ID (e.g., 'it4lia/irene'). Alternative to --checkpoint.
        variable: Name of the rain rate variable in the NetCDF file.
        forecast_steps: Number of future timesteps to forecast (must be >= 1).
        ensemble_size: Number of ensemble members to generate (must be >= 1).
        device: Device for inference ('cpu' or 'cuda').
        output: Path for the output NetCDF file.

    Raises:
        ValueError: On non-positive forecast_steps/ensemble_size, a missing
            variable, non-3-D input data, or when no model source is given.
    """
    if forecast_steps < 1:
        raise ValueError(f"forecast_steps must be a positive integer, got {forecast_steps}")
    if ensemble_size < 1:
        raise ValueError(f"ensemble_size must be a positive integer, got {ensemble_size}")

    # Read and validate the input before loading the model: a wrong path or
    # variable name should fail fast, not after a potentially slow model load.
    past = _read_rain_rate(input, variable)
    print(f"Input shape: {past.shape}")

    model = _load_model(checkpoint, hub_repo, device)

    # Run inference
    t0 = time.perf_counter()
    preds = model.predict(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size)
    elapsed = time.perf_counter() - t0
    print(f"Output shape: {preds.shape} (ensemble, time, H, W)")
    print(f"Elapsed: {elapsed:.2f}s")

    # Build output dataset; the dim names assume preds is 4-D
    # (ensemble, time, H, W) as printed above.
    ds_out = xr.Dataset(
        {
            "precipitation_forecast": xr.DataArray(
                data=preds,
                dims=["ensemble_member", "forecast_step", "y", "x"],
                attrs={"units": "mm/h", "long_name": "Ensemble precipitation forecast"},
            ),
        },
        attrs={
            "model": "ConvGRU-Ensemble",
            "forecast_steps": forecast_steps,
            "ensemble_size": ensemble_size,
            "source_file": str(input),
        },
    )

    ds_out.to_netcdf(output)
    print(f"Predictions saved to: {output}")


def serve(
    checkpoint: str | None = None,
    hub_repo: str | None = None,
    host: str = "0.0.0.0",
    port: int = 8000,
    device: str | None = None,
):
    """
    Start the FastAPI inference server.

    The model source and device are passed to the server app through the
    MODEL_CHECKPOINT, HF_REPO_ID and DEVICE environment variables before
    uvicorn is started.

    Args:
        checkpoint: Path to local .ckpt checkpoint file.
        hub_repo: HuggingFace Hub repo ID (e.g., 'it4lia/irene'). Alternative to --checkpoint.
        host: Host to bind to.
        port: Port to listen on.
        device: Device for inference ('cpu' or 'cuda'). When given, it overrides
            any DEVICE environment variable. When omitted, an existing DEVICE
            value is kept, otherwise 'cpu' is used.
    """
    import os

    if checkpoint is not None:
        os.environ["MODEL_CHECKPOINT"] = checkpoint
    if hub_repo is not None:
        os.environ["HF_REPO_ID"] = hub_repo
    # An explicit --device must win over the environment, just as --checkpoint
    # and --hub-repo do. Only fall back to the env value / 'cpu' when it is omitted.
    if device is not None:
        os.environ["DEVICE"] = device
    else:
        os.environ.setdefault("DEVICE", "cpu")

    import uvicorn

    uvicorn.run("convgru_ensemble.serve:app", host=host, port=port)


def main():
    """Expose ``predict`` and ``serve`` as Python Fire subcommands."""
    subcommands = dict(predict=predict, serve=serve)
    fire.Fire(subcommands)


if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()