|
|
"""Utility functions for CF conventions and coordinate system helpers.""" |
|
|
|
|
|
import re |
|
|
from typing import Dict, List, Optional, Tuple, Any |
|
|
import numpy as np |
|
|
import xarray as xr |
|
|
from pyproj import CRS |
|
|
|
|
|
|
|
|
def guess_cf_axis(da: xr.DataArray, coord_name: str) -> Optional[str]:
    """
    Guess the CF axis type (X, Y, Z, T) for a coordinate.

    Evidence is consulted from most to least explicit; the first hit wins:

    1. the coordinate's ``axis`` attribute (only if it is one of X/Y/Z/T),
    2. its ``standard_name`` attribute,
    3. the coordinate name itself,
    4. its ``units`` attribute.

    Args:
        da: DataArray containing the coordinate
        coord_name: Name of the coordinate

    Returns:
        CF axis type ('X', 'Y', 'Z', 'T') or None when the coordinate is
        missing from ``da.coords`` or nothing identifies it.
    """
    if coord_name not in da.coords:
        return None

    attrs = da.coords[coord_name].attrs

    # 1. Explicit axis attribute. Anything that is not a valid axis letter is
    #    ignored so the remaining heuristics still get a chance.
    axis = attrs.get('axis')
    if isinstance(axis, str) and axis.strip().upper() in ('X', 'Y', 'Z', 'T'):
        return axis.strip().upper()

    # 2. standard_name attribute.
    standard_name = str(attrs.get('standard_name', '')).strip().lower()
    if standard_name in ('longitude', 'projection_x_coordinate'):
        return 'X'
    if standard_name in ('latitude', 'projection_y_coordinate'):
        return 'Y'
    if standard_name == 'time':
        return 'T'
    if (
        'altitude' in standard_name
        or 'height' in standard_name
        or standard_name in ('air_pressure', 'depth')
    ):
        return 'Z'

    # 3. Coordinate name. Single letters must be a whole token of the name
    #    (e.g. "x", "grid_x"); matching them as substrings would classify
    #    "height" or "altitude" as time and "index" as X.
    name_lower = str(coord_name).lower()
    tokens = set(re.split(r'[^a-z0-9]+', name_lower))

    def name_matches(letter, fragments):
        return letter in tokens or any(frag in name_lower for frag in fragments)

    if name_matches('x', ('lon',)):
        return 'X'
    if name_matches('y', ('lat',)):
        return 'Y'
    if name_matches('t', ('time',)):
        return 'T'
    if name_matches('z', ('lev', 'pressure', 'height', 'alt', 'depth')):
        return 'Z'

    # 4. Units.
    units = str(attrs.get('units', '')).strip().lower()
    if any(u in units for u in ('degree_east', 'degrees_east', 'degree_e', 'degrees_e')):
        return 'X'
    if any(u in units for u in ('degree_north', 'degrees_north', 'degree_n', 'degrees_n')):
        return 'Y'
    # Time units have the form "<period> since <reference date>".
    if ' since ' in units:
        return 'T'
    # Short unit symbols are compared exactly; as substrings "m" would also
    # match "mm", "m s-1" or "minutes since ...".
    if units in ('pa', 'hpa', 'mbar', 'mb', 'm', 'km'):
        return 'Z'

    return None
|
|
|
|
|
|
|
|
def identify_coordinates(da: xr.DataArray) -> Dict[str, str]:
    """
    Map each recognised CF axis of a DataArray to its dimension name.

    Every entry of ``da.dims`` is passed to ``guess_cf_axis``; dimensions
    for which no axis is guessed are left out. If several dimensions are
    guessed to be the same axis, the last one in ``da.dims`` order is kept.

    Args:
        da: Input DataArray

    Returns:
        Dictionary mapping axis type to coordinate name
    """
    guessed = ((guess_cf_axis(da, dim_name), dim_name) for dim_name in da.dims)
    return {axis: dim_name for axis, dim_name in guessed if axis}
|
|
|
|
|
|
|
|
def get_crs(da: xr.DataArray) -> Optional[CRS]:
    """
    Extract CRS information from a DataArray.

    Sources are tried in this order and the first one that parses wins:

    1. the coordinate named by the ``grid_mapping`` attribute, read as CF
       grid-mapping attributes,
    2. a coordinate called ``crs``, read as CF grid-mapping attributes,
    3. a coordinate called ``spatial_ref``, using the WKT stored in its
       ``spatial_ref`` or ``crs_wkt`` attribute,
    4. a guess of EPSG:4326 when the X/Y coordinate values lie within
       longitude [-180, 360] and latitude [-90, 90].

    Parsing failures are treated as "this source has no usable CRS" and the
    next source is tried.

    Args:
        da: Input DataArray

    Returns:
        pyproj CRS object or None
    """
    # Steps 1 and 2: CF grid-mapping variables.
    cf_candidates = []
    grid_mapping = da.attrs.get('grid_mapping')
    if grid_mapping and grid_mapping in da.coords:
        cf_candidates.append(grid_mapping)
    if 'crs' in da.coords:
        cf_candidates.append('crs')

    for name in cf_candidates:
        try:
            return CRS.from_cf(dict(da.coords[name].attrs))
        except Exception:
            # Best effort: fall through to the next source. Unlike a bare
            # ``except:`` this lets KeyboardInterrupt/SystemExit propagate.
            continue

    # Step 3: WKT on a 'spatial_ref' coordinate. The WKT is read from
    # ``.attrs`` explicitly because attribute-style access on a DataArray
    # may resolve to a same-named coordinate instead of the attribute.
    if 'spatial_ref' in da.coords:
        sr_attrs = da.coords['spatial_ref'].attrs
        for key in ('spatial_ref', 'crs_wkt'):
            wkt = sr_attrs.get(key)
            if not wkt:
                continue
            try:
                return CRS.from_wkt(wkt)
            except Exception:
                continue

    # Step 4: guess from the value ranges of the X/Y coordinates.
    axes = identify_coordinates(da)
    if 'X' in axes and 'Y' in axes:
        x_coord = da.coords[axes['X']]
        y_coord = da.coords[axes['Y']]
        try:
            x_min, x_max = float(x_coord.min()), float(x_coord.max())
            y_min, y_max = float(y_coord.min()), float(y_coord.max())
        except (TypeError, ValueError):
            # Empty or non-numeric coordinates: nothing to guess from.
            return None

        if -180 <= x_min <= x_max <= 360 and -90 <= y_min <= y_max <= 90:
            return CRS.from_epsg(4326)

    return None
