| """Plot FuXi forecast output as weather maps. |
| |
| Self-contained visualization using matplotlib + cartopy with discrete |
| color scales for key weather variables (precipitation, wind speed). |
| |
| Usage: |
| python plot.py --output_dir ./output --channels t2m z500 tp |
| python plot.py --output_dir ./output --steps 1 3 5 --discrete |
| """ |
|
|
| import argparse |
| import logging |
| import os |
| import re |
| from pathlib import Path |
|
|
| import matplotlib as mpl |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import xarray as xr |
|
|
| mpl.use("Agg") |
|
|
| from matplotlib.colors import BoundaryNorm, ListedColormap |
|
|
| try: |
| import cartopy.crs as ccrs |
| import cartopy.feature as cfeature |
|
|
| HAS_CARTOPY = True |
| except ImportError: |
| HAS_CARTOPY = False |
|
|
# Module-level logging setup: INFO level, timestamped "<time> <LEVEL> <msg>"
# lines. basicConfig runs at import time and only takes effect if the root
# logger has no handlers configured yet.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
|
|
| |
|
|
# Discrete color scales keyed by base variable name (lower-cased, trailing
# digits stripped). "levels" are bin boundaries, "colors" the fill colors.
# NOTE(review): units are presumably mm for "tp" and m/s for wind speed -- confirm.
VAR_COLORS = {
    # Total precipitation: 14 boundaries -> 13 bins, one color per bin.
    "tp": dict(
        levels=[0, 0.1, 1, 2, 5, 10, 15, 20, 30, 40, 50, 100, 300, 1000],
        colors=[
            "#FFFFFF", "#F0E6C3", "#B6F391", "#52ED52", "#95CFFF",
            "#368EFF", "#1061FF", "#0033FF", "#FFFF00", "#FFA500",
            "#FF0000", "#8B2500", "#FF00FF",
        ],
    ),
    # Wind speed: 10 boundaries and 10 colors, i.e. one color more than the
    # 9 bins (the last one covers values above the top boundary). The
    # thresholds appear to follow the Beaufort scale from force 4 upwards.
    "gs": dict(
        levels=[0.0, 5.5, 8.0, 10.8, 13.9, 17.2, 20.8, 24.5, 28.5, 32.7],
        colors=[
            "#FFFFFF", "#8FCEF0", "#489B9F", "#49B154", "#9FCE51",
            "#FAE159", "#F8B547", "#F26429", "#DC3328", "#B01A20",
        ],
    ),
}
# The other wind-speed names share the "gs" scale (the same dict object, not a copy).
VAR_COLORS.update(dict.fromkeys(("ws", "ws10m", "ws100m"), VAR_COLORS["gs"]))
|
|
# Matplotlib colormap name used for variables with no preset in VAR_COLORS;
# those fields get a continuous scale spanning the data's min..max.
DEFAULT_CMAP: str = "viridis"
|
|
|
|
| |
|
|
def _strip_altitude(name):
    """Return *name* with any trailing run of digits removed.

    Level suffixes are dropped (``'z500' -> 'z'``, ``'t2' -> 't'``); names
    that do not end in a digit come back unchanged (``'tp' -> 'tp'``,
    ``'ws10m' -> 'ws10m'``).
    """
    trailing_digits = re.compile(r"\d+$")
    return trailing_digits.sub("", name)
|
|
|
|
def _match_levels_to_colors(levels, colors):
    """Pair preset bin boundaries with fill colors so every color is used once.

    ``K`` boundaries define ``K - 1`` bins. Returns ``(levels, colors, extend)``:

    * more colors than bins: keep ``K`` colors; the last one is reserved for
      values at or above the top boundary (``extend="max"``);
    * otherwise: keep every color and the first ``len(colors) + 1``
      boundaries (``extend="neither"``).

    Raises:
        ValueError: if fewer than two boundaries or no colors are given.
    """
    levels = list(levels)
    colors = list(colors)
    if len(levels) < 2 or not colors:
        raise ValueError("need at least two levels and one color")
    n_bins = len(levels) - 1
    if len(colors) > n_bins:
        return levels, colors[: n_bins + 1], "max"
    return levels[: len(colors) + 1], colors, "neither"


