import os
import io
import json
import tempfile
import datetime as dt
from typing import List, Optional, Tuple, Union
import requests
import pandas as pd
import plotly.graph_objects as go
import gradio as gr
# Root of the GribStream v2 REST API; model endpoints are appended to it.
API_BASE_URL = "https://gribstream.com/api/v2"
# Name of the environment variable / Space secret holding the API token.
TOKEN_ENV_VAR = "GRIB_API"
# Level attached to every variable requested from the API.
DEFAULT_LEVEL = "sfc"

# Supported wave models: API identifier -> display name.
WAVE_MODELS = dict(
    ifswave="ECMWF IFS Wave (deterministic)",
    ifswaef="ECMWF IFS Wave Ensemble",
)

# ECMWF short name -> (human-readable description, display unit).
WAVE_VARIABLES = dict(
    swh=("Significant wave height", "m"),
    mwd=("Mean wave direction", "° (true)"),
    mwp=("Mean wave period", "s"),
    mp2=("Mean zero-crossing wave period", "s"),
    pp1d=("Peak wave period", "s"),
)
class ConfigurationError(RuntimeError):
    """Signal that a required setting (such as the API token secret) is absent."""
def get_token() -> str:
    """Return the GribStream API token from the ``TOKEN_ENV_VAR`` environment variable.

    Returns:
        The token with surrounding whitespace removed.

    Raises:
        ConfigurationError: If the variable is unset, empty, or whitespace-only.
    """
    # Strip surrounding whitespace: secrets pasted with a trailing newline would
    # otherwise be sent verbatim in the ``Authorization`` header, and a
    # whitespace-only value should count as "missing".
    token = (os.getenv(TOKEN_ENV_VAR) or "").strip()
    if not token:
        raise ConfigurationError(
            f"The `{TOKEN_ENV_VAR}` secret is missing. "
            "Add it in your Space settings (Settings → Variables & secrets)."
        )
    return token
def to_iso_utc(value: str, label: str) -> str:
    """Normalise ISO-8601 text to a ``Z``-suffixed UTC string.

    Args:
        value: Timestamp text. Surrounding whitespace is ignored, and a trailing
            ``Z`` or ``z`` is accepted. Naive timestamps are interpreted as UTC;
            timestamps with an offset are converted to UTC.
        label: Human-readable field name used in error messages.

    Returns:
        The UTC timestamp as produced by ``datetime.isoformat()``, with the
        ``+00:00`` offset written as ``Z``.

    Raises:
        ValueError: If ``value`` is empty/blank or cannot be parsed.
    """
    text = (value or "").strip()
    # Check emptiness after stripping so blank input reports "required"
    # rather than "invalid timestamp".
    if not text:
        raise ValueError(f"{label} is required.")
    # fromisoformat() on older Pythons does not accept a "Z" designator.
    if text[-1] in "Zz":
        text = text[:-1] + "+00:00"
    try:
        parsed = dt.datetime.fromisoformat(text)
    except ValueError as exc:
        raise ValueError(f"{label} must be a valid ISO-8601 timestamp.") from exc
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=dt.timezone.utc)
    else:
        parsed = parsed.astimezone(dt.timezone.utc)
    # Use the canonical Z suffix expected by the API.
    return parsed.isoformat().replace("+00:00", "Z")
def build_members_list(raw: str, model: str) -> Optional[List[int]]:
    """Turn comma-separated ensemble member numbers into a list of ints.

    Only the ``ifswaef`` ensemble model takes members; any other model, blank
    input, or input with no non-empty entries yields ``None``.

    Raises:
        ValueError: If any non-empty entry is not an integer.
    """
    if model != "ifswaef" or not raw:
        return None
    tokens = [piece.strip() for piece in raw.split(",")]
    try:
        parsed = [int(token) for token in tokens if token]
    except ValueError as err:
        raise ValueError(
            "Ensemble members must be a comma-separated list of integers."
        ) from err
    return parsed if parsed else None
def make_alias(name: str) -> str:
    """Derive a lowercase, underscore-separated alias from a variable name.

    Every non-alphanumeric character acts as a separator; runs of separators
    collapse to one underscore and leading/trailing ones are dropped. Falls
    back to ``"value"`` when nothing alphanumeric remains.
    """
    mapped = [char.lower() if char.isalnum() else "_" for char in name]
    pieces = [piece for piece in "".join(mapped).split("_") if piece]
    return "_".join(pieces) if pieces else "value"
