Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Body, HTTPException, Response, Query | |
| from fastapi.responses import Response | |
| import io, zipfile | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import xarray as xr | |
| import zarr | |
| from huggingface_hub import hf_hub_download | |
| from precip import compute_precipitation_probability | |
# FastAPI application instance.
# NOTE(review): none of the handler functions below carry a visible
# @app.get/@app.post route decorator in this view — confirm they are
# registered with this app somewhere, otherwise they are never served.
app = FastAPI()
def home():
    """Return a fixed liveness message so clients can check the API is up."""
    payload = {"message": "Zarr Chunk Render API is live."}
    return payload
# Runs at import time: fetch the zipped Zarr archive from the Hugging Face Hub
# dataset repository; the returned value is the local path of the archive.
zip_path = hf_hub_download(
    repo_type="dataset",
    repo_id="vValentine7/test_full_1",
    filename="weather_all_7_07_chunk.zarr.zip",
)

# Read the archive in place through a read-only Zarr ZIP store.
# NOTE(review): zarr.ZipStore is the zarr-python 2.x location; in 3.x it lives
# under zarr.storage — confirm the pinned zarr version.
store = zarr.ZipStore(zip_path, mode="r")

# The Zarr hierarchy sits in a folder at the root of the ZIP, so open that
# group explicitly (change the name if the archive's root folder differs).
ds = xr.open_zarr(store, consolidated=False, group="weather_all_7_07_chunk.zarr")
def normalize(arr: np.ndarray, lo: float, hi: float) -> np.ndarray:
    """Linearly rescale ``arr`` so that ``lo`` maps to 0 and ``hi`` maps to 1.

    Values outside ``[lo, hi]`` are clipped into ``[0, 1]``; NaNs propagate
    unchanged. If ``hi < lo`` the mapping is inverted (still clipped).

    When ``hi == lo`` the range is degenerate (e.g. a constant field, or equal
    user-supplied bounds). Instead of the NaN/inf a 0/0 or x/0 division would
    produce, every non-NaN value maps to 0.

    Args:
        arr: Input values.
        lo: Value mapped to 0.
        hi: Value mapped to 1.

    Returns:
        Float array of the same shape as ``arr``.
    """
    span = hi - lo
    if span == 0:
        # Degenerate range: keep NaN "holes" but avoid dividing by zero.
        return np.where(np.isnan(arr), np.nan, 0.0)
    return np.clip((arr - lo) / span, 0, 1)
def render_frames_zip(
    variable: str = Body(..., description="e.g. '2t', 'msl', '10u', 'q', etc."),
    level: int | None = Body(None, description="pressure level for 4D vars (e.g. 500)"),
    cmap: str = Body("plasma"),
    vmin: float | None = Body(None),
    vmax: float | None = Body(None),
):
    """Return a ZIP archive containing one PNG per timestep of ``variable``.

    If ``variable`` is 4D you MUST supply ``level``. It is looked up in
    ``ds["pressure_level"]`` and the matching index is taken along axis 1
    (assumed to be the pressure-level axis — TODO confirm for every 4D var).

    Args:
        variable: Name of a data variable in ``ds``.
        level: Pressure level value for 4D variables; ignored otherwise.
        cmap: Name of a registered Matplotlib colormap.
        vmin: Lower colour bound; defaults to the NaN-ignoring minimum over
            all timesteps of the selected field.
        vmax: Upper colour bound; defaults to the NaN-ignoring maximum.

    Returns:
        An ``application/zip`` response whose entries are named
        ``{variable}_frame_{t:02d}.png``.

    Raises:
        HTTPException: 404 for an unknown variable; 400 for an unknown
            colormap, or for a missing/unknown ``level`` on a 4D variable.
    """
    # 1) validate inputs before reading any data
    if variable not in ds.data_vars:
        raise HTTPException(404, f"Unknown variable '{variable}'")
    if cmap not in plt.colormaps():
        # Without this check an unknown name fails inside imshow -> HTTP 500.
        raise HTTPException(400, f"Unknown colormap '{cmap}'")

    # 2) select the level on the lazy DataArray, so only that slice is loaded
    #    rather than every pressure level of the variable
    field = ds[variable]
    if field.ndim == 4:
        if level is None:
            raise HTTPException(400, "Must supply 'level' for a 4D variable")
        levels = ds["pressure_level"].values.tolist()
        try:
            idx = levels.index(level)
        except ValueError:
            raise HTTPException(400, f"Level {level} not found; available: {levels}") from None
        field = field[:, idx]
    arr = field.values  # expected 3D here; axis 0 is iterated as time

    # 3) colour-scale bounds (user-supplied or data-driven)
    lo = vmin if vmin is not None else float(np.nanmin(arr))
    hi = vmax if vmax is not None else float(np.nanmax(arr))

    # 4) render every timestep into an in-memory ZIP
    zip_buf = io.BytesIO()
    with zipfile.ZipFile(zip_buf, mode="w") as zf:
        for t in range(arr.shape[0]):
            frame = normalize(arr[t], lo, hi)
            fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
            try:
                ax.imshow(frame, cmap=cmap, origin="upper")
                ax.axis("off")
                img = io.BytesIO()
                # Save *this* figure explicitly: plt.savefig saves whatever is
                # pyplot's current figure, which another concurrently running
                # handler could have changed.
                fig.savefig(img, format="png", bbox_inches="tight", pad_inches=0)
            finally:
                plt.close(fig)  # release the figure even if rendering failed
            zf.writestr(f"{variable}_frame_{t:02d}.png", img.getvalue())

    return Response(
        content=zip_buf.getvalue(),
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{variable}_frames.zip"'},
    )
