# VISION / inference_io_ssh_vision.py
# YuanGao-YG's picture
# Upload 19 files
# 484df85 verified
# raw
# history blame
# 5.63 kB
import xarray as xr
import torch as pt
from torch import nn
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F
import os
from model.vision import VISION
import os.path as osp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import h5py
# Restrict this process to CUDA device 0; set before any CUDA call is made.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

# Crop window applied to every field read from the HDF5 file:
# rows high1:high2 and columns width1:width2.
high1, high2 = 0, 512
width1, width2 = 0, 512
class testDataset(Dataset):
    """Serve normalised 7-channel samples read from an HDF5 ``fields`` array.

    Each sample is built from ``fields[index, channel, rows, cols]`` for the
    channels listed in ``CHANNELS``. NaNs are replaced by 0, then every plane
    is standardised with the matching channel of the ``mean``/``std`` arrays.

    Args:
        data_path: HDF5 file containing the ``fields`` dataset.
        mean_path: ``.npy`` file with per-channel means, indexed
            ``[0, channel, :, :]``.
        std_path: ``.npy`` file with per-channel standard deviations,
            indexed the same way.
        size: Maximum number of samples to expose. It is clamped to the
            number of samples actually present in the file.
    """

    # Channel indices of ``fields`` in the order they are stacked in a sample.
    # The original code names them ssh, u, v, b, w_20, w_40, w_60.
    CHANNELS = (0, 1, 2, 8, 3, 4, 5)

    def __init__(self, data_path: str = './data/KD48_demo.h5',
                 mean_path: str = './data/mean.npy',
                 std_path: str = './data/std.npy',
                 size: int = 50):
        super().__init__()
        self.data_path = data_path
        # Load the statistics before opening the HDF5 file so that a missing
        # .npy file does not leave an open file handle behind.
        self.mean = np.load(mean_path)
        self.std = np.load(std_path)
        # Spatial crop taken from the module-level bounds.
        self.rows = slice(high1, high2)
        self.cols = slice(width1, width2)
        self.data_file = h5py.File(self.data_path, 'r')
        # Never advertise more samples than the file holds.
        self.size = min(size, self.data_file['fields'].shape[0])

    def __getitem__(self, index: int) -> np.ndarray:
        """Return the normalised sample ``index`` stacked along axis 0."""
        fields = self.data_file['fields']
        planes = []
        for channel in self.CHANNELS:
            plane = np.nan_to_num(
                fields[index, channel, self.rows, self.cols], nan=0
            )
            # NOTE(review): the statistics are used without the crop applied,
            # so they must broadcast against the cropped plane. Confirm this
            # still holds if the crop bounds are ever changed.
            plane = (plane - self.mean[0, channel, :, :]) / self.std[0, channel, :, :]
            planes.append(plane)
        return np.stack(planes, axis=0)

    def __len__(self) -> int:
        """Return the number of samples exposed by the dataset."""
        return self.size

    def __del__(self):
        # __init__ may have failed before the file was opened, in which case
        # the attribute does not exist; closing must never raise from here.
        handle = getattr(self, 'data_file', None)
        if handle is not None:
            handle.close()
            self.data_file = None
# Instantiate the dataset and wrap it in a sequential, single-sample loader
# that runs in the main process (no worker subprocesses).
testdataset = testDataset()
testloader = Data.DataLoader(
    testdataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
# Build the network on the GPU and restore its trained weights.
model = VISION().cuda()
checkpoint_path = './checkpoint_VISION/best_mse.pt'
# NOTE: pt.load unpickles the file; only load checkpoints from a trusted source.
ckpt = pt.load(checkpoint_path, map_location='cpu')
# The checkpoint is either a dict holding the weights under 'model' or the
# state dict itself.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
# Strip only a *leading* 'module.' (presumably added by a DataParallel-style
# wrapper at save time). A blanket str.replace would also mangle keys that
# merely contain the substring, e.g. 'submodule.weight' -> 'subweight'.
_wrapper_prefix = 'module.'
state_dict = {
    (k[len(_wrapper_prefix):] if k.startswith(_wrapper_prefix) else k): v
    for k, v in state_dict.items()
}
# strict=True: fail loudly on any missing or unexpected key.
model.load_state_dict(state_dict, strict=True)
model.eval()
def _flush_buffers(dset_pred, dset_true, preds, trues, start):
    """Write the buffered batches to both datasets as one contiguous slab.

    Args:
        dset_pred: Output dataset receiving the predictions.
        dset_true: Output dataset receiving the ground truth.
        preds: List of prediction arrays, each shaped (batch, C, H, W).
        trues: List of ground-truth arrays, aligned with ``preds``.
        start: First output row to write.

    Returns:
        The row index just past the written slab (``start`` if the buffers
        were empty). Both buffers are emptied in place.
    """
    if not preds:
        return start
    preds_block = np.concatenate(preds, axis=0)
    trues_block = np.concatenate(trues, axis=0)
    stop = start + preds_block.shape[0]
    # Only the first three channels are stored, matching the dataset shape.
    dset_pred[start:stop, :, :, :] = preds_block[:, 0:3, :, :]
    dset_true[start:stop, :, :, :] = trues_block[:, 0:3, :, :]
    preds.clear()
    trues.clear()
    return stop


folder_path = './result'
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(folder_path, exist_ok=True)
output_path = osp.join(folder_path, 'results_io_ssh_vision.h5')
if os.path.exists(output_path):
    os.remove(output_path)

N = len(testdataset)
# Spatial size of every sample follows from the crop bounds.
H = high2 - high1
W = width2 - width1

buffer_preds = []
buffer_trues = []
batch_size_to_save = 1  # number of loader batches buffered per HDF5 write
current_count = 10  # not read anywhere in the visible part of this script

# The context manager guarantees the output file is closed even if inference
# raises part-way through.
with h5py.File(output_path, 'w') as f_out, pt.no_grad():
    dset_pred = f_out.create_dataset('predicted', shape=(N, 3, H, W), dtype='float32')
    dset_true = f_out.create_dataset('ground_truth', shape=(N, 3, H, W), dtype='float32')

    next_row = 0  # next output row to be written
    num = 0  # running batch counter, used for logging only
    for data in tqdm(testloader, desc="Loading data"):
        # Channel 0 is the model input; channels 4:7 are the targets
        # (w_20, w_40, w_60 in the dataset's stacking order).
        xbatch = data[:, 0:1, :, :].cuda().float()
        ybatch = data[:, 4:7, :, :].cuda().float()

        out = model(xbatch)
        print(out.shape)

        mse = pt.mean((ybatch - out) ** 2)
        print(num, mse)

        buffer_preds.append(out.detach().cpu().numpy().astype(np.float32))
        buffer_trues.append(ybatch.detach().cpu().numpy().astype(np.float32))

        if len(buffer_preds) >= batch_size_to_save:
            next_row = _flush_buffers(
                dset_pred, dset_true, buffer_preds, buffer_trues, next_row
            )
        num += 1

    # Write whatever is left when the batch count is not a multiple of
    # batch_size_to_save.
    next_row = _flush_buffers(
        dset_pred, dset_true, buffer_preds, buffer_trues, next_row
    )

print("Results successfully saved to HDF5 file.")