"""Plotting functions for 1D, 2D, and map visualizations."""

import io
import os
from typing import Optional, Dict, Any, Tuple, Literal, Union

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import xarray as xr

try:
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    HAS_CARTOPY = True
except ImportError:
    HAS_CARTOPY = False

from .utils import identify_coordinates, get_crs, is_geographic, format_value


def setup_matplotlib():
    """Switch matplotlib to the non-interactive 'Agg' backend and default style."""
    plt.switch_backend('Agg')
    plt.style.use('default')


def _label_with_units(name, units) -> str:
    """Return ``"name (units)"``, or just ``name`` when *units* is empty."""
    return f"{name} ({units})" if units else f"{name}"


def _resolve_color_limits(da: xr.DataArray, style: dict) -> Tuple[Any, Any]:
    """Return ``(vmin, vmax)`` from *style*, falling back to the data range.

    A missing or ``None`` entry is replaced by the data minimum/maximum; the
    reduction over *da* is only performed when it is actually needed.
    """
    vmin = style.get('vmin')
    if vmin is None:
        vmin = float(da.min().values)
    vmax = style.get('vmax')
    if vmax is None:
        vmax = float(da.max().values)
    return vmin, vmax


def _is_descending(coord) -> bool:
    """Return True when a 1D coordinate's first value is greater than its last."""
    values = np.asarray(coord.values)
    return values.size > 1 and bool(values[0] > values[-1])


def plot_1d(da: xr.DataArray, x_dim: Optional[str] = None, **style) -> Figure:
    """
    Create a 1D line plot.

    Args:
        da: Input DataArray (should be 1D or have only one varying dimension)
        x_dim: Dimension to use as x-axis (auto-detected if None: the first
            dimension with more than one element)
        **style: Style parameters. Recognised keys: ``color``, ``linewidth``,
            ``linestyle``, ``marker``, ``markersize``, ``alpha`` and ``grid``.

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If no dimension has more than one element, or if
            ``x_dim`` is not a dimension of ``da``.
    """
    setup_matplotlib()

    if x_dim is None:
        # Pick the first dimension that actually varies.
        x_dim = next((dim for dim in da.dims if da.sizes[dim] > 1), None)
        if x_dim is None:
            raise ValueError("No suitable dimension found for 1D plot")

    if x_dim not in da.dims:
        raise ValueError(f"Dimension '{x_dim}' not found in DataArray")

    x_data = da.coords[x_dim]

    # Drop the other length-1 dimensions so that an array with a single
    # varying dimension (e.g. shape (1, N)) lines up with the x coordinate.
    singleton_dims = [
        dim for dim in da.dims if dim != x_dim and da.sizes[dim] == 1
    ]
    y_data = da.squeeze(dim=singleton_dims) if singleton_dims else da

    # matplotlib matches x against the first axis of y, so put x_dim first.
    if y_data.ndim > 1 and y_data.dims[0] != x_dim:
        trailing = [dim for dim in y_data.dims if dim != x_dim]
        y_data = y_data.transpose(x_dim, *trailing)

    fig, ax = plt.subplots(figsize=(10, 6))

    line_style = {
        'color': style.get('color', 'blue'),
        'linewidth': style.get('linewidth', 1.5),
        'linestyle': style.get('linestyle', '-'),
        'marker': style.get('marker', ''),
        'markersize': style.get('markersize', 4),
        'alpha': style.get('alpha', 1.0),
    }
    ax.plot(x_data, y_data, **line_style)

    ax.set_xlabel(_label_with_units(x_dim, x_data.attrs.get('units', '')))
    ax.set_ylabel(
        _label_with_units(da.name or 'Value', da.attrs.get('units', ''))
    )
    ax.set_title(da.attrs.get('long_name', da.name or 'Data'))

    if style.get('grid', True):
        ax.grid(True, alpha=0.3)

    # Slant the tick labels for time-like axes (by name or datetime64 dtype).
    if 'time' in str(x_dim).lower() or x_data.dtype.kind == 'M':
        fig.autofmt_xdate()

    fig.tight_layout()
    return fig


