# app.py — Zarr Chunk Render API
# (source: vValentine7's Hugging Face Space, commit c062128)
from fastapi import FastAPI, Body, HTTPException, Response, Query
from fastapi.responses import Response
import io, zipfile
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import zarr
from huggingface_hub import hf_hub_download
from precip import compute_precipitation_probability
app = FastAPI()
@app.get("/")
def home():
    """Health-check endpoint confirming the service is up."""
    payload = {"message": "Zarr Chunk Render API is live."}
    return payload
# --- One-time dataset setup, executed at import time ---
# NOTE(review): this blocks app startup until the download completes, and the
# ZipStore is never closed — acceptable for a long-lived service process.
# 1) Download your ZIP from Hugging Face
zip_path = hf_hub_download(
    repo_id="vValentine7/test_full_1",
    filename="weather_all_7_07_chunk.zarr.zip",
    repo_type="dataset"
)
# 2) Mount it as a read-only Zarr store
store = zarr.ZipStore(zip_path, mode="r")
# 3) Open the correct group inside the ZIP
ds = xr.open_zarr(
    store,
    group="weather_all_7_07_chunk.zarr",  # use whatever folder name you see at the root of the ZIP
    consolidated=False
)
def normalize(arr: np.ndarray, lo: float, hi: float) -> np.ndarray:
    """Linearly rescale *arr* into [0, 1] using the bounds [lo, hi].

    Values below `lo` clamp to 0 and values above `hi` clamp to 1.
    If `hi == lo` (a constant field) the original code divided by zero
    and propagated NaN/inf; return an all-zero array instead.
    """
    if hi == lo:  # constant field: avoid 0/0 -> NaN in the rendered frame
        return np.zeros_like(arr, dtype=float)
    return np.clip((arr - lo) / (hi - lo), 0, 1)
@app.post("/render-frames-zip")
def render_frames_zip(
    variable: str = Body(..., description="e.g. '2t', 'msl', '10u', 'q', etc."),
    level: int | None = Body(None, description="pressure level for 4D vars (e.g. 500)"),
    cmap: str = Body("plasma"),
    vmin: float | None = Body(None),
    vmax: float | None = Body(None),
):
    """
    Returns a ZIP containing one PNG per timestep for `variable`.
    If `variable` has shape (time,pressure_level,lat,lon), you MUST supply `level`.

    Raises 404 for an unknown variable, 400 for a missing/unknown level.
    """
    # 1) validate variable
    if variable not in ds.data_vars:
        raise HTTPException(404, f"Unknown variable '{variable}'")
    # 2) load data into a NumPy array
    arr = ds[variable].values  # either (T, Y, X) or (T, P, Y, X)
    # 3) if 4D, slice out the requested pressure level
    if arr.ndim == 4:
        if level is None:
            raise HTTPException(400, "Must supply 'level' for a 4D variable")
        levels = ds["pressure_level"].values.tolist()
        try:
            idx = levels.index(level)
        except ValueError:
            raise HTTPException(400, f"Level {level} not found; available: {levels}")
        arr = arr[:, idx, :, :]  # now shape (T, Y, X)
    # 4) normalization bounds: explicit overrides win, else global data min/max
    #    (global bounds keep the colormap consistent across frames)
    lo = vmin if vmin is not None else float(np.nanmin(arr))
    hi = vmax if vmax is not None else float(np.nanmax(arr))
    # 5) build an in-memory ZIP with one PNG per timestep
    zip_buf = io.BytesIO()
    with zipfile.ZipFile(zip_buf, mode="w") as zf:
        for t in range(arr.shape[0]):
            frame = normalize(arr[t], lo, hi)
            fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
            ax.imshow(frame, cmap=cmap, origin="upper")
            ax.axis("off")
            img = io.BytesIO()
            # Save via the figure object rather than pyplot's implicit
            # "current figure" (plt.savefig), which races between concurrent
            # requests sharing matplotlib's global state.
            fig.savefig(img, format="png", bbox_inches="tight", pad_inches=0)
            plt.close(fig)
            zf.writestr(f"{variable}_frame_{t:02d}.png", img.getvalue())
    zip_buf.seek(0)
    return Response(
        content=zip_buf.read(),
        media_type="application/zip",
        headers={"Content-Disposition": f"attachment; filename={variable}_frames.zip"}
    )
