File size: 446 Bytes
f637d77 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | import numpy as np
def denormalize_and_split(pred: np.ndarray, out_stats: dict):
    """Undo output normalization and split the prediction into four channels.

    Parameters
    ----------
    pred : np.ndarray
        Normalized prediction of shape ``(1, T, C, H, W)`` or ``(T, C, H, W)``
        with ``C >= 4`` (e.g. ``(1, 27, 4, 300, 300)``). A leading batch axis
        is accepted only when its size is 1. Array-likes are converted with
        ``np.asarray``.
    out_stats : dict
        Mapping with ``"mean"`` and ``"std"`` entries. Each is array-like and
        must broadcast against the 4-D prediction once two trailing singleton
        axes (for H and W) are appended: per-step/per-channel ``(T, C)``,
        per-channel ``(C,)``, or a scalar.

    Returns
    -------
    tuple of np.ndarray
        ``(u, v, w, k)``: channels 0, 1, 2 and 3 of the denormalized
        prediction, each of shape ``(T, H, W)``.

    Raises
    ------
    ValueError
        If ``pred`` is not 4-D or 5-D, has a batch axis larger than 1, or has
        fewer than 4 channels.
    """
    pred = np.asarray(pred)

    if pred.ndim == 5:
        # Only a single-sample batch is supported; indexing [0] on a larger
        # batch would silently discard the remaining samples.
        if pred.shape[0] != 1:
            raise ValueError(
                f"expected a batch of size 1, got shape {pred.shape}"
            )
        pred = pred[0]
    if pred.ndim != 4:
        raise ValueError(
            f"expected pred of shape (T, C, H, W) or (1, T, C, H, W), "
            f"got shape {pred.shape}"
        )
    if pred.shape[1] < 4:
        raise ValueError(
            f"expected at least 4 channels on axis 1, got shape {pred.shape}"
        )

    # Append singleton H/W axes so the stats broadcast over the spatial grid.
    # Ellipsis (rather than [:, :, ...]) also accepts (C,) or scalar stats.
    mean = np.asarray(out_stats["mean"])[..., None, None]
    std = np.asarray(out_stats["std"])[..., None, None]
    pred = pred * std + mean  # (T, C, H, W)

    u = pred[:, 0]
    v = pred[:, 1]
    w = pred[:, 2]
    k = pred[:, 3]
    return u, v, w, k
|