ncview / tensorview /grid.py
Nipun's picture
🌍 TensorView v1.0 - Complete NetCDF/HDF/GRIB viewer
433dab5
"""Grid alignment and combination operations."""
from typing import Literal, Dict, Any, Tuple
import numpy as np
import xarray as xr
from .utils import identify_coordinates, get_crs, is_geographic
def _overlap_grid(lower: float, upper: float, step: float) -> np.ndarray:
    """Return points spaced ``step`` apart from ``lower`` that never exceed ``upper``."""
    # Count whole steps explicitly rather than calling
    # np.arange(lower, upper + step, step): with a float step, arange can emit
    # a trailing point beyond ``upper`` (or one extra point from rounding),
    # and such points fall outside the overlap being interpolated onto.
    n_steps = int(np.floor((upper - lower) / step + 1e-9))
    grid = lower + step * np.arange(n_steps + 1)
    # Clamp rounding overshoot on the last point back onto the upper bound.
    return np.minimum(grid, upper)


def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align two DataArrays onto shared coordinates for combination operations.

    Args:
        a, b: Input DataArrays. Every dimension they share must have a
            coordinate; for 'interp' those coordinates must be convertible
            with ``float()``.
        method: Alignment method.
            'reindex': for each shared dimension, nearest-neighbour reindex
            both arrays onto whichever input coordinate has more points
            (``a``'s coordinate on a tie).
            'interp': interpolate both arrays onto a regular grid spanning the
            overlap of the two coordinate ranges, spaced by the smaller of the
            two first-interval spacings.

    Returns:
        Tuple ``(a_aligned, b_aligned)``.

    Raises:
        ValueError: If both arrays report a CRS and the two differ, if the
            arrays share no dimension, if ``method`` is unknown, or (for
            'interp') if the coordinate ranges do not overlap or the grid
            spacing is zero.
    """
    # Check CRS compatibility (only enforced when both arrays report one)
    crs_a = get_crs(a)
    crs_b = get_crs(b)
    if crs_a and crs_b and not crs_a.equals(crs_b):
        raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")

    # Shared dimensions, kept in the order they appear on ``a`` so the
    # processing order is deterministic.
    common_dims = [dim for dim in a.dims if dim in b.dims]
    if not common_dims:
        raise ValueError("No common dimensions found for alignment")

    if method == "reindex":
        targets = {}
        for dim in common_dims:
            coord_a = a.coords[dim]
            coord_b = b.coords[dim]
            # The coordinate with more points serves as the target
            targets[dim] = coord_a if len(coord_a) >= len(coord_b) else coord_b
        return (
            a.reindex(targets, method="nearest"),
            b.reindex(targets, method="nearest"),
        )

    if method == "interp":
        common_coords = {}
        for dim in common_dims:
            coord_a = a.coords[dim]
            coord_b = b.coords[dim]

            # Overlap (intersection) of the two coordinate ranges
            lower = max(float(coord_a.min()), float(coord_b.min()))
            upper = min(float(coord_a.max()), float(coord_b.max()))
            if lower > upper:
                raise ValueError(
                    f"No overlapping coordinate range along dimension '{dim}'"
                )

            # Spacing estimated from the first interval of each coordinate;
            # a single-point coordinate falls back to 1.0.
            res_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0
            res_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0
            res = min(abs(res_a), abs(res_b))

            if lower == upper:
                # The overlap is a single point; no spacing is needed.
                common_coords[dim] = np.array([lower])
            elif res == 0:
                raise ValueError(
                    f"Cannot build interpolation grid along dimension '{dim}': "
                    "coordinate spacing is zero"
                )
            else:
                common_coords[dim] = _overlap_grid(lower, upper, res)

        return a.interp(common_coords), b.interp(common_coords)

    raise ValueError(f"Unknown alignment method: {method}")
def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray:
    """
    Combine two DataArrays element-wise after aligning them.

    The inputs are first passed through ``align_for_combine`` with its default
    method, then combined according to ``op``.

    Args:
        a, b: Input DataArrays
        op: Operation to apply: 'sum' (a + b), 'avg' ((a + b) / 2) or
            'diff' (a - b)

    Returns:
        Combined DataArray named ``"{a.name}_{op}_{b.name}"``, with a
        descriptive ``long_name`` attribute. ``units`` is set only when both
        inputs report equal units (an empty string when neither defines any).

    Raises:
        ValueError: If ``op`` is not one of the supported operations.
    """
    left, right = align_for_combine(a, b)

    # Human-readable labels: prefer long_name, fall back to the variable name
    label_a = a.attrs.get('long_name', a.name)
    label_b = b.attrs.get('long_name', b.name)

    # Compute the values and the matching description in a single dispatch
    if op == "sum":
        result = left + right
        description = f"{label_a} + {label_b}"
    elif op == "avg":
        result = (left + right) / 2
        description = f"Average of {label_a} and {label_b}"
    elif op == "diff":
        result = left - right
        description = f"{label_a} - {label_b}"
    else:
        raise ValueError(f"Unknown operation: {op}")

    result.name = f"{a.name}_{op}_{b.name}"
    result.attrs['long_name'] = description

    # Units survive only when both inputs agree on them
    units_match = a.attrs.get('units') == b.attrs.get('units')
    if units_match:
        result.attrs['units'] = a.attrs.get('units', '')

    return result
def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
    """
    Create a cross-section of the DataArray.

    Args:
        da: Input DataArray
        along: Dimension to keep for the section (e.g., 'time', 'lat')
        fixed: Dictionary of {dim: value} for dimensions to fix. Entries whose
            key is not a dimension of ``da`` are ignored. ``slice`` values are
            applied as label ranges; every other value is matched to the
            nearest coordinate label.

    Returns:
        Cross-section DataArray carrying a copy of ``da``'s attributes, with
        the applied selections appended to ``long_name``. When nothing in
        ``fixed`` applies, ``da`` itself is returned unchanged.

    Raises:
        ValueError: If ``along`` is not a dimension of ``da``, or if the
            selection removed the ``along`` dimension.
    """
    if along not in da.dims:
        raise ValueError(f"Dimension '{along}' not found in DataArray")

    nearest_sel = {}
    range_sel = {}
    section_info = []
    for dim, value in fixed.items():
        if dim not in da.dims:
            continue  # best-effort: ignore keys that are not dimensions

        if isinstance(value, slice):
            # xarray rejects ``method=`` when any indexer is a slice, so
            # ranges are applied in a separate, exact ``sel`` call.
            range_sel[dim] = value
        else:
            nearest_sel[dim] = value

        # Python and NumPy real scalars get a fixed-precision label
        if isinstance(value, (int, float, np.integer, np.floating)):
            section_info.append(f"{dim}={value:.3f}")
        else:
            section_info.append(f"{dim}={value}")

    result = da
    if range_sel:
        result = result.sel(range_sel)
    if nearest_sel:
        # A single nearest-neighbour lookup covers numbers and time strings
        result = result.sel(nearest_sel, method='nearest')

    # Ensure the 'along' dimension is preserved
    if along not in result.dims:
        raise ValueError(f"Section operation removed the '{along}' dimension")

    if section_info:
        # Work on a copy so the caller's attribute dict is never modified
        result.attrs = da.attrs.copy()
        long_name = result.attrs.get('long_name', result.name)
        result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})"

    return result
def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
    """
    Reduce a DataArray over its spatial dimensions.

    The dimensions to reduce are whatever ``identify_coordinates`` reports
    under the 'X' and 'Y' keys; both are reduced together when present.

    Args:
        da: Input DataArray
        method: Aggregation method ('mean', 'sum', 'std')

    Returns:
        Reduced DataArray carrying a copy of ``da``'s attributes, with
        ``long_name`` rewritten as "<Method> of <original long_name>".

    Raises:
        ValueError: If no 'X' or 'Y' entry is found, or if ``method`` is not
            one of the supported aggregations.
    """
    axis_map = identify_coordinates(da)
    reduce_dims = [axis_map[axis] for axis in ('X', 'Y') if axis in axis_map]
    if not reduce_dims:
        raise ValueError("No spatial dimensions found for aggregation")

    # Only these reductions are exposed; each maps to the DataArray method of
    # the same name.
    if method not in ("mean", "sum", "std"):
        raise ValueError(f"Unknown aggregation method: {method}")
    reducer = getattr(da, method)
    result = reducer(dim=reduce_dims)

    # Carry the source metadata over and describe the reduction
    result.attrs = da.attrs.copy()
    base_name = result.attrs.get('long_name', result.name)
    result.attrs['long_name'] = f"{method.capitalize()} of {base_name}"
    return result