def plot_2d(da: xr.DataArray, kind: Literal["image", "contour"] = "image",
            x_dim: Optional[str] = None, y_dim: Optional[str] = None,
            **style) -> Figure:
    """
    Create a 2D plot (image or contour).

    Args:
        da: Input DataArray (should be 2D)
        kind: Plot type ('image' or 'contour')
        x_dim, y_dim: Dimensions to use for axes. When omitted they are taken
            from the 'X' / 'Y' entries returned by ``identify_coordinates``,
            falling back to the last / second-to-last dimension.
        **style: Style parameters. Recognised keys: ``cmap``, ``vmin``,
            ``vmax``, ``levels``, ``contour_lines`` and ``colorbar``.

    Returns:
        matplotlib Figure

    Raises:
        ValueError: If ``kind`` is not 'image' or 'contour', if ``da`` has
            fewer than two dimensions, or if the requested dimensions are
            not present in ``da``.
    """
    # Validate before any figure is created so a bad call leaks nothing.
    if kind not in ("image", "contour"):
        raise ValueError(
            f"Unknown plot kind '{kind}'; expected 'image' or 'contour'"
        )
    if da.ndim < 2:
        raise ValueError("plot_2d requires a DataArray with two dimensions")

    setup_matplotlib()

    if x_dim is None or y_dim is None:
        coords = identify_coordinates(da)
        if x_dim is None:
            x_dim = coords.get('X', da.dims[-1])  # Default to last dimension
        if y_dim is None:
            y_dim = coords.get('Y', da.dims[-2])  # Default to second-to-last

    if x_dim not in da.dims or y_dim not in da.dims:
        raise ValueError(f"Dimensions {x_dim}, {y_dim} not found in DataArray")

    # Transpose to (y, x) order, which is what imshow/contourf expect.
    da_plot = da.transpose(y_dim, x_dim)
    x_coord = da.coords[x_dim]
    y_coord = da.coords[y_dim]

    cmap = style.get('cmap', 'viridis')
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    vmin, vmax = _resolve_color_limits(da, style)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(10, 8))

    if kind == "image":
        # imshow places row/column 0 at the low end of the extent
        # (origin='lower'), so flip any axis whose coordinate is descending;
        # otherwise the picture would be mirrored relative to its tick labels.
        values = da_plot.values
        if _is_descending(x_coord):
            values = values[:, ::-1]
        if _is_descending(y_coord):
            values = values[::-1, :]
        extent = [
            float(x_coord.min()), float(x_coord.max()),
            float(y_coord.min()), float(y_coord.max()),
        ]
        im = ax.imshow(values, extent=extent, aspect='auto', origin='lower',
                       cmap=cmap, norm=norm)
    else:  # kind == "contour"
        levels = style.get('levels', 20)
        if isinstance(levels, int):
            # An integer means "this many evenly spaced levels".
            levels = np.linspace(vmin, vmax, levels)

        X, Y = np.meshgrid(x_coord, y_coord)
        im = ax.contourf(X, Y, da_plot.values, levels=levels, cmap=cmap,
                         norm=norm)

        if style.get('contour_lines', False):
            cs = ax.contour(X, Y, da_plot.values, levels=levels, colors='k',
                            linewidths=0.5)
            ax.clabel(cs, inline=True, fontsize=8)

    if style.get('colorbar', True):
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(
            _label_with_units(da.name or 'Value', da.attrs.get('units', ''))
        )

    ax.set_xlabel(_label_with_units(x_dim, x_coord.attrs.get('units', '')))
    ax.set_ylabel(_label_with_units(y_dim, y_coord.attrs.get('units', '')))
    ax.set_title(da.attrs.get('long_name', da.name or 'Data'))

    fig.tight_layout()
    return fig


