"""Utility functions for CF conventions and coordinate system helpers.""" import re from typing import Dict, List, Optional, Tuple, Any import numpy as np import xarray as xr from pyproj import CRS def guess_cf_axis(da: xr.DataArray, coord_name: str) -> Optional[str]: """ Guess the CF axis type (X, Y, Z, T) for a coordinate. Args: da: DataArray containing the coordinate coord_name: Name of the coordinate Returns: CF axis type ('X', 'Y', 'Z', 'T') or None """ if coord_name not in da.coords: return None coord = da.coords[coord_name] attrs = coord.attrs name_lower = coord_name.lower() # Check explicit axis attribute if 'axis' in attrs: return attrs['axis'].upper() # Check standard_name standard_name = attrs.get('standard_name', '').lower() if standard_name in ['longitude', 'projection_x_coordinate']: return 'X' elif standard_name in ['latitude', 'projection_y_coordinate']: return 'Y' elif standard_name in ['time']: return 'T' elif 'altitude' in standard_name or 'height' in standard_name or standard_name == 'air_pressure': return 'Z' # Check coordinate name patterns if any(pattern in name_lower for pattern in ['lon', 'x']): return 'X' elif any(pattern in name_lower for pattern in ['lat', 'y']): return 'Y' elif any(pattern in name_lower for pattern in ['time', 't']): return 'T' elif any(pattern in name_lower for pattern in ['lev', 'level', 'pressure', 'z', 'height', 'alt']): return 'Z' # Check units units = attrs.get('units', '').lower() if any(unit in units for unit in ['degree_east', 'degrees_east', 'degree_e']): return 'X' elif any(unit in units for unit in ['degree_north', 'degrees_north', 'degree_n']): return 'Y' elif any(unit in units for unit in ['days since', 'hours since', 'seconds since']): return 'T' elif any(unit in units for unit in ['pa', 'hpa', 'mbar', 'mb', 'm', 'km']): return 'Z' return None def identify_coordinates(da: xr.DataArray) -> Dict[str, str]: """ Identify coordinate types in a DataArray. Args: da: Input DataArray Returns: Dictionary mapping axis type to coordinate name """ coords = {} for coord_name in da.dims: axis = guess_cf_axis(da, coord_name) if axis: coords[axis] = coord_name return coords def get_crs(da: xr.DataArray) -> Optional[CRS]: """ Extract CRS information from a DataArray. Args: da: Input DataArray Returns: pyproj CRS object or None """ # Check for grid_mapping attribute grid_mapping = da.attrs.get('grid_mapping') if grid_mapping and grid_mapping in da.coords: gm_var = da.coords[grid_mapping] # Try to construct CRS from grid mapping attributes try: crs_attrs = dict(gm_var.attrs) return CRS.from_cf(crs_attrs) except: pass # Check for crs coordinate if 'crs' in da.coords: try: return CRS.from_cf(dict(da.coords['crs'].attrs)) except: pass # Check for spatial_ref coordinate (common in rioxarray) if 'spatial_ref' in da.coords: try: spatial_ref = da.coords['spatial_ref'] if hasattr(spatial_ref, 'spatial_ref'): return CRS.from_wkt(spatial_ref.spatial_ref) elif 'crs_wkt' in spatial_ref.attrs: return CRS.from_wkt(spatial_ref.attrs['crs_wkt']) except: pass # Default to geographic CRS if we have lat/lon coords = identify_coordinates(da) if 'X' in coords and 'Y' in coords: x_coord = da.coords[coords['X']] y_coord = da.coords[coords['Y']] # Check if coordinates look like geographic x_range = float(x_coord.max()) - float(x_coord.min()) y_range = float(y_coord.max()) - float(y_coord.min()) if -180 <= x_coord.min() <= x_coord.max() <= 360 and -90 <= y_coord.min() <= y_coord.max() <= 90: return CRS.from_epsg(4326) # WGS84 return None def is_geographic(da: xr.DataArray) -> bool: """Check if DataArray uses geographic coordinates.""" crs = get_crs(da) if crs: return crs.is_geographic # Fallback: check coordinate ranges coords = identify_coordinates(da) if 'X' in coords and 'Y' in coords: x_coord = da.coords[coords['X']] y_coord = da.coords[coords['Y']] x_range = float(x_coord.max()) - float(x_coord.min()) y_range = float(y_coord.max()) - float(y_coord.min()) return (-180 <= x_coord.min() <= x_coord.max() <= 360 and -90 <= y_coord.min() <= y_coord.max() <= 90) return False def ensure_longitude_range(da: xr.DataArray, range_type: str = '180') -> xr.DataArray: """ Ensure longitude coordinates are in the specified range. Args: da: Input DataArray range_type: '180' for [-180, 180] or '360' for [0, 360] Returns: DataArray with adjusted longitude coordinates """ coords = identify_coordinates(da) if 'X' not in coords: return da x_coord = coords['X'] da_copy = da.copy() if range_type == '180': # Convert to [-180, 180] da_copy.coords[x_coord] = ((da_copy.coords[x_coord] + 180) % 360) - 180 elif range_type == '360': # Convert to [0, 360] da_copy.coords[x_coord] = da_copy.coords[x_coord] % 360 # Sort by longitude if needed if x_coord in da_copy.dims: da_copy = da_copy.sortby(x_coord) return da_copy def get_time_bounds(da: xr.DataArray) -> Optional[Tuple[Any, Any]]: """Get time bounds from a DataArray.""" coords = identify_coordinates(da) if 'T' not in coords: return None time_coord = da.coords[coords['T']] return (time_coord.min().values, time_coord.max().values) def format_value(value: Any, coord_name: str = '') -> str: """Format a coordinate value for display.""" if isinstance(value, np.datetime64): return str(value)[:19] # Remove nanoseconds elif isinstance(value, (int, np.integer)): return str(value) elif isinstance(value, (float, np.floating)): if 'time' in coord_name.lower(): return f"{value:.1f}" else: return f"{value:.3f}" else: return str(value)