DGMR_SolRad / model_architect /inference_model.py
Jason-thingnario's picture
feat: initial implementation of DGMR solar radiation nowcasting models
2fa5aae
import torch.nn as nn
from .DGMR_SO.model import Generator as DGMR_SO
from .Generator_only.model_clr_idx import Generator as Generator_only
class Predictor(nn.Module):
def __init__(
self,
model_type,
):
super().__init__()
if model_type == 'DGMR_SO':
self.generator = DGMR_SO(
in_channels=1,
base_channels=24,
down_step=4,
prev_step=4,
sigma=1
)
elif model_type == 'Generator_only':
self.generator = Generator_only(
in_channels=1,
base_channels=24,
down_step=4,
prev_step=4,
)
def forward(self, x, x2, topo, datetime_feat, pred_step=36):
"""
x: input seq -> dims (N, D, C, H, W)
x2: input seq (WRF) -> dims (N, D, C, H, W)
topo: topography -> dims (N, 1, H=512, W=512)
datetime_feat -> dims (N, D, 4)
"""
pred = self.generator(x, x2, topo, datetime_feat, pred_step=pred_step)
return pred