Spaces:
Runtime error
Runtime error
| """ | |
| FastAPI application exposing the rain nowcast API and a Gradio UI. | |
| The previous Streamlit proxy was difficult to keep alive on Spaces due to | |
| websocket restrictions. This module provides the same REST endpoints while | |
| mounting a lightweight Gradio front-end so the UI works without websocket | |
| tunnelling. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import json | |
| import subprocess | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, Tuple | |
| import joblib | |
| import pandas as pd | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException, Query | |
| from pydantic import BaseModel, Field | |
| from xgboost import XGBClassifier | |
# --------- Paths ---------
# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
MODELS, RESULTS, SCRIPTS = (ROOT / sub for sub in ("models", "results", "scripts"))

# Trained artifacts: the joblib file is loaded first, the raw XGBoost JSON is
# the fallback, and the metadata JSON holds features/thresholds/horizon.
MODEL_PATH = MODELS.joinpath("rain_xgb_tuned.joblib")
META_PATH = MODELS.joinpath("rain_xgb_tuned_meta.json")
MODEL_JSON_PATH = MODELS.joinpath("xgb_tuned.json")

# Cached hourly weather table shared by all prediction requests.
HOURLY_CSV = RESULTS.joinpath("hourly.csv")
# Make training utilities importable: prepend the repository root to the module
# search path so that `scripts.train_xgb_tuned_final` resolves. This must run
# before the import below.
import sys
sys.path[:0] = [str(ROOT)]
from scripts.train_xgb_tuned_final import build_features  # type: ignore
# --------- Load model + meta at startup ---------
# The metadata JSON is mandatory: without it the app refuses to start.
if META_PATH.exists():
    meta = json.loads(META_PATH.read_text())
else:
    raise RuntimeError(
        "Model metadata missing. Run `python scripts/train_xgb_tuned_final.py` "
        "or copy models/rain_xgb_tuned_meta.json into place."
    )
FEATURES = meta["features"]  # column names, selected in this order before scoring
THRESH = meta["thresholds"]  # read later via keys default / high_recall / high_precision
HORIZON_H = int(meta["horizon_hours"])
def _load_model() -> XGBClassifier:
    """Return the trained classifier, preferring the joblib artifact.

    When only the JSON booster exists, an ``XGBClassifier`` is rebuilt from
    ``meta["model"]["params"]`` (empty if absent) and its weights are loaded
    from ``MODEL_JSON_PATH``.

    Raises:
        RuntimeError: if neither model artifact exists.
    """
    # joblib unpickles the file; acceptable because it is a local, trusted artifact.
    if MODEL_PATH.exists():
        return joblib.load(MODEL_PATH)
    if not MODEL_JSON_PATH.exists():
        raise RuntimeError(
            "Model artifact missing. Run `python scripts/train_xgb_tuned_final.py` "
            "to generate models/rain_xgb_tuned.joblib (or xgb_tuned.json), "
            "or copy the trained file into the models/ directory."
        )
    model_cfg = meta.get("model", {})
    clf = XGBClassifier(**model_cfg.get("params", {}))
    clf.load_model(MODEL_JSON_PATH)
    return clf


model = _load_model()
# --------- Helpers ---------
# Sidecar recording which request (lat/lon/past_days) produced HOURLY_CSV.
# The CSV is one shared cache file, so without this record a recent file
# fetched for one location would be served for every other location.
HOURLY_REQUEST_JSON = HOURLY_CSV.with_name(HOURLY_CSV.stem + ".request.json")


def _hourly_cache_is_valid(request: Dict[str, object], max_age_hours: float) -> bool:
    """Return True if HOURLY_CSV is recent enough and was built for ``request``."""
    if not (HOURLY_CSV.exists() and HOURLY_REQUEST_JSON.exists()):
        return False
    age_hours = (datetime.now().timestamp() - HOURLY_CSV.stat().st_mtime) / 3600
    if age_hours > max_age_hours:
        return False
    try:
        cached_request = json.loads(HOURLY_REQUEST_JSON.read_text())
    except (OSError, ValueError):
        # An unreadable or corrupt sidecar simply forces a refresh.
        return False
    return cached_request == request


def _refresh_hourly(lat: float, lon: float, past_days: int, request: Dict[str, object]) -> None:
    """Re-run the fetch/export scripts and record ``request`` beside the CSV.

    Raises:
        HTTPException: status 502 if a script fails or cannot be launched, or
            if no CSV exists once the scripts have finished.
    """
    env = os.environ.copy()
    env["LAT"] = str(lat)
    env["LON"] = str(lon)
    env["PAST_DAYS"] = str(past_days)
    # Invalidate the sidecar first so a failed or partial refresh can never be
    # mistaken for a valid cache entry by a later request.
    try:
        HOURLY_REQUEST_JSON.unlink()
    except FileNotFoundError:
        pass
    try:
        subprocess.run(["bash", str(SCRIPTS / "fetch_weather.sh")], check=True, env=env)
        # sys.executable runs the exporter under this interpreter (and its
        # installed packages) instead of whichever `python3` is on PATH.
        subprocess.run([sys.executable, str(SCRIPTS / "export_hourly.py")], check=True, env=env)
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers a missing `bash` binary or script file.
        raise HTTPException(status_code=502, detail=f"Data refresh failed: {exc}") from exc
    if not HOURLY_CSV.exists():
        raise HTTPException(
            status_code=502, detail="Data refresh failed: hourly CSV was not produced."
        )
    HOURLY_REQUEST_JSON.write_text(json.dumps(request))


def ensure_hourly(
    lat: float,
    lon: float,
    past_days: int = 90,
    max_age_hours: float = 6.0,
) -> pd.DataFrame:
    """Return hourly weather for (lat, lon), refreshing the cached CSV if needed.

    The cached ``results/hourly.csv`` is reused only when it is at most
    ``max_age_hours`` old AND was produced for the same ``lat``/``lon``/
    ``past_days``. Otherwise ``scripts/fetch_weather.sh`` and
    ``scripts/export_hourly.py`` are re-run with LAT/LON/PAST_DAYS set in their
    environment.

    NOTE(review): concurrent requests for different locations still share one
    CSV and may race; serialise refreshes if that becomes a problem.

    Raises:
        HTTPException: status 502 when the data refresh fails.
    """
    past_days = int(past_days)
    request = {"lat": float(lat), "lon": float(lon), "past_days": past_days}
    if not _hourly_cache_is_valid(request, max_age_hours):
        _refresh_hourly(lat, lon, past_days, request)
    return pd.read_csv(HOURLY_CSV, parse_dates=["time"])
def predict_latest(df: pd.DataFrame, mode: str) -> Dict[str, object]:
    """Score the most recent hour of ``df`` and return a structured response.

    Args:
        df: hourly observations with a parsed ``time`` column plus whatever raw
            columns ``build_features`` consumes.
        mode: ``"default"``, ``"recall"`` or ``"precision"``; selects which
            threshold from the metadata turns the probability into a decision.

    Returns:
        Dict with ``timestamp`` (ISO string), ``probability``, ``threshold``,
        ``mode``, ``decision`` (``"RAIN"`` / ``"No rain"``) and ``horizon_hours``.

    Raises:
        HTTPException: 400 for an unsupported mode, 422 when too few rows remain
            after feature building, 500 when expected features are missing.
    """
    thresholds = {
        "default": float(THRESH["default"]),
        "recall": float(THRESH["high_recall"]),
        "precision": float(THRESH["high_precision"]),
    }
    # Reject an unknown mode before paying for feature engineering and inference.
    if mode not in thresholds:
        raise HTTPException(status_code=400, detail=f"Unsupported mode '{mode}'.")
    threshold = thresholds[mode]

    # Defensive copy so feature building cannot modify the caller's frame.
    Xdf = build_features(df.copy())
    if Xdf.empty:
        raise HTTPException(status_code=422, detail="Not enough rows to build features.")
    try:
        # Select and order columns exactly as listed in the training metadata.
        Xdf = Xdf[FEATURES]
    except KeyError as exc:
        raise HTTPException(status_code=500, detail=f"Feature mismatch: {exc}") from exc

    latest_row = Xdf.iloc[[-1]].values  # 2-D array holding only the newest row
    probability = float(model.predict_proba(latest_row)[0, 1])
    decision = "RAIN" if probability >= threshold else "No rain"
    # Assumes build_features keeps df's index labels so the scored row maps
    # back to its timestamp. TODO confirm against the training script.
    ts = df.loc[Xdf.index, "time"].iloc[-1]
    return {
        "timestamp": ts.isoformat(),
        "probability": probability,
        "threshold": threshold,
        "mode": mode,
        "decision": decision,
        "horizon_hours": HORIZON_H,
    }
def format_prediction(result: Dict[str, object]) -> str:
    """Render a prediction dict (shaped like predict_latest's output) as Markdown.

    Reads ``decision``, ``probability``, ``threshold``, ``mode`` and
    ``timestamp``. It uses ``horizon_hours`` when present and otherwise falls
    back to the model's ``HORIZON_H``.
    """
    # Unicode escapes keep these glyphs intact even if the source file is
    # re-encoded (the literals here had turned into mojibake).
    rain_icon = "\U0001F327\uFE0F"  # U+1F327 CLOUD WITH RAIN + emoji presentation
    dry_icon = "\u2705"  # U+2705 WHITE HEAVY CHECK MARK
    emoji = rain_icon if result["decision"] == "RAIN" else dry_icon
    horizon = result["horizon_hours"] if "horizon_hours" in result else HORIZON_H
    probability = result["probability"]
    threshold = result["threshold"]
    return (
        f"{emoji} **Decision:** {result['decision']} (mode **{result['mode']}**)\n\n"
        f"- Probability of rain \u2264 {horizon}h: **{probability:.3f}**\n"
        f"- Threshold: **{threshold:.2f}**\n"
        f"- Issued for hour ending **{result['timestamp']}**"
    )
class PredictBody(BaseModel):
    """JSON body accepted by the POST prediction endpoint.

    Default coordinates are Lagos. Latitude and longitude are limited to valid
    geographic ranges, so bad input fails validation up front instead of
    surfacing later as a failed data refresh.
    """

    lat: float = Field(6.5244, ge=-90, le=90, description="Latitude")
    lon: float = Field(3.3792, ge=-180, le=180, description="Longitude")
    mode: str = Field("default", description="default | recall | precision")
    past_days: int = Field(90, ge=14, le=180, description="How much history to fetch (days)")
app = FastAPI(title="Rain Nowcast API", version="1.1.0")

# Modes accepted by predict_latest. Checked before any data refresh so that a
# typo in ``mode`` does not first trigger the fetch/export scripts.
_SUPPORTED_MODES = ("default", "recall", "precision")


def _require_supported_mode(mode: str) -> None:
    """Raise HTTP 400 for a mode that predict_latest would reject anyway."""
    if mode not in _SUPPORTED_MODES:
        raise HTTPException(status_code=400, detail=f"Unsupported mode '{mode}'.")


# NOTE(review): these handlers had no route decorators, so the REST API was
# never registered. The paths below are assumed; confirm against existing
# clients. They are registered before Gradio is mounted at "/" at the end of
# the module, so they are matched first.
@app.get("/health")
def health() -> Dict[str, object]:
    """Report service status and the model configuration read from metadata."""
    # Mirror _load_model's precedence: joblib artifact first, JSON booster otherwise.
    model_file = MODEL_PATH if MODEL_PATH.exists() else MODEL_JSON_PATH
    return {
        "status": "ok",
        "model_file": model_file.name,
        "horizon_hours": HORIZON_H,
        "thresholds": THRESH,
        "features": FEATURES,
    }


@app.post("/predict")
def predict(body: PredictBody) -> Dict[str, object]:
    """Score the latest hour for the location and mode given in the JSON body."""
    _require_supported_mode(body.mode)
    df = ensure_hourly(body.lat, body.lon, body.past_days)
    return {"ok": True, "result": predict_latest(df, body.mode)}


@app.get("/predict")
def predict_get(
    lat: float = Query(6.5244, ge=-90, le=90),
    lon: float = Query(3.3792, ge=-180, le=180),
    mode: str = Query("default"),
    past_days: int = Query(90, ge=14, le=180),
) -> Dict[str, object]:
    """Same as ``predict``, with parameters taken from the query string."""
    _require_supported_mode(mode)
    df = ensure_hourly(lat, lon, past_days)
    return {"ok": True, "result": predict_latest(df, mode)}
# --------- Gradio UI ---------
# Dropdown label -> (latitude, longitude). "Custom" is only a placeholder: when
# it is selected, _resolve_location uses the numeric lat/lon inputs instead.
CITY_PRESETS: Dict[str, Tuple[float, float]] = dict(
    [
        ("Lagos π³π¬", (6.5244, 3.3792)),
        ("Accra π¬π", (5.6037, -0.1870)),
        ("Nairobi π°πͺ", (-1.2864, 36.8172)),
        ("Kampala πΊπ¬", (0.3476, 32.5825)),
        ("Addis Ababa πͺπΉ", (8.9806, 38.7578)),
        ("Custom", (0.0, 0.0)),
    ]
)
def _resolve_location(city: str, lat: float, lon: float) -> Tuple[float, float, str]:
    """Map the UI's city choice to ``(lat, lon, display_label)``.

    Any key of CITY_PRESETS other than "Custom" uses its stored coordinates.
    "Custom" (or an unknown label) uses the numeric ``lat``/``lon`` inputs.

    Raises:
        ValueError: if the numeric inputs are needed but are empty, or lie
            outside latitude [-90, 90] / longitude [-180, 180].
    """
    if city != "Custom" and city in CITY_PRESETS:
        preset_lat, preset_lon = CITY_PRESETS[city]
        return preset_lat, preset_lon, city
    # A cleared number field can arrive as None; formatting it would raise a
    # bare TypeError, so report the problem explicitly instead.
    if lat is None or lon is None:
        raise ValueError("Enter both latitude and longitude for a custom location.")
    if not (-90 <= lat <= 90 and -180 <= lon <= 180):
        raise ValueError(f"Coordinates out of range: lat={lat}, lon={lon}.")
    return lat, lon, f"Custom ({lat:.3f}, {lon:.3f})"
def gradio_predict(
    city: str,
    lat: float,
    lon: float,
    mode: str,
    past_days: int,
) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
    """Gradio click handler: run a nowcast and build the three UI outputs.

    Returns:
        ``(summary_markdown, latest_prediction_row, chart_long_frame)``. The
        last item is the final 48 rows of weather data melted to
        ``time``/``series``/``value`` columns for the line plot.

    Raises:
        gr.Error: wrapping HTTPException / ValueError raised by the helpers, so
            the message is shown to the user rather than a generic error.
    """
    try:
        chosen_lat, chosen_lon, label = _resolve_location(city, lat, lon)
        # Slider values may arrive as floats; the fetch scripts expect whole days.
        df = ensure_hourly(chosen_lat, chosen_lon, int(past_days))
        result = predict_latest(df, mode)
    except HTTPException as exc:
        raise gr.Error(str(exc.detail)) from exc
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc

    summary = format_prediction(result)

    # melt() returns a new frame, so no defensive copies are needed here.
    recent = df.tail(48)[["time", "temp_c", "humidity", "precip_mm", "rain_mm"]]
    chart = recent.melt(id_vars="time", var_name="series", value_name="value")
    if pd.api.types.is_datetime64_any_dtype(chart["time"]):
        chart["time"] = chart["time"].dt.strftime("%Y-%m-%d %H:%M")

    latest = pd.DataFrame(
        {
            "location": [label],
            "timestamp": [result["timestamp"]],
            "mode": [result["mode"]],
            "probability": [result["probability"]],
            "threshold": [result["threshold"]],
            "decision": [result["decision"]],
        }
    )
    return summary, latest, chart
# Default to the first preset so the dropdown value always matches a key of
# CITY_PRESETS, and seed the custom lat/lon inputs with its coordinates.
_DEFAULT_CITY = next(iter(CITY_PRESETS))
_DEFAULT_LAT, _DEFAULT_LON = CITY_PRESETS[_DEFAULT_CITY]

with gr.Blocks(css=".gradio-container {max-width: 900px;}") as demo:
    # "\U0001F327\uFE0F" is the cloud-with-rain emoji; the escape keeps it safe
    # from source re-encoding (the literal here had turned into mojibake).
    gr.Markdown(
        "# \U0001F327\uFE0F Rain Nowcast\nPredict the probability of rain in the next "
        f"{HORIZON_H} hours using the tuned XGBoost model."
    )

    with gr.Row():
        city_input = gr.Dropdown(
            label="City preset",
            choices=list(CITY_PRESETS.keys()),
            value=_DEFAULT_CITY,
        )
        mode_input = gr.Radio(
            label="Decision mode",
            choices=["default", "recall", "precision"],
            value="default",
            info="default=balanced, recall=warn more, precision=extra picky",
        )

    with gr.Row():
        lat_input = gr.Number(label="Latitude (used if city is Custom)", value=_DEFAULT_LAT)
        lon_input = gr.Number(label="Longitude (used if city is Custom)", value=_DEFAULT_LON)
        past_days_input = gr.Slider(
            label="History window (days)",
            minimum=14,
            maximum=180,
            value=90,
            step=1,
        )

    submit = gr.Button("Run prediction", variant="primary")
    summary_md = gr.Markdown()
    latest_df = gr.Dataframe(label="Latest prediction", wrap=True)
    chart_df = gr.LinePlot(
        label="Last 48h weather (hourly)",
        x="time",
        y="value",
        color="series",
        overlay_point=True,
        width=900,
        height=350,
    )

    submit.click(
        gradio_predict,
        inputs=[city_input, lat_input, lon_input, mode_input, past_days_input],
        outputs=[summary_md, latest_df, chart_df],
    )

    gr.Markdown(
        "Model features match the training pipeline "
        "(see `scripts/train_xgb_tuned_final.py`). Data fetched from Open-Meteo."
    )

# Mount last so that routes already registered on the FastAPI app are matched
# before Gradio's handler at "/".
app = gr.mount_gradio_app(app, demo, path="/")