"""Gradio app for exploring ECMWF IFS wave forecasts through the GribStream API.

Two views are offered: a point time-series plot and a global-grid snapshot
download. The GribStream token is read from the ``GRIB_API`` secret.
"""

import os
import io
import json
import tempfile
import datetime as dt
from typing import List, Optional, Tuple, Union

import requests
import pandas as pd
import plotly.graph_objects as go
import gradio as gr

API_BASE_URL = "https://gribstream.com/api/v2"
TOKEN_ENV_VAR = "GRIB_API"
DEFAULT_LEVEL = "sfc"

# Model identifiers accepted by the GribStream history endpoint.
WAVE_MODELS = {
    "ifswave": "ECMWF IFS Wave (deterministic)",
    "ifswaef": "ECMWF IFS Wave Ensemble",
}

# ECMWF short name -> (human-readable label, unit).
WAVE_VARIABLES = {
    "swh": ("Significant wave height", "m"),
    "mwd": ("Mean wave direction", "° (true)"),
    "mwp": ("Mean wave period", "s"),
    "mp2": ("Mean zero-crossing wave period", "s"),
    "pp1d": ("Peak wave period", "s"),
}


class ConfigurationError(RuntimeError):
    """Raised when the application is missing a required configuration."""


def get_token() -> str:
    """Return the GribStream API token from the environment.

    Raises:
        ConfigurationError: if the ``GRIB_API`` secret is not set.
    """
    token = os.getenv(TOKEN_ENV_VAR)
    if not token:
        raise ConfigurationError(
            f"The `{TOKEN_ENV_VAR}` secret is missing. "
            "Add it in your Space settings (Settings → Variables & secrets)."
        )
    return token


def to_iso_utc(value: str, label: str) -> str:
    """Normalise ISO-8601 text to a Z-suffixed UTC string.

    Args:
        value: user-supplied timestamp text (naive values are assumed UTC).
        label: field name used in error messages.

    Raises:
        ValueError: if ``value`` is empty or not valid ISO-8601.
    """
    if not value:
        raise ValueError(f"{label} is required.")
    text = value.strip()
    # datetime.fromisoformat() on older Pythons rejects a trailing "Z".
    if text.endswith("Z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = dt.datetime.fromisoformat(text)
    except ValueError as exc:
        raise ValueError(f"{label} must be a valid ISO-8601 timestamp.") from exc
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=dt.timezone.utc)
    else:
        parsed = parsed.astimezone(dt.timezone.utc)
    # Use the canonical Z suffix expected by the API.
    return parsed.replace(tzinfo=dt.timezone.utc).isoformat().replace("+00:00", "Z")


def build_members_list(raw: str, model: str) -> Optional[List[int]]:
    """Parse a CSV list of ensemble members (only meaningful for ifswaef).

    Returns ``None`` for the deterministic model, for blank input, or when the
    list parses to nothing, so the API falls back to its default member.
    """
    if model != "ifswaef":
        return None
    if not raw:
        return None
    members: List[int] = []
    for chunk in raw.split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        try:
            members.append(int(chunk))
        except ValueError as exc:
            raise ValueError(
                "Ensemble members must be a comma-separated list of integers."
            ) from exc
    return members or None


def make_alias(name: str) -> str:
    """Create a lowercase alias compatible with the API response."""
    cleaned = "".join(ch.lower() if ch.isalnum() else "_" for ch in name)
    # Collapse runs of underscores produced by consecutive non-alnum chars.
    cleaned = "_".join(part for part in cleaned.split("_") if part)
    return cleaned or "value"


def parse_variable_list(raw: str) -> List[str]:
    """Split a comma/newline separated list of variable names."""
    if not raw:
        return []
    parts = []
    for chunk in raw.replace("\n", ",").split(","):
        chunk = chunk.strip()
        if chunk:
            parts.append(chunk)
    return parts


def decode_json_response(response: requests.Response) -> List[dict]:
    """Parse GribStream JSON/NDJSON responses into a list of dictionaries."""
    try:
        payload: Union[List[dict], dict] = response.json()
    except ValueError:
        # Fall back to NDJSON-style decoding: one JSON object per line.
        data: List[dict] = []
        for line in response.text.strip().splitlines():
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
        return data
    if isinstance(payload, dict):
        return [payload]
    return list(payload)


