File size: 3,560 Bytes
be89dda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import time
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence
import sys
from datetime import datetime, timedelta
import numpy as np
import torch
import torch.nn.functional as F
from model_architect.UNet_DDPM import UNet_with_time, DDPM
@dataclass()
class Config:
    """Hyper-parameters consumed by the UNet backbone and the DDPM wrapper.

    Every field has a default, so ``Config()`` is a complete configuration;
    individual fields may be reassigned after construction.
    """

    # Number of input / output frames (by name; defaults 12 in, 6 out).
    input_frame: int = 12
    output_frame: int = 6
    # Number of conditioning channels; presumably the channel count of the
    # conditional input tensor — TODO confirm against UNet_with_time.
    cond_nc: int = 5
    # Width of the timestep embedding (name-based — confirm in UNet_with_time).
    time_emb_dim: int = 128
    # Base channel width and one channel multiplier per resolution level.
    base_chs: int = 32
    chs_mult: tuple = (1, 2, 4, 8, 8)
    # Attention switch per resolution level (presumably aligned with
    # chs_mult): 0 means no attention, 1 means use attention.
    use_attn_list: tuple = (0, 0, 1, 1, 1)
    # Residual blocks per level (name-based — confirm in UNet_with_time).
    n_res_blocks: int = 2
    # Presumably the number of diffusion timesteps in the DDPM schedule.
    n_steps: int = 1000
    dropout: float = 0.1
def data_loading(BASETIME, device, data_dir='./sample_data'):
    """Load every array from ``<data_dir>/sample_<BASETIME>.npz`` as tensors.

    Args:
        BASETIME: Time string embedded in the file name (e.g. '202504131100').
        device: torch device each tensor is moved to.
        data_dir: Directory holding the sample archives. Defaults to
            './sample_data' relative to the working directory, which was the
            previously hard-coded location.

    Returns:
        dict mapping each array name stored in the archive to a
        ``torch.Tensor`` on ``device``.

    Raises:
        FileNotFoundError: if the archive does not exist.
    """
    npz_path = Path(data_dir) / f'sample_{BASETIME}.npz'
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager closes it once every array has been read into memory.
    with np.load(npz_path) as data_npz:
        return {key: torch.from_numpy(data_npz[key]).to(device) for key in data_npz}
def arg_parse(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows calling the parser from tests or notebooks.

    Returns:
        argparse.Namespace with ``pred_hr`` ('1hr' or '6hr'), ``pred_mode``
        ('DDPM' or 'DDIM') and ``basetime`` (str).

    Raises:
        SystemExit: raised by argparse on an unknown option or invalid choice.
    """
    parser = argparse.ArgumentParser(
        description='Run DDPM/DDIM inference for a single base time.'
    )
    parser.add_argument(
        '--pred-hr',
        type=str,
        default='1hr',
        choices=['1hr', '6hr'],
        help='Prediction horizon; selects the model configuration and checkpoint.',
    )
    parser.add_argument(
        '--pred-mode',
        type=str,
        default='DDPM',
        choices=['DDPM', 'DDIM'],
        help='Sampling procedure used to generate the prediction.',
    )
    parser.add_argument(
        '--basetime',
        type=str,
        default='202504131100',
        help='Base time string identifying the input sample file.',
    )
    return parser.parse_args(argv)
def _load_model(model_config, pred_hr, device):
    """Build the UNet-backed DDPM for ``pred_hr`` and load its trained weights.

    Args:
        model_config: ``Config`` for the backbone; ``output_frame`` also sets
            the first dimension of the sampler's output shape.
        pred_hr: '1hr' or '6hr'; selects the checkpoint file.
        device: torch device the model is moved to.

    Returns:
        The model in eval mode, on ``device``.

    Raises:
        KeyError: if ``pred_hr`` is neither '1hr' nor '6hr'.
    """
    ckpt_paths = {
        '1hr': './model_weights/ft06_01hr/weights.ckpt',
        '6hr': './model_weights/ft36_06hr/weights.ckpt',
    }
    backbone = UNet_with_time(model_config)
    model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))
    # Deserialize onto CPU: without map_location, a checkpoint saved from CUDA
    # tensors cannot be loaded on a CPU-only host. The weights are copied into
    # the (CPU) model, which is then moved to `device` as a whole.
    ckpt = torch.load(ckpt_paths[pred_hr], map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model.to(device)


def main():
    """Run one prediction from the command line and save it as .npy.

    Loads the sample for ``--basetime``, samples a clear-sky index with the
    model selected by ``--pred-hr`` using the ``--pred-mode`` sampler,
    multiplies it by the clear-sky field, and writes
    ``./pred_<basetime>_<pred_hr>_<pred_mode>.npy``.
    """
    args = arg_parse()
    pred_hr = args.pred_hr
    pred_mode = args.pred_mode
    BASETIME = args.basetime
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = data_loading(BASETIME, device)

    model_config = Config()
    if pred_hr == '6hr':
        model_config.input_frame = 72
        model_config.output_frame = 36

    print("Prediction mode:", pred_mode)
    print("Prediction horizon:", pred_hr)

    # Conditional input: previous Himawari frames concatenated with topography
    # along dim 1 (expected (B, 5, 512, 512) — TODO confirm against the data).
    # squeeze(2) only removes dim 2 when it has size 1.
    prev_himawari = inputs['Himawari'].squeeze(2)
    input_ = torch.cat([prev_himawari, inputs['topo']], dim=1)
    # WRF conditioning, upsampled 4x spatially with bilinear interpolation.
    # Expected to carry 36 frames (TODO confirm): 1hr keeps the first
    # output_frame of them, 6hr uses them as-is.
    WRF = F.interpolate(inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear')
    clearsky = inputs['clearsky']
    if pred_hr == '1hr':
        WRF = WRF[:, :model_config.output_frame]
        clearsky = clearsky[:, :model_config.output_frame]

    model = _load_model(model_config, pred_hr, device)

    # Inference only: disable autograd so no graph is retained across the
    # sampling steps (harmless if the sample_* methods already do this).
    with torch.no_grad():
        if pred_mode == 'DDPM':
            pred_clr_idx = model.sample_ddpm(
                input_,
                input_cond=WRF,
                verbose="text",
            )
        elif pred_mode == 'DDIM':
            pred_clr_idx = model.sample_ddim(
                input_,
                input_cond=WRF,
                ddim_steps=100,
                verbose="text",
            )
        else:
            raise ValueError(f"Unsupported prediction mode: {pred_mode!r}")

    # Rescale with (x + 1) / 2 and clamp to [0, 1]; presumably the sampler
    # emits the clear-sky index normalised to [-1, 1] — TODO confirm.
    pred_clr_idx = ((pred_clr_idx + 1.0) / 2.0).clamp(0.0, 1.0)
    # Clear-sky index times the clear-sky field gives solar radiation.
    pred_srad = pred_clr_idx * clearsky

    np.save(f'./pred_{BASETIME}_{pred_hr}_{pred_mode}.npy', pred_srad.cpu().numpy())
    print('Done')


if __name__ == "__main__":
    main()
|