"""Data utilities for FuXi 2.1 inference.

Handles: input loading, normalization stats, output denormalization.
"""

import os
from typing import Optional

import numpy as np
import xarray as xr


def load_input(path: str) -> xr.DataArray:
    """Load the pre-normalized model input from a NetCDF file.

    Expected shape: (time=2, channel=85, lat=721, lon=1440).
    Input must already be z-score normalized; no normalization is applied here.

    The ``"input"`` variable is read fully into memory and the file is closed
    before returning, so no file handle outlives the call.

    Args:
        path: Path to a NetCDF file containing an ``"input"`` variable.

    Returns:
        The ``"input"`` variable as an in-memory DataArray.
    """
    with xr.open_dataset(path) as ds:
        # Materialize before the context manager closes the file; otherwise
        # the returned array would stay lazily backed by an open handle.
        return ds["input"].load()


def load_norm_stats(model_dir: str) -> tuple[np.ndarray, np.ndarray]:
    """Load mean.nc and std.nc for output denormalization.

    Each file must contain a single data variable. ``xr.load_dataarray``
    reads it into memory and closes the file immediately.

    Args:
        model_dir: Directory containing ``mean.nc`` and ``std.nc``.

    Returns:
        ``(mean, std)`` as NumPy arrays, each expected to have shape (85,).
    """
    mean = xr.load_dataarray(os.path.join(model_dir, "mean.nc")).values
    std = xr.load_dataarray(os.path.join(model_dir, "std.nc")).values
    return mean, std


def postprocess(
    output: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    exp_channel: Optional[int] = -1,
) -> np.ndarray:
    """Denormalize output in-place: output = output * std + mean.

    The tp channel (last, index 84) was log1p-normalized during training, so
    expm1 is applied to it after denormalization. Values are capped at 20
    before expm1 and floored at 0 afterwards. Pass ``exp_channel=None`` to
    skip this step.

    Args:
        output: Floating-point array with shape (..., C, H, W); modified in
            place.
        mean: Per-channel means with C entries (any shape; flattened), or a
            single value applied to every channel.
        std: Per-channel standard deviations; same layout as ``mean``.
        exp_channel: Channel index for expm1 reversal (-1 = last), or None.

    Returns:
        ``output`` itself (the same object) after in-place modification.

    Raises:
        TypeError: If ``output`` does not have a floating-point dtype.
        ValueError: If ``output`` has fewer than 3 dimensions, or ``mean`` /
            ``std`` do not have exactly 1 or C entries.
    """
    if not np.issubdtype(output.dtype, np.floating):
        # Casting the stats to an integer dtype would silently truncate them.
        raise TypeError(
            f"output must have a floating-point dtype, got {output.dtype}"
        )
    if output.ndim < 3:
        raise ValueError(
            f"output must have shape (..., C, H, W), got shape {output.shape}"
        )

    n_channels = output.shape[-3]
    mean = np.ravel(mean)
    std = np.ravel(std)
    for label, stat in (("mean", mean), ("std", std)):
        if stat.size not in (1, n_channels):
            raise ValueError(
                f"{label} has {stat.size} entries but output has "
                f"{n_channels} channels"
            )

    # A (C, 1, 1) array broadcasts against (..., C, H, W): one value per
    # channel, repeated over H, W and any leading dims.
    m = mean.astype(output.dtype).reshape(-1, 1, 1)
    s = std.astype(output.dtype).reshape(-1, 1, 1)
    np.multiply(output, s, out=output)
    np.add(output, m, out=output)

    if exp_channel is not None:
        # Basic indexing returns a view, so the in-place ufuncs below write
        # straight into ``output``.
        ch = output[..., exp_channel, :, :]
        # Cap the log-space value first: expm1(20) is ~4.85e8, far below the
        # float32 overflow point.
        np.clip(ch, None, 20, out=ch)
        np.expm1(ch, out=ch)
        np.clip(ch, 0, None, out=ch)  # floor results at 0
    return output