def parse_variable_list(raw: str) -> List[str]:
    """Return the non-blank names in *raw*, separated by commas and/or newlines."""
    if not raw:
        return []
    pieces = (piece.strip() for piece in raw.replace("\n", ",").split(","))
    return [piece for piece in pieces if piece]
def decode_json_response(response: "requests.Response") -> List[dict]:
    """Parse GribStream JSON/NDJSON responses into a list of dictionaries.

    The body is first parsed as one JSON document. If that fails, it is
    decoded as NDJSON: one JSON value per non-blank line.

    Args:
        response: HTTP response; only ``.json()`` and ``.text`` are used.

    Returns:
        The records in response order: a JSON array is returned as a list, a
        JSON object is wrapped in a one-element list, and an empty body or a
        JSON ``null`` yields ``[]``.

    Raises:
        ValueError: If a single JSON document is neither an object, an array
            nor ``null``, or if an NDJSON line is not valid JSON.
    """
    try:
        payload: Union[List[dict], dict, None] = response.json()
    except ValueError:
        # Not a single JSON document - fall back to NDJSON-style decoding.
        return [
            json.loads(line)
            for line in (raw_line.strip() for raw_line in response.text.splitlines())
            if line
        ]
    # ``list(None)`` would raise TypeError, so treat JSON null as "no records".
    if payload is None:
        return []
    if isinstance(payload, dict):
        return [payload]
    if isinstance(payload, list):
        return list(payload)
    # A scalar (e.g. a JSON string) would otherwise be exploded character by
    # character into bogus "records".
    raise ValueError(
        f"Unexpected JSON payload of type {type(payload).__name__}; "
        "expected an object or an array."
    )
def fetch_wave_history(
    token: str,
    model: str,
    variables: List[dict],
    *,
    from_time: Optional[str] = None,
    until_time: Optional[str] = None,
    times_list: Optional[List[str]] = None,
    min_horizon: int,
    max_horizon: int,
    coordinates: Optional[List[dict]] = None,
    grid: Optional[dict] = None,
    members: Optional[List[int]] = None,
    accept: str = "application/ndjson",
    timeout: int = 120,
) -> pd.DataFrame:
    """Call GribStream's ``{API_BASE_URL}/{model}/history`` endpoint and return a dataframe.

    Args:
        token: Bearer token placed in the ``Authorization`` header.
        model: Model identifier inserted into the URL (e.g. ``"ifswave"``).
        variables: Variable specs sent verbatim as ``"variables"``. Callers in
            this module pass ``{"name", "level", "alias"}`` dicts.
        from_time: Start of an ISO-8601 time range (sent as ``"fromTime"``).
        until_time: End of the range (``"untilTime"``). Both bounds are needed
            together unless ``times_list`` is given. Equal bounds are widened
            by one minute.
        times_list: Explicit ISO-8601 times, sent as ``"timesList"``.
        min_horizon: Lower forecast-horizon bound (``"minHorizon"``).
        max_horizon: Upper forecast-horizon bound (``"maxHorizon"``).
        coordinates: Point locations; either this or ``grid`` is required.
        grid: Grid definition sent as ``"grid"``.
        members: Ensemble members, sent as ``"members"`` when non-empty.
        accept: ``Accept`` header. ``"text/csv"`` responses are parsed with
            ``pd.read_csv``; anything else is decoded as JSON/NDJSON.
        timeout: Request timeout in seconds passed to ``requests.post``.

    Returns:
        One row per returned record, or an empty dataframe when the response
        carries no records (including a completely empty CSV body).

    Raises:
        ValueError: On missing/unparseable arguments or an inverted time range.
        RuntimeError: When the API answers with an HTTP error status.
    """
    if not variables:
        raise ValueError("At least one variable must be specified.")
    if not coordinates and not grid:
        raise ValueError("Provide either coordinates or a grid definition.")

    def parse_iso(text: str) -> dt.datetime:
        # Accept a trailing "Z" and treat naive timestamps as UTC, so a naive
        # bound can be compared with an aware one without raising TypeError.
        text = text.strip()
        if text[-1:] in ("Z", "z"):
            text = text[:-1] + "+00:00"
        parsed = dt.datetime.fromisoformat(text)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=dt.timezone.utc)
        return parsed

    if from_time and until_time:
        start_dt = parse_iso(from_time)
        end_dt = parse_iso(until_time)
        if end_dt < start_dt:
            raise ValueError("until_time must not be before from_time.")
        # If equal, nudge the end forward by a minute - presumably the API
        # rejects an empty time range (TODO confirm against API docs).
        if end_dt == start_dt:
            end_dt += dt.timedelta(minutes=1)
            until_time = end_dt.astimezone(dt.timezone.utc).isoformat().replace("+00:00", "Z")
    elif not times_list:
        raise ValueError("Provide either a time range or an explicit times list.")

    url = f"{API_BASE_URL}/{model}/history"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": accept,
    }
    payload: dict = {
        "minHorizon": int(min_horizon),
        "maxHorizon": int(max_horizon),
        "variables": variables,
    }
    if from_time and until_time:
        payload["fromTime"] = from_time
        payload["untilTime"] = until_time
    if times_list:
        payload["timesList"] = times_list
    if coordinates:
        payload["coordinates"] = coordinates
    if grid:
        payload["grid"] = grid
    if members:
        payload["members"] = members

    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    try:
        response.raise_for_status()
    except requests.HTTPError as exc:
        detail = response.text[:500]  # Trim to keep message readable.
        raise RuntimeError(f"API request failed: {response.status_code} {detail}") from exc

    if accept == "text/csv":
        # pd.read_csv raises EmptyDataError on a completely empty body; report
        # "no rows" instead so callers that retry on an empty frame still can.
        if not response.content.strip():
            return pd.DataFrame()
        return pd.read_csv(io.BytesIO(response.content))

    records = decode_json_response(response)
    return pd.DataFrame(records) if records else pd.DataFrame()
