File size: 7,317 Bytes
433dab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Grid alignment and combination operations."""

from typing import Literal, Dict, Any, Tuple
import numpy as np
import xarray as xr
from .utils import identify_coordinates, get_crs, is_geographic


def _overlap_grid(coord_a: xr.DataArray, coord_b: xr.DataArray, dim: Any) -> np.ndarray:
    """Build a regular grid covering the range shared by two 1-D coordinates.

    The grid starts at the larger of the two minima, uses the finer of the two
    spacings (estimated from the first two points of each coordinate), and
    never extends past the smaller of the two maxima.
    """
    lo = max(float(coord_a.min()), float(coord_b.min()))
    hi = min(float(coord_a.max()), float(coord_b.max()))
    if lo > hi:
        raise ValueError(f"No overlapping coordinate range along dimension '{dim}'")

    # A single-point coordinate has no spacing; fall back to 1.0.
    step_a = abs(float(coord_a[1] - coord_a[0])) if len(coord_a) > 1 else 1.0
    step_b = abs(float(coord_b[1] - coord_b[0])) if len(coord_b) > 1 else 1.0
    step = min(step_a, step_b)
    if not step > 0:
        raise ValueError(f"Cannot determine a positive grid spacing along dimension '{dim}'")

    # Count the points explicitly instead of np.arange(lo, hi + step, step):
    # that form can emit a point beyond `hi`, which interpolates to NaN.
    n_points = int(np.floor((hi - lo) / step + 1e-9)) + 1
    grid = lo + step * np.arange(n_points)
    # Guard against the last point exceeding `hi` by floating-point round-off.
    return np.clip(grid, lo, hi)


def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align two DataArrays for combination operations.
    
    Args:
        a, b: Input DataArrays
        method: Alignment method. 'reindex' snaps both arrays (nearest
            neighbour) onto whichever input has more points along each shared
            dimension; 'interp' interpolates both onto a regular grid spanning
            the overlapping coordinate range at the finer spacing (numeric
            coordinates only).
        
    Returns:
        Tuple of aligned DataArrays
        
    Raises:
        ValueError: If both inputs carry a CRS and they differ, if the inputs
            share no dimension, if `method` is unknown, or (for 'interp') if a
            shared dimension has no overlapping range or no positive spacing.
    """
    # Check CRS compatibility (only when both arrays actually carry one)
    crs_a = get_crs(a)
    crs_b = get_crs(b)
    
    if crs_a and crs_b and not crs_a.equals(crs_b):
        raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")
    
    # Shared dimensions, kept in a's order so the processing is deterministic
    common_dims = [dim for dim in a.dims if dim in b.dims]
    
    if not common_dims:
        raise ValueError("No common dimensions found for alignment")
    
    if method == "reindex":
        # For each shared dimension use the coordinate with more points
        # (ties go to `a`) as the target of a nearest-neighbour reindex.
        targets = {}
        for dim in common_dims:
            coord_a = a.coords[dim]
            coord_b = b.coords[dim]
            targets[dim] = coord_a if len(coord_a) >= len(coord_b) else coord_b
        
        a_aligned = a.reindex(targets, method='nearest')
        b_aligned = b.reindex(targets, method='nearest')
    
    elif method == "interp":
        # Interpolate both arrays onto a common grid over the overlap
        common_coords = {
            dim: _overlap_grid(a.coords[dim], b.coords[dim], dim)
            for dim in common_dims
        }
        
        a_aligned = a.interp(common_coords)
        b_aligned = b.interp(common_coords)
    
    else:
        raise ValueError(f"Unknown alignment method: {method}")
    
    return a_aligned, b_aligned