def fetch_wave_history(
    token: str,
    model: str,
    variables: List[dict],
    *,
    from_time: Optional[str] = None,
    until_time: Optional[str] = None,
    times_list: Optional[List[str]] = None,
    min_horizon: int,
    max_horizon: int,
    coordinates: Optional[List[dict]] = None,
    grid: Optional[dict] = None,
    members: Optional[List[int]] = None,
    accept: str = "application/ndjson",
    timeout: int = 120,
) -> pd.DataFrame:
    """Call GribStream's history endpoint and return a dataframe.

    Either a ``from_time``/``until_time`` window or an explicit ``times_list``
    must be supplied, along with either point ``coordinates`` or a ``grid``.

    Raises:
        ValueError: on invalid argument combinations.
        RuntimeError: when the API responds with an HTTP error status.
    """
    if not variables:
        raise ValueError("At least one variable must be specified.")
    if not coordinates and not grid:
        raise ValueError("Provide either coordinates or a grid definition.")
    if from_time and until_time:

        def parse_iso(text: str) -> dt.datetime:
            text = text.strip()
            if text.endswith("Z"):
                text = text[:-1] + "+00:00"
            return dt.datetime.fromisoformat(text)

        start_dt = parse_iso(from_time)
        end_dt = parse_iso(until_time)
        if end_dt < start_dt:
            raise ValueError("until_time must not be before from_time.")
        # If equal, nudge the end forward by a small epsilon to keep API happy.
        if end_dt == start_dt:
            end_dt += dt.timedelta(minutes=1)
            until_time = end_dt.isoformat().replace("+00:00", "Z")
    elif not times_list:
        raise ValueError("Provide either a time range or an explicit times list.")

    url = f"{API_BASE_URL}/{model}/history"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
        "Accept": accept,
    }
    payload: dict = {
        "minHorizon": int(min_horizon),
        "maxHorizon": int(max_horizon),
        "variables": variables,
    }
    if from_time and until_time:
        payload["fromTime"] = from_time
        payload["untilTime"] = until_time
    if times_list:
        payload["timesList"] = times_list
    if coordinates:
        payload["coordinates"] = coordinates
    if grid:
        payload["grid"] = grid
    if members:
        payload["members"] = members

    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    try:
        response.raise_for_status()
    except requests.HTTPError as exc:
        detail = response.text[:500]  # Trim to keep message readable.
        raise RuntimeError(
            f"API request failed: {response.status_code} {detail}"
        ) from exc

    if accept == "text/csv":
        buffer = io.BytesIO(response.content)
        df = pd.read_csv(buffer)
        return df

    records = decode_json_response(response)
    if not records:
        return pd.DataFrame()
    return pd.DataFrame(records)


def prepare_results(
    df: pd.DataFrame,
    alias: str,
    variable: str,
) -> Tuple[pd.DataFrame, go.Figure, str]:
    """Clean returned dataframe, build the plot, and craft a status message.

    Raises:
        ValueError: if the dataframe is empty or missing expected columns.
    """
    if df.empty:
        raise ValueError("No data returned for the selected configuration.")

    expected_cols = {"forecasted_at", "forecasted_time", alias, "lat", "lon"}
    missing = expected_cols - set(df.columns)
    if missing:
        raise ValueError(
            "Unexpected payload format. Missing columns: " + ", ".join(sorted(missing))
        )

    df["forecasted_at"] = pd.to_datetime(df["forecasted_at"], utc=True, errors="coerce")
    df["forecasted_time"] = pd.to_datetime(
        df["forecasted_time"], utc=True, errors="coerce"
    )
    df = df.dropna(subset=["forecasted_at", "forecasted_time"])
    df = df.sort_values("forecasted_time").reset_index(drop=True)
    df["lead_time_hours"] = (
        (df["forecasted_time"] - df["forecasted_at"]).dt.total_seconds() / 3600.0
    )

    variable_label, unit = WAVE_VARIABLES.get(variable.lower(), (variable, ""))
    label = f"{variable_label} ({unit})" if unit else variable_label

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=df["forecasted_time"],
            y=df[alias],
            mode="lines+markers",
            name=label,
            # NOTE: only the f-string segment needs doubled braces; in the
            # plain-string segments Plotly expects single-brace %{...} refs.
            hovertemplate=(
                "Valid: %{x|%Y-%m-%d %HZ}<br>"
                f"{label}: %{{y:.2f}}<br>"
                "Lead: %{customdata:.0f} h"
            ),
            # 1-D customdata so %{customdata:.0f} formats the value directly.
            customdata=df["lead_time_hours"],
        )
    )
    fig.update_layout(
        template="plotly_white",
        xaxis_title="Forecast valid time (UTC)",
        yaxis_title=label,
        margin=dict(l=50, r=20, t=40, b=40),
    )

    start = df["forecasted_time"].min().strftime("%Y-%m-%d %H:%M UTC")
    end = df["forecasted_time"].max().strftime("%Y-%m-%d %H:%M UTC")
    status = (
        f"Retrieved {len(df)} rows from {start} to {end}. "
        f"Lead times range from {df['lead_time_hours'].min():.0f} h to "
        f"{df['lead_time_hours'].max():.0f} h."
    )

    display_df = df[
        [
            "forecasted_time",
            "forecasted_at",
            "lead_time_hours",
            alias,
            "lat",
            "lon",
        ]
    ].rename(
        columns={
            "forecasted_time": "valid_time",
            "forecasted_at": "forecast_issue_time",
            alias: label,
        }
    )
    return display_df, fig, status


