import numpy as np
from scipy.ndimage import zoom


def _resample_to_square(stack, size):
    """Linearly resample a (channels, H, W) stack to (channels, size, size)."""
    _, height, width = stack.shape
    # Derive the factors from the actual grid so any input resolution lands
    # on the requested size; the channel axis is left untouched (factor 1).
    return zoom(stack, (1, size / height, size / width), order=1)


def build_model_input(input_dict: dict, in_stats: dict, *, target_size: int = 300):
    """Assemble the normalized, resampled 4-channel model input.

    The high-resolution terrain fields and the low-resolution wind fields are
    normalized per channel, resampled onto a common square grid with linear
    interpolation, and concatenated.

    Args:
        input_dict: Mapping with the 2-D arrays ``"dem"``, ``"roughness"``,
            ``"u_100m"`` and ``"v_100m"``. ``dem``/``roughness`` must share
            one shape and ``u_100m``/``v_100m`` must share one shape (the
            original script used 301x301 and 9x9 grids respectively).
        in_stats: Mapping with the 1-D arrays ``"high_mean"``/``"high_std"``
            (one entry per terrain channel, order: dem, roughness) and
            ``"low_mean"``/``"low_std"`` (one entry per wind channel,
            order: u, v).
        target_size: Edge length of the square output grid. Defaults to 300,
            matching the original script.

    Returns:
        ``float32`` array of shape ``(4, target_size, target_size)`` with the
        channel order ``[u_100m, v_100m, dem, roughness]``.

    Raises:
        ValueError: If ``target_size`` is not positive.
    """
    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    # Raw inputs, stacked channel-first.
    dem_rough = np.stack([input_dict["dem"], input_dict["roughness"]], axis=0)
    uv_100m = np.stack([input_dict["u_100m"], input_dict["v_100m"]], axis=0)

    # Per-channel statistics, reshaped to (C, 1, 1) so they broadcast over
    # the two spatial axes.
    high_mean = in_stats["high_mean"][:, None, None]
    high_std = in_stats["high_std"][:, None, None]
    low_mean = in_stats["low_mean"][:, None, None]
    low_std = in_stats["low_std"][:, None, None]

    # Normalize before resampling, as the original script does.
    dem_rough = (dem_rough - high_mean) / high_std
    uv_100m = (uv_100m - low_mean) / low_std

    # Bring both stacks onto the same (target_size, target_size) grid.
    dem_rough = _resample_to_square(dem_rough, target_size)
    uv_100m = _resample_to_square(uv_100m, target_size)

    # Wind channels first, then terrain channels.
    x = np.concatenate([uv_100m, dem_rough], axis=0).astype(np.float32)
    return x