# ncview / tensorview / utils.py
# Nipun's picture
# TensorView v1.0 - Complete NetCDF/HDF/GRIB viewer
# 433dab5
"""Utility functions for CF conventions and coordinate system helpers."""
import re
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
import xarray as xr
from pyproj import CRS
def guess_cf_axis(da: xr.DataArray, coord_name: str) -> Optional[str]:
    """
    Guess the CF axis type (X, Y, Z, T) for a coordinate.

    Evidence is consulted in this order and the first match wins: the
    ``axis`` attribute, the ``standard_name`` attribute, the coordinate
    name, and finally the ``units`` attribute.

    Args:
        da: DataArray containing the coordinate
        coord_name: Name of the coordinate

    Returns:
        CF axis type ('X', 'Y', 'Z', 'T') or None if the coordinate does
        not exist or nothing identifies it
    """
    if coord_name not in da.coords:
        return None

    attrs = da.coords[coord_name].attrs

    # 1. Explicit axis attribute. Only the four CF values are accepted; any
    #    other value is ignored so the remaining heuristics can still run.
    axis = str(attrs.get('axis', '')).strip().upper()
    if axis in ('X', 'Y', 'Z', 'T'):
        return axis

    # 2. standard_name
    standard_name = str(attrs.get('standard_name', '')).strip().lower()
    if standard_name in ('longitude', 'projection_x_coordinate'):
        return 'X'
    if standard_name in ('latitude', 'projection_y_coordinate'):
        return 'Y'
    if standard_name == 'time':
        return 'T'
    if ('altitude' in standard_name or 'height' in standard_name
            or standard_name in ('air_pressure', 'depth')):
        return 'Z'

    # 3. Coordinate name.
    name_lower = str(coord_name).lower()

    # 3a. Multi-letter keywords are checked for every axis before any
    #     single-letter hint, so that names such as 'height', 'altitude' or
    #     'depth' are not claimed by the letter 't', nor 'xlat' by 'x'.
    keyword_hints = (
        ('X', ('lon',)),
        ('Y', ('lat',)),
        ('T', ('time',)),
        ('Z', ('lev', 'pressure', 'height', 'alt', 'depth')),
    )
    for axis_type, keywords in keyword_hints:
        if any(keyword in name_lower for keyword in keywords):
            return axis_type

    # 3b. Single-letter hints only count inside short (<= 2 letter) tokens of
    #     the name, e.g. 'x', 'nx', 'x_rho', 'grid_yt'. Longer words that
    #     merely contain the letter ('max', 'layer', 'station') are ignored.
    short_tokens = [
        token for token in re.split(r'[^a-z]+', name_lower)
        if 0 < len(token) <= 2
    ]
    for axis_type, letter in (('X', 'x'), ('Y', 'y'), ('T', 't'), ('Z', 'z')):
        if any(letter in token for token in short_tokens):
            return axis_type

    # 4. Units
    units = str(attrs.get('units', '')).strip().lower()
    # Covers degree_east, degrees_east, degree_E and degrees_E (and the
    # equivalent northward spellings).
    if re.search(r'degrees?_e', units):
        return 'X'
    if re.search(r'degrees?_n', units):
        return 'Y'
    # Time units have the form "<unit> since <reference>", whatever the unit.
    if re.search(r'\bsince\b', units):
        return 'T'
    # Pressure / length units must match exactly; substring matching would
    # treat any unit string containing an 'm' as vertical.
    vertical_units = {
        'pa', 'hpa', 'mbar', 'mb', 'millibar', 'millibars',
        'm', 'km', 'meter', 'meters', 'metre', 'metres',
    }
    if units in vertical_units:
        return 'Z'

    return None
def identify_coordinates(da: xr.DataArray) -> Dict[str, str]:
    """
    Map each recognised CF axis of a DataArray to a dimension name.

    Only the names in ``da.dims`` are examined. If several dimensions
    resolve to the same axis, the one that appears last in ``da.dims`` is
    the one kept.

    Args:
        da: Input DataArray

    Returns:
        Dictionary mapping axis type to coordinate name; dimensions whose
        axis cannot be guessed are omitted
    """
    # Pair every dimension with its guessed axis, then drop the unknowns.
    guessed = ((guess_cf_axis(da, dim_name), dim_name) for dim_name in da.dims)
    return {axis: dim_name for axis, dim_name in guessed if axis}