def make_global_heatmap(
    df: pd.DataFrame,
    alias: str,
    variable_label: str,
    valid_time_iso: str,
) -> Optional[go.Figure]:
    """Generate a Plotly heatmap from a gridded dataframe.

    Returns ``None`` when the alias column is absent or no plottable rows
    remain after cleaning, so callers can skip the preview gracefully.
    """
    if alias not in df.columns:
        return None
    subset = df.dropna(subset=[alias, "lat", "lon"]).copy()
    if subset.empty:
        return None
    subset["lat"] = subset["lat"].astype(float)
    subset["lon"] = subset["lon"].astype(float)
    if "forecasted_time" in subset.columns:
        subset["forecasted_time"] = pd.to_datetime(
            subset["forecasted_time"], utc=True, errors="coerce"
        )
        subset = subset.dropna(subset=["forecasted_time"])
        if not subset.empty:
            # Keep a single valid time so the pivot is unambiguous.
            target_time = subset["forecasted_time"].min()
            subset = subset[subset["forecasted_time"] == target_time]
    if "member" in subset.columns:
        # For ensembles, show the lowest-numbered member per grid cell.
        subset = (
            subset.sort_values("member").groupby(["lat", "lon"], as_index=False).first()
        )
    subset = subset.sort_values(["lat", "lon"])
    subset = subset.drop_duplicates(subset=["lat", "lon"], keep="first")
    if subset.empty:
        return None
    pivot = subset.pivot(index="lat", columns="lon", values=alias)
    if pivot.empty:
        return None
    pivot = pivot.sort_index(ascending=False)  # north at the top of the map
    lats = pivot.index.to_list()
    lons = pivot.columns.to_list()
    values = pivot.values
    fig = go.Figure(
        data=go.Heatmap(
            x=lons,
            y=lats,
            z=values,
            colorscale="Viridis",
            colorbar=dict(title=variable_label),
        )
    )
    fig.update_layout(
        title=f"{variable_label} at {valid_time_iso}",
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        template="plotly_white",
        margin=dict(l=60, r=20, t=60, b=40),
    )
    return fig


def run_query(
    model: str,
    variable: str,
    custom_variable: str,
    latitude: float,
    longitude: float,
    from_time: str,
    until_time: str,
    min_horizon: int,
    max_horizon: int,
    raw_members: str,
) -> Tuple[str, Optional[pd.DataFrame], Optional[go.Figure]]:
    """Gradio handler for the point time-series tab.

    Returns a (status markdown, table dataframe, plot figure) triple; on any
    failure the status carries the error and the other outputs are ``None``.
    """
    try:
        token = get_token()
        dropdown_value = (variable or "").strip()
        custom_value = (custom_variable or "").strip()
        variable_name = custom_value or dropdown_value.lower()
        if not variable_name:
            raise ValueError("Select a variable or provide a custom variable name.")
        if latitude is None or longitude is None:
            raise ValueError("Latitude and longitude are required.")
        if not (-90 <= latitude <= 90 and -180 <= longitude <= 180):
            raise ValueError(
                "Latitude must be within [-90, 90] and longitude within [-180, 180]."
            )
        window_start = to_iso_utc(from_time, "From time")
        window_end = to_iso_utc(until_time, "Until time")
        # Compare as datetimes: lexicographic comparison of ISO strings breaks
        # when one timestamp carries microseconds and the other does not.
        start_dt = dt.datetime.fromisoformat(window_start.replace("Z", "+00:00"))
        end_dt = dt.datetime.fromisoformat(window_end.replace("Z", "+00:00"))
        if end_dt <= start_dt:
            raise ValueError("Until time must be after the From time.")
        lower = int(min(min_horizon, max_horizon))
        upper = int(max(min_horizon, max_horizon))
        members = build_members_list(raw_members, model)
        alias = make_alias(variable_name)
        df = fetch_wave_history(
            token=token,
            model=model,
            variables=[
                {"name": variable_name, "level": DEFAULT_LEVEL, "alias": alias}
            ],
            from_time=window_start,
            until_time=window_end,
            min_horizon=lower,
            max_horizon=upper,
            members=members,
            coordinates=[{"lat": float(latitude), "lon": float(longitude)}],
            accept="application/ndjson",
        )
        display_df, fig, status = prepare_results(df, alias, variable_name)
        return status, display_df, fig
    except ConfigurationError as exc:
        return f"⚠️ {exc}", None, None
    except Exception as exc:  # noqa: BLE001
        return f"❌ {exc}", None, None


