| import torch |
| import torch.nn as nn |
| import pandas as pd |
| import numpy as np |
| from functools import partial |
| from datetime import datetime, timedelta |
| from pathlib import Path |
| import pickle |
|
|
| import dask |
| import dask.array as da |
| import cartopy |
| import cartopy.crs as ccrs |
| import xarray as xr |
| import xarray.ufuncs as xu |
| import matplotlib.pyplot as plt |
|
|
| from model.afnonet import AFNONet |
|
|
# The ERA5 variables downloaded/loaded by get_data(). Surface fields use their
# plain names; pressure-level fields are written 'variable@level' (hPa).
# Note: 'total_precipitation' is loaded here but is NOT one of the 20 input
# channels stacked in get_data() -- it is only used as the precipitation
# target/plot field ('tp').
DATANAMES = ['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature',
'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850',
'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850',
'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour',
'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850',
'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850',
'total_precipitation']
# Maps the long ERA5 variable name to the short NetCDF variable name, so that
# 'name@level' entries can be renamed to 'shortname@level' after loading
# (see the rename_vars call in get_data()).
DATAMAP = {
'geopotential': 'z',
'relative_humidity': 'r',
'temperature': 't',
'u_component_of_wind': 'u',
'v_component_of_wind': 'v'
}
|
|
|
|
def load_model():
    """Build the two AFNONet models and load their checkpoints from disk.

    Returns:
        tuple: (backbone_model, precip_model) -- the 20-in/20-out backbone
        that advances the atmospheric state one step, and the 20-in/1-out
        precipitation head. Both are loaded on CPU and left in train mode
        (callers switch to eval()).
    """
    # 720 x 1440 grid; 20 input channels, 20 backbone outputs, 1 precip output.
    h, w = 720, 1440
    x_c, y_c, p_c = 20, 20, 1

    backbone_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c,
                             norm_layer=partial(nn.LayerNorm, eps=1e-6))
    ckpt = torch.load('./backbone.pt', map_location="cpu")
    backbone_model.load_state_dict(ckpt['model'])

    precip_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=p_c,
                           norm_layer=partial(nn.LayerNorm, eps=1e-6))
    ckpt = torch.load('./precipitation.pt', map_location="cpu")
    precip_model.load_state_dict(ckpt['model'])

    # Bug fix: the original had no return, so the caller's unpacking
    # `backbone_model, precip_model = load_model()` raised TypeError.
    return backbone_model, precip_model
|
|
|
|
def imcol(data, img_path, img_name, **kwargs):
    """Render a single 2-D geographic field to ``<img_path>/<img_name>.jpg``.

    Args:
        data: 2-D xarray DataArray with latitude/longitude coordinates.
        img_path: pathlib.Path of the output directory (must exist).
        img_name: file stem; '.jpg' is appended.
        **kwargs: forwarded to DataArray.plot (e.g. cmap, vmin, vmax).
    """
    fig = plt.figure(figsize=(20, 10))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # The plot handle is not needed (the original bound it to an unused `I`);
    # rasterized=True keeps the saved figure small.
    data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False,
              add_labels=False, rasterized=True, **kwargs)
    ax.coastlines(resolution='110m')

    # Build the output path with pathlib instead of string formatting
    # (the original name `dirname` was a misnomer -- this is a file path).
    out_file = img_path.absolute() / f'{img_name}.jpg'

    plt.axis('off')
    plt.savefig(out_file, bbox_inches='tight', pad_inches=0.)
    plt.close(fig)  # release the figure to avoid matplotlib memory buildup