def get_crs(da: xr.DataArray) -> Optional[CRS]:
    """
    Extract CRS information from a DataArray.

    Sources are tried in order and the first one that yields a CRS wins:

    1. the coordinate named by the ``grid_mapping`` attribute (looked up in
       ``da.attrs`` and then ``da.encoding``), parsed as CF attributes;
    2. a coordinate called ``crs``, parsed as CF attributes;
    3. a coordinate called ``spatial_ref`` carrying a WKT string in its
       ``spatial_ref`` or ``crs_wkt`` attribute;
    4. EPSG:4326, if the X and Y coordinates fall inside longitude /
       latitude value ranges.

    Args:
        da: Input DataArray

    Returns:
        pyproj CRS object or None
    """
    # 1 + 2. CF grid-mapping style coordinates.
    # NOTE(review): xarray is understood to move ``grid_mapping`` from attrs
    # to encoding when opened with decode_coords="all" - confirm against the
    # xarray version in use.
    grid_mapping = da.attrs.get('grid_mapping') or da.encoding.get('grid_mapping')
    cf_names = [grid_mapping] if isinstance(grid_mapping, str) else []
    if 'crs' not in cf_names:
        cf_names.append('crs')

    for cf_name in cf_names:
        if cf_name not in da.coords:
            continue
        try:
            return CRS.from_cf(dict(da.coords[cf_name].attrs))
        except Exception:
            # Best effort: unusable metadata just means "try the next
            # source". Exception (not a bare except) keeps Ctrl-C working.
            continue

    # 3. spatial_ref coordinate (common in rioxarray). The WKT is read from
    # the attrs mapping explicitly: attribute-style access on a DataArray
    # can resolve to a coordinate of the same name instead of the attribute.
    if 'spatial_ref' in da.coords:
        ref_attrs = da.coords['spatial_ref'].attrs
        for wkt_key in ('spatial_ref', 'crs_wkt'):
            wkt = ref_attrs.get(wkt_key)
            if not wkt:
                continue
            try:
                return CRS.from_wkt(wkt)
            except Exception:
                continue

    # 4. Default to geographic CRS if the X/Y values look like lon/lat.
    axes = identify_coordinates(da)
    if 'X' not in axes or 'Y' not in axes:
        return None

    x_coord = da.coords[axes['X']]
    y_coord = da.coords[axes['Y']]
    try:
        x_min, x_max = float(x_coord.min()), float(x_coord.max())
        y_min, y_max = float(y_coord.min()), float(y_coord.max())
    except (TypeError, ValueError):
        # Non-numeric or empty coordinates cannot be range-checked.
        return None

    if -180 <= x_min <= x_max <= 360 and -90 <= y_min <= y_max <= 90:
        return CRS.from_epsg(4326)  # WGS84
    return None
def is_geographic(da: xr.DataArray) -> bool:
    """
    Check if DataArray uses geographic coordinates.

    The decision is delegated to ``get_crs``: the result is True only when
    a CRS is found and that CRS reports itself as geographic. ``get_crs``
    already falls back to a longitude/latitude value-range test when no CRS
    metadata is present, so repeating that test here after it returned None
    could never yield True.

    Args:
        da: Input DataArray

    Returns:
        True if a geographic CRS was determined, otherwise False
    """
    crs = get_crs(da)
    # Compare against None explicitly rather than relying on the
    # truthiness of the CRS object.
    return crs is not None and bool(crs.is_geographic)
def ensure_longitude_range(da: xr.DataArray, range_type: str = '180') -> xr.DataArray:
    """
    Ensure longitude coordinates are in the specified range.

    Args:
        da: Input DataArray
        range_type: '180' for [-180, 180) or '360' for [0, 360). The
            integers 180 and 360 are accepted as well. Any other value
            leaves the longitude values unchanged.

    Returns:
        A copy of the DataArray with adjusted longitude values, sorted by
        longitude, with the longitude coordinate's attributes preserved.
        If no X axis is identified, the input object itself is returned.
    """
    coords = identify_coordinates(da)
    if 'X' not in coords:
        return da

    x_name = coords['X']
    da_copy = da.copy()
    lon = da_copy.coords[x_name]
    lon_attrs = dict(lon.attrs)

    # Accept 180 / 360 given as integers as well as strings.
    mode = str(range_type)
    if mode == '180':
        # Convert to [-180, 180)
        wrapped = ((lon + 180) % 360) - 180
    elif mode == '360':
        # Convert to [0, 360)
        wrapped = lon % 360
    else:
        wrapped = None

    if wrapped is not None:
        # Arithmetic may drop the coordinate's attributes (units,
        # standard_name, axis, ...) depending on xarray's keep_attrs
        # setting, so restore them explicitly before assigning.
        da_copy.coords[x_name] = wrapped.assign_attrs(lon_attrs)

    # Sort by longitude if needed
    if x_name in da_copy.dims:
        da_copy = da_copy.sortby(x_name)
    return da_copy
def get_time_bounds(da: xr.DataArray) -> Optional[Tuple[Any, Any]]:
    """
    Return the smallest and largest values of the time coordinate.

    Args:
        da: Input DataArray

    Returns:
        ``(minimum, maximum)`` taken from the coordinate identified as the
        T axis, or None when no T axis is identified
    """
    axes = identify_coordinates(da)
    if 'T' in axes:
        times = da.coords[axes['T']]
        earliest = times.min().values
        latest = times.max().values
        return (earliest, latest)
    return None
def format_value(value: Any, coord_name: str = '') -> str:
    """
    Format a coordinate value for display.

    Args:
        value: Value to format. Zero-dimensional NumPy arrays are unwrapped
            to their scalar first, so they are formatted like the scalar
            they hold.
        coord_name: Name of the coordinate the value belongs to; floats are
            shown with one decimal when it contains 'time' (case
            insensitive) and three decimals otherwise.

    Returns:
        Display string: datetimes truncated to 19 characters (seconds
        resolution), integers via ``str``, floats with fixed decimals, and
        ``str(value)`` for anything else.
    """
    # A 0-d array is not an instance of np.datetime64 / np.floating, so
    # without unwrapping it would skip every branch below. Indexing with ()
    # yields the NumPy scalar (unlike .item(), which turns datetime64[ns]
    # into a plain integer).
    if isinstance(value, np.ndarray) and value.ndim == 0:
        value = value[()]

    if isinstance(value, np.datetime64):
        return str(value)[:19]  # Remove sub-second digits
    if isinstance(value, (int, np.integer)):
        return str(value)
    if isinstance(value, (float, np.floating)):
        decimals = 1 if 'time' in coord_name.lower() else 3
        return f"{value:.{decimals}f}"
    return str(value)