def prepare_results(
    df: pd.DataFrame,
    alias: str,
    variable: str,
) -> "Tuple[pd.DataFrame, go.Figure, str]":
    """Clean a point time-series dataframe, plot it, and build a status message.

    Args:
        df: Records from the history endpoint. Must contain ``forecasted_at``,
            ``forecasted_time``, ``lat``, ``lon`` and the ``alias`` column. An
            optional ``member`` column splits the plot into one trace per
            ensemble member. The input frame is not modified.
        alias: Name of the column holding the requested variable's values.
        variable: Variable short name, used to look up a label and unit in
            ``WAVE_VARIABLES`` (falls back to the raw name without a unit).

    Returns:
        ``(display_df, figure, status)``. ``display_df`` is sorted by valid
        time, with columns renamed for display (``valid_time``,
        ``forecast_issue_time``, ``lead_time_hours``, the labelled value, ...).

    Raises:
        ValueError: If ``df`` is empty, lacks an expected column, or has no
            rows with parseable timestamps.
    """
    if df.empty:
        raise ValueError("No data returned for the selected configuration.")
    expected_cols = {"forecasted_at", "forecasted_time", alias, "lat", "lon"}
    missing = expected_cols - set(df.columns)
    if missing:
        raise ValueError(
            "Unexpected payload format. Missing columns: "
            + ", ".join(sorted(missing))
        )

    # Work on a copy so the caller's dataframe is left untouched.
    df = df.copy()
    df["forecasted_at"] = pd.to_datetime(df["forecasted_at"], utc=True, errors="coerce")
    df["forecasted_time"] = pd.to_datetime(df["forecasted_time"], utc=True, errors="coerce")
    df = df.dropna(subset=["forecasted_at", "forecasted_time"])
    if df.empty:
        # Without this guard the min()/max().strftime() calls below fail on NaT.
        raise ValueError("No rows with valid forecast timestamps were returned.")

    has_members = "member" in df.columns
    sort_keys = ["forecasted_time", "member"] if has_members else ["forecasted_time"]
    df = df.sort_values(sort_keys, kind="stable").reset_index(drop=True)
    df["lead_time_hours"] = (
        (df["forecasted_time"] - df["forecasted_at"]).dt.total_seconds() / 3600.0
    )

    variable_label, unit = WAVE_VARIABLES.get(variable.lower(), (variable, ""))
    label = f"{variable_label} ({unit})" if unit else variable_label

    # Plotly hover syntax is %{...}; only the middle piece is an f-string, so
    # only it needs doubled braces. <br> is Plotly's line break.
    hovertemplate = (
        "Valid: %{x|%Y-%m-%d %HZ}<br>"
        f"{label}: %{{y:.2f}}<br>"
        "Lead: %{customdata:.0f} h"
    )

    # One trace per ensemble member; a single trace would join interleaved
    # members with a zig-zag line.
    groups = (
        df.groupby("member", sort=True, dropna=False) if has_members else [(None, df)]
    )
    fig = go.Figure()
    for member, group in groups:
        fig.add_trace(
            go.Scatter(
                x=group["forecasted_time"],
                y=group[alias],
                mode="lines+markers",
                name=label if member is None else f"Member {member}",
                hovertemplate=hovertemplate,
                # 1-D so that %{customdata} refers to the lead time directly.
                customdata=group["lead_time_hours"],
            )
        )
    fig.update_layout(
        template="plotly_white",
        xaxis_title="Forecast valid time (UTC)",
        yaxis_title=label,
        margin=dict(l=50, r=20, t=40, b=40),
    )

    start = df["forecasted_time"].min().strftime("%Y-%m-%d %H:%M UTC")
    end = df["forecasted_time"].max().strftime("%Y-%m-%d %H:%M UTC")
    status = (
        f"Retrieved {len(df)} rows from {start} to {end}. "
        f"Lead times range from {df['lead_time_hours'].min():.0f} h to "
        f"{df['lead_time_hours'].max():.0f} h."
    )

    display_cols = ["forecasted_time", "forecasted_at", "lead_time_hours"]
    if has_members:
        display_cols.append("member")
    display_cols += [alias, "lat", "lon"]
    display_df = df[display_cols].rename(
        columns={
            "forecasted_time": "valid_time",
            "forecasted_at": "forecast_issue_time",
            alias: label,
        }
    )
    return display_df, fig, status
