# File size: 7,317 Bytes (433dab5)
"""Grid alignment and combination operations."""
from typing import Literal, Dict, Any, Tuple
import numpy as np
import xarray as xr
from .utils import identify_coordinates, get_crs, is_geographic
def _overlap_grid(dim, coord_a, coord_b) -> np.ndarray:
    """Build an evenly spaced grid covering the overlap of two 1-D coordinates."""
    # Overlap (intersection) of the two coordinate ranges.
    min_val = max(float(coord_a.min()), float(coord_b.min()))
    max_val = min(float(coord_a.max()), float(coord_b.max()))

    # Spacing is estimated from the first two points of each coordinate;
    # a single-point coordinate falls back to a spacing of 1.0.
    res_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0
    res_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0
    res = min(abs(res_a), abs(res_b))
    if not np.isfinite(res) or res <= 0:
        raise ValueError(
            f"Cannot build interpolation grid for dimension '{dim}': "
            f"coordinate spacing must be finite and non-zero (got {res})"
        )

    span = max_val - min_val
    if span < 0:
        # The two coordinates do not overlap: there is nothing to interpolate onto.
        return np.array([], dtype=float)

    # Count whole steps that fit inside the overlap instead of relying on
    # np.arange with a float stop, which can emit a point past max_val.
    # The small epsilon absorbs rounding error in the division.
    n_steps = int(np.floor(span / res + 1e-9))
    grid = min_val + res * np.arange(n_steps + 1)
    # Guard against a last point that exceeds max_val by a rounding error,
    # which would otherwise interpolate to NaN.
    return np.minimum(grid, max_val)