def plot_map(da: xr.DataArray, proj: str = "PlateCarree", **style) -> Figure:
    """
    Create a map plot with cartopy.

    Args:
        da: Input DataArray with geographic coordinates
        proj: Map projection name. Unknown names fall back to 'PlateCarree'.
        **style: Style parameters. Recognised keys: ``cmap``, ``vmin``,
            ``vmax``, ``plot_type`` ('pcolormesh' or 'contourf'), ``levels``,
            ``coastlines``, ``borders``, ``ocean``, ``land``, ``gridlines``,
            ``extent`` and ``colorbar``.

    Returns:
        matplotlib Figure

    Raises:
        ImportError: If cartopy is not installed.
        ValueError: If ``is_geographic`` rejects ``da`` or the 'X'/'Y'
            coordinates cannot be identified.
    """
    if not HAS_CARTOPY:
        raise ImportError("Cartopy is required for map plotting")

    setup_matplotlib()

    if not is_geographic(da):
        raise ValueError(
            "DataArray does not appear to have geographic coordinates"
        )

    coords = identify_coordinates(da)
    if 'X' not in coords or 'Y' not in coords:
        raise ValueError("Could not identify longitude/latitude coordinates")
    lon_dim = coords['X']
    lat_dim = coords['Y']

    # Map names to projection classes and instantiate only the selected one,
    # rather than constructing every projection on each call.
    projection_classes = {
        'PlateCarree': ccrs.PlateCarree,
        'Robinson': ccrs.Robinson,
        'Mollweide': ccrs.Mollweide,
        'Orthographic': ccrs.Orthographic,
        'NorthPolarStereo': ccrs.NorthPolarStereo,
        'SouthPolarStereo': ccrs.SouthPolarStereo,
        'Miller': ccrs.Miller,
        'InterruptedGoodeHomolosine': ccrs.InterruptedGoodeHomolosine,
    }
    projection = projection_classes.get(proj, ccrs.PlateCarree)()

    fig, ax = plt.subplots(figsize=(12, 8),
                           subplot_kw={'projection': projection})

    # Transpose to (lat, lon) order to match the (lons, lats) plotting calls.
    da_plot = da.transpose(lat_dim, lon_dim)
    lons = da.coords[lon_dim].values
    lats = da.coords[lat_dim].values

    cmap = style.get('cmap', 'viridis')
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    vmin, vmax = _resolve_color_limits(da, style)

    # The data coordinates are passed with a PlateCarree transform regardless
    # of the projection used for display.
    data_crs = ccrs.PlateCarree()
    if style.get('plot_type', 'pcolormesh') == 'contourf':
        levels = style.get('levels', 20)
        if isinstance(levels, int):
            levels = np.linspace(vmin, vmax, levels)
        im = ax.contourf(lons, lats, da_plot.values, levels=levels, cmap=cmap,
                         transform=data_crs)
    else:
        im = ax.pcolormesh(lons, lats, da_plot.values, cmap=cmap,
                           transform=data_crs, vmin=vmin, vmax=vmax,
                           shading='auto')

    if style.get('coastlines', True):
        ax.coastlines(resolution='50m', color='black', linewidth=0.5)
    if style.get('borders', False):
        ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    if style.get('ocean', False):
        ax.add_feature(cfeature.OCEAN, color='lightblue', alpha=0.5)
    if style.get('land', False):
        ax.add_feature(cfeature.LAND, color='lightgray', alpha=0.5)

    if style.get('gridlines', True):
        gl = ax.gridlines(draw_labels=True, alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False

    if 'extent' in style:
        ax.set_extent(style['extent'], crs=ccrs.PlateCarree())
    else:
        ax.set_global()

    if style.get('colorbar', True):
        cbar = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.05,
                            shrink=0.8)
        cbar.set_label(
            _label_with_units(da.name or 'Value', da.attrs.get('units', ''))
        )

    ax.set_title(da.attrs.get('long_name', da.name or 'Data'), pad=20)

    fig.tight_layout()
    return fig


def export_fig(fig: Figure, fmt: Literal["png", "svg", "pdf"] = "png",
               dpi: int = 150,
               out_path: Optional[str] = None) -> Union[str, bytes]:
    """
    Export a figure to file or return as bytes.

    Args:
        fig: matplotlib Figure
        fmt: Output format
        dpi: Resolution for raster formats
        out_path: Output file path (if None, returns bytes)

    Returns:
        ``out_path`` when a path was given, otherwise the encoded figure
        as ``bytes``.
    """
    if out_path is None:
        # Render into memory and hand the encoded bytes back to the caller.
        with io.BytesIO() as buf:
            fig.savefig(buf, format=fmt, dpi=dpi, bbox_inches='tight')
            return buf.getvalue()

    fig.savefig(out_path, format=fmt, dpi=dpi, bbox_inches='tight')
    return out_path


def create_subplot_figure(n_plots: int,
                          ncols: int = 2) -> Tuple[Figure, np.ndarray]:
    """Create a figure with multiple subplots.

    Args:
        n_plots: Number of subplots that will be used (must be >= 1).
        ncols: Number of columns in the grid (must be >= 1).

    Returns:
        ``(fig, axes)`` where ``axes`` is a numpy array of Axes. Grid cells
        beyond ``n_plots`` are hidden.

    Raises:
        ValueError: If ``n_plots`` or ``ncols`` is smaller than 1.
    """
    if n_plots < 1 or ncols < 1:
        raise ValueError("n_plots and ncols must be positive integers")

    nrows = (n_plots + ncols - 1) // ncols  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows))

    # plt.subplots squeezes its result; wrap it back into an array.
    if n_plots == 1:
        axes = np.array([axes])
    elif nrows == 1:
        axes = axes.reshape(1, -1)

    # Hide unused subplots
    for unused_ax in axes.flat[n_plots:]:
        unused_ax.set_visible(False)

    return fig, axes


def add_statistics_text(ax: Axes, da: xr.DataArray, x: float = 0.02,
                        y: float = 0.98) -> None:
    """Add a min/max/mean/std summary box to a plot.

    Args:
        ax: Axes to annotate.
        da: DataArray the statistics are computed from.
        x, y: Position of the text box in axes coordinates; the text is
            anchored at its top edge.
    """
    summary = {
        'Min': da.min(),
        'Max': da.max(),
        'Mean': da.mean(),
        'Std': da.std(),
    }
    text = '\n'.join(
        f"{label}: {float(stat.values):.3g}" for label, stat in summary.items()
    )
    ax.text(x, y, text, transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
            verticalalignment='top', fontsize=8)