| import time |
| import argparse |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import List, Sequence |
| import sys |
| from datetime import datetime, timedelta |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from model_architect.UNet_DDPM import UNet_with_time, DDPM |
|
|
@dataclass
class Config:
    """Hyper-parameter bundle handed to ``UNet_with_time`` when the net is built.

    The defaults below are what the script uses for the ``'1hr'`` horizon; the
    entry point overwrites ``input_frame`` and ``output_frame`` for ``'6hr'``.

    NOTE(review): the meaning of each field is defined by ``UNet_with_time`` /
    ``DDPM``, which are not visible here. The grouping comments below are
    inferred from the field names only and should be confirmed against the
    model code.
    """

    # Frame counts (presumably input / predicted time steps).
    input_frame: int = 12
    output_frame: int = 6

    # Conditioning and time-embedding sizes (presumably channel counts).
    cond_nc: int = 5
    time_emb_dim: int = 128

    # Backbone layout: one entry per resolution level in both tuples.
    base_chs: int = 32
    chs_mult: tuple = (1, 2, 4, 8, 8)
    use_attn_list: tuple = (0, 0, 1, 1, 1)
    n_res_blocks: int = 2

    # Diffusion schedule length and regularisation.
    n_steps: int = 1000
    dropout: float = 0.1
|
|
def data_loading(BASETIME, device):
    """Load ``./sample_data/sample_{BASETIME}.npz`` and return its arrays as tensors.

    Args:
        BASETIME: Value interpolated into the sample file name.
        device: Target passed to ``Tensor.to`` for every loaded array.

    Returns:
        dict mapping each array name stored in the archive to a
        ``torch.Tensor`` moved to ``device``.

    Raises:
        FileNotFoundError: Propagated from ``np.load`` when the sample file
            does not exist.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open and
    # reads members lazily; the ``with`` block closes the handle once every
    # array has been materialised instead of leaking it.
    with np.load(f'./sample_data/sample_{BASETIME}.npz') as data_npz:
        return {
            key: torch.from_numpy(data_npz[key]).to(device)
            for key in data_npz.files
        }
|
|
def arg_parse():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with three string attributes:
        ``pred_hr`` (``'1hr'`` or ``'6hr'``, default ``'1hr'``),
        ``pred_mode`` (``'DDPM'`` or ``'DDIM'``, default ``'DDPM'``) and
        ``basetime`` (free-form, default ``'202504131100'``).
    """
    cli = argparse.ArgumentParser()

    # Options restricted to a fixed set of values: (flag, default, allowed).
    restricted_options = (
        ('--pred-hr', '1hr', ['1hr', '6hr']),
        ('--pred-mode', 'DDPM', ['DDPM', 'DDIM']),
    )
    for flag, fallback, allowed in restricted_options:
        cli.add_argument(flag, type=str, default=fallback, choices=allowed)

    cli.add_argument('--basetime', type=str, default='202504131100')
    return cli.parse_args()
|
|
def main():
    """Run one diffusion forecast for a sample file and save it as ``.npy``.

    Steps: parse the CLI options, load ``./sample_data/sample_{basetime}.npz``,
    build the ``UNet_with_time`` + ``DDPM`` model for the requested horizon,
    load the matching checkpoint, sample with DDPM or DDIM, rescale the sample
    to [0, 1], multiply it by the ``clearsky`` array and write the result to
    ``./pred_{basetime}_{pred_hr}_{pred_mode}.npy``.
    """
    args = arg_parse()
    pred_hr = args.pred_hr
    pred_mode = args.pred_mode

    BASETIME = args.basetime
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = data_loading(BASETIME, device)

    # The 6hr model uses longer input/output sequences than the defaults.
    model_config = Config()
    if pred_hr == '6hr':
        model_config.input_frame = 72
        model_config.output_frame = 36
    print("Prediction mode:", pred_mode)
    print("Prediction horizon:", pred_hr)

    # Drop the singleton axis 2 and stack the fields along dim 1.
    # NOTE(review): presumably dim 1 is the frame/channel axis — confirm
    # against the sample file layout.
    prev_himawari = inputs['Himawari'].squeeze(2)
    topo = inputs['topo']
    input_ = torch.cat([prev_himawari, topo], dim=1)
    # WRF is upsampled 4x so that it can be used as the sampler's condition.
    WRF = F.interpolate(inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear')

    clearsky = inputs['clearsky']
    if pred_hr == '1hr':
        # Only the first 6 entries along dim 1 are used for the 1hr horizon.
        WRF = WRF[:, :6]
        clearsky = clearsky[:, :6]

    backbone = UNet_with_time(model_config)
    model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))

    if pred_hr == '1hr':
        ckpt_path = './model_weights/ft06_01hr/weights.ckpt'
    elif pred_hr == '6hr':
        ckpt_path = './model_weights/ft36_06hr/weights.ckpt'

    # map_location='cpu' lets a checkpoint saved on GPU be loaded on a
    # CPU-only machine (the device fallback above); the model is moved to the
    # target device afterwards.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(device)

    # Inference only: make sure no autograd state is tracked across the
    # sampling steps, whether or not the samplers disable it themselves.
    with torch.no_grad():
        if pred_mode == 'DDPM':
            pred_clr_idx = model.sample_ddpm(
                input_,
                input_cond=WRF,
                verbose="text"
            )
        elif pred_mode == 'DDIM':
            pred_clr_idx = model.sample_ddim(
                input_,
                input_cond=WRF,
                ddim_steps=100,
                verbose="text"
            )

    # Affine map x -> (x + 1) / 2 (i.e. [-1, 1] -> [0, 1]), then clamp.
    pred_clr_idx = (pred_clr_idx + 1.0) / 2.0
    pred_clr_idx = pred_clr_idx.clamp(0.0, 1.0)

    # NOTE(review): names suggest clear-sky index * clear-sky radiation =
    # surface radiation — confirm units with the data provider.
    pred_srad = pred_clr_idx * clearsky

    np.save(f'./pred_{BASETIME}_{pred_hr}_{pred_mode}.npy', pred_srad.cpu().numpy())
    print('Done')


if __name__ == "__main__":
    main()
|
|