def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align two DataArrays for combination operations.

    Args:
        a, b: Input DataArrays. Every dimension they share must have a
            coordinate on both arrays.
        method: Alignment method.
            'reindex': for each shared dimension, both arrays are reindexed
            (nearest neighbour) onto the coordinate with more points; ties
            favour ``a``.
            'interp': both arrays are interpolated onto an evenly spaced grid
            that covers only the overlap of their coordinates, using the
            finer of the two spacings.

    Returns:
        Tuple of aligned DataArrays ``(a_aligned, b_aligned)``.

    Raises:
        ValueError: If both arrays carry a CRS and they differ, if the arrays
            share no dimension, if a coordinate spacing is zero or non-finite
            for 'interp', or if ``method`` is not recognised.
    """
    # Check CRS compatibility (only enforced when both arrays report a CRS)
    crs_a = get_crs(a)
    crs_b = get_crs(b)
    if crs_a and crs_b and not crs_a.equals(crs_b):
        raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")

    # Find common dimensions
    common_dims = set(a.dims) & set(b.dims)
    if not common_dims:
        raise ValueError("No common dimensions found for alignment")

    if method == "reindex":
        a_aligned, b_aligned = a, b
        for dim in common_dims:
            coord_a = a.coords[dim]
            coord_b = b.coords[dim]
            # Number of points is used as a proxy for resolution.
            target_coord = coord_a if len(coord_a) >= len(coord_b) else coord_b
            a_aligned = a_aligned.reindex({dim: target_coord}, method='nearest')
            b_aligned = b_aligned.reindex({dim: target_coord}, method='nearest')
    elif method == "interp":
        common_coords = {
            dim: _overlap_grid(dim, a.coords[dim], b.coords[dim])
            for dim in common_dims
        }
        a_aligned = a.interp(common_coords)
        b_aligned = b.interp(common_coords)
    else:
        raise ValueError(f"Unknown alignment method: {method}")

    return a_aligned, b_aligned
def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray:
    """
    Combine two DataArrays with the specified operation.

    The inputs are first aligned with ``align_for_combine`` using its default
    method, then combined element-wise.

    Args:
        a, b: Input DataArrays
        op: Operation: 'sum' (a + b), 'avg' ((a + b) / 2) or 'diff' (a - b)

    Returns:
        Combined DataArray named ``"{a.name}_{op}_{b.name}"``, with a
        ``long_name`` attribute describing the operation. ``units`` is carried
        over only when both inputs define the same units.

    Raises:
        ValueError: If ``op`` is not one of the supported operations, or if
            alignment fails.
    """
    long_name_templates = {
        "sum": "{a} + {b}",
        "avg": "Average of {a} and {b}",
        "diff": "{a} - {b}",
    }
    # Reject a bad operation before doing any alignment work.
    if op not in ("sum", "avg", "diff"):
        raise ValueError(f"Unknown operation: {op}")

    a_aligned, b_aligned = align_for_combine(a, b)

    if op == "sum":
        result = a_aligned + b_aligned
    elif op == "avg":
        result = (a_aligned + b_aligned) / 2
    else:  # op == "diff"
        result = a_aligned - b_aligned

    # Update attributes
    result.name = f"{a.name}_{op}_{b.name}"
    label_a = a.attrs.get('long_name', a.name)
    label_b = b.attrs.get('long_name', b.name)
    result.attrs['long_name'] = long_name_templates[op].format(a=label_a, b=label_b)

    # Preserve units only when both inputs actually define them and they
    # match; two inputs without units must not yield an empty 'units' entry.
    units = a.attrs.get('units')
    if units is not None and units == b.attrs.get('units'):
        result.attrs['units'] = units

    return result
def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
    """
    Create a cross-section of the DataArray.

    Args:
        da: Input DataArray
        along: Dimension to keep for the section (e.g., 'time', 'lat')
        fixed: Dictionary of {dim: value} for dimensions to fix. Keys that are
            not dimensions of ``da`` are ignored. Values are matched to the
            nearest coordinate label, except for slices and for dimensions
            with string-typed coordinates, which are selected exactly because
            nearest-neighbour lookup is not defined for them.

    Returns:
        Cross-section DataArray carrying a copy of the attributes of ``da``,
        with the fixed values appended to ``long_name``.

    Raises:
        ValueError: If ``along`` is not a dimension of ``da``, or if the
            selection removes the ``along`` dimension.
    """
    if along not in da.dims:
        raise ValueError(f"Dimension '{along}' not found in DataArray")

    nearest_sel = {}
    exact_sel = {}
    section_info = []
    for dim, value in fixed.items():
        if dim not in da.dims:
            continue

        # Slices and string labels cannot be combined with method='nearest',
        # so route them through an exact selection instead.
        if isinstance(value, slice) or da.coords[dim].dtype.kind in "US":
            exact_sel[dim] = value
        else:
            nearest_sel[dim] = value

        # Record the requested value for the long_name suffix.
        if isinstance(value, (int, float)):
            section_info.append(f"{dim}={value:.3f}")
        else:
            section_info.append(f"{dim}={value}")

    result = da
    if nearest_sel:
        result = result.sel(nearest_sel, method='nearest')
    if exact_sel:
        result = result.sel(exact_sel)

    # Ensure the 'along' dimension is preserved
    if along not in result.dims:
        raise ValueError(f"Section operation removed the '{along}' dimension")

    # Update metadata
    result.attrs = da.attrs.copy()
    if section_info:
        long_name = result.attrs.get('long_name', result.name)
        result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})"

    return result
def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
    """
    Reduce a DataArray over its spatial dimensions.

    The dimensions reduced are those reported under the 'X' and 'Y' keys by
    ``identify_coordinates``; whichever of the two is present is used.

    Args:
        da: Input DataArray
        method: Aggregation method ('mean', 'sum', 'std')

    Returns:
        Spatially aggregated DataArray with the attributes of ``da`` copied
        over and ``long_name`` prefixed with the capitalised method name.

    Raises:
        ValueError: If no spatial dimension is identified, or if ``method``
            is not one of the supported aggregations.
    """
    axis_lookup = identify_coordinates(da)
    reduce_over = [axis_lookup[axis] for axis in ('X', 'Y') if axis in axis_lookup]
    if not reduce_over:
        raise ValueError("No spatial dimensions found for aggregation")

    if method not in ("mean", "sum", "std"):
        raise ValueError(f"Unknown aggregation method: {method}")
    # The supported method names match the DataArray reduction methods.
    reduced = getattr(da, method)(dim=reduce_over)

    # Carry the source metadata over and describe the reduction in long_name.
    reduced.attrs = da.attrs.copy()
    base_label = reduced.attrs.get('long_name', reduced.name)
    reduced.attrs['long_name'] = f"{method.capitalize()} of {base_label}"
    return reduced