|
|
|
|
|
|
|
|
def is_geographic(da: xr.DataArray) -> bool:
    """
    Check if DataArray uses geographic coordinates.

    When ``get_crs`` finds a CRS, its ``is_geographic`` flag decides.
    Otherwise the X/Y coordinate values are inspected: the array counts as
    geographic when they lie within longitude [-180, 360] and latitude
    [-90, 90].

    Args:
        da: Input DataArray

    Returns:
        True if the coordinates appear to be geographic, False otherwise
        (including when no X/Y coordinates can be identified).
    """
    crs = get_crs(da)
    if crs is not None:
        return bool(crs.is_geographic)

    axes = identify_coordinates(da)
    if 'X' not in axes or 'Y' not in axes:
        return False

    x_coord = da.coords[axes['X']]
    y_coord = da.coords[axes['Y']]
    try:
        # Plain floats so the result is a real bool, not a 0-d array.
        x_min, x_max = float(x_coord.min()), float(x_coord.max())
        y_min, y_max = float(y_coord.min()), float(y_coord.max())
    except (TypeError, ValueError):
        # Empty or non-numeric coordinates cannot be range-checked.
        return False

    return -180 <= x_min <= x_max <= 360 and -90 <= y_min <= y_max <= 90
|
|
|
|
|
|
|
|
def ensure_longitude_range(da: xr.DataArray, range_type: str = '180') -> xr.DataArray:
    """
    Ensure longitude coordinates are in the specified range.

    The X coordinate (as found by ``identify_coordinates``) is wrapped into
    the requested range on a copy of ``da``, keeping the coordinate's
    attributes, and the result is sorted along that dimension.

    Args:
        da: Input DataArray
        range_type: '180' for [-180, 180] or '360' for [0, 360]. The
            integers 180 and 360 are accepted as well. Any other value
            leaves the coordinate values unchanged (the result is still
            sorted).

    Returns:
        DataArray with adjusted longitude coordinates. ``da`` itself is
        returned when it has no identifiable X coordinate.
    """
    axes = identify_coordinates(da)
    if 'X' not in axes:
        return da

    lon_name = axes['X']
    result = da.copy()
    lon = result.coords[lon_name]

    mode = str(range_type)
    if mode == '180':
        wrapped = ((lon.values + 180) % 360) - 180
    elif mode == '360':
        wrapped = lon.values % 360
    else:
        wrapped = None

    if wrapped is not None:
        # Rebuild the coordinate with its original dims and attrs; plain
        # arithmetic on the coordinate can drop units/standard_name/axis.
        result.coords[lon_name] = (lon.dims, wrapped, dict(lon.attrs))

    # Wrapping can break monotonic order, so restore it.
    if lon_name in result.dims:
        result = result.sortby(lon_name)

    return result
|
|
|
|
|
|
|
|
def get_time_bounds(da: xr.DataArray) -> Optional[Tuple[Any, Any]]:
    """
    Get time bounds from a DataArray.

    Args:
        da: Input DataArray

    Returns:
        ``(minimum, maximum)`` of the coordinate identified as the T axis,
        each taken from ``.values`` of the reduction, or None when no T
        axis is identified.
    """
    time_name = identify_coordinates(da).get('T')
    if time_name is None:
        return None

    time_values = da.coords[time_name]
    earliest = time_values.min().values
    latest = time_values.max().values
    return (earliest, latest)
|
|
|
|
|
|
|
|
def format_value(value: Any, coord_name: str = '') -> str:
    """
    Format a coordinate value for display.

    Args:
        value: Value to format. 0-d NumPy arrays are unwrapped to their
            scalar first, so they are formatted like the scalar itself.
        coord_name: Name of the coordinate the value belongs to. Floats are
            shown with one decimal when the name contains "time"
            (case-insensitive) and with three decimals otherwise.

    Returns:
        The formatted string: datetimes truncated to second precision
        (``YYYY-MM-DDTHH:MM:SS``), integers as-is, floats with fixed
        decimals, anything else via ``str``.
    """
    # A 0-d array is not an instance of the NumPy scalar types checked
    # below; indexing with an empty tuple yields the underlying scalar.
    if isinstance(value, np.ndarray) and value.ndim == 0:
        value = value[()]

    if isinstance(value, np.datetime64):
        # Keep "YYYY-MM-DDTHH:MM:SS" and drop sub-second digits.
        return str(value)[:19]
    if isinstance(value, (int, np.integer)):
        return str(value)
    if isinstance(value, (float, np.floating)):
        if 'time' in coord_name.lower():
            return f"{value:.1f}"
        return f"{value:.3f}"
    return str(value)