def make_global_heatmap(
    df: pd.DataFrame,
    alias: str,
    variable_label: str,
    valid_time_iso: str,
) -> "Optional[go.Figure]":
    """Generate a Plotly heatmap preview from a gridded dataframe.

    Only the earliest ``forecasted_time`` present is mapped. When a ``member``
    column exists, the lowest member's value is kept for each grid cell.

    Args:
        df: Gridded records with ``lat``, ``lon`` and the ``alias`` column.
        alias: Column to colour the map by.
        variable_label: Label for the colour bar and title.
        valid_time_iso: Time string shown in the title.

    Returns:
        The figure, or ``None`` when nothing can be mapped: missing columns,
        no non-null values, or unparseable coordinates. The caller treats the
        map as an optional preview, so this returns ``None`` instead of raising.
    """
    if not {alias, "lat", "lon"}.issubset(df.columns):
        return None
    subset = df.dropna(subset=[alias, "lat", "lon"]).copy()
    # Coerce coordinates to floats; unparseable entries are dropped instead of
    # aborting the preview.
    subset["lat"] = pd.to_numeric(subset["lat"], errors="coerce").astype(float)
    subset["lon"] = pd.to_numeric(subset["lon"], errors="coerce").astype(float)
    subset = subset.dropna(subset=["lat", "lon"])
    if subset.empty:
        return None

    if "forecasted_time" in subset.columns:
        subset["forecasted_time"] = pd.to_datetime(
            subset["forecasted_time"], utc=True, errors="coerce"
        )
        subset = subset.dropna(subset=["forecasted_time"])
        if not subset.empty:
            target_time = subset["forecasted_time"].min()
            subset = subset[subset["forecasted_time"] == target_time]
    if "member" in subset.columns:
        subset = subset.sort_values("member").groupby(["lat", "lon"], as_index=False).first()
    subset = subset.sort_values(["lat", "lon"])
    # pivot() raises on duplicate (lat, lon) pairs, so keep one row per cell.
    subset = subset.drop_duplicates(subset=["lat", "lon"], keep="first")
    if subset.empty:
        return None

    pivot = subset.pivot(index="lat", columns="lon", values=alias)
    if pivot.empty:
        return None
    pivot = pivot.sort_index(ascending=False)

    fig = go.Figure(
        data=go.Heatmap(
            x=pivot.columns.to_list(),
            y=pivot.index.to_list(),
            z=pivot.values,
            colorscale="Viridis",
            colorbar=dict(title=variable_label),
        )
    )
    fig.update_layout(
        title=f"{variable_label} at {valid_time_iso}",
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        template="plotly_white",
        margin=dict(l=60, r=20, t=60, b=40),
    )
    return fig
