# Provenance (page-scrape residue kept as a comment so the module parses):
# YuanGao-YG's picture
# Upload 97 files
# 912fe5a verified
import functools
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@functools.lru_cache(maxsize=None)
def _load_stats(path):
    """Load a normalization-statistics ``.npy`` file once and reuse it.

    The same array object is returned to every caller, so it is marked
    read-only; callers must index it (e.g. ``stats[:, channels]``) rather
    than modify it.
    """
    stats = np.load(path)
    stats.flags.writeable = False
    return stats


def reshape_fields(img, inp_or_tar, params, train, normalize=True, orog=None, add_noise=False):
    """Optionally z-score normalize a field array and return it as a tensor.

    Args:
        img: array of shape ``(c, h, w)`` or ``(n_history + 1, c, h, w)``.
            A 3-D input gets a leading axis of length 1. The caller's
            array is never modified.
        inp_or_tar: ``'inp'`` selects ``params.in_channels``; any other value
            selects ``params.out_channels``.
        params: config object. When normalizing, it must provide
            ``normalization``, ``in_channels``/``out_channels``,
            ``global_means_path`` and ``global_stds_path``.
        train, orog, add_noise: accepted for signature compatibility and
            currently ignored.
        normalize: if true and ``params.normalization == 'zscore'``, the
            input is transformed as ``(img - means) / stds``. Any other
            ``normalization`` value leaves the data unnormalized.

    Returns:
        ``torch.Tensor`` of ``np.squeeze(img)``. Every size-1 axis is
        removed, so a single-step input becomes ``(c, h, w)`` (or fewer
        dims if ``c``, ``h`` or ``w`` is 1). Multi-step input stays 4-D.
        NOTE(review): the original header comment promised a reshape to
        ``(c * (n_history + 1), h, w)``, which is not performed. Confirm
        what callers expect.
    """
    img = np.asarray(img)
    if img.ndim == 3:
        # A single time step: add the leading history axis.
        img = np.expand_dims(img, 0)

    if normalize and params.normalization == 'zscore':
        channels = params.in_channels if inp_or_tar == 'inp' else params.out_channels
        # Stats paths are resolved against the parent of this file's directory.
        # os.path.join leaves an absolute path unchanged.
        parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        means = _load_stats(os.path.join(parent_dir, params.global_means_path))[:, channels]
        stds = _load_stats(os.path.join(parent_dir, params.global_stds_path))[:, channels]
        # presumably the stats are shaped (1, C_total, 1, 1) so they broadcast
        # over (n_history + 1, c, h, w) -- TODO confirm against the stats files.

        # Normalize a private copy so the caller's data is untouched.
        # Floating/complex inputs keep their dtype, which matches the
        # original in-place arithmetic exactly. Integer/bool inputs (which
        # in-place ops reject) are promoted to the stats' dtype.
        if np.issubdtype(img.dtype, np.inexact):
            out_dtype = img.dtype
        else:
            out_dtype = np.result_type(img.dtype, means.dtype, stds.dtype)
        img = np.array(img, dtype=out_dtype)
        img -= means
        img /= stds

    return torch.as_tensor(np.squeeze(img))