@app.post("/render-frame")
def render_frame(
    variable: str = Body(..., description="e.g. '2t', 'msl', '10u', 'q'"),
    time_idx: int = Body(..., description="0-based timestep index"),
    level: int | None = Body(None, description="pressure level for 4D vars"),
    cmap: str = Body("plasma"),
    vmin: float | None = Body(None),
    vmax: float | None = Body(None),
):
    """
    Render one timestep of `variable` as a single PNG.

    For 4D variables (time, pressure_level, lat, lon) `level` is required.
    Raises 404 for an unknown variable, 400 for a bad level or time_idx.
    """
    # 1) validate variable
    if variable not in ds.data_vars:
        raise HTTPException(404, f"Unknown variable '{variable}'")
    # 2) pull out the raw array
    arr = ds[variable].values  # either (T, P, Y, X) or (T, Y, X)
    # 3) if 4D, slice by level
    if arr.ndim == 4:
        if level is None:
            raise HTTPException(400, "Must supply 'level' for a 4D variable")
        levels = ds["pressure_level"].values.tolist()
        try:
            idx = levels.index(level)
        except ValueError:
            raise HTTPException(400, f"Level {level} not found; available: {levels}")
        arr = arr[:, idx, :, :]  # now shape (T, Y, X)
    # 4) check time index
    if not (0 <= time_idx < arr.shape[0]):
        raise HTTPException(400, f"time_idx {time_idx} out of range (0–{arr.shape[0]-1})")
    # 5) normalization bounds: explicit overrides win, else global data min/max
    lo = vmin if vmin is not None else float(np.nanmin(arr))
    hi = vmax if vmax is not None else float(np.nanmax(arr))
    # 6) render that one frame
    frame = normalize(arr[time_idx], lo, hi)
    img_buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    ax.imshow(frame, cmap=cmap, origin="upper")
    ax.axis("off")
    # Save via the figure object rather than pyplot's implicit "current
    # figure" (plt.savefig), which races between concurrent requests.
    fig.savefig(img_buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    img_buf.seek(0)
    return Response(
        content=img_buf.read(),
        media_type="image/png",
        headers={
            "Content-Disposition":
                f"attachment; filename={variable}_t{time_idx:02d}.png"
        }
    )
@app.get("/temp-trend")
async def temp_trend(
    lat: float = Query(..., description="Latitude in decimal degrees"),
    lon: float = Query(..., description="Longitude in decimal degrees")
):
    """
    Return the hourly 2 m temperature series at the grid point nearest
    (lat, lon): a list of { timestamp: ISO, tempK, tempC, tempF }.
    """
    grid_lats = ds["lat"].values
    grid_lons = ds["lon"].values
    hours = ds["lead_time"].values
    start = np.datetime64(ds["time"].values[0])  # first forecast time
    # nearest-neighbour grid indices for the requested coordinate
    i_lat = int(np.abs(grid_lats - lat).argmin())
    i_lon = int(np.abs(grid_lons - lon).argmin())
    # 2 m temperature series (Kelvin) at [time, lat, lon]
    series = ds["2t"][:, i_lat, i_lon].values
    results = []
    for step, lead in enumerate(hours):
        stamp = start + np.timedelta64(int(lead), "h")
        kelvin = float(series[step])
        celsius = kelvin - 273.15
        fahrenheit = celsius * 9 / 5 + 32
        results.append({
            "timestamp": str(stamp),
            "tempK": round(kelvin, 2),
            "tempC": round(celsius, 2),
            "tempF": round(fahrenheit, 2),
        })
    return results
@app.get("/daily-highlow")
async def daily_highlow(
    lat: float = Query(..., description="Latitude in decimal degrees"),
    lon: float = Query(..., description="Longitude in decimal degrees")
):
    """
    Returns a list of { day: 'Mon', highF, lowF } for each forecast day.

    Reuses temp_trend for the hourly Fahrenheit series and groups it by
    calendar date (insertion order of the dict preserves day order).
    """
    trend = await temp_trend(lat, lon)  # list of per-hour dicts
    # group Fahrenheit temps by the ISO date part of each timestamp
    by_date: dict[str, list[float]] = {}
    for pt in trend:
        date = pt["timestamp"].split("T", 1)[0]
        by_date.setdefault(date, []).append(pt["tempF"])
    # build per-day highs/lows
    # (a dead first computation of `weekday` via np.datetime_as_string,
    #  immediately overwritten, was removed here)
    out = []
    for date, temps in by_date.items():
        # np.datetime64 -> datetime.date -> abbreviated weekday name
        weekday = np.datetime64(date).astype("datetime64[D]").tolist().strftime("%a")
        out.append({
            "day": weekday,
            "highF": round(max(temps), 1),
            "lowF": round(min(temps), 1),
        })
    return out
@app.get("/wind10-vectors")
def wind10_vectors(stride: int = Query(20), include_vorticity: bool = Query(False)):
    """
    Return subsampled 10 m wind vectors for every forecast timestep.

    Each timestep yields a list of {lat, lng, u, v[, zeta]} points, taking
    every `stride`-th grid point along both axes.  When include_vorticity
    is True, zeta is the finite-difference relative vorticity dv/dx - du/dy.
    """
    u10 = ds["10u"].values  # shape: (40, 720, 1440)
    v10 = ds["10v"].values
    lats = ds["lat"].values
    lons = ds["lon"].values
    # 2D coordinate grids matching the data layout (lat rows, lon cols)
    lon2d, lat2d = np.meshgrid(lons, lats)
    # subsampling slices applied identically to data and coordinates
    slat = slice(None, None, stride)
    slon = slice(None, None, stride)
    R = 6.371e6  # Earth radius in meters
    # grid spacing in meters: dy along latitude; dx along longitude,
    # shrinking with cos(lat) toward the poles
    dy = np.gradient(np.radians(lats)) * R
    dx = np.gradient(np.radians(lons)) * R * np.cos(np.radians(lats))[:, None]
    times = [str(t) for t in ds["time"].values]
    out = []
    for t in range(u10.shape[0]):
        U = u10[t]
        V = v10[t]
        # finite-difference relative vorticity: zeta = dv/dx - du/dy
        # NOTE(review): assumes axis 0 is latitude and axis 1 is longitude,
        # consistent with the meshgrid above — confirm against the dataset.
        du_dy = np.gradient(U, axis=0) / dy[:, None]
        dv_dx = np.gradient(V, axis=1) / dx
        ZETA = dv_dx - du_dy  # shape: (720, 1440)
        # subsample everything with the same slices so points stay aligned
        u = U[slat, slon]
        v = V[slat, slon]
        z = ZETA[slat, slon]
        LAT = lat2d[slat, slon]
        LON = lon2d[slat, slon]
        pts = []
        for i in range(u.shape[0]):
            for j in range(u.shape[1]):
                pt = {
                    "lat": float(LAT[i, j]),
                    "lng": float(LON[i, j]),
                    "u": float(u[i, j]),
                    "v": float(v[i, j])
                }
                if include_vorticity:
                    pt["zeta"] = float(z[i, j])
                pts.append(pt)
        out.append(pts)
    return {"times": times, "vectors": out}
@app.get("/variables")
def list_variables():
    """Enumerate every data variable name in the dataset."""
    names = [name for name in ds.data_vars]
    return {"variables": names}
@app.get("/pressure-levels")
def pressure_levels():
    """Report the pressure levels (in hPa) shared by the 4D variables."""
    levels = ds["pressure_level"].values
    return {"levels": levels.tolist()}
@app.get("/available-times")
def available_times():
    """List every forecast timestamp (as a string) in stored order."""
    stamps = ds["time"].values
    return {"times": [str(stamp) for stamp in stamps]}
@app.get("/precipitation-grid")
def precipitation_grid(stride: int = Query(4)):
    """
    Per-timestep precipitation-probability grid, subsampled by `stride`
    along both latitude and longitude.
    """
    times = [str(t) for t in ds["time"].values]
    lats = ds["lat"].values[::stride]
    lons = ds["lon"].values[::stride]
    out = []
    for t in range(len(ds.time)):
        snapshot = ds.isel(time=t)
        ppp = compute_precipitation_probability(snapshot)  # only 2D array
        vals = ppp[::stride, ::stride].values
        frame = [
            {
                "lat": float(la),
                "lng": float(lo),
                "value": float(vals[i, j]),
            }
            for i, la in enumerate(lats)
            for j, lo in enumerate(lons)
        ]
        out.append(frame)
    return {"times": times, "ppp": out}