File size: 6,683 Bytes
433dab5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
"""Utility functions for CF conventions and coordinate system helpers."""
import re
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
import xarray as xr
from pyproj import CRS
def guess_cf_axis(da: xr.DataArray, coord_name: str) -> Optional[str]:
    """
    Guess the CF axis type (X, Y, Z, T) for a coordinate.

    Evidence is checked from most to least authoritative, and the first
    match wins:
    1. an explicit ``axis`` attribute,
    2. the ``standard_name`` attribute,
    3. the coordinate name,
    4. the ``units`` attribute.

    Name matching uses multi-letter fragments ('lon', 'lat', 'time', 'lev',
    ...) as substrings. The single letters x/y/t/z count only as short name
    tokens (e.g. 'x', 'xt' in 'xt_ocean', 'z1'). This stops names such as
    'index' or 'height' from being read as X or T.

    Args:
        da: DataArray containing the coordinate
        coord_name: Name of the coordinate

    Returns:
        CF axis type ('X', 'Y', 'Z', 'T') or None when the coordinate is not
        in ``da.coords`` or no rule matches. An explicit ``axis`` attribute
        is returned upper-cased as-is.
    """
    if coord_name not in da.coords:
        return None
    attrs = da.coords[coord_name].attrs
    name_lower = coord_name.lower()

    # 1. Explicit axis attribute.
    if 'axis' in attrs:
        return str(attrs['axis']).upper()

    # 2. CF standard_name.
    standard_name = str(attrs.get('standard_name', '')).lower()
    if standard_name in ('longitude', 'grid_longitude', 'projection_x_coordinate'):
        return 'X'
    if standard_name in ('latitude', 'grid_latitude', 'projection_y_coordinate'):
        return 'Y'
    if standard_name == 'time':
        return 'T'
    if standard_name == 'air_pressure' or any(
        word in standard_name for word in ('altitude', 'height', 'depth')
    ):
        return 'Z'

    # 3. Coordinate name. Tokens are the alphanumeric runs of the name
    # ('xt_ocean' -> 'xt', 'ocean'). A bare axis letter only counts as a
    # token of at most two characters that starts with that letter, so
    # 't' is no longer found inside 'height' or 'station'.
    tokens = [tok for tok in re.split(r'[\W_]+', name_lower) if tok]
    name_rules = (
        ('X', 'x', ('lon',)),
        ('Y', 'y', ('lat',)),
        ('T', 't', ('time',)),
        ('Z', 'z', ('lev', 'pressure', 'height', 'alt', 'depth')),
    )
    for axis, letter, fragments in name_rules:
        if any(fragment in name_lower for fragment in fragments):
            return axis
        if any(len(tok) <= 2 and tok.startswith(letter) for tok in tokens):
            return axis

    # 4. Units.
    units = str(attrs.get('units', '')).strip().lower()
    # degree_east / degrees_east / degree_E / degreesE ...
    if re.search(r'\bdegrees?_?e(?:ast)?\b', units):
        return 'X'
    # degree_north / degrees_north / degree_N / degreesN ...
    if re.search(r'\bdegrees?_?n(?:orth)?\b', units):
        return 'Y'
    # CF time units have the form "<unit> since <reference date>".
    if re.search(r'\bsince\b', units):
        return 'T'
    # Compare whole unit words: a substring test on 'm' would also match
    # 'minutes since ...' or 'mm/day'.
    vertical_units = {
        'pa', 'hpa', 'kpa', 'pascal', 'pascals', 'mbar', 'mb', 'millibar',
        'millibars', 'm', 'km', 'meter', 'meters', 'metre', 'metres',
    }
    unit_words = units.split()
    if unit_words and unit_words[0] in vertical_units:
        return 'Z'
    return None
def identify_coordinates(da: xr.DataArray) -> Dict[str, str]:
    """
    Map each CF axis type found among the DataArray's dimensions to its name.

    Every dimension name is passed to ``guess_cf_axis`` in ``da.dims`` order.
    Dimensions for which it returns a falsy result are skipped. If several
    dimensions resolve to the same axis, the last one wins.

    Args:
        da: Input DataArray

    Returns:
        Dictionary mapping axis type (as returned by ``guess_cf_axis``) to
        dimension name
    """
    guesses = ((guess_cf_axis(da, dim), dim) for dim in da.dims)
    return {axis: dim for axis, dim in guesses if axis}