|
|
|
|
def _render_frames(data, suffix, save_path, cmap, wlim, plim, tlim):
    """Render wind-speed, precipitation and 2m-temperature images for every
    timestep in `data`, naming files '<field>_<timestamp>_<suffix>.jpg'."""
    for i in range(len(data.time)):
        u = data['u10'].isel(time=i)
        v = data['v10'].isel(time=i)
        wind = np.sqrt(u ** 2 + v ** 2)
        precip = data['tp'].isel(time=i)
        temp = data['t2m'].isel(time=i)

        # Renamed from `datetime` to avoid shadowing the datetime import.
        timestamp = pd.to_datetime(str(wind['time'].values)).strftime('%Y-%m-%d %H:%M:%S')
        print(f'plot {timestamp}')

        imcol(wind, save_path, img_name=f'wind_{timestamp}_{suffix}',
              cmap=cmap, vmin=wlim[0], vmax=wlim[1])
        imcol(precip, save_path, img_name=f'precipitation_{timestamp}_{suffix}',
              cmap=cmap, vmin=plim[0], vmax=plim[1])
        imcol(temp, save_path, img_name=f'temperature_{timestamp}_{suffix}',
              cmap=cmap, vmin=tlim[0], vmax=tlim[1])


def plot(real_data, pred_data, save_path):
    """Plot real vs. predicted wind speed, precipitation and temperature.

    Color limits are computed over BOTH datasets so that 'real' and 'pred'
    frames of the same field share a common color scale.

    Uses np.sqrt instead of xu.sqrt: xarray.ufuncs was deprecated and removed
    from xarray, and numpy ufuncs apply directly to DataArrays.
    """
    cmap_t = 'RdYlBu_r'

    # Shared wind-speed limits across real and predicted data.
    real_wind = np.sqrt(real_data['u10'] ** 2 + real_data['v10'] ** 2)
    pred_wind = np.sqrt(pred_data['u10'] ** 2 + pred_data['v10'] ** 2)
    wlim = (min(real_wind.values.min(), pred_wind.values.min()),
            max(real_wind.values.max(), pred_wind.values.max()))

    # Shared precipitation limits.
    plim = (min(real_data['tp'].values.min(), pred_data['tp'].values.min()),
            max(real_data['tp'].values.max(), pred_data['tp'].values.max()))

    # Shared temperature limits.
    tlim = (min(real_data['t2m'].values.min(), pred_data['t2m'].values.min()),
            max(real_data['t2m'].values.max(), pred_data['t2m'].values.max()))

    _render_frames(real_data, 'real', save_path, cmap_t, wlim, plim, tlim)
    _render_frames(pred_data, 'pred', save_path, cmap_t, wlim, plim, tlim)
|
|
|
|
def get_pred(sample, scaler, times=None, latitude=None, longitude=None):
    """Autoregressively roll the models forward and return the forecast.

    Args:
        sample: numpy array of normalized input fields; only sample[0]
            (the first timestep) is used as the initial condition.
        scaler: dict with 'mean' and 'std' arrays used to undo normalization.
        times: forecast timestamps; one model step is run per entry.
            Despite the None default, this must be provided (len(times)
            below would fail on None).
        latitude, longitude: coordinate vectors for the output Dataset.

    Returns:
        xr.Dataset with variables u10, v10, t2m, tp over
        (time, latitude, longitude).
    """

    backbone_model, precip_model = load_model()

    # Initial condition: first timestep only, as float32.
    sample = torch.from_numpy(sample[0])
    sample = sample.float()
    
    backbone_model.eval()
    precip_model.eval()
    pred = []
    # (H, W, C) -> (C, H, W), then prepend a batch dimension of 1.
    x = sample.unsqueeze(0).transpose(3, 2).transpose(2, 1)
    for i in range(len(times)):
        print(f"predict {times[i]}")

        # NOTE(review): models were loaded with map_location="cpu" and the
        # input is a CPU tensor, so this CUDA autocast context is likely a
        # no-op -- confirm the intended device.
        with torch.cuda.amp.autocast():
            x = backbone_model(x)  # feed prediction back in: autoregressive rollout
            tmp = x.transpose(1, 2).transpose(2, 3)  # back to (N, H, W, C) layout
            p = precip_model(x)

        # Undo normalization. Only the first 3 channels (plotted as
        # u10/v10/t2m below) are kept; precipitation is appended as a
        # 4th channel using the last mean/std entry.
        tmp = tmp.detach().numpy()[0, :, :, :3] * scaler['std'][:3] + scaler['mean'][:3]
        p = p.detach().numpy()[0, 0, :, :, np.newaxis] * scaler['std'][-1] + scaler['mean'][-1]
        tmp = np.concatenate([tmp, p], axis=-1)
        pred.append(tmp)

    # Stack per-step fields into (time, lat, lon, 4) and wrap them as a
    # lazily-chunked Dataset. Chunk sizes hard-code a 720x1440 grid --
    # TODO confirm against the actual model output shape.
    pred = np.asarray(pred)
    pred_data = xr.Dataset({
        'u10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 0], chunks=(7, 720, 1440))),
        'v10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 1], chunks=(7, 720, 1440))),
        't2m': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 2], chunks=(7, 720, 1440))),
        'tp': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 3], chunks=(7, 720, 1440))),
    },
        coords={'time': (['time'], times),
                'latitude': (['latitude'], latitude),
                'longitude': (['longitude'], longitude)
                }
    )

    return pred_data