def run_query(
    model: str,
    variable: str,
    custom_variable: str,
    latitude: float,
    longitude: float,
    from_time: str,
    until_time: str,
    min_horizon: int,
    max_horizon: int,
    raw_members: str,
) -> Tuple[str, Optional[pd.DataFrame], Optional[go.Figure]]:
    """Fetch and plot a point time series (callback for the "Point time series" tab).

    Returns:
        ``(status_markdown, table, figure)``. Any error is reported in the
        status string - prefixed with "⚠️" for configuration problems and "❌"
        otherwise - with ``None`` for the table and figure.
    """
    try:
        token = get_token()
        dropdown_value = (variable or "").strip()
        custom_value = (custom_variable or "").strip()
        # A custom name overrides the dropdown; dropdown codes are displayed
        # upper-case but sent lower-case.
        variable_name = custom_value or dropdown_value.lower()
        if not variable_name:
            raise ValueError("Select a variable or provide a custom variable name.")
        if latitude is None or longitude is None:
            raise ValueError("Latitude and longitude are required.")
        if not (-90 <= latitude <= 90 and -180 <= longitude <= 180):
            raise ValueError("Latitude must be within [-90, 90] and longitude within [-180, 180].")

        window_start = to_iso_utc(from_time, "From time")
        window_end = to_iso_utc(until_time, "Until time")
        # Compare instants, not strings: isoformat() omits zero microseconds,
        # so two valid UTC strings do not always sort lexically in time order.
        start_instant = dt.datetime.fromisoformat(window_start.replace("Z", "+00:00"))
        end_instant = dt.datetime.fromisoformat(window_end.replace("Z", "+00:00"))
        if end_instant <= start_instant:
            raise ValueError("Until time must be after the From time.")

        # Accept the horizon sliders in either order.
        lower = int(min(min_horizon, max_horizon))
        upper = int(max(min_horizon, max_horizon))
        members = build_members_list(raw_members, model)
        alias = make_alias(variable_name)
        df = fetch_wave_history(
            token=token,
            model=model,
            variables=[{"name": variable_name, "level": DEFAULT_LEVEL, "alias": alias}],
            from_time=window_start,
            until_time=window_end,
            min_horizon=lower,
            max_horizon=upper,
            members=members,
            coordinates=[{"lat": float(latitude), "lon": float(longitude)}],
            accept="application/ndjson",
        )
        display_df, fig, status = prepare_results(df, alias, variable_name)
        return status, display_df, fig
    except ConfigurationError as exc:
        return f"⚠️ {exc}", None, None
    except Exception as exc:  # noqa: BLE001
        return f"❌ {exc}", None, None