def get_crs(da: xr.DataArray) -> Optional[CRS]:
    """
    Extract CRS information from a DataArray.

    Sources are tried in this order, and the first one that parses wins:

    1. the coordinate named by ``da.attrs['grid_mapping']`` (CF attributes),
    2. a ``crs`` coordinate (CF attributes),
    3. a ``spatial_ref`` coordinate (common in rioxarray). Its WKT is read
       from the ``spatial_ref`` attribute, falling back to ``crs_wkt``,
    4. WGS84 (EPSG:4326) when the X/Y coordinates found by
       ``identify_coordinates`` lie within [-180, 360] and [-90, 90].

    Metadata that fails to parse is skipped rather than raised.

    Args:
        da: Input DataArray

    Returns:
        pyproj CRS object or None
    """
    # 1./2. CF grid-mapping variables.
    cf_sources = []
    grid_mapping = da.attrs.get('grid_mapping')
    if grid_mapping and grid_mapping in da.coords:
        cf_sources.append(grid_mapping)
    if 'crs' in da.coords:
        cf_sources.append('crs')
    for name in cf_sources:
        try:
            return CRS.from_cf(dict(da.coords[name].attrs))
        except Exception:  # unparsable/unsupported metadata: try the next source
            pass

    # 3. spatial_ref coordinate carrying WKT in its attributes.
    if 'spatial_ref' in da.coords:
        sr_attrs = da.coords['spatial_ref'].attrs
        # Read attrs explicitly. xarray's attribute-style access looks up
        # coordinates before attrs, so ``coord.spatial_ref`` can return the
        # coordinate itself instead of the WKT string.
        for key in ('spatial_ref', 'crs_wkt'):
            wkt = sr_attrs.get(key)
            if not wkt:
                continue
            try:
                return CRS.from_wkt(wkt)
            except Exception:  # malformed WKT: try the next attribute
                pass

    # 4. Geographic fallback based on the coordinate value ranges.
    axes = identify_coordinates(da)
    if 'X' in axes and 'Y' in axes:
        x = da.coords[axes['X']]
        y = da.coords[axes['Y']]
        x_min, x_max = float(x.min()), float(x.max())
        y_min, y_max = float(y.min()), float(y.max())
        if -180 <= x_min <= x_max <= 360 and -90 <= y_min <= y_max <= 90:
            return CRS.from_epsg(4326)  # WGS84
    return None
def is_geographic(da: xr.DataArray) -> bool:
    """
    Check if DataArray uses geographic coordinates.

    The answer comes from the CRS returned by ``get_crs``. A DataArray with
    no recoverable CRS is reported as not geographic.

    Args:
        da: Input DataArray

    Returns:
        True if ``get_crs`` yields a geographic CRS, otherwise False.
    """
    crs = get_crs(da)
    # get_crs already returns EPSG:4326 when the X/Y values fall inside the
    # longitude/latitude ranges. When it returns None, that same range test
    # has failed, so repeating it here could only ever produce False.
    return crs is not None and bool(crs.is_geographic)
def ensure_longitude_range(da: xr.DataArray, range_type: str = '180') -> xr.DataArray:
    """
    Ensure longitude coordinates are in the specified range.

    The X coordinate found by ``identify_coordinates`` is wrapped with
    modular arithmetic, and the result is sorted along that dimension. The
    coordinate's attributes (units, standard_name, ...) are kept.

    Args:
        da: Input DataArray
        range_type: '180' for [-180, 180) or '360' for [0, 360). The ints
            180 and 360 are accepted as well.

    Returns:
        A DataArray with adjusted longitude coordinates. ``da`` itself is
        returned when it has no X coordinate, or when the X coordinate has
        ``standard_name='projection_x_coordinate'``.

    Raises:
        ValueError: If ``range_type`` is not '180' or '360'.
    """
    mode = str(range_type)
    if mode not in ('180', '360'):
        raise ValueError(f"range_type must be '180' or '360', got {range_type!r}")

    coords = identify_coordinates(da)
    if 'X' not in coords:
        return da
    x_name = coords['X']
    lon = da.coords[x_name]
    if lon.attrs.get('standard_name') == 'projection_x_coordinate':
        # Projected x values (not degrees) must not be wrapped modulo 360.
        return da

    values = np.asarray(lon.values)
    if mode == '180':
        wrapped = ((values + 180) % 360) - 180  # 180 maps to -180
    else:
        wrapped = values % 360  # 360 maps to 0

    # Rebuild the coordinate from (dims, data, attrs), because plain
    # DataArray arithmetic drops its attrs. assign_coords also avoids the
    # deep copy of the whole data array that ``da.copy()`` made.
    result = da.assign_coords({x_name: (lon.dims, wrapped, dict(lon.attrs))})
    if x_name in result.dims:
        result = result.sortby(x_name)
    return result
def get_time_bounds(da: xr.DataArray) -> Optional[Tuple[Any, Any]]:
    """
    Return the (earliest, latest) values of the time coordinate.

    The time coordinate is the dimension that ``identify_coordinates``
    classifies as axis 'T'. Each bound is the ``.values`` of that
    coordinate's ``min()`` / ``max()``.

    Returns:
        ``(min, max)`` tuple, or None when no time dimension is identified.
    """
    time_name = identify_coordinates(da).get('T')
    if time_name is None:
        return None
    times = da.coords[time_name]
    earliest = times.min().values
    latest = times.max().values
    return earliest, latest
def format_value(value: Any, coord_name: str = '') -> str:
    """
    Format a coordinate value for display.

    Args:
        value: Value to format. A 0-d numpy array (e.g. the result of
            ``.min().values`` on a DataArray) is first unwrapped to its
            scalar, so it is formatted like that scalar.
        coord_name: Name of the coordinate the value belongs to. Float
            values get 1 decimal place when it contains 'time'
            (case-insensitive), otherwise 3.

    Returns:
        ``np.datetime64`` values cut to at most 19 characters
        (``YYYY-MM-DDTHH:MM:SS``). Integers are returned as-is, floats at
        fixed precision, and anything else via ``str()``.
    """
    # Unwrap 0-d arrays; otherwise they skip every isinstance branch below.
    if isinstance(value, np.ndarray) and value.ndim == 0:
        value = value[()]
    if isinstance(value, np.datetime64):
        return str(value)[:19]  # drop fractional seconds
    if isinstance(value, (int, np.integer)):
        return str(value)
    if isinstance(value, (float, np.floating)):
        decimals = 1 if 'time' in coord_name.lower() else 3
        return f"{value:.{decimals}f}"
    return str(value)