import numpy as np


def denormalize_and_split(pred: np.ndarray, out_stats: dict):
    """Undo per-(row, channel) normalization and split out the four channels.

    Parameters
    ----------
    pred : np.ndarray
        Normalized prediction.
        - It may be 4-D with shape ``(N, C, H, W)``.
        - It may be 5-D with a leading batch axis of size 1, ``(1, N, C, H, W)``.
        - The original shape note was ``(1, 27, 4, 300, 300)`` /
          ``(27, 4, 300, 300)``. Only the rank is enforced here.
        - Axis 1 (after removing the batch axis) is the channel axis.
    out_stats : dict
        Must provide ``"mean"`` and ``"std"``. Each is array-like (anything
        ``np.asarray`` accepts) of shape ``(N, C)``. They are broadcast over
        the last two axes of ``pred``.

    Returns
    -------
    tuple of np.ndarray
        ``(u, v, w, k)``: channels 0, 1, 2 and 3 of the denormalized array.
        Each has shape ``(N, H, W)`` and is a view into one shared result
        array.

    Raises
    ------
    ValueError
        If a 5-D ``pred`` has a batch size other than 1, or if ``pred`` is
        not 4-D once the batch axis has been removed.
    """
    if pred.ndim == 5:
        # Only a singleton batch is meaningful for this function. Indexing
        # [0] on a larger batch would silently discard every other sample.
        if pred.shape[0] != 1:
            raise ValueError(
                f"expected batch size 1 for 5-D pred, got shape {pred.shape}"
            )
        pred = pred[0]
    if pred.ndim != 4:
        # Without this check, a 3-D input would broadcast against the
        # (N, C, 1, 1) statistics and yield a plausible-looking but wrong
        # 4-D result.
        raise ValueError(f"expected pred of shape (N, C, H, W), got {pred.shape}")

    # asarray accepts plain lists as well as ndarrays. The two trailing
    # singleton axes let the (N, C) statistics broadcast over (H, W).
    mean = np.asarray(out_stats["mean"])[:, :, None, None]
    std = np.asarray(out_stats["std"])[:, :, None, None]
    denorm = pred * std + mean  # (N, C, H, W)

    u = denorm[:, 0]
    v = denorm[:, 1]
    w = denorm[:, 2]
    k = denorm[:, 3]
    return u, v, w, k