def run_global_download(
    model: str,
    variables: List[str],
    custom_variables: str,
    valid_time: str,
    min_horizon: int,
    max_horizon: int,
    grid_step: float,
    raw_members: str,
    preview_variable: str,
) -> Tuple[str, Optional[pd.DataFrame], Optional[go.Figure], Optional[str]]:
    """Download a global wave snapshot, save it as CSV, and build a map preview.

    Callback for the "Global snapshot download" tab.

    Returns:
        ``(status_markdown, preview_table, map_figure, csv_path)``.
        ``preview_table`` holds at most the first 500 rows. ``map_figure`` is
        ``None`` when no map could be drawn; the data is still returned in that
        case. On failure the status carries the error ("⚠️" for configuration
        problems, "❌" otherwise) and the other three entries are ``None``.
    """
    try:
        token = get_token()
        selected_from_dropdown = [name.lower() for name in (variables or [])]
        custom_list = [name.lower() for name in parse_variable_list(custom_variables)]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        variable_names = list(dict.fromkeys([*selected_from_dropdown, *custom_list]))
        if not variable_names:
            raise ValueError("Select at least one variable or enter custom variable names.")
        valid_iso = to_iso_utc(valid_time, "Valid time")
        lower = int(min(min_horizon, max_horizon))
        upper = int(max(min_horizon, max_horizon))
        if grid_step is None or grid_step <= 0:
            raise ValueError("Grid step must be a positive number of degrees.")
        members = build_members_list(raw_members, model)

        alias_map = {name: make_alias(name) for name in variable_names}
        variables_payload = [
            {"name": name, "level": DEFAULT_LEVEL, "alias": alias}
            for name, alias in alias_map.items()
        ]
        global_grid = {
            "minLatitude": -90,
            "maxLatitude": 90,
            "minLongitude": -180,
            "maxLongitude": 180,
            "step": float(grid_step),
        }

        def request_payload(as_times_list: bool) -> pd.DataFrame:
            """Fetch the snapshot either via ``times_list`` or via a from/until range."""
            kwargs = dict(
                token=token,
                model=model,
                variables=variables_payload,
                min_horizon=lower,
                max_horizon=upper,
                grid=global_grid,
                members=members,
                accept="text/csv",
                timeout=240,
            )
            if as_times_list:
                kwargs["times_list"] = [valid_iso]
            else:
                kwargs["from_time"] = valid_iso
                kwargs["until_time"] = valid_iso
            try:
                return fetch_wave_history(**kwargs)
            except pd.errors.EmptyDataError:
                # pd.read_csv raises on a completely empty CSV body; treat it as
                # "no rows" so the range fallback below still gets a chance.
                return pd.DataFrame()

        # Try the exact valid time first; fall back to a from/until range.
        df = request_payload(as_times_list=True)
        if df.empty:
            df = request_payload(as_times_list=False)
        if df.empty:
            raise ValueError("No data returned for the requested global configuration.")

        df = df.copy()
        for col in ("forecasted_time", "forecasted_at"):
            if col in df.columns:
                df[col] = pd.to_datetime(df[col], utc=True, errors="coerce")
        if {"forecasted_time", "forecasted_at"}.issubset(df.columns):
            df["lead_time_hours"] = (
                (df["forecasted_time"] - df["forecasted_at"]).dt.total_seconds() / 3600.0
            )

        preview = df.head(500).copy()
        preview_rows = len(preview)

        # Write through the open handle: reopening a NamedTemporaryFile by name
        # while it is still open fails on Windows. delete=False keeps the file
        # for Gradio to serve, so remove it ourselves if writing fails.
        tmp = tempfile.NamedTemporaryFile(
            delete=False, suffix=".csv", mode="w", newline="", encoding="utf-8"
        )
        try:
            with tmp:
                df.to_csv(tmp, index=False)
        except Exception:
            os.unlink(tmp.name)
            raise
        csv_path = tmp.name

        preview_choice = (preview_variable or variable_names[0]).strip().lower()
        if preview_choice not in alias_map:
            preview_choice = variable_names[0]
        variable_label, unit = WAVE_VARIABLES.get(
            preview_choice,
            (preview_choice.upper(), ""),
        )
        display_label = f"{variable_label} ({unit})" if unit else variable_label

        # The map is only a preview: a plotting failure must not discard the
        # data that was already fetched and written.
        map_note = ""
        try:
            map_fig = make_global_heatmap(df, alias_map[preview_choice], display_label, valid_iso)
        except Exception as map_exc:  # noqa: BLE001
            map_fig = None
            map_note = f" Map preview unavailable: {map_exc}"

        lead_info = ""
        if "lead_time_hours" in df.columns:
            lead_series = df["lead_time_hours"].dropna()
            if not lead_series.empty:
                lead_info = (
                    f" Lead times {lead_series.min():.0f}–{lead_series.max():.0f} h."
                )
        status = (
            f"Fetched {len(df)} rows across {len(variable_names)} variable(s) "
            f"for {valid_iso}.{lead_info} Showing the first {preview_rows} rows.{map_note}"
        )
        return status, preview, map_fig, csv_path
    except ConfigurationError as exc:
        return f"⚠️ {exc}", None, None, None
    except Exception as exc:  # noqa: BLE001
        return f"❌ {exc}", None, None, None
