Spaces:
Running
Running
| """ | |
| Model loading and inference utilities for the weather forecast demo. | |
| Wraps the existing inference/predict.py logic, adding user-friendly | |
| post-processing (Celsius, wind speed/direction, rain likelihood). | |
| """ | |
| import math | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
# In HF Space, models/ is in the same directory as this file
# PROJECT_ROOT is also the base that load_model() joins checkpoint paths onto.
PROJECT_ROOT = Path(__file__).resolve().parent
# Put this directory first on the import path so the project-local ``models``
# package resolves; this must happen BEFORE the import on the next line.
sys.path.insert(0, str(PROJECT_ROOT))
from models import create_model, get_model_defaults
# -- Model cache (loaded once, reused across requests) ----------------------
# Filled lazily by load_model(); starts empty at import time.
_model_cache: dict = {}

# Forecast target variables as (source variable name, display label, unit).
# The order matches how format_forecast() unpacks the model's output vector.
TARGET_VARS = [
    ("TMP@2m_above_ground",       "Temperature (2m)",    "K"),
    ("RH@2m_above_ground",        "Relative Humidity",   "%"),
    ("UGRD@10m_above_ground",     "U-Wind (10m)",        "m/s"),
    ("VGRD@10m_above_ground",     "V-Wind (10m)",        "m/s"),
    ("GUST@surface",              "Wind Gust",           "m/s"),
    ("APCP_1hr_acc_fcst@surface", "Precipitation (1hr)", "mm"),
]