def render_frame(
    variable: str = Body(..., description="e.g. '2t', 'msl', '10u', 'q'"),
    time_idx: int = Body(..., description="0-based timestep index"),
    level: int | None = Body(None, description="pressure level for 4D vars"),
    cmap: str = Body("plasma"),
    vmin: float | None = Body(None),
    vmax: float | None = Body(None),
):
    """Render a single timestep of ``variable`` as a PNG attachment.

    If ``variable`` is 4D you MUST supply ``level``. It is looked up in
    ``ds["pressure_level"]`` and the matching index is taken along axis 1
    (assumed to be the pressure-level axis — TODO confirm for every 4D var).

    Args:
        variable: Name of a data variable in ``ds``.
        time_idx: 0-based index along the first (time) axis.
        level: Pressure level value for 4D variables; ignored otherwise.
        cmap: Name of a registered Matplotlib colormap.
        vmin: Lower colour bound; defaults to the NaN-ignoring minimum over
            *all* timesteps, so frames share one colour scale.
        vmax: Upper colour bound; defaults to the NaN-ignoring maximum.

    Returns:
        An ``image/png`` response named ``{variable}_t{time_idx:02d}.png``.

    Raises:
        HTTPException: 404 for an unknown variable; 400 for an unknown
            colormap, a missing/unknown ``level``, or an out-of-range ``time_idx``.
    """
    # 1) validate inputs before reading any data
    if variable not in ds.data_vars:
        raise HTTPException(404, f"Unknown variable '{variable}'")
    if cmap not in plt.colormaps():
        # Without this check an unknown name fails inside imshow -> HTTP 500.
        raise HTTPException(400, f"Unknown colormap '{cmap}'")

    # 2) select the level lazily so other levels are never loaded
    field = ds[variable]
    if field.ndim == 4:
        if level is None:
            raise HTTPException(400, "Must supply 'level' for a 4D variable")
        levels = ds["pressure_level"].values.tolist()
        try:
            idx = levels.index(level)
        except ValueError:
            raise HTTPException(400, f"Level {level} not found; available: {levels}") from None
        field = field[:, idx]

    # 3) check the time index using the shape only (no data read yet)
    n_times = field.shape[0]
    if not (0 <= time_idx < n_times):
        raise HTTPException(400, f"time_idx {time_idx} out of range (0–{n_times-1})")

    # 4) colour bounds: the whole series is needed only for a missing bound;
    #    with both bounds given, read just the requested frame
    if vmin is None or vmax is None:
        series = field.values
        lo = vmin if vmin is not None else float(np.nanmin(series))
        hi = vmax if vmax is not None else float(np.nanmax(series))
        raw = series[time_idx]
    else:
        lo, hi = vmin, vmax
        raw = field[time_idx].values

    # 5) render that one frame
    frame = normalize(raw, lo, hi)
    img_buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    try:
        ax.imshow(frame, cmap=cmap, origin="upper")
        ax.axis("off")
        # Save *this* figure explicitly rather than pyplot's "current" figure.
        fig.savefig(img_buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)  # release the figure even if rendering failed

    return Response(
        content=img_buf.getvalue(),
        media_type="image/png",
        headers={
            "Content-Disposition":
                f'attachment; filename="{variable}_t{time_idx:02d}.png"'
        },
    )
