File size: 1,102 Bytes
2fa5aae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch.nn as nn

from .DGMR_SO.model import Generator as DGMR_SO
from .Generator_only.model_clr_idx import Generator as Generator_only


class Predictor(nn.Module):
    """Select a generator backbone by name and delegate prediction to it.

    Args:
        model_type: Which generator to build; must be ``'DGMR_SO'`` or
            ``'Generator_only'``.
        in_channels: Forwarded to the generator's ``in_channels`` argument.
        base_channels: Forwarded to the generator's ``base_channels`` argument.
        down_step: Forwarded to the generator's ``down_step`` argument.
        prev_step: Forwarded to the generator's ``prev_step`` argument.
        sigma: Forwarded as ``sigma`` to the ``'DGMR_SO'`` generator only;
            ignored for ``'Generator_only'``.

    The keyword defaults reproduce the configuration this wrapper always used
    (in_channels=1, base_channels=24, down_step=4, prev_step=4, sigma=1).

    Raises:
        ValueError: If ``model_type`` is not one of the supported names. An
            unknown name is rejected here instead of leaving ``self.generator``
            unset (which would otherwise only fail later, inside ``forward``).
    """

    def __init__(
        self,
        model_type,
        *,
        in_channels=1,
        base_channels=24,
        down_step=4,
        prev_step=4,
        sigma=1,
    ):
        super().__init__()

        if model_type == 'DGMR_SO':
            self.generator = DGMR_SO(
                in_channels=in_channels,
                base_channels=base_channels,
                down_step=down_step,
                prev_step=prev_step,
                sigma=sigma,
            )
        elif model_type == 'Generator_only':
            self.generator = Generator_only(
                in_channels=in_channels,
                base_channels=base_channels,
                down_step=down_step,
                prev_step=prev_step,
            )
        else:
            # Fail fast: without this, the module would be built with no
            # generator and the mistake would surface much later.
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                "expected 'DGMR_SO' or 'Generator_only'"
            )

    def forward(self, x, x2, topo, datetime_feat, pred_step=36):
        """Run the selected generator and return its output unchanged.

        The shapes below are the ones stated by the original authors; this
        wrapper does not check them.

        Args:
            x: input sequence, dims (N, D, C, H, W).
            x2: input sequence (WRF), dims (N, D, C, H, W).
            topo: topography, dims (N, 1, H=512, W=512).
            datetime_feat: datetime features, dims (N, D, 4).
            pred_step: forwarded to the generator as ``pred_step``.

        Returns:
            Whatever the underlying generator returns.
        """
        return self.generator(x, x2, topo, datetime_feat, pred_step=pred_step)