def default_time_window(hours_back: int = 6, hours_forward: int = 24) -> Tuple[str, str]:
    """Return a ``(start, end)`` pair of UTC timestamps around the current hour.

    Args:
        hours_back: Hours before the current (truncated) UTC hour for ``start``.
        hours_forward: Hours after it for ``end``.

    Returns:
        Two ISO-8601 strings with a ``Z`` suffix, e.g. ``"2025-10-23T00:00:00Z"``.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware "now"
    # and drop tzinfo so isoformat() keeps the bare "...Z" form used by the UI.
    now = dt.datetime.now(dt.timezone.utc).replace(
        minute=0, second=0, microsecond=0, tzinfo=None
    )
    start = (now - dt.timedelta(hours=hours_back)).isoformat() + "Z"
    end = (now + dt.timedelta(hours=hours_forward)).isoformat() + "Z"
    return start, end
def default_valid_time(offset_hours: int = 0) -> str:
    """Return the current UTC hour rounded down to a multiple of 3, plus an offset.

    Args:
        offset_hours: Hours added after the 3-hour alignment.

    Returns:
        An ISO-8601 UTC string with a ``Z`` suffix, e.g. ``"2025-10-23T06:00:00Z"``.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware "now"
    # and drop tzinfo so isoformat() keeps the bare "...Z" form used by the UI.
    now = dt.datetime.now(dt.timezone.utc).replace(
        minute=0, second=0, microsecond=0, tzinfo=None
    )
    # Round down to the previous 3-hour boundary. The intent is to match the
    # wave model's output cadence - NOTE(review): confirm the step size for
    # the horizons in use.
    aligned = now - dt.timedelta(hours=now.hour % 3)
    return (aligned + dt.timedelta(hours=offset_hours)).isoformat() + "Z"