async def temp_trend(
    lat: float = Query(..., description="Latitude in decimal degrees"),
    lon: float = Query(..., description="Longitude in decimal degrees")
):
    """Return the 2 m temperature series at the grid point nearest (lat, lon).

    Each entry is ``{"timestamp": str, "tempK": float, "tempC": float,
    "tempF": float}``, with temperatures rounded to 2 decimals. Timestamps are
    ``ds["time"][0]`` plus each ``lead_time`` value interpreted as hours.

    Longitude matching is done on the circle, so both -180..180 and 0..360
    inputs find the right column whichever convention the grid uses.

    NOTE(review): this is ``async`` but performs blocking dataset reads on the
    event loop; kept async because callers ``await`` it.
    """
    # 1) pull coords into numpy
    lats = ds["lat"].values
    lons = ds["lon"].values
    # assumes lead_time holds plain hour counts; if it is timedelta64, int()
    # below yields nanoseconds instead of hours — TODO confirm dtype
    lead_hours = ds["lead_time"].values
    init_time = np.datetime64(ds["time"].values[0])  # first forecast time

    # 2) find nearest grid-point
    lat_idx = int(np.abs(lats - lat).argmin())
    # Wrapped angular distance in [0, 180]: e.g. -74 matches 286 on a 0..360
    # grid, and 359.9 is adjacent to 0 instead of 359.9 degrees away.
    lon_dist = np.abs((lons - lon + 180.0) % 360.0 - 180.0)
    lon_idx = int(lon_dist.argmin())

    # 3) read only the point series of 2 m temperature (Kelvin)
    t2m = ds["2t"][:, lat_idx, lon_idx].values

    # 4) build the response
    out = []
    for i, kh in enumerate(lead_hours):
        ts = init_time + np.timedelta64(int(kh), "h")
        k = float(t2m[i])
        c = k - 273.15
        f = c * 9 / 5 + 32
        out.append({
            "timestamp": str(ts),
            "tempK": round(k, 2),
            "tempC": round(c, 2),
            "tempF": round(f, 2),
        })
    return out
async def daily_highlow(
    lat: float = Query(..., description="Latitude in decimal degrees"),
    lon: float = Query(..., description="Longitude in decimal degrees")
):
    """Return ``{"day": 'Mon', "highF": float, "lowF": float}`` per forecast day.

    Builds on ``temp_trend``: samples are grouped by the calendar-date part of
    their timestamp (the text before ``"T"``). The max/min Fahrenheit value of
    each group is rounded to 1 decimal. Days appear in first-seen order.
    Days at the ends of the forecast window may contain only a few samples.

    NOTE(review): grouping uses the timestamp's own date, presumably UTC —
    confirm whether callers expect local-day highs/lows.
    """
    trend = await temp_trend(lat, lon)  # list of dicts

    # group Fahrenheit samples by date string "YYYY-MM-DD"
    by_date: dict[str, list[float]] = {}
    for pt in trend:
        date = pt["timestamp"].split("T", 1)[0]
        by_date.setdefault(date, []).append(pt["tempF"])

    out = []
    for date, temps in by_date.items():
        # datetime64[D].tolist() gives a datetime.date -> abbreviated weekday
        weekday = np.datetime64(date, "D").tolist().strftime("%a")
        out.append({
            "day": weekday,
            "highF": round(max(temps), 1),
            "lowF": round(min(temps), 1),
        })
    return out