def get_color_kwargs(var_name, data_values, discrete=False):
    """Build color kwargs (``cmap`` plus ``norm`` or ``vmin``/``vmax``) for pcolormesh.

    Args:
        var_name: Variable/channel name such as ``"tp"``, ``"z500"`` or
            ``"ws10m"``. It is lower-cased, cut at the first ``"_"`` and
            stripped of trailing digits before being looked up in ``VAR_COLORS``.
        data_values: Field values; only used by the continuous fallback.
        discrete: Kept for API compatibility. Variables with a preset in
            ``VAR_COLORS`` always get their discrete scale and all other
            variables get the continuous ``DEFAULT_CMAP``, so this flag does
            not currently change the result.

    Returns:
        ``{"cmap", "norm"}`` for preset variables. Otherwise
        ``{"cmap", "vmin", "vmax"}`` spanning the finite data range, or just
        ``{"cmap"}`` when there is no finite value (matplotlib then autoscales).
    """
    base = _strip_altitude(str(var_name).lower().split("_")[0])

    cfg = VAR_COLORS.get(base)
    if cfg is not None:
        # Boundaries must outnumber colors by one. Truncating both lists to
        # the same length used to drop the top boundary and made BoundaryNorm
        # stretch the bins over the colors, silently skipping one color.
        levels, colors, extend = _match_levels_to_colors(cfg["levels"], cfg["colors"])
        cmap = ListedColormap(colors)
        if extend == "neither":
            # Out-of-range values map to the first/last color.
            norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)
        else:
            # The extra color becomes the "over" bin; BoundaryNorm rejects
            # clip=True together with extend.
            norm = BoundaryNorm(levels, ncolors=cmap.N, extend=extend)
        return {"cmap": cmap, "norm": norm}

    # Continuous fallback. Ignore NaN *and* +/-inf so the limits stay finite.
    vals = np.asarray(data_values, dtype=np.float32)
    finite = vals[np.isfinite(vals)]
    if finite.size == 0:
        return {"cmap": DEFAULT_CMAP}
    return {"cmap": DEFAULT_CMAP, "vmin": finite.min(), "vmax": finite.max()}
|
|
|
|
def plot_field(
    data_2d: np.ndarray,
    lats: np.ndarray,
    lons: np.ndarray,
    var_name: str,
    title: str = "",
    save_path: str | None = None,
    discrete: bool = False,
    coastline: bool = True,
    gridline: bool = True,
    dpi: int = 150,
):
    """Plot a single 2D weather field on a map.

    Uses a cartopy PlateCarree map centred on 180 degrees longitude when
    cartopy is importable, otherwise a plain matplotlib axes.

    Args:
        data_2d: Field drawn with ``pcolormesh``; presumably shaped
            ``(len(lats), len(lons))`` -- NOTE(review): confirm with callers.
        lats: Latitude coordinates; only the first/last entries are used for
            sizing and the map extent. Either order is accepted.
        lons: Longitude coordinates, ascending or descending.
        var_name: Variable name; selects the color scale through
            ``get_color_kwargs`` and labels the colorbar.
        title: Figure title; ``var_name.upper()`` when empty.
        save_path: PNG destination (parent directories are created). When
            falsy, the figure is shown with ``plt.show()`` instead.
        discrete: Forwarded to ``get_color_kwargs``.
        coastline: Draw coastlines (cartopy only).
        gridline: Draw labelled gridlines (cartopy only).
        dpi: Resolution of both the figure and the saved file.

    Returns:
        ``save_path`` as passed in.
    """
    # Work with ordered bounds so descending coordinates still give a
    # positive figure width (matplotlib rejects a negative figsize) and a
    # well-formed map extent.
    lon_lo, lon_hi = sorted((float(lons[0]), float(lons[-1])))
    lat_lo, lat_hi = sorted((float(lats[0]), float(lats[-1])))
    aspect = (lon_hi - lon_lo) / max(lat_hi - lat_lo, 1)
    figsize = (8 * aspect, 6)

    projection = ccrs.PlateCarree(180) if HAS_CARTOPY else None
    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": projection}, dpi=dpi)

    # Always release the figure, even if drawing or saving fails, so repeated
    # calls do not accumulate open figures.
    try:
        color_kwargs = get_color_kwargs(var_name, data_2d, discrete=discrete)

        if HAS_CARTOPY:
            img = ax.pcolormesh(
                lons, lats, data_2d,
                transform=ccrs.PlateCarree(),
                shading="auto",
                **color_kwargs,
            )
            # Latitudes are clamped just inside the poles.
            # NOTE(review): presumably to avoid a degenerate extent at +/-90 -- confirm.
            extent = [lon_lo, lon_hi, max(lat_lo, -89.5), min(lat_hi, 89.5)]
            ax.set_extent(extent, crs=ccrs.PlateCarree())
            if coastline:
                ax.add_feature(cfeature.COASTLINE, edgecolor="k", linewidth=0.5)
            if gridline:
                gl = ax.gridlines(draw_labels=True, color="gray", alpha=0.5,
                                  linewidth=1, linestyle="--", crs=ccrs.PlateCarree())
                gl.top_labels = gl.right_labels = False
        else:
            img = ax.pcolormesh(lons, lats, data_2d, shading="auto", **color_kwargs)

        ax.set_title(title or var_name.upper(), fontsize=14, fontweight="bold")

        # Slim colorbar axes placed just to the right of the map axes.
        pos = ax.get_position()
        cax = fig.add_axes([pos.x1 + 0.008, pos.y0, 0.015, pos.height])
        cbar = fig.colorbar(img, cax=cax, orientation="vertical", extend="both", extendfrac=0.03)
        cbar.ax.tick_params(labelsize=10)
        cbar.set_label(var_name, size=10)

        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=dpi)
        else:
            # The module selects the non-interactive Agg backend at import
            # time, so this only displays something if the caller switched
            # to an interactive backend.
            plt.show()
    finally:
        plt.close(fig)

    return save_path
