|
|
import torch |
|
|
import numpy as np |
|
|
import argparse |
|
|
|
|
|
from model_architect.inference_model import Predictor |
|
|
|
|
|
|
|
|
def data_loading(BASETIME, device):
    """Load every array from ``./sample_data/sample_{BASETIME}.npz`` as tensors.

    Args:
        BASETIME: String interpolated into the sample filename.
        device: ``torch.device`` the tensors are moved to.

    Returns:
        dict mapping each array name stored in the ``.npz`` archive to a
        tensor built with ``torch.from_numpy`` and moved to ``device``.
    """
    # np.load on an .npz archive returns an NpzFile that keeps the underlying
    # file open; the context manager closes it once every array has been read
    # (indexing an NpzFile reads the array fully into memory).
    with np.load(f'./sample_data/sample_{BASETIME}.npz') as data_npz:
        return {
            key: torch.from_numpy(data_npz[key]).to(device)
            for key in data_npz.files
        }
|
|
|
|
|
|
|
|
def model_loading(model_type, device):
    """Build a ``Predictor`` and load its generator weights from a checkpoint.

    Args:
        model_type: ``'DGMR_SO'`` or ``'Generator_only'``; passed to
            ``Predictor`` and used to pick the checkpoint file.
        device: ``torch.device`` the model is moved to.

    Returns:
        The ``Predictor`` with weights loaded, in eval mode, on ``device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    ckpt_paths = {
        'DGMR_SO': './model_weights/DGMR_SO/ft36/weights.ckpt',
        'Generator_only': './model_weights/Generator_only/ft36/weights.ckpt',
    }
    # Validate before doing any work; previously an unknown name left
    # ckpt_path unbound and failed later with an UnboundLocalError.
    try:
        ckpt_path = ckpt_paths[model_type]
    except KeyError:
        raise ValueError(
            f'Unsupported model_type {model_type!r}; '
            f'expected one of {sorted(ckpt_paths)}'
        ) from None

    model = Predictor(
        model_type=model_type,
    )

    # Deserialize onto the CPU so a checkpoint saved with CUDA tensors also
    # loads on a CPU-only host; the model is moved to `device` afterwards.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['generator_state_dict'])
    model.eval()
    model.to(device)

    return model
|
|
|
|
|
|
|
|
def arg_parse():
    """Read the command-line options for this inference script.

    Returns:
        argparse.Namespace with:
          * ``model_type`` -- one of ``'Generator_only'`` / ``'DGMR_SO'``
            (default ``'DGMR_SO'``), from ``--model-type``;
          * ``basetime`` -- string (default ``'202504131100'``), from
            ``--basetime``.
    """
    supported_models = ['Generator_only', 'DGMR_SO']

    cli = argparse.ArgumentParser()
    cli.add_argument('--model-type', type=str, default='DGMR_SO',
                     choices=supported_models)
    cli.add_argument('--basetime', type=str, default='202504131100')
    return cli.parse_args()
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse options, load inputs and model, run one
    # forecast, and save the result as a .npy file.
    args = arg_parse()
    model_type, BASETIME = args.model_type, args.basetime
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inputs = data_loading(BASETIME, device)
    model = model_loading(model_type, device)

    with torch.no_grad():
        model_args = [inputs[name] for name in ('Himawari', 'WRF', 'topo', 'time_feat')]
        raw_pred = model(*model_args, pred_step=36)
        # Drop axis 2 (presumably a singleton channel axis -- confirm) and
        # clip the predicted clear-sky index to [0, 1].
        pred_clr_idx = raw_pred.squeeze(2).clamp(0, 1)

    # Scale the clear-sky index by the 'clearsky' input (presumably clear-sky
    # radiation -- confirm) to obtain the predicted radiation field.
    pred_srad = pred_clr_idx * inputs['clearsky']

    np.save(f'./pred_{BASETIME}_{model_type}.npy', pred_srad.cpu().numpy())
    print('Done')
|
|
|