# VISION — inference_io_ssh_vision.py
# (repository listing header: uploaded by YuanGao-YG, "Upload 19 files", commit 484df85 verified)
import xarray as xr
import torch as pt
from torch import nn
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F
import os
from model.vision import VISION
import os.path as osp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import h5py
# Pin inference to the first visible GPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Spatial crop bounds (rows high1:high2, cols width1:width2) applied to every
# field read from the HDF5 file; currently the full 512x512 tile.
# ("high" is presumably a typo for "height" — kept because the class below reads these names.)
high1 = 0
high2 = 512
width1 = 0
width2 = 512
class testDataset(Dataset):
    """Dataset of normalized 7-channel samples from the demo HDF5 file.

    Each item is a float array of shape (7, 512, 512) stacking the channels
    (ssh, u, v, b, w_20, w_40, w_60) — i.e. raw channels 0, 1, 2, 8, 3, 4, 5
    of the file's 'fields' dataset — cropped to ``high1:high2, width1:width2``,
    NaN-filled with 0, and normalized per channel with the stored mean/std.
    """

    # Raw-channel indices of 'fields', in the order they are stacked into the output.
    _CHANNELS = (0, 1, 2, 8, 3, 4, 5)

    def __init__(self):
        super().__init__()
        self.data_path = './data/KD48_demo.h5'
        self.data_file = h5py.File(self.data_path, 'r')
        # mean/std arrays are indexed as [0, channel, :, :] below,
        # so they are assumed to have shape (1, C, H, W) — TODO confirm.
        self.mean = np.load('./data/mean.npy')
        self.std = np.load('./data/std.npy')
        # Number of samples exposed; hard-coded to the demo-file size.
        self.size = 50

    def _channel(self, index: int, ch: int) -> np.ndarray:
        """Read raw channel ``ch`` of sample ``index``: crop, NaN->0, normalize."""
        raw = self.data_file['fields'][index, ch:ch + 1, high1:high2, width1:width2]
        raw = np.nan_to_num(raw, nan=0)[0]
        return (raw - self.mean[0, ch, :, :]) / (self.std[0, ch, :, :])

    def __getitem__(self, index):
        # Was 7 copy-pasted read/normalize stanzas; the stacked order is preserved.
        return np.stack([self._channel(index, ch) for ch in self._CHANNELS], axis=0)

    def __len__(self):
        return self.size

    def __del__(self):
        # Best-effort close: __del__ may run after a failed __init__ (no
        # data_file attribute) or during interpreter teardown, and must not raise.
        try:
            self.data_file.close()
        except Exception:
            pass
# Deterministic single-sample loader over the demo dataset (no worker processes,
# so the h5py handle opened in the dataset is used only from this process).
testdataset = testDataset()
testloader = DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=0)
# Instantiate the network on the GPU and restore the best-MSE checkpoint.
model = VISION().cuda()
checkpoint_path = './checkpoint_VISION/best_mse.pt'
# NOTE(review): pt.load unpickles the file — only load trusted checkpoints.
ckpt = pt.load(checkpoint_path, map_location='cpu')
# A checkpoint may wrap the weights under a 'model' key or be a bare state dict.
if isinstance(ckpt, dict) and 'model' in ckpt:
    state_dict = ckpt['model']
else:
    state_dict = ckpt
# Strip any DataParallel/DistributedDataParallel 'module.' prefixes.
state_dict = {key.replace('module.', ''): tensor for key, tensor in state_dict.items()}
model.load_state_dict(state_dict, strict=True)
model.eval()
# --- Output file setup ------------------------------------------------------
folder_path = './result'
# exist_ok avoids the exists()/makedirs() race of the original check-then-create.
os.makedirs(folder_path, exist_ok=True)
output_path = osp.join(folder_path, 'results_io_ssh_vision.h5')
if os.path.exists(output_path):
    os.remove(output_path)  # start from a clean results file on every run

N = len(testdataset)
# Output tiles match the crop applied inside the dataset (previously a
# hard-coded 512 duplicating the crop constants).
H = high2 - high1
W = width2 - width1
f_out = h5py.File(output_path, 'w')
# 3 channels per sample: the predicted / true (w_20, w_40, w_60) fields.
dset_pred = f_out.create_dataset('predicted', shape=(N, 3, H, W), dtype='float32')
dset_true = f_out.create_dataset('ground_truth', shape=(N, 3, H, W), dtype='float32')

# Buffers accumulating per-batch results before they are flushed to disk.
buffer_preds = []
buffer_trues = []
buffer_indices = []
batch_size_to_save = 1  # flush to the HDF5 file after this many batches
# (removed: unused `current_count = 10` dead variable)
def _flush_buffers(preds, trues, indices):
    """Write buffered prediction/target batches to the HDF5 datasets, then clear.

    ``indices`` must be increasing (h5py fancy-selection requirement); they are,
    because batches arrive in order from an unshuffled loader.
    """
    if not preds:
        return
    preds_block = np.concatenate(preds, axis=0)
    trues_block = np.concatenate(trues, axis=0)
    # One assignment writes all 3 channels (was three per-channel writes).
    dset_pred[indices, :, :, :] = preds_block
    dset_true[indices, :, :, :] = trues_block
    preds.clear()
    trues.clear()
    indices.clear()


with pt.no_grad():
    for num, data in enumerate(tqdm(testloader, desc="Loading data")):
        # Input: the SSH channel only; targets: the three w channels
        # (positions 4:7 of the stacked (ssh, u, v, b, w_20, w_40, w_60) sample).
        xbatch = data[:, 0:1, :, :].cuda().float()
        ybatch = data[:, 4:7, :, :].cuda().float()
        out = model(xbatch)
        print(out.shape)
        mse = pt.mean((ybatch - out) ** 2)
        print(num, mse)
        buffer_preds.append(out.detach().cpu().numpy().astype(np.float32))
        buffer_trues.append(ybatch.detach().cpu().numpy().astype(np.float32))
        buffer_indices.append(num)
        if len(buffer_preds) == batch_size_to_save:
            _flush_buffers(buffer_preds, buffer_trues, buffer_indices)
    # Flush any remainder smaller than batch_size_to_save
    # (was a verbatim copy of the in-loop flush code).
    _flush_buffers(buffer_preds, buffer_trues, buffer_indices)

f_out.close()
print("Results successfully saved to HDF5 file.")