|
|
|
|
| |
|
|
| def plot_forecast( |
| output_dir: str | Path, |
| channels: list[str] | None = None, |
| steps: list[int] | None = None, |
| plot_dir: str | Path | None = None, |
| discrete: bool = False, |
| coastline: bool = True, |
| gridline: bool = True, |
| dpi: int = 150, |
| ): |
| """Plot forecast fields from saved NetCDF output files.""" |
| output_dir = Path(output_dir) |
| if plot_dir is None: |
| plot_dir = output_dir / "plots" |
| plot_dir = Path(plot_dir) |
| plot_dir.mkdir(parents=True, exist_ok=True) |
|
|
| nc_files = sorted([f for f in output_dir.glob("*.nc") if f.stem.isdigit()]) |
| if not nc_files: |
| logger.error(f"No step .nc files (001.nc, 002.nc, ...) found in {output_dir}") |
| return |
|
|
| if steps: |
| nc_files = [f for f in nc_files if int(f.stem) in steps] |
|
|
| for nc_path in nc_files: |
| da = xr.open_dataarray(nc_path) |
| step_idx = int(nc_path.stem) |
| valid_time = da.attrs.get("valid_time", "") |
|
|
| available = list(da.coords["channel"].values) |
| plot_channels = channels if channels else available |
| valid_ch = [c for c in plot_channels if c in available] |
| if not valid_ch: |
| logger.warning(f"None of {plot_channels} found in {nc_path.name}, skipping") |
| continue |
|
|
| lats = da.coords["lat"].values |
| lons = da.coords["lon"].values |
|
|
| for ch in valid_ch: |
| field = da.sel(channel=ch).values |
| title = f"{ch.upper()} | Step {step_idx}" |
| if valid_time: |
| title += f" | {valid_time}" |
| save_path = str(plot_dir / f"step{step_idx:03d}_{ch}.png") |
| plot_field( |
| field, lats, lons, ch, |
| title=title, save_path=save_path, |
| discrete=discrete, coastline=coastline, |
| gridline=gridline, dpi=dpi, |
| ) |
| print(f"Saved: {save_path}") |
|
|
| logger.info(f"All plots saved to {plot_dir}") |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Plot FuXi forecast fields") |
| parser.add_argument("--output_dir", required=True, help="Directory with .nc output files") |
| parser.add_argument("--channels", nargs="+", default=None, |
| help="Channels to plot (default: all)") |
| parser.add_argument("--steps", nargs="+", type=int, default=None, |
| help="Step indices to plot (default: all)") |
| parser.add_argument("--plot_dir", default=None, help="Output directory for plots") |
| parser.add_argument("--discrete", action="store_true", help="Use discrete color scales") |
| parser.add_argument("--no-coastline", action="store_true", help="Hide coastlines") |
| parser.add_argument("--no-gridline", action="store_true", help="Hide gridlines") |
| parser.add_argument("--dpi", type=int, default=150, help="Output resolution") |
| args = parser.parse_args() |
| plot_forecast( |
| args.output_dir, |
| channels=args.channels, |
| steps=args.steps, |
| plot_dir=args.plot_dir, |
| discrete=args.discrete, |
| coastline=not args.no_coastline, |
| gridline=not args.no_gridline, |
| dpi=args.dpi, |
| ) |
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|