import xarray as xr
import torch as pt
from torch import nn
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F
import os
from model.vision import VISION
import os.path as osp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import h5py

# Expose only the first GPU to this process; must be set before CUDA is
# initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Spatial crop applied to every field: rows [high1, high2), cols [width1, width2).
high1 = 0
high2 = 512
width1 = 0
width2 = 512

# Indices into axis 1 of the HDF5 'fields' dataset, in the order testDataset
# stacks them: ssh, u, v, b, w_20, w_40, w_60.
FIELD_CHANNELS = (0, 1, 2, 8, 3, 4, 5)

# Positions within one testDataset sample used as model input / target.
INPUT_SLICE = slice(0, 1)   # ssh
TARGET_SLICE = slice(4, 7)  # w_20, w_40, w_60


class testDataset(Dataset):
    """Cropped, normalised samples read from an HDF5 'fields' dataset.

    Each item stacks 7 channels on axis 0, in the order ssh, u, v, b, w_20,
    w_40, w_60 (see FIELD_CHANNELS). NaNs are replaced by 0 before each
    channel is normalised with the mean/std loaded from ./data/mean.npy and
    ./data/std.npy.

    NOTE(review): mean/std are used over their full spatial extent while the
    data are cropped to [high1:high2, width1:width2]; this assumes the .npy
    statistics were computed on the same crop — confirm before changing the
    crop bounds.
    """

    def __init__(self, data_path='./data/KD48_demo.h5', size=50):
        """Open the HDF5 file and load normalisation statistics.

        Args:
            data_path: HDF5 file containing a 'fields' dataset.
            size: maximum number of samples to expose (clamped to the number
                of samples actually present in the file).
        """
        super().__init__()
        self.data_path = data_path
        # The handle stays open for the dataset's lifetime; presumably why the
        # loader uses num_workers=0 — verify before raising it.
        self.data_file = h5py.File(self.data_path, 'r')
        self.mean = np.load('./data/mean.npy')
        self.std = np.load('./data/std.npy')
        # Never report more samples than the file holds.
        self.size = min(size, self.data_file['fields'].shape[0])

    def _load_channel(self, index, channel):
        """Read one cropped channel of sample `index`, zero NaNs, and normalise."""
        field = self.data_file['fields'][index, channel, high1:high2, width1:width2]
        field = np.nan_to_num(field, nan=0)
        return (field - self.mean[0, channel, :, :]) / self.std[0, channel, :, :]

    def __getitem__(self, index):
        return np.stack(
            [self._load_channel(index, channel) for channel in FIELD_CHANNELS],
            axis=0,
        )

    def __len__(self):
        return self.size

    def __del__(self):
        # data_file is missing if __init__ failed before the file was opened.
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()


def _strip_prefix(key, prefix='module.'):
    """Remove leading `prefix` occurrences from a state-dict key.

    Only a *leading* prefix is removed; str.replace would also corrupt keys
    such as 'submodule.weight' (-> 'subweight').
    """
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


def load_model(checkpoint_path='./checkpoint_VISION/best_mse.pt'):
    """Build VISION on the GPU, load weights from `checkpoint_path`, set eval mode."""
    model = VISION().cuda()
    ckpt = pt.load(checkpoint_path, map_location='cpu')
    # Accept either a bare state dict or a dict wrapping it under 'model'.
    state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
    state_dict = {_strip_prefix(k): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


def _flush(dset_pred, dset_true, start, preds, trues):
    """Write buffered batches to rows [start, start + n) and clear the buffers.

    Returns the first row index after the written block.
    """
    preds_block = np.concatenate(preds, axis=0)
    trues_block = np.concatenate(trues, axis=0)
    stop = start + preds_block.shape[0]
    dset_pred[start:stop] = preds_block
    dset_true[start:stop] = trues_block
    preds.clear()
    trues.clear()
    return stop


def main():
    """Run VISION on the test set and save predictions and targets to HDF5."""
    testdataset = testDataset()
    testloader = Data.DataLoader(
        dataset=testdataset,
        batch_size=1,
        # Sequential order lets each batch be written as a contiguous row block.
        shuffle=False,
        num_workers=0,
    )
    model = load_model()

    folder_path = './result'
    os.makedirs(folder_path, exist_ok=True)
    output_path = osp.join(folder_path, 'results_io_ssh_vision.h5')

    n_samples = len(testdataset)
    n_targets = TARGET_SLICE.stop - TARGET_SLICE.start
    out_shape = (n_samples, n_targets, high2 - high1, width2 - width1)
    save_every = 1  # number of batches buffered before each HDF5 write

    # Mode 'w' truncates an existing file, so no explicit removal is needed;
    # the context manager closes the file even if inference raises.
    with h5py.File(output_path, 'w') as f_out, pt.no_grad():
        dset_pred = f_out.create_dataset('predicted', shape=out_shape, dtype='float32')
        dset_true = f_out.create_dataset('ground_truth', shape=out_shape, dtype='float32')

        buffer_preds, buffer_trues = [], []
        write_start = 0  # first output row not yet written
        seen = 0  # samples processed so far (index of the current batch's first sample)
        for data in tqdm(testloader, desc="Loading data"):
            data = data.to(device='cuda', dtype=pt.float32)
            xbatch = data[:, INPUT_SLICE]
            ybatch = data[:, TARGET_SLICE]
            out = model(xbatch)

            # MSE in normalised units, reported per batch.
            mse = F.mse_loss(out, ybatch)
            tqdm.write(f"{seen} {mse.item()}")
            seen += data.shape[0]

            buffer_preds.append(out.cpu().numpy().astype(np.float32, copy=False))
            buffer_trues.append(ybatch.cpu().numpy().astype(np.float32, copy=False))
            if len(buffer_preds) == save_every:
                write_start = _flush(dset_pred, dset_true, write_start,
                                     buffer_preds, buffer_trues)

        if buffer_preds:
            _flush(dset_pred, dset_true, write_start, buffer_preds, buffer_trues)

    print("Results successfully saved to HDF5 file.")


if __name__ == "__main__":
    main()