| import xarray as xr |
| import torch as pt |
| from torch import nn |
| import torch.utils.data as Data |
| import numpy as np |
| import torch.nn.functional as F |
| import os |
| from model.vision import VISION |
| import os.path as osp |
| from tqdm import tqdm |
| from torch.utils.data import Dataset, DataLoader |
| import h5py |
|
|
|
|
# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Crop window applied to the last two axes of every field that testDataset
# reads: rows high1:high2 and columns width1:width2.
high1, high2 = 0, 512
width1, width2 = 0, 512
|
|
class testDataset(Dataset):
    """Dataset of normalized 7-channel samples read from an HDF5 file.

    Each item is built from the ``'fields'`` dataset of the HDF5 file: the
    channels listed in ``_CHANNELS`` are read for the requested index, cropped
    to the module-level window (``high1:high2``, ``width1:width2``), cleaned
    with ``np.nan_to_num(..., nan=0)`` and normalized with the matching
    channel of the ``mean`` / ``std`` arrays.

    Args:
        data_path: Path of the HDF5 file containing the ``'fields'`` dataset.
        mean_path: Path of the ``.npy`` file with the normalization means.
        std_path: Path of the ``.npy`` file with the normalization stds.
        size: Number of samples reported by ``len()``.
            NOTE(review): not checked against the number of entries actually
            stored in the file - confirm it does not exceed it.
    """

    # (label, channel index in 'fields'), listed in the order the planes are
    # stacked in the returned sample. Labels are the variable names used by
    # the original implementation.
    _CHANNELS = (
        ('ssh', 0),
        ('u', 1),
        ('v', 2),
        ('b', 8),
        ('w_20', 3),
        ('w_40', 4),
        ('w_60', 5),
    )

    def __init__(self, data_path='./data/KD48_demo.h5',
                 mean_path='./data/mean.npy', std_path='./data/std.npy',
                 size=50):
        super().__init__()
        # Defined before anything can fail so that __del__ always finds it.
        self.data_file = None
        self.data_path = data_path
        self.data_file = h5py.File(self.data_path, 'r')
        self.mean = np.load(mean_path)
        self.std = np.load(std_path)
        self.size = size
        # Crop window, captured once from the module-level constants.
        self._rows = slice(high1, high2)
        self._cols = slice(width1, width2)

    def __getitem__(self, index):
        """Return sample ``index`` as an array stacked along axis 0.

        The seven planes follow the ``_CHANNELS`` order:
        ssh, u, v, b, w_20, w_40, w_60.
        """
        fields = self.data_file['fields']
        planes = []
        for _label, channel in self._CHANNELS:
            raw = fields[index, channel:channel + 1, self._rows, self._cols]
            plane = np.nan_to_num(raw, nan=0)[0]
            # NOTE(review): mean/std are used uncropped, so their trailing
            # axes must broadcast against the cropped plane - confirm if the
            # crop window is ever changed.
            centred = plane - self.mean[0, channel, :, :]
            planes.append(centred / self.std[0, channel, :, :])
        return np.stack(planes, axis=0)

    def __len__(self):
        """Return the configured number of samples."""
        return self.size

    def __del__(self):
        """Close the HDF5 handle if it was ever opened."""
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()
|
|
|
|
# Evaluation data: samples are served one at a time, in dataset order, from
# the main process (no worker processes).
testdataset = testDataset()
testloader = DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=0)
|
|
def _strip_module_components(key):
    """Remove 'module' path components from a state-dict key.

    Only whole, non-final dotted components equal to 'module' are dropped, so
    'module.conv.weight' becomes 'conv.weight' while a key such as
    'submodule.weight' is left untouched (a plain ``str.replace`` would turn
    it into 'subweight').
    """
    *parents, leaf = key.split('.')
    return '.'.join([part for part in parents if part != 'module'] + [leaf])


model = VISION().cuda()
checkpoint_path = './checkpoint_VISION/best_mse.pt'
# NOTE(security): pt.load unpickles the file - only load checkpoints that come
# from a trusted source.
ckpt = pt.load(checkpoint_path, map_location='cpu')
# The checkpoint is either a bare state dict or a dict holding it under 'model'.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
state_dict = {_strip_module_components(k): v for k, v in state_dict.items()}
# strict=True: any missing or unexpected key is reported as an error.
model.load_state_dict(state_dict, strict=True)

# Switch to evaluation mode before running inference.
model.eval()
|
|
# Directory that receives the result file; exist_ok avoids the race between
# checking for the directory and creating it.
folder_path = './result'
os.makedirs(folder_path, exist_ok=True)

# Remove any result file left over from a previous run.
output_path = osp.join(folder_path, 'results_io_ssh_vision.h5')
try:
    os.remove(output_path)
except FileNotFoundError:
    pass
|
|
def _flush_buffers(dset_pred, dset_true, start, preds, trues):
    """Write the buffered batches to both datasets and empty the buffers.

    The buffered arrays are concatenated along axis 0 and their first three
    channels are stored in rows ``start .. start + n - 1`` of ``dset_pred``
    and ``dset_true``, where ``n`` is the number of buffered samples.

    Returns:
        The index of the next free row (``start`` when nothing was buffered).
    """
    if not preds:
        return start
    preds_block = np.concatenate(preds, axis=0)
    trues_block = np.concatenate(trues, axis=0)
    stop = start + preds_block.shape[0]
    dset_pred[start:stop] = preds_block[:, :3]
    dset_true[start:stop] = trues_block[:, :3]
    preds.clear()
    trues.clear()
    return stop


N = len(testdataset)
# Spatial size of one sample, i.e. the crop window used by the dataset.
H = high2 - high1
W = width2 - width1

# Number of loader batches to accumulate before writing them to disk.
batch_size_to_save = 1

# The context manager closes the output file even if inference fails.
with h5py.File(output_path, 'w') as f_out:
    dset_pred = f_out.create_dataset('predicted', shape=(N, 3, H, W), dtype='float32')
    dset_true = f_out.create_dataset('ground_truth', shape=(N, 3, H, W), dtype='float32')

    buffer_preds = []
    buffer_trues = []
    # Row of the output datasets that receives the next sample; counted in
    # samples so it stays correct if the loader batch size is raised above 1.
    next_row = 0

    with pt.no_grad():
        for num, data in enumerate(tqdm(testloader, desc="Loading data")):
            # Input is the first stacked channel; the targets are channels
            # 4..6, which testDataset stacks as w_20, w_40, w_60.
            xbatch = data[:, 0:1, :, :].cuda().float()
            ybatch = data[:, 4:7, :, :].cuda().float()
            out = model(xbatch)
            print(out.shape)

            mse = pt.mean((ybatch - out) ** 2)
            print(num, mse)

            buffer_preds.append(out.detach().cpu().numpy().astype(np.float32))
            buffer_trues.append(ybatch.detach().cpu().numpy().astype(np.float32))

            if len(buffer_preds) >= batch_size_to_save:
                next_row = _flush_buffers(
                    dset_pred, dset_true, next_row, buffer_preds, buffer_trues
                )

    # Write whatever is left when the batch count is not a multiple of
    # batch_size_to_save.
    next_row = _flush_buffers(
        dset_pred, dset_true, next_row, buffer_preds, buffer_trues
    )

print("Results successfully saved to HDF5 file.")
|
|