Jason-thingnario committed on
Commit
2fa5aae
·
0 Parent(s):

feat: initial implementation of DGMR solar radiation nowcasting models

- Add the model architecture
- Include pre-trained model weights for DGMR_SO & Generator_only models
- Add inference pipeline and sample data for testing
- Configure project structure with requirements and documentation

.gitattributes ADDED
@@ -0,0 +1,5 @@
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
README.md ADDED
@@ -0,0 +1,73 @@
+ # DGMR Solar Radiation Nowcasting
+
+ A deep learning model for solar radiation nowcasting using a modified [Deep Generative Model of Rainfall (DGMR)](https://www.nature.com/articles/s41586-021-03854-z) architecture with Solar radiation Output (DGMR-SO). The model predicts the clearsky index and converts it to solar radiation for up to 36 time steps ahead.
+
+ ## Overview
+
+ This repository implements two model variants for solar radiation forecasting:
+ - **DGMR_SO**: the full deep generative model, trained with one generator and two discriminators (spatial and temporal)
+ - **Generator_only**: the generator alone, trained without the discriminators
+
+ The model uses multiple input sources:
+ - **Himawari satellite data**: clearsky index calculated from Himawari satellite observations
+ - **WRF prediction**: clearsky index from WRF's solar irradiation prediction
+ - **Topography**: static topographical features
+ - **Time features**: temporal sin/cos encodings of day and hour
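
The sample `.npz` files bundle these inputs under the keys that `inference.py` reads: `Himawari`, `WRF`, `topo`, `time_feat`, and `clearsky`. A minimal validation sketch; the shapes below are assumptions inferred from the model docstrings (batch 1, 4 past frames, 36 lead steps, 512x512 grid) and are not verified against the shipped samples:

```python
import numpy as np

# Keys inference.py reads from each sample .npz; shapes are illustrative assumptions
# (batch=1, 4 past frames, 36 lead steps, 512x512 grid).
EXPECTED = {
    "Himawari": (1, 4, 1, 512, 512),   # past clearsky-index frames
    "WRF": (1, 36, 1, 512, 512),       # NWP clearsky index per lead step
    "topo": (1, 1, 512, 512),          # static topography
    "time_feat": (1, 36, 4),           # sin/cos encodings of day and hour
    "clearsky": (1, 36, 512, 512),     # clearsky radiation for the final conversion
}

def check_sample(npz) -> list:
    """Return the keys that are missing or mis-shaped in a loaded sample."""
    bad = []
    for key, shape in EXPECTED.items():
        if key not in npz or tuple(npz[key].shape) != shape:
            bad.append(key)
    return bad
```

A dict loaded via `np.load("sample_202504131100.npz")` can be passed straight to `check_sample`.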
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone <repository-url>
+ cd DGMR_SolRad
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Requirements
+
+ - Python 3.x
+ - PyTorch 2.4.0
+ - NumPy 1.26.4
+ - einops 0.8.0
+
+ ## Usage
+
+ ### Basic Inference
+
+ Run solar radiation prediction using the pre-trained models:
+
+ ```bash
+ python inference.py --model-type DGMR_SO --basetime 202504131100
+ ```
+
+ ### Command Line Arguments
+
+ - `--model-type`: choose between `DGMR_SO` and `Generator_only` (default: `DGMR_SO`)
+ - `--basetime`: timestamp of the input data in `YYYYMMDDHHMM` format (default: `202504131100`)
+
+ ### Example
+
+ ```bash
+ # Using the DGMR_SO model
+ python inference.py --model-type DGMR_SO --basetime 202504131100
+
+ # Using the Generator-only model
+ python inference.py --model-type Generator_only --basetime 202507151200
+ ```
+
+ ## Sample Data
+
+ The repository includes sample data files:
+ - `sample_202504131100.npz`
+ - `sample_202504161200.npz`
+ - `sample_202507151200.npz`
+
+ ## Model Weights
+
+ Pre-trained weights are available for both models:
+ - `model_weights/DGMR_SO/ft36/weights.ckpt`
+ - `model_weights/Generator_only/ft36/weights.ckpt`
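
Once `inference.py` has run, the saved array can be inspected as follows. Per the comment in the script, its dims are `(1, 36, 512, 512)`, i.e. (batch, lead step, height, width); a zero placeholder stands in for `np.load` so the sketch is self-contained:

```python
import numpy as np

# Placeholder for: pred_srad = np.load("./pred_202504131100_DGMR_SO.npy")
pred_srad = np.zeros((1, 36, 512, 512), dtype=np.float32)

step_1 = pred_srad[0, 0]                       # radiation map one step ahead
mean_by_step = pred_srad.mean(axis=(0, 2, 3))  # domain-average radiation per lead step
```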
inference.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ import numpy as np
+ import argparse
+
+ from model_architect.inference_model import Predictor
+
+
+ def data_loading(BASETIME, device):
+     data_npz = np.load(f'./sample_data/sample_{BASETIME}.npz')
+
+     inputs = {}
+     for key in data_npz:
+         inputs[key] = torch.from_numpy(data_npz[key]).to(device)
+
+     return inputs
+
+
+ def model_loading(model_type, device):
+     if model_type == 'DGMR_SO':
+         ckpt_path = './model_weights/DGMR_SO/ft36/weights.ckpt'
+     elif model_type == 'Generator_only':
+         ckpt_path = './model_weights/Generator_only/ft36/weights.ckpt'
+
+     model = Predictor(
+         model_type=model_type,
+     )
+
+     ckpt = torch.load(ckpt_path, weights_only=True)
+     model.load_state_dict(ckpt['generator_state_dict'])
+     model.eval()
+     model.to(device)
+
+     return model
+
+
+ def arg_parse():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         '--model-type',
+         type=str,
+         default='DGMR_SO',
+         choices=[
+             'Generator_only',
+             'DGMR_SO'])
+     parser.add_argument('--basetime', type=str, default='202504131100')
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == '__main__':
+     args = arg_parse()
+     model_type = args.model_type
+     BASETIME = args.basetime
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     inputs = data_loading(BASETIME, device)
+     model = model_loading(model_type, device)
+
+     # prediction
+     with torch.no_grad():
+         pred_clr_idx = model(
+             inputs['Himawari'],
+             inputs['WRF'],
+             inputs['topo'],
+             inputs['time_feat'],
+             pred_step=36,
+         )
+         pred_clr_idx = pred_clr_idx.squeeze(2).clamp(0, 1)
+
+     # transform clearsky index to solar radiation
+     pred_srad = pred_clr_idx * inputs['clearsky']  # dim: (1, 36, 512, 512)
+
+     # save prediction
+     np.save(f'./pred_{BASETIME}_{model_type}.npy', pred_srad.cpu().numpy())
+     print('Done')
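
The final conversion in the script is a plain elementwise product: the predicted clearsky index (clamped to [0, 1]) times the theoretical clearsky radiation. A toy numeric illustration with made-up values:

```python
import numpy as np

clr_idx = np.array([0.0, 0.5, 1.0])         # predicted clearsky index, already in [0, 1]
clearsky = np.array([800.0, 800.0, 800.0])  # clearsky radiation in W/m^2 (illustrative)
srad = clr_idx * clearsky                   # -> [0., 400., 800.]
```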
model_architect/DGMR_SO/discriminator.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ import torch.nn as nn
+ from ..components.common import DBlock
+ import einops
+
+
+ class TemporalDiscriminator(nn.Module):
+     def __init__(self, in_channel: int, base_c: int = 24):
+         super().__init__()
+         self.in_channel = in_channel
+
+         # (N, C, D, H, W)
+         self.down_sample = nn.AvgPool3d(
+             kernel_size=(1, 2, 2),
+             stride=(1, 2, 2)
+         )
+
+         # (N, D, C, H, W)
+         self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
+
+         # in_channel, out_channel
+         # Conv3D -> (N, C, D, H, W)
+         chn = base_c * 2 * in_channel
+         self.d3_1 = DBlock(in_channel=in_channel * 4,
+                            out_channel=chn,
+                            conv_type='3d', apply_relu=False, apply_down=True)
+
+         self.d3_2 = DBlock(in_channel=chn,
+                            out_channel=2 * chn,
+                            conv_type='3d', apply_relu=True, apply_down=True)
+
+         self.Dlist = nn.ModuleList()
+         for i in range(3):
+             chn = chn * 2
+             self.Dlist.append(
+                 DBlock(in_channel=chn,
+                        out_channel=2 * chn,
+                        conv_type='2d', apply_relu=True, apply_down=True)
+             )
+
+         self.last_D = DBlock(in_channel=2 * chn,
+                              out_channel=2 * chn,
+                              conv_type='2d', apply_relu=True, apply_down=False)
+
+         self.fc = nn.Linear(2 * chn, 1)
+         self.relu = nn.ReLU()
+         # TODO: close bn
+         # self.bn = nn.BatchNorm1d(2 * chn)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.down_sample(x)
+         x = self.space_to_depth(x)
+
+         # go through the 3D Blocks
+         # from (N, D, C, H, W) -> (N, C, D, H, W)
+         x = torch.permute(x, dims=(0, 2, 1, 3, 4))
+         x = self.d3_1(x)
+         x = self.d3_2(x)
+         # go through the 2D Blocks, permute -> (N, D, C, H, W)
+         x = torch.permute(x, dims=(0, 2, 1, 3, 4))
+         n, d, c, h, w = list(x.size())
+
+         fea = einops.rearrange(x, "n d c h w -> (n d) c h w")
+         for dd in self.Dlist:
+             fea = dd(fea)
+
+         fea = self.last_D(fea)
+
+         fea = torch.sum(self.relu(fea), dim=[2, 3])
+         # fea = self.bn(fea)
+         fea = self.fc(fea)
+
+         y = torch.reshape(fea, (n, d, 1))      # dims -> (N, D, 1)
+         y = torch.sum(y, keepdim=True, dim=1)  # dims -> (N, 1, 1)
+
+         return y
+
+
+ class SpatialDiscriminator(nn.Module):
+     def __init__(self, in_channel: int, base_c: int = 24):
+         super().__init__()
+         self.in_channel = in_channel
+
+         # (N, C, H, W)
+         self.down_sample = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
+
+         # (N, C, H, W)
+         self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
+
+         # first DBlock doesn't apply relu
+         chn = base_c * in_channel
+         self.d1 = DBlock(in_channel=in_channel * 4,
+                          out_channel=chn * 2,
+                          conv_type='2d', apply_relu=False, apply_down=True)
+
+         self.Dlist = nn.ModuleList()
+         for i in range(4):
+             chn = chn * 2
+             self.Dlist.append(
+                 DBlock(in_channel=chn,
+                        out_channel=2 * chn,
+                        conv_type='2d', apply_relu=True, apply_down=True)
+             )
+
+         self.last_D = DBlock(in_channel=2 * chn,
+                              out_channel=2 * chn,
+                              conv_type='2d', apply_relu=True, apply_down=False)
+
+         self.fc = nn.Linear(2 * chn, 1)
+         self.relu = nn.ReLU()
+         # TODO: close BN
+         # self.bn = nn.BatchNorm1d(2 * chn)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # note: input dims -> (N, D, C, H, W)
+         # randomly pick up 8 out of 18
+         perm = torch.randperm(x.shape[1])
+         random_idx = perm[:8]
+
+         fea = x[:, random_idx, :, :, :]
+
+         n, d, c, h, w = list(fea.size())
+
+         fea = einops.rearrange(fea, "n d c h w -> (n d) c h w")
+         fea = self.down_sample(fea)
+         fea = self.space_to_depth(fea)
+
+         # apply DBlocks
+         fea = self.d1(fea)
+         for dd in self.Dlist:
+             fea = dd(fea)
+
+         fea = self.last_D(fea)
+
+         # sum
+         fea = torch.sum(self.relu(fea), dim=[2, 3])
+         # fea = self.bn(fea)
+         fea = self.fc(fea)
+
+         y = torch.reshape(fea, (n, d, 1))      # dims -> (N, D, 1)
+         y = torch.sum(y, keepdim=True, dim=1)  # dims -> (N, 1, 1)
+
+         return y
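
Both discriminators start with `nn.PixelUnshuffle(2)`, which trades spatial resolution for channels; this is why the first `DBlock` takes `in_channel * 4`. A NumPy re-implementation of that shape bookkeeping, assuming the standard PixelUnshuffle layout:

```python
import numpy as np

def pixel_unshuffle(x, r=2):
    """Rearrange (N, C, H, W) -> (N, C*r*r, H//r, W//r), mimicking nn.PixelUnshuffle."""
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // r, r, w // r, r)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # bring the r x r sub-pixels next to the channel dim
    return x.reshape(n, c * r * r, h // r, w // r)

y = pixel_unshuffle(np.zeros((1, 1, 8, 8)))
# y.shape -> (1, 4, 4, 4): one input channel becomes in_channel * 4
```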
model_architect/DGMR_SO/generator.py ADDED
@@ -0,0 +1,274 @@
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.utils.parametrizations import spectral_norm
+ from torch.distributions import normal
+ import einops
+
+ from ..components.ConvGRU import ConvGRU
+ from ..components.common import GBlock, Up_GBlock, LBlock, AttentionLayer, DBlock
+
+
+ class Sampler(nn.Module):
+     def __init__(self, in_channels, base_channels=24, up_step=4):
+         """
+         up_step should be the same as down_step in the context-condition stack.
+         """
+         super().__init__()
+         base_c = base_channels
+
+         self.up_steps = up_step
+         self.convgru_list = nn.ModuleList()
+         self.conv1x1_list = nn.ModuleList()
+         self.gblock_list = nn.ModuleList()
+         self.upg_list = nn.ModuleList()
+
+         for i in range(self.up_steps):
+             # different scale
+             chs1 = base_c * 2**(self.up_steps - i + 1) * in_channels
+             chs2 = base_c * 2**(self.up_steps - i) * in_channels
+             # convgru
+             self.convgru_list.append(
+                 ConvGRU(chs1, chs2, 3)
+             )
+             # conv1x1
+             self.conv1x1_list.append(
+                 spectral_norm(
+                     nn.Conv2d(
+                         in_channels=chs2,
+                         out_channels=chs1,
+                         kernel_size=(1, 1))))
+             # GBlock
+             self.gblock_list.append(
+                 GBlock(in_channel=chs1, out_channel=chs1)
+             )
+             # upgblock
+             self.upg_list.append(
+                 Up_GBlock(in_channel=chs1)
+             )
+
+         # output
+         # TODO: close Batch
+         # self.bn = nn.BatchNorm2d(chs2)
+         self.relu = nn.ReLU()
+         self.last_conv1x1 = spectral_norm(
+             nn.Conv2d(in_channels=chs2,
+                       out_channels=4,
+                       kernel_size=(1, 1))
+         )
+         self.depth_to_space = nn.PixelShuffle(upscale_factor=2)
+
+     def forward(self, latents, img_feat, init_states, pred_step):
+         """
+         latents dim -> (N, C, W, H)
+         img_feat dim -> (N, D, C, W, H)
+         init_states dim -> (N, C, W, H)
+         """
+         # expand time dims at axis=1
+         latents = torch.unsqueeze(latents, dim=1)
+
+         # repeat batch_size
+         if latents.shape[0] == 1:
+             # expand batch
+             latents = einops.repeat(
+                 latents,
+                 "b d c h w -> (repeat b) d c h w",
+                 repeat=init_states[0].shape[0])
+         # repeat time step
+         latents = einops.repeat(
+             latents, "b d c h w -> b (repeat d) c h w", repeat=pred_step
+         )
+         # TODO: add img feas
+         # seq_out = latents + img_feat
+         seq_out = torch.cat((img_feat, latents), dim=2)
+         # init_states should be reversed
+         # scale-up steps
+         for i in range(self.up_steps):
+             seq_out = self.convgru_list[i](
+                 seq_out, init_states[(self.up_steps - 1) - i])
+             # seq_out output shape -> (D, N, C, W, H)
+
+             # for-loop over time steps
+             seq_out = [self.conv1x1_list[i](h) for h in seq_out]
+             seq_out = [self.gblock_list[i](h) for h in seq_out]
+             seq_out = [self.upg_list[i](h) for h in seq_out]
+             # output: seq_out list dim -> D * [N, C, H, W]
+             # stack at dim == 1 to become (N, D, C, H, W)
+             seq_out = torch.stack(seq_out, dim=1)
+
+         # final output
+         # for-loop over time steps
+         output = []
+         for t in range(seq_out.shape[1]):
+             y = seq_out[:, t, :, :, :]
+             # y = self.bn(y)
+             y = self.relu(y)
+             y = self.last_conv1x1(y)
+             y = self.depth_to_space(y)
+             output.append(y)
+
+         output = torch.stack(output, dim=1)
+
+         return output
+
+
+ class LatentConditionStack(nn.Module):
+     def __init__(self, out_channels, down_step, sigma, attn=True):
+         """
+         in_shape dims -> e.g. (8, 8) -> (H, W)
+         base_c is 1/96 of out_channels, with a minimum of 4
+         """
+         super().__init__()
+
+         self.down_step = down_step
+         self.base_c = out_channels // 96
+         if self.base_c < 4:
+             self.base_c = 4
+
+         self.out_channels = out_channels
+         self.attn = attn
+
+         # define the noise distribution
+         self.dist = normal.Normal(loc=0.0, scale=sigma)
+
+         self.conv3x3 = spectral_norm(
+             nn.Conv2d(
+                 in_channels=self.base_c,
+                 out_channels=self.base_c,
+                 kernel_size=(3, 3),
+                 padding=1
+             )
+         )
+
+         cc = self.base_c
+         self.l1 = LBlock(cc, cc * 3)
+         self.l2 = LBlock(cc * 3, cc * 6)
+         self.l3 = LBlock(cc * 6, cc * 24)
+         if self.attn:
+             self.attn = AttentionLayer(cc * 24, cc * 24)
+         self.l4 = LBlock(cc * 24, self.out_channels)
+
+     def forward(self, x, batch_size=1, z=None):
+         """
+         x shape -> (batch_size, time, c, width, height)
+         """
+         width = x.shape[3]
+         height = x.shape[4]
+         # shape after the down steps
+         s_w = width // (2 * 2**self.down_step)
+         s_h = height // (2 * 2**self.down_step)
+
+         in_shape = [self.base_c] + [s_w, s_h]
+
+         target_shape = [batch_size] + in_shape
+         if z is None:
+             z = self.dist.sample(target_shape)
+             z = z.type_as(x)
+
+         # first conv
+         z = self.conv3x3(z)
+
+         # LBlocks
+         z = self.l1(z)
+         z = self.l2(z)
+         z = self.l3(z)
+         if self.attn:
+             z = self.attn(z)
+
+         z = self.l4(z)
+
+         return z
+
+     # TODO: modification (change the number of parameters)
+
+
+ class ContextConditionStack(nn.Module):
+     def __init__(self,
+                  in_channels: int = 1,
+                  base_channels: int = 24,
+                  down_step: int = 4,
+                  prev_step: int = 4):
+         """
+         base_channels: e.g. 24 -> output_channel: 384
+         output_channel: base_c * in_c * 2**(down_step - 2) * prev_step
+         down_step: int
+         prev_step: int
+         """
+         super().__init__()
+         self.in_channels = in_channels
+         self.down_step = down_step
+         self.prev_step = prev_step
+
+         base_c = base_channels
+         in_c = in_channels
+
+         # channels at the different scales
+         chs = [4 * in_c] + [base_c * in_c * 2**(i + 1) for i in range(down_step)]
+
+         self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
+         self.Dlist = nn.ModuleList()
+         self.convList = nn.ModuleList()
+         for i in range(down_step):
+             self.Dlist.append(
+                 DBlock(in_channel=chs[i],
+                        out_channel=chs[i + 1],
+                        apply_relu=True, apply_down=True)
+             )
+
+             self.convList.append(
+                 spectral_norm(
+                     nn.Conv2d(in_channels=prev_step * chs[i + 1],
+                               out_channels=prev_step * chs[i + 1] // 4,
+                               kernel_size=(3, 3),
+                               padding=1)
+                 )
+             )
+
+         # ReLU
+         self.relu = nn.ReLU()
+
+     def forward(self,
+                 x: torch.Tensor) -> Tuple[torch.Tensor,
+                                           torch.Tensor,
+                                           torch.Tensor,
+                                           torch.Tensor]:
+         """
+         input dims -> (N, D, C, H, W)
+         """
+         x = self.space_to_depth(x)
+         tsteps = x.shape[1]
+         assert tsteps == self.prev_step
+
+         # features is indexed by scale:
+         # [scale1 -> [t1, t2, t3, t4], scale2 -> [t1, t2, t3, t4], scale3 -> [...]]
+         features = [[] for _ in range(self.down_step)]
+
+         for st in range(tsteps):
+             in_x = x[:, st, :, :, :]
+             # in_x -> (Batch(N), C, H, W)
+             for scale in range(self.down_step):
+                 in_x = self.Dlist[scale](in_x)
+                 features[scale].append(in_x)
+
+         out_scale = []
+         for i, cc in enumerate(self.convList):
+             # after stacking, dims -> (Batch, Time, C, H, W)
+             # the mixing layer concatenates Time and C
+             stacked = self._mixing_layer(torch.stack(features[i], dim=1))
+             out = self.relu(cc(stacked))
+             out_scale.append(out)
+
+         return out_scale
+
+     def _mixing_layer(self, x):
+         # convert from (N, Time, C, H, W) -> (N, Time*C, H, W)
+         # then apply Conv2d
+         stacked = einops.rearrange(x, "b t c h w -> b (t c) h w")
+
+         return stacked
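
The shape arithmetic in `LatentConditionStack.forward` works out as follows for a 512x512 input. This is pure bookkeeping; `out_channels=384` assumes the defaults (`base_channels=24`, `prev_step=4`, `in_channels=1`), and the LBlocks are assumed to keep spatial size, as in the original DGMR:

```python
# Noise shape in LatentConditionStack.forward for a 512x512 input, down_step=4.
width = height = 512
down_step = 4
out_channels = 384                   # base_channels * 2**(down_step - 2) * prev_step * in_channels
base_c = max(out_channels // 96, 4)  # -> 4

s_w = width // (2 * 2 ** down_step)  # 512 // 32 -> 16
s_h = height // (2 * 2 ** down_step)
z_shape = (1, base_c, s_w, s_h)      # noise is sampled at (1, 4, 16, 16)
# The LBlocks then expand channels 4 -> 12 -> 24 -> 96 -> 384 at that resolution.
```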
model_architect/DGMR_SO/img_extractor.py ADDED
@@ -0,0 +1,49 @@
+ import torch.nn as nn
+
+ from ..components.common import DBlock
+
+
+ class ImageExtractor(nn.Module):
+     def __init__(
+             self,
+             in_channels,
+             out_channels,
+             apply_down_flag,
+             down_step=4):
+         """
+         in_c -> 1
+         base_c is 1/96 of out_channels, with a minimum of 4
+         """
+         super().__init__()
+         self.down_step = down_step
+
+         self.base_c = out_channels // 96
+         if self.base_c < 4:
+             self.base_c = 4
+         cc = self.base_c
+
+         self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
+
+         chs = [in_channels * 4, cc * 3, cc * 6, cc * 24, out_channels]
+         self.DList = nn.ModuleList()
+         for i in range(down_step):
+             self.DList.append(
+                 DBlock(
+                     in_channel=chs[i],
+                     out_channel=chs[i + 1],
+                     conv_type='2d',
+                     apply_down=apply_down_flag[i]
+                 ),
+             )
+
+     def forward(self, x):
+         """
+         x -> dims (N, C, H, W)
+         """
+         y = self.space_to_depth(x)
+         # for-loop over the DBlocks
+         for i in range(self.down_step):
+             y = self.DList[i](y)
+
+         return y
model_architect/DGMR_SO/model.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn as nn
+ from .discriminator import TemporalDiscriminator, SpatialDiscriminator
+ from .generator import Sampler, ContextConditionStack, LatentConditionStack
+ from .img_extractor import ImageExtractor
+
+
+ class Generator(nn.Module):
+     def __init__(
+             self,
+             in_channels,
+             base_channels,
+             down_step,
+             prev_step,
+             sigma
+     ):
+         super().__init__()
+         out_channels = base_channels * \
+             2**(down_step - 2) * prev_step * in_channels
+
+         self.latentStack = LatentConditionStack(
+             out_channels=out_channels,
+             down_step=down_step,
+             sigma=sigma
+         )
+
+         self.contextStack = ContextConditionStack(
+             in_channels=in_channels,
+             base_channels=base_channels,
+             down_step=down_step,
+             prev_step=prev_step
+         )
+
+         self.sampler = Sampler(
+             in_channels=in_channels,
+             base_channels=base_channels,
+             up_step=down_step
+         )
+
+         self.encode_time = nn.Linear(
+             4, base_channels * 2**(down_step) * in_channels)
+
+         self.topo_extractor = ImageExtractor(
+             in_channels=1,
+             out_channels=base_channels * 2**(down_step - 1) * in_channels,
+             apply_down_flag=[True, True, True, True],
+             down_step=down_step
+         )
+
+         self.nwp_extractor = ImageExtractor(
+             in_channels=1,  # TODO: fixed for now
+             out_channels=base_channels * 2**(down_step - 1) * in_channels,
+             apply_down_flag=[False, True, False, True],
+             down_step=down_step
+         )
+
+     def forward(self, x, x2, topo, datetime_feat, pred_step=36):
+         """
+         x: input seq -> dims (N, D, C, H, W)
+         x2: input seq (WRF) -> dims (N, D, C, H, W)
+         topo: topography -> dims (N, 1, H=512, W=512)
+         datetime_feat -> dims (N, D, 4)
+         """
+         context_inits = self.contextStack(x)
+         batch_size = context_inits[0].shape[0]
+         zlatent = self.latentStack(x, batch_size=batch_size)
+
+         # topo feature
+         topo_feat = self.topo_extractor(topo)
+         # encode time feature
+         time_feat = self.encode_time(datetime_feat)
+         # extract nwp feature
+         nwp_feat = []
+         # for-loop over x2 time steps
+         for i in range(x2.shape[1]):
+             nwp_ = self.nwp_extractor(x2[:, i, ...])
+             # concat topo and nwp features
+             concat_feat = torch.cat((nwp_, topo_feat), dim=1)
+             nwp_feat.append(concat_feat)
+         nwp_feat = torch.stack(nwp_feat, dim=1)
+         fuse_feat = nwp_feat + time_feat.unsqueeze(-1).unsqueeze(-1)
+
+         pred = self.sampler(zlatent, fuse_feat, context_inits, pred_step)
+
+         return pred
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, in_channels, base_channels):
+         super().__init__()
+         self.spatial = SpatialDiscriminator(
+             in_channel=in_channels, base_c=base_channels)
+         self.temporal = TemporalDiscriminator(
+             in_channel=in_channels, base_c=base_channels)
+
+     def forward(self, x, y):
+         """
+         x -> dims (N, D, C, H, W), e.g. input frames
+         y -> dims (N, D, C, H, W), e.g. output frames
+         """
+         spatial_out = self.spatial(y)
+         temporal_out = self.temporal(torch.cat([x, y], dim=1))
+
+         dis_out = torch.cat([spatial_out, temporal_out], dim=1)
+
+         return dis_out
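
The channel bookkeeping in `Generator.__init__` is self-consistent: the broadcast add of `time_feat` onto `nwp_feat` works because the two extractor outputs concatenate to exactly the width of the time encoding. With the example values from the `ContextConditionStack` docstring:

```python
# Channel arithmetic in Generator.__init__ with the docstring's example values.
in_channels, base_channels, down_step, prev_step = 1, 24, 4, 4

out_channels = base_channels * 2 ** (down_step - 2) * prev_step * in_channels  # -> 384
time_dim = base_channels * 2 ** down_step * in_channels                        # encode_time output
extractor_dim = base_channels * 2 ** (down_step - 1) * in_channels             # topo/nwp extractors

# concat of nwp + topo features matches the time embedding width
assert extractor_dim + extractor_dim == time_dim
```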
model_architect/Generator_only/generator_clr_idx_wrf_topot.py ADDED
@@ -0,0 +1,282 @@
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.utils.parametrizations import spectral_norm
+ import einops
+
+ from ..components.ConvGRU import ConvGRUCell
+ from ..components.common import GBlock, Up_GBlock, DBlock
+
+
+ class Sampler(nn.Module):
+     def __init__(self, in_channels, base_channels=24, up_step=4):
+         """
+         up_step should be the same as down_step in the context-condition stack.
+         """
+         super().__init__()
+         base_c = base_channels
+
+         self.up_steps = up_step
+         self.convgru_list = nn.ModuleList()
+         self.conv1x1_list = nn.ModuleList()
+         self.gblock_list = nn.ModuleList()
+         self.upg_list = nn.ModuleList()
+
+         # image extractor
+         self.img_extractor = ImageExtractor(
+             in_channels=2,
+             out_channels=base_c * 2**(self.up_steps) * in_channels,
+             apply_down_flag=[True, True, True, True],
+             down_step=self.up_steps
+         )
+
+         self.nwp_extractor = ImageExtractor(
+             in_channels=1,
+             out_channels=base_c * 2**(self.up_steps) * in_channels,
+             apply_down_flag=[False, True, False, True],
+             down_step=self.up_steps
+         )
+
+         self.encode_time = nn.Linear(
+             4, base_c * 2**(self.up_steps + 1) * in_channels)
+
+         for i in range(self.up_steps):
+             # different scale
+             chs1 = base_c * 2**(self.up_steps - i + 1) * in_channels
+             chs2 = base_c * 2**(self.up_steps - i) * in_channels
+             # convgru
+             self.convgru_list.append(
+                 ConvGRUCell(chs1, chs2, 3)
+             )
+             # conv1x1
+             self.conv1x1_list.append(
+                 spectral_norm(
+                     nn.Conv2d(
+                         in_channels=chs2,
+                         out_channels=chs1,
+                         kernel_size=(1, 1))))
+             # GBlock
+             self.gblock_list.append(
+                 GBlock(in_channel=chs1, out_channel=chs1)
+             )
+             # upgblock
+             self.upg_list.append(
+                 Up_GBlock(in_channel=chs1)
+             )
+
+         # output
+         # self.bn = nn.BatchNorm2d(chs2)
+         self.relu = nn.ReLU()
+         self.last_conv1x1 = spectral_norm(
+             nn.Conv2d(in_channels=chs2,
+                       out_channels=4,
+                       kernel_size=(1, 1))
+         )
+         self.depth_to_space = nn.PixelShuffle(upscale_factor=2)
+
+     def forward(
+             self,
+             input_img,
+             nwp_inputs,
+             topo,
+             time_feat,
+             init_states,
+             pred_step,
+             thres=None):
+         """
+         input_img dim -> (N, Tstep, C, W, H) -> Tstep can be 1 or pred_steps
+         nwp_inputs dim -> (N, Tstep, C, W, H)
+         init_states dim -> [(N, C, W, H)-1, (N, C, W, H)-2, ...]
+         probs -> (tsteps)
+         """
+         hh = init_states
+         output = []
+         img_t_len = input_img.shape[1]
+
+         xx = None
+         for t in range(pred_step):
+             time_emb = self.encode_time(time_feat[:, t])  # time emb
+
+             # The 1st tstep image should be the truth;
+             # from the 2nd tstep on, apply scheduled sampling
+             if t == 0:
+                 input_ = input_img[:, 0, :, :, :]
+             else:
+                 if thres is not None and img_t_len > 1:
+                     # use ground truth
+                     if torch.rand(1) < thres:
+                         input_ = input_img[:, t, :, :, :]
+                     else:
+                         input_ = xx
+                 else:
+                     input_ = xx
+
+             nwp_in = nwp_inputs[:, t]
+             # image extractor -> extract the T-step image;
+             # xx is the output of the image extractor
+             input_ = torch.cat((input_, topo), dim=1)
+             xx = self.img_extractor(input_)
+             nwp_feat = self.nwp_extractor(nwp_in)
+             xx = torch.cat((xx, nwp_feat), dim=1)
+
+             xx = xx + time_emb[:, :, None, None]  # add time embedding
+
+             for up in range(self.up_steps):
+                 # convGRU
+                 # init_states should be reversed
+                 h_index = (self.up_steps - 1) - up
+                 xx, out_hh = self.convgru_list[up](xx, hh[h_index])
+                 hh[h_index] = out_hh
+                 # conv1x1
+                 xx = self.conv1x1_list[up](xx)
+                 # gblock
+                 xx = self.gblock_list[up](xx)
+                 # upg_list
+                 xx = self.upg_list[up](xx)
+
+             # xx = self.bn(xx)
+             xx = self.relu(xx)
+             xx = self.last_conv1x1(xx)
+             xx = self.depth_to_space(xx)
+
+             # prediction
+             output.append(xx)
+
+         output = torch.stack(output, dim=1)
+
+         return output
+
+
+ class ImageExtractor(nn.Module):
+     def __init__(
+             self,
+             in_channels,
+             out_channels,
+             apply_down_flag,
+             down_step=4):
+         """
+         in_c -> 1
+         base_c is 1/96 of out_channels, with a minimum of 4
+         """
+         super().__init__()
+         self.down_step = down_step
+
+         self.base_c = out_channels // 96
+         if self.base_c < 4:
+             self.base_c = 4
+         cc = self.base_c
+
+         self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
+
+         chs = [in_channels * 4, cc * 3, cc * 6, cc * 24, out_channels]
+         self.DList = nn.ModuleList()
+         for i in range(down_step):
+             self.DList.append(
+                 DBlock(
+                     in_channel=chs[i],
+                     out_channel=chs[i + 1],
+                     conv_type='2d',
+                     apply_down=apply_down_flag[i]
+                 ),
+             )
+
+     def forward(self, x):
+         y = self.space_to_depth(x)
+         # for-loop over the DBlocks
+         for i in range(self.down_step):
+             y = self.DList[i](y)
+
+         return y
+
+
+ class ContextConditionStack(nn.Module):
+     def __init__(self,
+                  in_channels: int = 1,
+                  base_channels: int = 24,
+                  down_step: int = 4,
+                  prev_step: int = 4):
+         """
+         base_channels: e.g. 24 -> output_channel: 384
+         output_channel: base_c * in_c * 2**(down_step - 2) * prev_step
+         down_step: int
+         prev_step: int
+         """
+         super().__init__()
+         self.in_channels = in_channels
+         self.down_step = down_step
+         self.prev_step = prev_step
+
+         base_c = base_channels
+         in_c = in_channels
+
+         # channels at the different scales
+         chs = [4 * in_c] + [base_c * in_c * 2**(i + 1) for i in range(down_step)]
+
+         self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
+         self.Dlist = nn.ModuleList()
+         self.convList = nn.ModuleList()
+         for i in range(down_step):
+             self.Dlist.append(
+                 DBlock(in_channel=chs[i],
+                        out_channel=chs[i + 1],
+                        apply_relu=True, apply_down=True)
+             )
+
+             self.convList.append(
+                 spectral_norm(
+                     nn.Conv2d(in_channels=prev_step * chs[i + 1],
+                               out_channels=prev_step * chs[i + 1] // 4,
+                               kernel_size=(3, 3),
+                               padding=1)
+                 )
+             )
+
+         # ReLU
+         self.relu = nn.ReLU()
+
+     def forward(self,
+                 x: torch.Tensor) -> Tuple[torch.Tensor,
+                                           torch.Tensor,
+                                           torch.Tensor,
+                                           torch.Tensor]:
+         """
+         input dims -> (N, D, C, H, W)
+         """
+         x = self.space_to_depth(x)
+         tsteps = x.shape[1]
+         assert tsteps == self.prev_step
+
+         # features is indexed by scale:
+         # [scale1 -> [t1, t2, t3, t4], scale2 -> [t1, t2, t3, t4], scale3 -> [...]]
+         features = [[] for _ in range(self.down_step)]
+
+         for st in range(tsteps):
+             in_x = x[:, st, :, :, :]
+             # in_x -> (Batch(N), C, H, W)
+             for scale in range(self.down_step):
+                 in_x = self.Dlist[scale](in_x)
+                 features[scale].append(in_x)
+
+         out_scale = []
+         for i, cc in enumerate(self.convList):
+             # after stacking, dims -> (Batch, Time, C, H, W)
+             # the mixing layer concatenates Time and C
+             stacked = self._mixing_layer(torch.stack(features[i], dim=1))
+             out = self.relu(cc(stacked))
+             out_scale.append(out)
+
+         return out_scale
+
+     def _mixing_layer(self, x):
+         # convert from (N, Time, C, H, W) -> (N, Time*C, H, W)
+         # then apply Conv2d
+         stacked = einops.rearrange(x, "b t c h w -> b (t c) h w")
+
+         return stacked
model_architect/Generator_only/model_clr_idx.py ADDED
@@ -0,0 +1,57 @@
+ import torch.nn as nn
+
+ from .generator_clr_idx_wrf_topot import Sampler, ContextConditionStack
+
+
+ class Generator(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         base_channels,
+         down_step,
+         prev_step
+     ):
+
+         super().__init__()
+         self.contextStack = ContextConditionStack(
+             in_channels=in_channels,
+             base_channels=base_channels,
+             down_step=down_step,
+             prev_step=prev_step
+         )
+
+         self.sampler = Sampler(
+             in_channels=in_channels,
+             base_channels=base_channels,
+             up_step=down_step,
+         )
+
+     def forward(
+         self,
+         x,
+         x2,
+         topo,
+         time_feat,
+         pred_step=12,
+         y=None,
+         thres=None
+     ):
+         """
+         x: input seq -> dims (N, T, C, H, W)
+         x2: input seq for WRF -> dims (N, T, C, H, W)
+         """
+         context_inits = self.contextStack(x)
+         if y is None:
+             y = x[:, -1:, :, :, :]
+
+         pred = self.sampler(
+             y,
+             x2,
+             topo,
+             time_feat,
+             context_inits,
+             pred_step,
+             thres
+         )
+
+         return pred
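The `y = x[:, -1:, :, :, :]` default above seeds the sampler with the last observed frame while keeping the time axis as a length-1 dimension, so the sampler still receives a 5-D tensor. A quick NumPy sketch of that slicing (shapes are illustrative):

```python
import numpy as np

x = np.zeros((2, 4, 1, 64, 64))  # (N, T, C, H, W): 4 context frames
y = x[:, -1:, :, :, :]           # last frame; -1: (not -1) keeps the T axis
print(y.shape)  # (2, 1, 1, 64, 64)
```

Using `x[:, -1]` instead would drop the time axis and yield `(2, 1, 64, 64)`.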
model_architect/__init__.py ADDED
File without changes
model_architect/components/ConvGRU.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.utils.parametrizations import spectral_norm
+ from torch.autograd import Variable
+
+
+ class ConvGRUCell(torch.nn.Module):
+     def __init__(self, in_channel, out_channel, kernel_size=3):
+         super().__init__()
+         padding = kernel_size // 2
+         self.out_channel = out_channel
+
+         # produces the reset and update gates
+         self.conv1 = spectral_norm(
+             torch.nn.Conv2d(
+                 in_channels=in_channel + out_channel,
+                 out_channels=2 * out_channel,
+                 kernel_size=kernel_size,
+                 padding=padding
+             )
+         )
+
+         # produces the candidate state
+         self.conv2 = spectral_norm(
+             torch.nn.Conv2d(
+                 in_channels=in_channel + out_channel,
+                 out_channels=out_channel,
+                 kernel_size=kernel_size,
+                 padding=padding
+             )
+         )
+
+     def forward(self, x, h_st):
+         """
+         x -> dim (Batch, in_channel, width, height)
+         h_st -> dim (Batch, out_channel, width, height)
+         """
+         x_shape = x.shape
+         h_shape = h_st.shape
+         # reflect-pad x by one pixel where the hidden state is larger
+         if h_shape[2] > x_shape[2]:
+             w_l = 1
+         else:
+             w_l = 0
+
+         if h_shape[3] > x_shape[3]:
+             h_b = 1
+         else:
+             h_b = 0
+
+         x = F.pad(x, (0, h_b, w_l, 0), "reflect")
+
+         xx = torch.cat([x, h_st], dim=1)
+         xx = self.conv1(xx)
+         gamma, beta = torch.split(xx, self.out_channel, dim=1)
+
+         reset_gate = torch.sigmoid(gamma)
+         update_gate = torch.sigmoid(beta)
+
+         # candidate state from the reset-gated hidden state
+         out = torch.cat([x, h_st * reset_gate], dim=1)
+         out = torch.tanh(self.conv2(out))
+
+         # convex combination of candidate and previous state
+         out = (1 - update_gate) * out + h_st * update_gate
+         new_st = out
+
+         return out, new_st
+
+
+ class ConvGRU(torch.nn.Module):
+     def __init__(self, in_channel, out_channel, kernel_size):
+         super().__init__()
+
+         self.out_channel = out_channel
+         self.convgru_cell = ConvGRUCell(in_channel, out_channel, kernel_size)
+
+     def _get_init_state(self, batch_size, img_w, img_h, dtype):
+         # zero-initialised hidden state matching the input's spatial dims
+         state = Variable(
+             torch.zeros(
+                 batch_size,
+                 self.out_channel,
+                 img_w,
+                 img_h)).type(dtype)
+
+         return state
+
+     def forward(self, x_sequence, init_hidden=None):
+         """
+         Args:
+             x_sequence shape -> (batch_size, time, c, width, height)
+         Return:
+             outputs shape -> (time, batch_size, c, width, height)
+         """
+         seq_len = x_sequence.shape[1]
+
+         img_w = x_sequence.shape[3]
+         img_h = x_sequence.shape[4]
+
+         dtype = x_sequence.type()
+         if init_hidden is None:
+             hidden_state = self._get_init_state(
+                 x_sequence.shape[0], img_w, img_h, dtype)
+         else:
+             hidden_state = init_hidden
+
+         out_list = []
+         for t in range(seq_len):
+             out, hidden_state = self.convgru_cell(
+                 x_sequence[:, t, :, :, :], hidden_state)
+             out_list.append(out)
+
+         outputs = torch.stack(out_list, dim=0)
+
+         return outputs
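The cell above implements the standard ConvGRU gating: reset and update gates come from `conv1`, a tanh candidate from `conv2`, then a convex combination with the previous state. The update equation in isolation, sketched with NumPy and constant gate values standing in for the conv outputs:

```python
import numpy as np

h = np.full((1, 8, 4, 4), 0.5)           # previous hidden state
candidate = np.full((1, 8, 4, 4), 1.0)   # stands in for tanh(conv2([x, h*reset]))
update = np.full((1, 8, 4, 4), 0.25)     # stands in for sigmoid(beta)

# out = (1 - update) * candidate + update * h
out = (1 - update) * candidate + update * h
print(out.flat[0])  # 0.875
```

With `update` near 0 the cell adopts the candidate; near 1 it keeps the old state, which is what lets the recurrence carry context across prediction steps.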
model_architect/components/common.py ADDED
@@ -0,0 +1,261 @@
+ from torch.nn.utils.parametrizations import spectral_norm
+ import torch
+ from torch.nn import functional as F
+ import einops
+
+
+ class GBlock(torch.nn.Module):
+     def __init__(self, in_channel: int, out_channel: int):
+         super().__init__()
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+
+         # TODO: batch norm currently disabled
+         # self.bn1 = torch.nn.BatchNorm2d(in_channel)
+         # self.bn2 = torch.nn.BatchNorm2d(out_channel)
+
+         self.relu = torch.nn.ReLU()
+         # 1x1 conv for the residual branch when channel counts differ
+         self.conv1x1 = spectral_norm(
+             torch.nn.Conv2d(in_channel, out_channel, kernel_size=1)
+         )
+
+         self.conv3x3_1 = spectral_norm(
+             torch.nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
+         )
+         self.conv3x3_2 = spectral_norm(
+             torch.nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
+         )
+
+     def forward(self, x: torch.Tensor):
+         # project the residual only when channel counts differ
+         if x.shape[1] != self.out_channel:
+             res = self.conv1x1(x)
+         else:
+             res = x.clone()
+
+         # first conv
+         # x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv3x3_1(x)
+         # second conv
+         # x = self.bn2(x)
+         x = self.relu(x)
+         x = self.conv3x3_2(x)
+
+         y = x + res
+
+         return y
+
+
+ class Up_GBlock(torch.nn.Module):
+     def __init__(self, in_channel: int):
+         super().__init__()
+         self.in_channel = in_channel
+         self.out_channel = in_channel // 2
+
+         # TODO: batch norm currently disabled
+         # self.bn1 = torch.nn.BatchNorm2d(in_channel)
+         # self.bn2 = torch.nn.BatchNorm2d(self.out_channel)
+
+         self.relu = torch.nn.ReLU()
+         self.up = torch.nn.Upsample(scale_factor=2, mode='nearest')
+
+         self.conv1x1 = spectral_norm(
+             torch.nn.Conv2d(in_channel, self.out_channel, kernel_size=1)
+         )
+
+         self.conv3x3_1 = spectral_norm(
+             torch.nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1)
+         )
+         self.conv3x3_2 = spectral_norm(
+             torch.nn.Conv2d(
+                 in_channel,
+                 self.out_channel,
+                 kernel_size=3,
+                 padding=1))
+
+     def forward(self, x):
+         res = self.up(x)
+         res = self.conv1x1(res)
+
+         # x = self.bn1(x)
+         x = self.relu(x)
+         x = self.up(x)
+         x = self.conv3x3_1(x)
+
+         # x = self.bn2(x)
+         x = self.relu(x)
+         x = self.conv3x3_2(x)
+
+         y = x + res
+
+         return y
+
+
+ class DBlock(torch.nn.Module):
+     def __init__(
+             self,
+             in_channel: int,
+             out_channel: int,
+             conv_type='2d',
+             apply_relu=True,
+             apply_down=False):
+         super().__init__()
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+         self.apply_relu = apply_relu
+         self.apply_down = apply_down
+
+         # choose 2d or 3d convolution / pooling
+         if conv_type == '2d':
+             self.avg_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+             conv = torch.nn.Conv2d
+         elif conv_type == '3d':
+             self.avg_pool = torch.nn.AvgPool3d(kernel_size=2, stride=2)
+             conv = torch.nn.Conv3d
+         else:
+             raise ValueError(f"unsupported conv_type: {conv_type}")
+
+         self.relu = torch.nn.ReLU()
+         self.conv1x1 = spectral_norm(
+             conv(in_channel, out_channel, kernel_size=1)
+         )
+
+         self.conv3x3_1 = spectral_norm(
+             conv(in_channel, out_channel, kernel_size=3, padding=1)
+         )
+         self.conv3x3_2 = spectral_norm(
+             conv(out_channel, out_channel, kernel_size=3, padding=1)
+         )
+
+     def forward(self, x):
+         # residual branch
+         if x.shape[1] != self.out_channel:
+             res = self.conv1x1(x)
+         else:
+             res = x.clone()
+         if self.apply_down:
+             res = self.avg_pool(res)
+
+         # main branch
+         if self.apply_relu:
+             x = self.relu(x)
+         x = self.conv3x3_1(x)
+         x = self.relu(x)
+         x = self.conv3x3_2(x)
+         if self.apply_down:
+             x = self.avg_pool(x)
+
+         # skip connection
+         y = res + x
+
+         return y
+
+
+ class LBlock(torch.nn.Module):
+     def __init__(self, in_channel, out_channel):
+         super().__init__()
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+
+         self.relu = torch.nn.ReLU()
+         conv = torch.nn.Conv2d
+         self.conv1x1 = conv(
+             in_channel,
+             (out_channel - in_channel),
+             kernel_size=1)
+
+         self.conv3x3_1 = conv(in_channel, in_channel, kernel_size=3, padding=1)
+         self.conv3x3_2 = conv(
+             in_channel,
+             out_channel,
+             kernel_size=3,
+             padding=1)
+
+     def forward(self, x):
+         # widen the residual to out_channel by concatenating a 1x1 projection
+         res = torch.cat([x, self.conv1x1(x)], dim=1)
+
+         x = self.relu(x)
+         x = self.conv3x3_1(x)
+         x = self.relu(x)
+         x = self.conv3x3_2(x)
+
+         y = x + res
+
+         return y
+
+
+ # Attention layer
+
+
+ def attention_einsum(q, k, v):
+     """
+     Apply self-attention to a single (h, w, c) feature map.
+     """
+     # Reshape 3D tensors to 2D with first dimension L = h x w
+     k = einops.rearrange(k, "h w c -> (h w) c")  # [h, w, c] -> [L, c]
+     v = einops.rearrange(v, "h w c -> (h w) c")  # [h, w, c] -> [L, c]
+
+     # Einstein summation corresponding to the query * key operation.
+     beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)
+
+     # Einstein summation corresponding to the attention * value operation.
+     out = torch.einsum("hwL, Lc->hwc", beta, v)
+
+     return out
+
+
+ class AttentionLayer(torch.nn.Module):
+     def __init__(self, in_channel, out_channel, ratio_kq=8, ratio_v=8):
+         super().__init__()
+
+         self.ratio_kq = ratio_kq
+         self.ratio_v = ratio_v
+         self.in_channel = in_channel
+         self.out_channel = out_channel
+
+         # compute query, key, and value with 1x1 convolutions
+         self.query = torch.nn.Conv2d(
+             in_channel,
+             out_channel // ratio_kq,
+             kernel_size=1,
+             bias=False
+         )
+
+         self.key = torch.nn.Conv2d(
+             in_channel,
+             out_channel // ratio_kq,
+             kernel_size=1,
+             bias=False
+         )
+
+         self.value = torch.nn.Conv2d(
+             in_channel,
+             out_channel // ratio_v,
+             kernel_size=1,
+             bias=False
+         )
+
+         # project back to out_channel; input width must match the value width
+         self.conv = torch.nn.Conv2d(
+             out_channel // ratio_v,
+             out_channel,
+             kernel_size=1,
+             bias=False
+         )
+
+         # learnable residual scale, initialised to zero
+         self.gamma = torch.nn.Parameter(torch.zeros(1))
+
+     def forward(self, x):
+         query = self.query(x)
+         key = self.key(x)
+         value = self.value(x)
+
+         # apply attention per batch element
+         out = []
+         for i in range(x.shape[0]):
+             out.append(attention_einsum(query[i], key[i], value[i]))
+
+         out = torch.stack(out, dim=0)
+         out = self.gamma * self.conv(out)
+         out = out + x  # skip connection
+
+         return out
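`attention_einsum` attends from every spatial position to all L = h·w positions. The same computation in plain NumPy, with softmax over the flattened key axis followed by a weighted sum of values (dimensions are illustrative):

```python
import numpy as np

h, w, c = 4, 4, 8
rng = np.random.default_rng(0)
q = rng.random((h, w, c))
k = rng.random((h, w, c)).reshape(-1, c)  # (L, c), L = h*w
v = rng.random((h, w, c)).reshape(-1, c)  # (L, c)

scores = np.einsum("hwc,Lc->hwL", q, k)       # query . key per position pair
beta = np.exp(scores)
beta /= beta.sum(axis=-1, keepdims=True)      # softmax over L
out = np.einsum("hwL,Lc->hwc", beta, v)       # attention-weighted values
print(out.shape)  # (4, 4, 8)
```

The per-sample Python loop in `AttentionLayer.forward` repeats exactly this over the batch dimension.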
model_architect/inference_model.py ADDED
@@ -0,0 +1,40 @@
+ import torch.nn as nn
+
+ from .DGMR_SO.model import Generator as DGMR_SO
+ from .Generator_only.model_clr_idx import Generator as Generator_only
+
+
+ class Predictor(nn.Module):
+     def __init__(
+         self,
+         model_type,
+     ):
+         super().__init__()
+
+         if model_type == 'DGMR_SO':
+             self.generator = DGMR_SO(
+                 in_channels=1,
+                 base_channels=24,
+                 down_step=4,
+                 prev_step=4,
+                 sigma=1
+             )
+
+         elif model_type == 'Generator_only':
+             self.generator = Generator_only(
+                 in_channels=1,
+                 base_channels=24,
+                 down_step=4,
+                 prev_step=4,
+             )
+
+         else:
+             raise ValueError(f"unknown model_type: {model_type}")
+
+     def forward(self, x, x2, topo, datetime_feat, pred_step=36):
+         """
+         x: input seq -> dims (N, D, C, H, W)
+         x2: input seq (WRF) -> dims (N, D, C, H, W)
+         topo: topography -> dims (N, 1, H=512, W=512)
+         datetime_feat -> dims (N, D, 4)
+         """
+         pred = self.generator(x, x2, topo, datetime_feat, pred_step=pred_step)
+
+         return pred
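The docstring above pins down the tensor contract for inference. A sketch of dummy inputs matching those dims — the 512-pixel spatial size for `x`/`x2` and the batch/step counts are assumptions for illustration; only `topo`'s 512×512 grid and the 4 time features are stated in the docstring:

```python
import numpy as np

N, D, C = 1, 4, 1          # one sample, prev_step=4 input frames, one channel
H = W = 512                # assumed satellite grid size

x = np.zeros((N, D, C, H, W), dtype=np.float32)        # Himawari clearsky index
x2 = np.zeros((N, D, C, H, W), dtype=np.float32)       # WRF clearsky index
topo = np.zeros((N, 1, 512, 512), dtype=np.float32)    # static topography
datetime_feat = np.zeros((N, D, 4), dtype=np.float32)  # sin/cos of day and hour

print(x.shape, topo.shape, datetime_feat.shape)
```

These arrays would be converted to `torch.Tensor` before being passed to `Predictor.forward`.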
model_weights/DGMR_SO/ft36/weights.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec17b13cb466248e335803a1aac17e9246ab24ea0518d609bcd0fcc04cd1f928
+ size 215336260
model_weights/Generator_only/ft36/weights.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:87e230402e7f1c8f1f65af63d994856bf6f5eb637a37fd613c4c83fb6f194dc1
+ size 222876572
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ numpy==1.26.4
+ torch==2.4.0
+ einops==0.8.0
sample_data/sample_202504131100.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6117356ac780a530645e192cc85d647b103c915b63433441d810dedc7cdd4ec1
+ size 33002900
sample_data/sample_202504161200.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7b9e6c7ed76f695f7c6f2f1a976f4c27128120e5c4328223809c27dc8feee52
+ size 33300209
sample_data/sample_202507151200.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e20eb69ea2bef6bb0074c376afa7e8398e7da8eb1edd9a1f11c343ffe711a299
+ size 33038261