import numpy as np


def denormalize_and_split(pred: np.ndarray, out_stats: dict):
    """Undo per-(step, channel) normalization and split the four output channels.

    Parameters
    ----------
    pred : np.ndarray
        Normalized prediction laid out as ``(T, C, H, W)`` or
        ``(B, T, C, H, W)``. The original author's comments describe
        ``T=27, C=4, H=W=300``. A leading batch axis of size 1 is dropped,
        which matches the previous behavior. A batch larger than 1 is kept
        intact instead of being silently truncated to its first sample.
    out_stats : dict
        Mapping with ``"mean"`` and ``"std"`` entries. These are array-likes
        of shape ``(T, C)`` (e.g. ``(27, 4)``) that broadcast over the
        trailing spatial ``(H, W)`` axes. Plain lists are accepted as well.

    Returns
    -------
    tuple of np.ndarray
        ``(u, v, w, k)``: channels 0, 1, 2 and 3 of the denormalized
        prediction. Each has shape ``(T, H, W)``, or ``(B, T, H, W)`` when a
        batch larger than 1 was passed.

    Raises
    ------
    ValueError
        If ``pred`` is not 4- or 5-dimensional.
    """
    pred = np.asarray(pred)
    if pred.ndim not in (4, 5):
        raise ValueError(
            "pred must be (T, C, H, W) or (B, T, C, H, W); "
            f"got shape {pred.shape}"
        )
    # Only squeeze a singleton batch; indexing [0] on a real batch would
    # silently discard every other sample.
    if pred.ndim == 5 and pred.shape[0] == 1:
        pred = pred[0]

    # Append two singleton axes so (T, C) stats broadcast over (H, W), and
    # also over a leading batch axis when one is present.
    mean = np.asarray(out_stats["mean"])[..., None, None]
    std = np.asarray(out_stats["std"])[..., None, None]
    pred = pred * std + mean

    # The channel axis is third from the end regardless of batching.
    u = pred[..., 0, :, :]
    v = pred[..., 1, :, :]
    w = pred[..., 2, :, :]
    k = pred[..., 3, :, :]
    return u, v, w, k