def build_interface() -> gr.Blocks:
    """Assemble the Gradio UI: a point time-series tab and a global snapshot tab.

    Returns:
        The configured ``gr.Blocks`` app, with queuing enabled, ready to
        ``launch()``.
    """
    start_default, end_default = default_time_window()
    valid_default = default_valid_time()
    with gr.Blocks(title="GribStream IFS Wave Explorer") as demo:
        gr.Markdown(
            """
# ECMWF Wave Data Explorer
Use your GribStream API token (stored as the `GRIB_API` secret) to pull ECMWF IFS wave forecasts via GribStream.
Choose between a point time-series view or a global snapshot download of the latest wave fields.
"""
        )
        with gr.Tabs():
            with gr.Tab("Point time series"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=320):
                        series_model_input = gr.Dropdown(
                            choices=list(WAVE_MODELS.keys()),
                            value="ifswave",
                            label="Wave model",
                            info="Choose `ifswave` for the deterministic run or `ifswaef` for the ensemble.",
                        )
                        # Codes are shown upper-case; run_query lower-cases them.
                        series_variable_input = gr.Dropdown(
                            choices=[code.upper() for code in WAVE_VARIABLES.keys()],
                            value="SWH",
                            label="Variable",
                            info="Wave parameters use ECMWF short names (e.g., SWH height, MWD direction, MWP period).",
                        )
                        series_custom_variable_input = gr.Textbox(
                            label="Custom variable (optional)",
                            placeholder="Override with another parameter, e.g. swh",
                            info="Leave blank to use the dropdown selection.",
                        )
                        series_latitude_input = gr.Number(
                            label="Latitude",
                            value=32.0,
                            precision=4,
                        )
                        series_longitude_input = gr.Number(
                            label="Longitude",
                            value=-64.0,
                            precision=4,
                        )
                        series_from_time_input = gr.Textbox(
                            label="From time (UTC)",
                            value=start_default,
                            info="ISO 8601 format, e.g. 2025-10-23T00:00:00Z",
                        )
                        series_until_time_input = gr.Textbox(
                            label="Until time (UTC)",
                            value=end_default,
                            info="ISO 8601 format, must be after the start time.",
                        )
                        series_min_horizon_input = gr.Slider(
                            label="Minimum forecast horizon (hours)",
                            value=0,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        series_max_horizon_input = gr.Slider(
                            label="Maximum forecast horizon (hours)",
                            value=72,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        series_members_input = gr.Textbox(
                            label="Ensemble members (IFS Waef only)",
                            placeholder="e.g. 0,1,2",
                            info="Leave blank for control (0). Ignored for deterministic model.",
                        )
                        series_submit = gr.Button("Fetch time series", variant="primary")
                    with gr.Column(scale=2):
                        series_status_output = gr.Markdown("Results will appear here once you hit **Fetch**.")
                        series_table_output = gr.Dataframe(
                            interactive=False,
                            wrap=False,
                        )
                        series_chart_output = gr.Plot(show_label=False)
                # Input order must match run_query's positional parameters.
                series_submit.click(
                    fn=run_query,
                    inputs=[
                        series_model_input,
                        series_variable_input,
                        series_custom_variable_input,
                        series_latitude_input,
                        series_longitude_input,
                        series_from_time_input,
                        series_until_time_input,
                        series_min_horizon_input,
                        series_max_horizon_input,
                        series_members_input,
                    ],
                    outputs=[series_status_output, series_table_output, series_chart_output],
                )
            with gr.Tab("Global snapshot download"):
                # A larger grid step means fewer grid points and thus a smaller file.
                gr.Markdown(
                    "Fetch the full global grid for a selected valid time, then download it as CSV. "
                    "Increase the grid spacing if you need a lighter file."
                )
                with gr.Row():
                    with gr.Column(scale=1, min_width=320):
                        global_model_input = gr.Dropdown(
                            choices=list(WAVE_MODELS.keys()),
                            value="ifswave",
                            label="Wave model",
                            info="`ifswave` deterministic or `ifswaef` ensemble.",
                        )
                        global_variables_input = gr.CheckboxGroup(
                            label="Variables",
                            choices=[code.upper() for code in WAVE_VARIABLES.keys()],
                            value=[code.upper() for code in WAVE_VARIABLES.keys()],
                            info="Select one or more parameters to include in the download.",
                        )
                        global_custom_variables_input = gr.Textbox(
                            label="Additional variables (optional)",
                            placeholder="Comma-separated list, e.g. mp2,pp1d",
                            info="Use ECMWF short names. Combined with the selection above.",
                        )
                        global_valid_time_input = gr.Textbox(
                            label="Forecast valid time (UTC)",
                            value=valid_default,
                            info="ISO 8601 format corresponding to the wave field time you need.",
                        )
                        global_min_horizon_input = gr.Slider(
                            label="Minimum forecast horizon (hours)",
                            value=0,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        global_max_horizon_input = gr.Slider(
                            label="Maximum forecast horizon (hours)",
                            value=24,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        global_grid_step_input = gr.Slider(
                            label="Grid spacing (degrees)",
                            value=0.5,
                            minimum=0.25,
                            maximum=2.0,
                            step=0.25,
                        )
                        global_members_input = gr.Textbox(
                            label="Ensemble members (IFS Waef only)",
                            placeholder="e.g. 0,1,2,3",
                            info="Leave blank for default control member. Ignored for deterministic model.",
                        )
                        global_preview_variable_input = gr.Dropdown(
                            label="Preview variable for map",
                            choices=[code.upper() for code in WAVE_VARIABLES.keys()],
                            value="SWH",
                            info="Used for the heatmap preview below.",
                        )
                        global_submit = gr.Button("Download global snapshot", variant="primary")
                    with gr.Column(scale=2):
                        global_status_output = gr.Markdown(
                            "The download link and preview will appear here after processing."
                        )
                        global_preview_output = gr.Dataframe(
                            interactive=False,
                            wrap=False,
                        )
                        global_map_output = gr.Plot(label="Global map preview", show_label=True)
                        global_file_output = gr.File(label="Download CSV")
                # Input order must match run_global_download's positional parameters.
                global_submit.click(
                    fn=run_global_download,
                    inputs=[
                        global_model_input,
                        global_variables_input,
                        global_custom_variables_input,
                        global_valid_time_input,
                        global_min_horizon_input,
                        global_max_horizon_input,
                        global_grid_step_input,
                        global_members_input,
                        global_preview_variable_input,
                    ],
                    outputs=[global_status_output, global_preview_output, global_map_output, global_file_output],
                )
    demo.queue()
    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start serving it.
    build_interface().launch()