def run_global_download(
    model: str,
    variables: List[str],
    custom_variables: str,
    valid_time: str,
    min_horizon: int,
    max_horizon: int,
    grid_step: float,
    raw_members: str,
    preview_variable: str,
) -> Tuple[str, Optional[pd.DataFrame], Optional[go.Figure], Optional[str]]:
    """Gradio handler for the global snapshot tab.

    Returns (status, preview dataframe, heatmap figure, CSV file path); on
    failure the status carries the error and the other outputs are ``None``.
    """
    try:
        token = get_token()
        selected_from_dropdown = [name.lower() for name in (variables or [])]
        custom_list = [name.lower() for name in parse_variable_list(custom_variables)]
        # dict.fromkeys de-duplicates while preserving selection order.
        variable_names = list(dict.fromkeys([*selected_from_dropdown, *custom_list]))
        if not variable_names:
            raise ValueError(
                "Select at least one variable or enter custom variable names."
            )
        valid_iso = to_iso_utc(valid_time, "Valid time")
        lower = int(min(min_horizon, max_horizon))
        upper = int(max(min_horizon, max_horizon))
        if grid_step <= 0:
            raise ValueError("Grid step must be a positive number of degrees.")
        members = build_members_list(raw_members, model)

        alias_map = {}
        variables_payload = []
        for name in variable_names:
            alias = make_alias(name)
            alias_map[name] = alias
            variables_payload.append(
                {"name": name, "level": DEFAULT_LEVEL, "alias": alias}
            )

        def request_payload(as_times_list: bool) -> pd.DataFrame:
            # Shared kwargs for both request styles; only the time fields vary.
            kwargs = dict(
                token=token,
                model=model,
                variables=variables_payload,
                min_horizon=lower,
                max_horizon=upper,
                grid={
                    "minLatitude": -90,
                    "maxLatitude": 90,
                    "minLongitude": -180,
                    "maxLongitude": 180,
                    "step": float(grid_step),
                },
                members=members,
                accept="text/csv",
                timeout=240,
            )
            if as_times_list:
                kwargs["times_list"] = [valid_iso]
            else:
                kwargs["from_time"] = valid_iso
                kwargs["until_time"] = valid_iso
            return fetch_wave_history(**kwargs)

        # Prefer the explicit times-list form; retry as a degenerate range.
        df = request_payload(as_times_list=True)
        if df.empty:
            df = request_payload(as_times_list=False)
        if df.empty:
            raise ValueError("No data returned for the requested global configuration.")

        df = df.copy()
        for col in ("forecasted_time", "forecasted_at"):
            if col in df.columns:
                df[col] = pd.to_datetime(df[col], utc=True, errors="coerce")
        if {"forecasted_time", "forecasted_at"}.issubset(df.columns):
            df["lead_time_hours"] = (
                (df["forecasted_time"] - df["forecasted_at"]).dt.total_seconds()
                / 3600.0
            )

        preview = df.head(500).copy()
        preview_rows = len(preview)
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
        # Close the handle before writing: Windows forbids opening the same
        # path a second time while the NamedTemporaryFile is still open.
        tmp.close()
        df.to_csv(tmp.name, index=False)

        preview_choice = (preview_variable or variable_names[0]).strip().lower()
        if preview_choice not in alias_map:
            preview_choice = variable_names[0]
        alias_for_map = alias_map[preview_choice]
        variable_label, unit = WAVE_VARIABLES.get(
            preview_choice,
            (preview_choice.upper(), ""),
        )
        display_label = f"{variable_label} ({unit})" if unit else variable_label
        map_fig = make_global_heatmap(df, alias_for_map, display_label, valid_iso)

        lead_info = ""
        if "lead_time_hours" in df.columns:
            lead_series = df["lead_time_hours"].dropna()
            if not lead_series.empty:
                lead_info = (
                    f" Lead times {lead_series.min():.0f}–{lead_series.max():.0f} h."
                )
        status = (
            f"Fetched {len(df)} rows across {len(variable_names)} variable(s) "
            f"for {valid_iso}.{lead_info} Showing the first {preview_rows} rows."
        )
        return status, preview, map_fig, tmp.name
    except ConfigurationError as exc:
        return f"⚠️ {exc}", None, None, None
    except Exception as exc:  # noqa: BLE001
        return f"❌ {exc}", None, None, None