|
|
|
|
def get_data(start_time, end_time=None):
    """Load ERA5 fields, normalize them, and return plot + model inputs.

    Args:
        start_time: first timestamp of the window (inclusive).
        end_time: last timestamp (inclusive). Defaults to None, which
            selects everything from start_time onward -- this also keeps the
            existing single-argument call `get_data(start_time)` working
            (it previously raised TypeError: missing argument).

    Returns:
        tuple: (raw_data, data, scaler) where
            raw_data -- xr.Dataset of u10/v10/t2m/tp (expver=1) for plotting,
            data -- normalized numpy array shaped (time, lat, lon, 20),
            scaler -- dict with 'mean'/'std' arrays used for normalization.
    """
    times = slice(start_time, end_time)

    # NOTE: pickle is only safe for trusted, locally produced files.
    with open('./scaler.pkl', "rb") as f:
        scaler = pickle.load(f)

    # Load every variable; pressure-level entries ('name@level') are renamed
    # so each level becomes a uniquely named variable after the merge.
    datas = []
    for file in DATANAMES:
        tmp = xr.open_mfdataset(f'./ERA5_rawdata/{file}/*.nc', combine='by_coords').sel(time=times)
        if '@' in file:
            k, v = file.split('@')
            tmp = tmp.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'})
        datas.append(tmp)
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        raw_data = xr.merge(datas, compat="identical", join="inner")

    # Stack the 20 input channels in the order the model expects.
    data = []
    for name in ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850', 'sp', 't@500', 't@850', 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']:
        raw = raw_data[name].values
        data.append(raw)

    data = np.stack(data, axis=-1)
    data = (data - scaler['mean']) / scaler['std']
    # Drop the first latitude row -- presumably trimming 721 rows to the
    # model's 720-row grid; TODO confirm against the source data.
    data = data[:, 1:, :, :]

    return raw_data[['u10', 'v10', 't2m', 'tp']].sel(expver=1), data, scaler
|
|
|
|
|
|
if __name__ == '__main__':
    # Forecast window; the model advances the state in 6-hour steps.
    start_time = datetime(2023, 1, 1, 0, 0)
    end_time = datetime(2023, 1, 5, 18, 0)
    num = int((end_time - start_time) / timedelta(hours=6))

    print(f"start_time: {start_time}, end_time: {end_time}, pred_num: {num}")

    # Bug fix: pass end_time explicitly -- the original call omitted it,
    # which raised TypeError against get_data's two-argument signature.
    real_data, sample, scaler = get_data(start_time, end_time)
    print(sample.shape)

    # NOTE(review): range(1, num) produces num-1 steps, so end_time itself is
    # never predicted -- confirm whether range(1, num + 1) was intended.
    pred_times = [start_time + timedelta(hours=6) * i for i in range(1, num)]
    pred = get_pred(sample, scaler=scaler, times=pred_times,
                    latitude=real_data.latitude[1:], longitude=real_data.longitude)

    save_path = Path("./output")
    save_path.mkdir(parents=True, exist_ok=True)

    plot(real_data, pred, save_path)