"""Grid alignment and combination operations.""" from typing import Literal, Dict, Any, Tuple import numpy as np import xarray as xr from .utils import identify_coordinates, get_crs, is_geographic def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]: """ Align two DataArrays for combination operations. Args: a, b: Input DataArrays method: Alignment method ('reindex', 'interp') Returns: Tuple of aligned DataArrays """ # Check CRS compatibility crs_a = get_crs(a) crs_b = get_crs(b) if crs_a and crs_b and not crs_a.equals(crs_b): raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}") # Get coordinate information coords_a = identify_coordinates(a) coords_b = identify_coordinates(b) # Find common dimensions common_dims = set(a.dims) & set(b.dims) if not common_dims: raise ValueError("No common dimensions found for alignment") # Align coordinates if method == "reindex": # Use nearest neighbor reindexing a_aligned = a b_aligned = b for dim in common_dims: if dim in a.dims and dim in b.dims: # Get the union of coordinates for this dimension coord_a = a.coords[dim] coord_b = b.coords[dim] # Use the coordinate with higher resolution if len(coord_a) >= len(coord_b): target_coord = coord_a else: target_coord = coord_b # Reindex both arrays to the target coordinate a_aligned = a_aligned.reindex({dim: target_coord}, method='nearest') b_aligned = b_aligned.reindex({dim: target_coord}, method='nearest') elif method == "interp": # Use interpolation # Find common coordinate grid common_coords = {} for dim in common_dims: if dim in a.dims and dim in b.dims: coord_a = a.coords[dim] coord_b = b.coords[dim] # Create a common grid (intersection) min_val = max(float(coord_a.min()), float(coord_b.min())) max_val = min(float(coord_a.max()), float(coord_b.max())) # Use the finer resolution res_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0 res_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0 res = min(abs(res_a), abs(res_b)) common_coords[dim] = np.arange(min_val, max_val + res, res) a_aligned = a.interp(common_coords) b_aligned = b.interp(common_coords) else: raise ValueError(f"Unknown alignment method: {method}") return a_aligned, b_aligned def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray: """ Combine two DataArrays with the specified operation. Args: a, b: Input DataArrays op: Operation ('sum', 'avg', 'diff') Returns: Combined DataArray """ # Align the arrays first a_aligned, b_aligned = align_for_combine(a, b) # Perform the operation if op == "sum": result = a_aligned + b_aligned elif op == "avg": result = (a_aligned + b_aligned) / 2 elif op == "diff": result = a_aligned - b_aligned else: raise ValueError(f"Unknown operation: {op}") # Update attributes result.name = f"{a.name}_{op}_{b.name}" if op == "sum": result.attrs['long_name'] = f"{a.attrs.get('long_name', a.name)} + {b.attrs.get('long_name', b.name)}" elif op == "avg": result.attrs['long_name'] = f"Average of {a.attrs.get('long_name', a.name)} and {b.attrs.get('long_name', b.name)}" elif op == "diff": result.attrs['long_name'] = f"{a.attrs.get('long_name', a.name)} - {b.attrs.get('long_name', b.name)}" # Preserve units if they match if a.attrs.get('units') == b.attrs.get('units'): result.attrs['units'] = a.attrs.get('units', '') return result def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray: """ Create a cross-section of the DataArray. Args: da: Input DataArray along: Dimension to keep for the section (e.g., 'time', 'lat') fixed: Dictionary of {dim: value} for dimensions to fix Returns: Cross-section DataArray """ if along not in da.dims: raise ValueError(f"Dimension '{along}' not found in DataArray") # Start with the full array result = da # Apply fixed selections selection = {} for dim, value in fixed.items(): if dim not in da.dims: continue coord = da.coords[dim] if isinstance(value, (int, float)): # Select nearest value selection[dim] = coord.sel({dim: value}, method='nearest') elif isinstance(value, str) and 'time' in dim.lower(): # Handle time strings selection[dim] = value else: selection[dim] = value if selection: result = result.sel(selection, method='nearest') # Ensure the 'along' dimension is preserved if along not in result.dims: raise ValueError(f"Section operation removed the '{along}' dimension") # Update metadata result.attrs = da.attrs.copy() # Add section info to long_name section_info = [] for dim, value in fixed.items(): if dim in da.dims: if isinstance(value, (int, float)): section_info.append(f"{dim}={value:.3f}") else: section_info.append(f"{dim}={value}") if section_info: long_name = result.attrs.get('long_name', result.name) result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})" return result def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray: """ Aggregate spatially (e.g., zonal mean). Args: da: Input DataArray method: Aggregation method ('mean', 'sum', 'std') Returns: Spatially aggregated DataArray """ coords = identify_coordinates(da) spatial_dims = [] if 'X' in coords: spatial_dims.append(coords['X']) if 'Y' in coords: spatial_dims.append(coords['Y']) if not spatial_dims: raise ValueError("No spatial dimensions found for aggregation") # Perform aggregation if method == "mean": result = da.mean(dim=spatial_dims) elif method == "sum": result = da.sum(dim=spatial_dims) elif method == "std": result = da.std(dim=spatial_dims) else: raise ValueError(f"Unknown aggregation method: {method}") # Update attributes result.attrs = da.attrs.copy() long_name = result.attrs.get('long_name', result.name) result.attrs['long_name'] = f"{method.capitalize()} of {long_name}" return result