|
|
"""Grid alignment and combination operations.""" |
|
|
|
|
|
from typing import Literal, Dict, Any, Tuple |
|
|
import numpy as np |
|
|
import xarray as xr |
|
|
from .utils import identify_coordinates, get_crs, is_geographic |
|
|
|
|
|
|
|
|
def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align two DataArrays for combination operations.

    Only dimensions that ``a`` and ``b`` share AND that carry an index
    coordinate on both arrays are aligned. Shared dimensions without a
    coordinate are left untouched, so later arithmetic must match them by size.

    Args:
        a, b: Input DataArrays
        method: Alignment method:
            - 'reindex': reindex both arrays onto whichever of the two
              coordinates has more points, using nearest-neighbour matching
              with no tolerance (labels outside an array's range are filled
              from its nearest edge point).
            - 'interp': interpolate both arrays onto a regular grid that
              spans the overlap of the two coordinates, with the finer of the
              two spacings. Spacing is estimated from the first two
              coordinate values, which assumes regularly spaced coordinates.

    Returns:
        Tuple of aligned DataArrays ``(a_aligned, b_aligned)``.

    Raises:
        ValueError: On an unknown ``method``, a CRS mismatch, no shared
            dimensions, or (for 'interp') a dimension whose coordinates do
            not overlap or have zero spacing.
    """
    # Fail fast before doing any CRS or coordinate work.
    if method not in ("reindex", "interp"):
        raise ValueError(f"Unknown alignment method: {method}")

    crs_a = get_crs(a)
    crs_b = get_crs(b)
    if crs_a and crs_b and not crs_a.equals(crs_b):
        raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")

    # Follow a's dimension order so processing is deterministic
    # (iterating a set of names is not).
    common_dims = [dim for dim in a.dims if dim in b.dims]
    if not common_dims:
        raise ValueError("No common dimensions found for alignment")

    # Label-based alignment needs an index coordinate on both sides; a bare
    # dimension is skipped instead of raising KeyError on ``.coords[dim]``.
    indexed_dims = [dim for dim in common_dims if dim in a.coords and dim in b.coords]

    if method == "reindex":
        a_aligned, b_aligned = a, b
        for dim in indexed_dims:
            coord_a = a.coords[dim]
            coord_b = b.coords[dim]
            # Target the coordinate with more points so the finer grid keeps
            # all of its samples.
            target_coord = coord_a if len(coord_a) >= len(coord_b) else coord_b
            a_aligned = a_aligned.reindex({dim: target_coord}, method='nearest')
            b_aligned = b_aligned.reindex({dim: target_coord}, method='nearest')
        return a_aligned, b_aligned

    # method == "interp"
    common_coords = {}
    for dim in indexed_dims:
        coord_a = a.coords[dim]
        coord_b = b.coords[dim]

        # Restrict the target grid to the overlap so neither array is
        # interpolated outside its own range.
        lo = max(float(coord_a.min()), float(coord_b.min()))
        hi = min(float(coord_a.max()), float(coord_b.max()))
        if lo > hi:
            raise ValueError(f"No overlapping range along dimension '{dim}'")

        # Single-point coordinates fall back to a spacing of 1.0.
        res_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0
        res_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0
        res = min(abs(res_a), abs(res_b))
        if res == 0:
            raise ValueError(f"Zero coordinate spacing along dimension '{dim}'")

        # np.arange(lo, hi + res, res) can emit a point past ``hi``, which
        # interp turns into NaN. That happens whenever the span is not a
        # multiple of ``res``, or just from float rounding. Count the steps
        # explicitly, with a small tolerance for rounding, and clip the last
        # point onto the overlap.
        n_steps = int(np.floor((hi - lo) / res + 1e-9))
        common_coords[dim] = np.minimum(lo + res * np.arange(n_steps + 1), hi)

    return a.interp(common_coords), b.interp(common_coords)
|
|
|
|
|
|
|
|
def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum", method: str = "reindex") -> xr.DataArray:
    """
    Combine two DataArrays with the specified operation.

    Both inputs are first aligned with ``align_for_combine`` and then combined
    element-wise.

    Args:
        a, b: Input DataArrays
        op: Operation: 'sum' (a + b), 'avg' ((a + b) / 2) or 'diff' (a - b)
        method: Alignment method forwarded to ``align_for_combine``
            ('reindex' or 'interp'). Defaults to 'reindex'.

    Returns:
        Combined DataArray. It is named ``"{a.name}_{op}_{b.name}"`` and has a
        descriptive ``long_name`` attribute. It has a ``units`` attribute only
        when both inputs declare the same units.

    Raises:
        ValueError: If ``op`` is unknown, or as raised by ``align_for_combine``.
    """
    label_a = a.attrs.get('long_name', a.name)
    label_b = b.attrs.get('long_name', b.name)

    # One table drives both the arithmetic and the long_name, so the two
    # can never drift apart.
    operations = {
        "sum": (lambda x, y: x + y, f"{label_a} + {label_b}"),
        "avg": (lambda x, y: (x + y) / 2, f"Average of {label_a} and {label_b}"),
        "diff": (lambda x, y: x - y, f"{label_a} - {label_b}"),
    }
    # Validate before aligning so a typo does not pay for the alignment first.
    if op not in operations:
        raise ValueError(f"Unknown operation: {op}")
    apply_op, long_name = operations[op]

    a_aligned, b_aligned = align_for_combine(a, b, method=method)
    result = apply_op(a_aligned, b_aligned)

    result.name = f"{a.name}_{op}_{b.name}"
    result.attrs['long_name'] = long_name

    units = a.attrs.get('units')
    if units is not None and units == b.attrs.get('units'):
        result.attrs['units'] = units
    else:
        # Inputs disagree on units (or have none): make sure no 'units'
        # attribute carried over by the arithmetic mislabels the result.
        result.attrs.pop('units', None)

    return result
|
|
|
|
|
|
|
|
def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
    """
    Create a cross-section of the DataArray.

    Each entry of ``fixed`` whose key is a dimension of ``da`` is selected
    with nearest-neighbour label matching (``sel(..., method='nearest')``).
    Keys that are not dimensions of ``da`` are ignored. The ``long_name``
    attribute is suffixed with the *requested* values: numeric values are
    formatted to three decimals, anything else verbatim.

    Args:
        da: Input DataArray
        along: Dimension to keep for the section (e.g., 'time', 'lat')
        fixed: Dictionary of {dim: value} for dimensions to fix

    Returns:
        Cross-section DataArray carrying a copy of ``da.attrs``.

    Raises:
        ValueError: If ``along`` is not a dimension of ``da``, or if the
            selection removes it (e.g. ``along`` itself is fixed to a scalar).
    """
    if along not in da.dims:
        raise ValueError(f"Dimension '{along}' not found in DataArray")

    # Single pass over ``fixed``: collect the indexers and their label text.
    selection: Dict[str, Any] = {}
    section_info = []
    for dim, value in fixed.items():
        if dim not in da.dims:
            continue
        selection[dim] = value
        if isinstance(value, (int, float)):
            section_info.append(f"{dim}={value:.3f}")
        else:
            section_info.append(f"{dim}={value}")

    # One nearest-neighbour .sel handles every value type at once; no
    # separate per-value lookup of the nearest coordinate is needed.
    result = da.sel(selection, method='nearest') if selection else da

    if along not in result.dims:
        raise ValueError(f"Section operation removed the '{along}' dimension")

    # Work on a copy so editing long_name below cannot mutate da's attrs.
    result.attrs = da.attrs.copy()

    if section_info:
        long_name = result.attrs.get('long_name', result.name)
        result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})"

    return result
|
|
|
|
|
|
|
|
def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
    """
    Reduce a DataArray over its horizontal dimensions.

    The 'X' and 'Y' entries reported by ``identify_coordinates`` are collected.
    Every one found is reduced in a single call, so with both present this is
    a plain, unweighted reduction over the whole horizontal grid. It is not a
    zonal mean, which would reduce X only.

    Args:
        da: Input DataArray
        method: Aggregation method ('mean', 'sum', 'std')

    Returns:
        Reduced DataArray. It carries a copy of ``da.attrs`` with ``long_name``
        rewritten as ``"<Method> of <original long_name or name>"``.

    Raises:
        ValueError: If no X/Y axis is found, or ``method`` is not supported.
    """
    axes = identify_coordinates(da)
    reduce_over = [axes[axis] for axis in ("X", "Y") if axis in axes]

    if not reduce_over:
        raise ValueError("No spatial dimensions found for aggregation")

    # The supported method names match the DataArray reduction methods
    # one-to-one, so dispatch by attribute name once validated.
    if method not in ("mean", "sum", "std"):
        raise ValueError(f"Unknown aggregation method: {method}")
    result = getattr(da, method)(dim=reduce_over)

    new_attrs = da.attrs.copy()
    base_label = new_attrs.get('long_name', result.name)
    new_attrs['long_name'] = f"{method.capitalize()} of {base_label}"
    result.attrs = new_attrs

    return result