def default_time_window(
    hours_back: int = 6, hours_forward: int = 24
) -> Tuple[str, str]:
    """Return a (start, end) ISO window bracketing the current UTC hour."""
    # Aware "now" stripped back to naive: same value utcnow() produced, but
    # without the DeprecationWarning on Python 3.12+.
    now = dt.datetime.now(dt.timezone.utc).replace(
        minute=0, second=0, microsecond=0, tzinfo=None
    )
    start = (now - dt.timedelta(hours=hours_back)).isoformat() + "Z"
    end = (now + dt.timedelta(hours=hours_forward)).isoformat() + "Z"
    return start, end


def default_valid_time(offset_hours: int = 0) -> str:
    """Return a default forecast valid time aligned to a 3-hourly cycle."""
    now = dt.datetime.now(dt.timezone.utc).replace(
        minute=0, second=0, microsecond=0, tzinfo=None
    )
    # Align to the previous 3-hour boundary, which matches ECMWF wave output cadence.
    hours_since_cycle = now.hour % 3
    if hours_since_cycle != 0:
        now -= dt.timedelta(hours=hours_since_cycle)
    return (now + dt.timedelta(hours=offset_hours)).isoformat() + "Z"


def build_interface() -> gr.Blocks:
    """Assemble the two-tab Gradio UI and wire the handlers."""
    start_default, end_default = default_time_window()
    valid_default = default_valid_time()
    with gr.Blocks(title="GribStream IFS Wave Explorer") as demo:
        gr.Markdown(
            """
            # ECMWF Wave Data Explorer

            Use your GribStream API token (stored as the `GRIB_API` secret) to pull
            ECMWF IFS wave forecasts via GribStream. Choose between a point
            time-series view or a global snapshot download of the latest wave fields.
            """
        )
        with gr.Tabs():
            with gr.Tab("Point time series"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=320):
                        series_model_input = gr.Dropdown(
                            choices=list(WAVE_MODELS.keys()),
                            value="ifswave",
                            label="Wave model",
                            info="Choose `ifswave` for the deterministic run or `ifswaef` for the ensemble.",
                        )
                        series_variable_input = gr.Dropdown(
                            choices=[code.upper() for code in WAVE_VARIABLES.keys()],
                            value="SWH",
                            label="Variable",
                            info="Wave parameters use ECMWF short names (e.g., SWH height, MWD direction, MWP period).",
                        )
                        series_custom_variable_input = gr.Textbox(
                            label="Custom variable (optional)",
                            placeholder="Override with another parameter, e.g. swh",
                            info="Leave blank to use the dropdown selection.",
                        )
                        series_latitude_input = gr.Number(
                            label="Latitude",
                            value=32.0,
                            precision=4,
                        )
                        series_longitude_input = gr.Number(
                            label="Longitude",
                            value=-64.0,
                            precision=4,
                        )
                        series_from_time_input = gr.Textbox(
                            label="From time (UTC)",
                            value=start_default,
                            info="ISO 8601 format, e.g. 2025-10-23T00:00:00Z",
                        )
                        series_until_time_input = gr.Textbox(
                            label="Until time (UTC)",
                            value=end_default,
                            info="ISO 8601 format, must be after the start time.",
                        )
                        series_min_horizon_input = gr.Slider(
                            label="Minimum forecast horizon (hours)",
                            value=0,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        series_max_horizon_input = gr.Slider(
                            label="Maximum forecast horizon (hours)",
                            value=72,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        series_members_input = gr.Textbox(
                            label="Ensemble members (IFS Waef only)",
                            placeholder="e.g. 0,1,2",
                            info="Leave blank for control (0). Ignored for deterministic model.",
                        )
                        series_submit = gr.Button("Fetch time series", variant="primary")
                    with gr.Column(scale=2):
                        series_status_output = gr.Markdown(
                            "Results will appear here once you hit **Fetch**."
                        )
                        series_table_output = gr.Dataframe(
                            interactive=False,
                            wrap=False,
                        )
                        series_chart_output = gr.Plot(show_label=False)
                series_submit.click(
                    fn=run_query,
                    inputs=[
                        series_model_input,
                        series_variable_input,
                        series_custom_variable_input,
                        series_latitude_input,
                        series_longitude_input,
                        series_from_time_input,
                        series_until_time_input,
                        series_min_horizon_input,
                        series_max_horizon_input,
                        series_members_input,
                    ],
                    outputs=[
                        series_status_output,
                        series_table_output,
                        series_chart_output,
                    ],
                )
            with gr.Tab("Global snapshot download"):
                gr.Markdown(
                    "Fetch the full global grid for a selected valid time, then download it as CSV. "
                    "Reduce the grid spacing if you need a lighter file."
                )
                with gr.Row():
                    with gr.Column(scale=1, min_width=320):
                        global_model_input = gr.Dropdown(
                            choices=list(WAVE_MODELS.keys()),
                            value="ifswave",
                            label="Wave model",
                            info="`ifswave` deterministic or `ifswaef` ensemble.",
                        )
                        global_variables_input = gr.CheckboxGroup(
                            label="Variables",
                            choices=[code.upper() for code in WAVE_VARIABLES.keys()],
                            value=[code.upper() for code in WAVE_VARIABLES.keys()],
                            info="Select one or more parameters to include in the download.",
                        )
                        global_custom_variables_input = gr.Textbox(
                            label="Additional variables (optional)",
                            placeholder="Comma-separated list, e.g. mp2,pp1d",
                            info="Use ECMWF short names. Combined with the selection above.",
                        )
                        global_valid_time_input = gr.Textbox(
                            label="Forecast valid time (UTC)",
                            value=valid_default,
                            info="ISO 8601 format corresponding to the wave field time you need.",
                        )
                        global_min_horizon_input = gr.Slider(
                            label="Minimum forecast horizon (hours)",
                            value=0,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        global_max_horizon_input = gr.Slider(
                            label="Maximum forecast horizon (hours)",
                            value=24,
                            minimum=0,
                            maximum=360,
                            step=1,
                        )
                        global_grid_step_input = gr.Slider(
                            label="Grid spacing (degrees)",
                            value=0.5,
                            minimum=0.25,
                            maximum=2.0,
                            step=0.25,
                        )
                        global_members_input = gr.Textbox(
                            label="Ensemble members (IFS Waef only)",
                            placeholder="e.g. 0,1,2,3",
                            info="Leave blank for default control member. Ignored for deterministic model.",
                        )
                        global_preview_variable_input = gr.Dropdown(
                            label="Preview variable for map",
                            choices=[code.upper() for code in WAVE_VARIABLES.keys()],
                            value="SWH",
                            info="Used for the heatmap preview below.",
                        )
                        global_submit = gr.Button(
                            "Download global snapshot", variant="primary"
                        )
                    with gr.Column(scale=2):
                        global_status_output = gr.Markdown(
                            "The download link and preview will appear here after processing."
                        )
                        global_preview_output = gr.Dataframe(
                            interactive=False,
                            wrap=False,
                        )
                        global_map_output = gr.Plot(
                            label="Global map preview", show_label=True
                        )
                        global_file_output = gr.File(label="Download CSV")
                global_submit.click(
                    fn=run_global_download,
                    inputs=[
                        global_model_input,
                        global_variables_input,
                        global_custom_variables_input,
                        global_valid_time_input,
                        global_min_horizon_input,
                        global_max_horizon_input,
                        global_grid_step_input,
                        global_members_input,
                        global_preview_variable_input,
                    ],
                    outputs=[
                        global_status_output,
                        global_preview_output,
                        global_map_output,
                        global_file_output,
                    ],
                )
        demo.queue()
    return demo


if __name__ == "__main__":
    app = build_interface()
    app.launch()