fuxi-2.1 / data_util.py
tpys's picture
Upload folder using huggingface_hub
7e46066 verified
Raw
History Blame Contribute Delete
1.73 kB
"""Data utilities for FuXi 2.1 inference.
Handles: input loading, normalization stats, output denormalization.
"""
import os
import numpy as np
import xarray as xr
def load_input(path: str) -> xr.DataArray:
    """Load the pre-normalized model input from a NetCDF file.

    Reads the ``"input"`` variable, pulls it (data and coordinates) fully into
    memory and closes the underlying file before returning, so no file handle
    is left open behind the returned array.

    Expected shape: (time=2, channel=85, lat=721, lon=1440).
    Input must already be z-score normalized; no normalization is applied here.

    Args:
        path: path to a NetCDF file containing an ``input`` variable.

    Returns:
        The in-memory ``input`` DataArray.
    """
    with xr.open_dataset(path) as ds:
        # .load() materializes the variable so it remains usable after the
        # context manager closes the file; returning the lazy view instead
        # would keep the file open for the array's whole lifetime.
        return ds["input"].load()
def load_norm_stats(model_dir: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the per-channel normalization statistics used to denormalize output.

    Reads ``mean.nc`` and ``std.nc`` from ``model_dir``. Each file is expected
    to hold a single data variable. Each file is opened, read fully into memory
    and closed again, so no file handles outlive this call.

    Args:
        model_dir: directory containing ``mean.nc`` and ``std.nc``.

    Returns:
        ``(mean, std)`` as NumPy arrays, each shape (85,) per the model's
        channel layout.
    """
    # xr.load_dataarray = open + load into memory + close, unlike
    # xr.open_dataarray which leaves the file open until garbage collection.
    mean = xr.load_dataarray(os.path.join(model_dir, "mean.nc")).values
    std = xr.load_dataarray(os.path.join(model_dir, "std.nc")).values
    return mean, std
def postprocess(
    output: np.ndarray, mean: np.ndarray, std: np.ndarray, exp_channel: int = -1
) -> np.ndarray:
    """Undo z-score normalization of a model output, modifying it in place.

    Every channel is mapped back via ``output * std + mean``. One channel
    (by default the last, index 84 -- the ``tp`` channel, which was
    log1p-normalized during training) additionally gets ``expm1`` applied to
    reverse that transform; pass ``exp_channel=None`` to skip this step.

    Args:
        output: array shaped (..., C, H, W); overwritten with the result.
        mean: per-channel means, C entries.
        std: per-channel standard deviations, C entries.
        exp_channel: channel index that gets the ``expm1`` inverse
            (-1 means the last channel), or None to disable it.

    Returns:
        ``output`` itself (the same object, now denormalized).
    """
    # The stats line up with the channel axis (third from the end); leading
    # axes get size-1 dims so broadcasting covers them.
    stat_shape = (1,) * (output.ndim - 3) + (mean.shape[0], 1, 1)
    shift = mean.reshape(stat_shape).astype(output.dtype)
    scale = std.reshape(stat_shape).astype(output.dtype)

    output *= scale
    output += shift

    if exp_channel is None:
        return output

    # Basic indexing yields a view, so the in-place ufuncs below write
    # straight into ``output``.
    log_channel = output[..., exp_channel, :, :]
    # Cap the log-space value before exponentiating (expm1(20) ~ 4.9e8),
    # presumably to keep expm1 from overflowing.
    np.clip(log_channel, None, 20, out=log_channel)
    np.expm1(log_channel, out=log_channel)
    # Negative log-space values give negative expm1 results; floor them at 0.
    np.clip(log_channel, 0, None, out=log_channel)
    return output