def wind10_vectors(stride: int = Query(20), include_vorticity: bool = Query(False)):
    """Return subsampled 10 m wind vectors, optionally with relative vorticity.

    Args:
        stride: Keep every ``stride``-th grid row and column; must be >= 1.
        include_vorticity: If true, each point also gets ``"zeta"`` =
            dv/dx - du/dy. It is computed with ``np.gradient`` on the
            full-resolution grid, using spherical spacings (R = 6.371e6 m),
            and only then subsampled. The zonal spacing scales with
            cos(latitude), so values very near the poles are ill-conditioned.

    Returns:
        ``{"times": [str, ...], "vectors": [[point, ...], ...]}`` with one list
        of ``{"lat", "lng", "u", "v"[, "zeta"]}`` dicts per timestep.
        Non-finite u/v/zeta values are returned as ``None``, since JSON has no
        NaN/Infinity.

    Raises:
        HTTPException: 400 if ``stride`` is smaller than 1 (a zero step would
            otherwise crash slicing with an HTTP 500).
    """
    if stride < 1:
        raise HTTPException(400, "stride must be >= 1")

    step = slice(None, None, stride)
    lats = ds["lat"].values
    lons = ds["lon"].values
    # Coordinates of the retained points (identical to meshgrid-then-slice).
    lon2d, lat2d = np.meshgrid(lons[step], lats[step])
    times = [str(t) for t in ds["time"].values]

    if include_vorticity:
        # Finite differences need the full-resolution fields.
        u10 = ds["10u"].values
        v10 = ds["10v"].values
        R = 6.371e6  # Earth radius in meters
        dy = np.gradient(np.radians(lats)) * R
        dx = np.gradient(np.radians(lons)) * R * np.cos(np.radians(lats))[:, None]
    else:
        # Only the subsampled points are returned, so read just those.
        u10 = ds["10u"][:, step, step].values
        v10 = ds["10v"][:, step, step].values

    def _json_float(x) -> float | None:
        # JSON cannot encode NaN/inf; report such values as null.
        val = float(x)
        return val if np.isfinite(val) else None

    out = []
    for t in range(u10.shape[0]):
        if include_vorticity:
            U, V = u10[t], v10[t]
            du_dy = np.gradient(U, axis=0) / dy[:, None]
            dv_dx = np.gradient(V, axis=1) / dx
            zeta = (dv_dx - du_dy)[step, step]
            u, v = U[step, step], V[step, step]
        else:
            u, v = u10[t], v10[t]
            zeta = None

        pts = []
        for i in range(u.shape[0]):
            for j in range(u.shape[1]):
                pt = {
                    "lat": float(lat2d[i, j]),
                    "lng": float(lon2d[i, j]),
                    "u": _json_float(u[i, j]),
                    "v": _json_float(v[i, j]),
                }
                if zeta is not None:
                    pt["zeta"] = _json_float(zeta[i, j])
                pts.append(pt)
        out.append(pts)
    return {"times": times, "vectors": out}
def list_variables():
    """Report the name of every data variable held by the dataset."""
    names = [name for name in ds.data_vars]
    return {"variables": names}
def pressure_levels():
    """Report the values of the ``pressure_level`` coordinate as a plain list.

    These are the levels accepted by the 4D-variable endpoints (presumably in
    hPa — confirm against the dataset attributes).
    """
    level_values = ds["pressure_level"].values
    return {"levels": level_values.tolist()}
def available_times():
    """Report every ``time`` coordinate value, stringified, in stored order."""
    stamps = ds["time"].values
    return {"times": list(map(str, stamps))}
def precipitation_grid(stride: int = Query(4)):
    """Return a subsampled precipitation-probability grid for every timestep.

    For each time index, ``compute_precipitation_probability`` is applied to
    that slice of ``ds``. Its result is then subsampled every ``stride``-th
    row/column.

    Args:
        stride: Subsampling step along both grid axes; must be >= 1.

    Returns:
        ``{"times": [str, ...], "ppp": [[{"lat", "lng", "value"}, ...], ...]}``
        with one flat list of points per timestep. Non-finite values are
        returned as ``None`` because JSON has no NaN/Infinity.

    Raises:
        HTTPException: 400 if ``stride`` is smaller than 1 (a zero step would
            otherwise crash slicing with an HTTP 500).
    """
    if stride < 1:
        raise HTTPException(400, "stride must be >= 1")

    times = [str(t) for t in ds["time"].values]
    lats = ds["lat"].values[::stride]
    lons = ds["lon"].values[::stride]

    out = []
    for t in range(len(ds.time)):
        # assumes the result is a 2D DataArray with axes ordered (lat, lon),
        # matching `lats`/`lons` above — TODO confirm in precip.py
        ppp = compute_precipitation_probability(ds.isel(time=t))
        ppp_vals = ppp[::stride, ::stride].values

        frame = []
        for i, lat in enumerate(lats):
            for j, lon in enumerate(lons):
                value = float(ppp_vals[i, j])
                frame.append({
                    "lat": float(lat),
                    "lng": float(lon),
                    "value": value if np.isfinite(value) else None,
                })
        out.append(frame)
    return {"times": times, "ppp": out}