# Selectable models: key -> display metadata. Checkpoint paths are relative to
# PROJECT_ROOT (load_model joins them onto it).
AVAILABLE_MODELS = {
    key: {"display_name": display, "checkpoint": ckpt, "params": n_params}
    for key, display, ckpt, n_params in (
        ("cnn_baseline", "CNN Baseline", "checkpoints/cnn_baseline.pt", "11.3M"),
        ("resnet18", "ResNet-18", "checkpoints/resnet18.pt", "11.2M"),
        ("vit", "WeatherViT", "checkpoints/vit.pt", "7.4M"),
    )
}
def load_model(model_name: str, device: str = "cpu"):
    """
    Load a trained model from its checkpoint, caching it in memory for reuse.

    Cached entries are keyed on ``(model_name, device)``. Asking for the same
    model on a different device therefore loads a fresh copy, rather than
    returning an instance (and norm stats) that live on another device.

    Args:
        model_name: Key into ``AVAILABLE_MODELS``.
        device: Device string used both as ``torch.load``'s ``map_location``
            and as the target of ``model.to``.

    Returns:
        (model, norm_stats) tuple. ``model`` is on ``device`` in eval mode;
        ``norm_stats`` is the checkpoint's ``"norm_stats"`` entry, or ``None``
        if the checkpoint has none.

    Raises:
        KeyError: if ``model_name`` is not a key of ``AVAILABLE_MODELS``.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    cache_key = (model_name, str(device))
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    try:
        model_info = AVAILABLE_MODELS[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(AVAILABLE_MODELS)}"
        ) from None

    ckpt_path = PROJECT_ROOT / model_info["checkpoint"]
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(review): weights_only=False unpickles arbitrary Python objects from
    # the file (presumably needed for the stored "args"/"norm_stats" entries;
    # confirm). Only ever load trusted, bundled checkpoints through this path.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    args = ckpt["args"]
    # The architecture to build comes from the checkpoint, not from model_name.
    ckpt_model_name = args["model"]
    defaults = get_model_defaults(ckpt_model_name)
    # A missing (or falsy) n_frames in the checkpoint falls back to the default.
    n_frames = args.get("n_frames") or defaults["n_frames"]

    model_kwargs = {
        "n_input_channels": 42,  # channel count of the (H, W, 42) input grid
        "n_targets": len(TARGET_VARS),
        "base_channels": args.get("base_channels", 64),
    }
    # n_frames is only passed for multi-frame configurations (> 1).
    if n_frames > 1:
        model_kwargs["n_frames"] = n_frames

    model = create_model(ckpt_model_name, **model_kwargs)
    model.load_state_dict(ckpt["model"])
    model.to(device).eval()

    norm_stats = ckpt.get("norm_stats")
    _model_cache[cache_key] = (model, norm_stats)
    return model, norm_stats
def _as_cpu_float_tensor(value) -> torch.Tensor:
    """Return *value* (tensor on any device, ndarray, or number sequence) as a float32 CPU tensor."""
    return torch.as_tensor(value, dtype=torch.float32, device="cpu")


def predict_raw(model, norm_stats, input_array: np.ndarray, device: str = "cpu") -> np.ndarray:
    """
    Run inference on a single channels-last input grid, e.g. (450, 449, 42).

    Args:
        model: Callable, normally the module from ``load_model``. It receives a
            (1, C, H, W) float tensor on ``device`` and is expected to return
            (1, n_targets).
        norm_stats: Optional dict with ``input_mean``/``input_std`` and
            ``target_mean``/``target_std``. Pass a falsy value to skip
            (de)normalization. Values may be tensors on any device, or
            array-likes. A 1-D input statistic is treated as per-channel.
        input_array: (H, W, C) array.
        device: Device the model runs on.

    Returns:
        np.ndarray of shape (n_targets,) -- (6,) for the bundled models -- with
        denormalized physical values.
    """
    # (H, W, C) -> (1, C, H, W). ascontiguousarray also handles negatively
    # strided views (e.g. flipped arrays), which torch.from_numpy rejects.
    x = torch.from_numpy(np.ascontiguousarray(input_array, dtype=np.float32))
    x = x.permute(2, 0, 1).unsqueeze(0)

    if norm_stats:
        # Normalization runs on the CPU. Stats from load_model were loaded with
        # map_location=device, so they may sit on an accelerator; move them to
        # the CPU first, otherwise mixing devices raises.
        mean = _as_cpu_float_tensor(norm_stats["input_mean"])
        std = _as_cpu_float_tensor(norm_stats["input_std"])
        # A flat (C,) vector must line up with dim 1 of (1, C, H, W), not
        # with the trailing width axis.
        if mean.dim() == 1:
            mean = mean.view(1, -1, 1, 1)
        if std.dim() == 1:
            std = std.view(1, -1, 1, 1)
        x = (x - mean) / (std + 1e-7)

    with torch.no_grad():
        pred = model(x.to(device)).squeeze(0).cpu()  # (n_targets,)

    if norm_stats:
        target_mean = _as_cpu_float_tensor(norm_stats["target_mean"])
        target_std = _as_cpu_float_tensor(norm_stats["target_std"])
        pred = pred * target_std + target_mean

    return pred.numpy()
def _wind_direction_str(degrees: float) -> str:
    """
    Convert a wind direction in degrees to a 16-point compass label.

    Each label covers a 22.5-degree sector centred on its nominal bearing;
    e.g. "N" spans [348.75, 11.25). Values outside [0, 360) wrap around.
    """
    dirs = ["N", "NNE", "NE", "ENE", "E", "ESE", "SE", "SSE",
            "S", "SSW", "SW", "WSW", "W", "WNW", "NW", "NNW"]
    # floor(x + 0.5) rounds exact sector boundaries up consistently. Built-in
    # round() uses banker's rounding, which sent 11.25 to "N" but 33.75 to "NE".
    idx = math.floor(degrees / 22.5 + 0.5) % 16
    return dirs[idx]
def format_forecast(pred: np.ndarray) -> dict:
    """
    Convert raw model output (6 physical values) into a user-friendly forecast dict.

    Args:
        pred: Indexable of at least 6 numbers, in TARGET_VARS order:
            temperature (K), relative humidity (%), U-wind and V-wind (m/s),
            wind gust (m/s), and 1-hour accumulated precipitation (mm).

    Returns:
        Dict with raw and derived quantities.
        - Humidity is clamped to [0, 100]; gust and precipitation are clamped
          to >= 0.
        - ``wind_dir_deg`` is the meteorological direction the wind blows FROM,
          in [0, 360).
        - A perfectly calm wind (zero speed) has no direction and is reported
          as 0.0 / "Calm".
    """
    temp_k, rh, u_wind, v_wind, gust, apcp = (float(pred[i]) for i in range(6))

    # Derived quantities
    temp_c = temp_k - 273.15
    temp_f = temp_c * 9 / 5 + 32
    wind_speed = math.hypot(u_wind, v_wind)

    if wind_speed == 0.0:
        # atan2(-0.0, -0.0) is -pi, so without this guard a dead calm would be
        # reported as a southerly wind (180 deg, "S").
        wind_dir_deg = 0.0
        wind_dir_str = "Calm"
    else:
        # Meteorological wind direction: direction FROM which wind blows
        wind_dir_deg = (math.degrees(math.atan2(-u_wind, -v_wind)) + 360) % 360
        wind_dir_str = _wind_direction_str(wind_dir_deg)

    # Rain likelihood from 1-hour precipitation (mm); clamp negative predictions.
    apcp = max(apcp, 0.0)
    if apcp > 5.0:
        rain_str = "Heavy Rain Likely"
    elif apcp > 2.0:
        rain_str = "Rain Likely"
    elif apcp > 0.5:
        rain_str = "Light Rain Possible"
    else:
        rain_str = "No Rain Expected"

    return {
        "temperature_k": temp_k,
        "temperature_c": temp_c,
        "temperature_f": temp_f,
        "humidity_pct": max(0.0, min(100.0, rh)),
        "u_wind_ms": u_wind,
        "v_wind_ms": v_wind,
        "wind_speed_ms": wind_speed,
        "wind_dir_deg": wind_dir_deg,
        "wind_dir_str": wind_dir_str,
        "gust_ms": max(gust, 0.0),
        "precipitation_mm": apcp,
        "rain_status": rain_str,
    }
def run_forecast(model_name: str, input_array: np.ndarray, device: str = "cpu") -> dict:
    """
    End-to-end forecast: load (or reuse) the model, predict, then format.

    Args:
        model_name: Key into ``AVAILABLE_MODELS``, forwarded to ``load_model``.
        input_array: Input grid, forwarded to ``predict_raw``.
        device: Device to load the model on and run inference with.

    Returns:
        The user-facing dict built by ``format_forecast``.
    """
    net, stats = load_model(model_name, device)
    return format_forecast(predict_raw(net, stats, input_array, device))