def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray:
    """
    Combine two DataArrays with the specified operation.
    
    The inputs are aligned with `align_for_combine` (default method) before
    the arithmetic is applied.
    
    Args:
        a, b: Input DataArrays
        op: Operation ('sum', 'avg', 'diff'); 'diff' is a - b
        
    Returns:
        Combined DataArray named '<a>_<op>_<b>', with a descriptive
        'long_name' attribute and, when both inputs agree on it, 'units'
        
    Raises:
        ValueError: If `op` is not a supported operation, or if alignment fails
    """
    # Reject a bad operation up front, before paying for the alignment
    if op not in ("sum", "avg", "diff"):
        raise ValueError(f"Unknown operation: {op}")
    
    # Unnamed arrays would otherwise show up as the text "None" in the labels
    name_a = a.name if a.name is not None else "a"
    name_b = b.name if b.name is not None else "b"
    label_a = a.attrs.get('long_name', name_a)
    label_b = b.attrs.get('long_name', name_b)
    
    # Align the arrays first
    a_aligned, b_aligned = align_for_combine(a, b)
    
    # Perform the operation and build the matching description
    if op == "sum":
        result = a_aligned + b_aligned
        long_name = f"{label_a} + {label_b}"
    elif op == "avg":
        result = (a_aligned + b_aligned) / 2
        long_name = f"Average of {label_a} and {label_b}"
    else:  # op == "diff"
        result = a_aligned - b_aligned
        long_name = f"{label_a} - {label_b}"
    
    # Update attributes
    result.name = f"{name_a}_{op}_{name_b}"
    result.attrs['long_name'] = long_name
    
    # Preserve units if they match (two inputs without units yield '')
    if a.attrs.get('units') == b.attrs.get('units'):
        result.attrs['units'] = a.attrs.get('units', '')
    
    return result


def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
    """
    Create a cross-section of the DataArray.
    
    Args:
        da: Input DataArray
        along: Dimension to keep for the section (e.g., 'time', 'lat')
        fixed: Dictionary of {dim: value} for dimensions to fix. Entries whose
            key is not a dimension of `da` are ignored. Slices and labels on
            string-typed coordinates are matched exactly; every other value
            is matched to the nearest coordinate.
        
    Returns:
        Cross-section DataArray carrying a copy of the input attributes, with
        the requested fixed values appended to 'long_name'
        
    Raises:
        ValueError: If `along` is not a dimension of `da`, or if the
            selection removed it (e.g. `along` was itself fixed to a scalar)
    """
    if along not in da.dims:
        raise ValueError(f"Dimension '{along}' not found in DataArray")
    
    # Split the indexers: xarray rejects method='nearest' for slice indexers,
    # and a nearest lookup is not defined for string-typed coordinates.
    nearest_sel = {}
    exact_sel = {}
    for dim, value in fixed.items():
        if dim not in da.dims:
            continue
        
        coord = da.coords[dim]
        if isinstance(value, slice) or coord.dtype.kind in "US":
            exact_sel[dim] = value
        else:
            nearest_sel[dim] = value
    
    # Start with the full array and apply the fixed selections
    result = da
    if exact_sel:
        result = result.sel(exact_sel)
    if nearest_sel:
        result = result.sel(nearest_sel, method='nearest')
    
    # Ensure the 'along' dimension is preserved
    if along not in result.dims:
        raise ValueError(f"Section operation removed the '{along}' dimension")
    
    # Update metadata
    result.attrs = da.attrs.copy()
    
    # Add section info to long_name (reports the requested values)
    section_info = []
    for dim, value in fixed.items():
        if dim in da.dims:
            if isinstance(value, (int, float)):
                section_info.append(f"{dim}={value:.3f}")
            else:
                section_info.append(f"{dim}={value}")
    
    if section_info:
        long_name = result.attrs.get('long_name', result.name)
        result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})"
    
    return result


def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
    """
    Reduce a DataArray over its horizontal (X/Y) dimensions.
    
    Args:
        da: Input DataArray
        method: Reduction to apply ('mean', 'sum', 'std')
        
    Returns:
        Reduced DataArray carrying a copy of the input attributes, with
        'long_name' prefixed by the capitalized method name
        
    Raises:
        ValueError: If no X/Y entry is reported by `identify_coordinates`,
            or if `method` is not one of the supported reductions
    """
    axis_names = identify_coordinates(da)
    
    # Collect whichever of the X and Y entries are present, X first
    reduce_over = [axis_names[axis] for axis in ('X', 'Y') if axis in axis_names]
    
    if not reduce_over:
        raise ValueError("No spatial dimensions found for aggregation")
    
    if method not in ("mean", "sum", "std"):
        raise ValueError(f"Unknown aggregation method: {method}")
    
    # The method name doubles as the DataArray reduction to call
    reduced = getattr(da, method)(dim=reduce_over)
    
    # Carry the metadata over and describe the reduction
    new_attrs = da.attrs.copy()
    base_label = new_attrs.get('long_name', reduced.name)
    new_attrs['long_name'] = f"{method.capitalize()} of {base_label}"
    reduced.attrs = new_attrs
    
    return reduced