Commit 2fa5aae

feat: initial implementation of DGMR solar radiation nowcasting models

- Add the model architecture
- Include pre-trained model weights for the DGMR_SO and Generator_only models
- Add an inference pipeline and sample data for testing
- Configure the project structure with requirements and documentation
- .gitattributes +5 -0
- .gitignore +1 -0
- README.md +73 -0
- inference.py +75 -0
- model_architect/DGMR_SO/discriminator.py +143 -0
- model_architect/DGMR_SO/generator.py +274 -0
- model_architect/DGMR_SO/img_extractor.py +49 -0
- model_architect/DGMR_SO/model.py +106 -0
- model_architect/Generator_only/generator_clr_idx_wrf_topot.py +282 -0
- model_architect/Generator_only/model_clr_idx.py +57 -0
- model_architect/__init__.py +0 -0
- model_architect/components/ConvGRU.py +112 -0
- model_architect/components/common.py +261 -0
- model_architect/inference_model.py +40 -0
- model_weights/DGMR_SO/ft36/weights.ckpt +3 -0
- model_weights/Generator_only/ft36/weights.ckpt +3 -0
- requirements.txt +3 -0
- sample_data/sample_202504131100.npz +3 -0
- sample_data/sample_202504161200.npz +3 -0
- sample_data/sample_202507151200.npz +3 -0
.gitattributes
ADDED
@@ -0,0 +1,5 @@

```
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
```
.gitignore
ADDED
@@ -0,0 +1 @@

```
__pycache__/
```
README.md
ADDED
@@ -0,0 +1,73 @@

# DGMR Solar Radiation Nowcasting

A deep learning model for solar radiation nowcasting using a modified [Deep Generative Model of Rainfall (DGMR)](https://www.nature.com/articles/s41586-021-03854-z) architecture with Solar radiation Output (DGMR-SO). The model predicts the clearsky index and converts it to solar radiation for up to 36 time steps ahead.

## Overview

This repository implements two model variants for solar radiation forecasting:
- **DGMR_SO**: Full deep generative model, trained with one generator and two discriminators
- **Generator_only**: Trained with the generator alone

The model uses multiple input sources:
- **Himawari satellite data**: Clearsky index calculated from Himawari satellite observations
- **WRF prediction**: Clearsky index from WRF's solar irradiation prediction
- **Topography**: Static topographical features
- **Time features**: Temporal sin/cos encoding of day and hour
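The temporal sin/cos encoding is produced by the data pipeline, which is not part of this commit; the model only requires that each time step carry 4 features (see `nn.Linear(4, ...)` in `model.py`). A minimal sketch of one common formulation, assuming day-of-year and hour-of-day periods (a hypothetical stand-in, not the repository's actual preprocessing):

```python
import math

def time_features(day_of_year: int, hour: int) -> list:
    """Encode day-of-year and hour-of-day as 4 smooth periodic features.

    Hypothetical formulation: the repository's real preprocessing is not
    included in this commit; only the 4-feature width is fixed by the model.
    """
    day_angle = 2 * math.pi * day_of_year / 365.0
    hour_angle = 2 * math.pi * hour / 24.0
    return [math.sin(day_angle), math.cos(day_angle),
            math.sin(hour_angle), math.cos(hour_angle)]

# e.g. basetime 202504131100 -> day-of-year 103, hour 11
feat = time_features(day_of_year=103, hour=11)
print(len(feat))  # 4 features, each in [-1, 1]
```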
## Installation

1. Clone the repository:
```bash
git clone <repository-url>
cd DGMR_SolRad
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

## Requirements

- Python 3.x
- PyTorch 2.4.0
- NumPy 1.26.4
- einops 0.8.0

## Usage

### Basic Inference

Run solar radiation prediction using the pre-trained models:

```bash
python inference.py --model-type DGMR_SO --basetime 202504131100
```

### Command Line Arguments

- `--model-type`: Choose between `DGMR_SO` or `Generator_only` (default: `DGMR_SO`)
- `--basetime`: Timestamp of the input data in the format YYYYMMDDHHMM (default: `202504131100`)

### Example

```bash
# Using the DGMR_SO model
python inference.py --model-type DGMR_SO --basetime 202504131100

# Using the Generator-only model
python inference.py --model-type Generator_only --basetime 202507151200
```

## Sample Data

The repository includes sample data files:
- `sample_202504131100.npz`
- `sample_202504161200.npz`
- `sample_202507151200.npz`

## Model Weights

Pre-trained weights are available for both models:
- `model_weights/DGMR_SO/ft36/weights.ckpt`
- `model_weights/Generator_only/ft36/weights.ckpt`
inference.py
ADDED
@@ -0,0 +1,75 @@

```python
import argparse

import numpy as np
import torch

from model_architect.inference_model import Predictor


def data_loading(BASETIME, device):
    data_npz = np.load(f'./sample_data/sample_{BASETIME}.npz')

    inputs = {}
    for key in data_npz:
        inputs[key] = torch.from_numpy(data_npz[key]).to(device)

    return inputs


def model_loading(model_type, device):
    if model_type == 'DGMR_SO':
        ckpt_path = './model_weights/DGMR_SO/ft36/weights.ckpt'
    elif model_type == 'Generator_only':
        ckpt_path = './model_weights/Generator_only/ft36/weights.ckpt'
    else:
        raise ValueError(f'Unknown model type: {model_type}')

    model = Predictor(model_type=model_type)

    ckpt = torch.load(ckpt_path, weights_only=True)
    model.load_state_dict(ckpt['generator_state_dict'])
    model.eval()
    model.to(device)

    return model


def arg_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-type',
        type=str,
        default='DGMR_SO',
        choices=['Generator_only', 'DGMR_SO'])
    parser.add_argument('--basetime', type=str, default='202504131100')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = arg_parse()
    model_type = args.model_type
    BASETIME = args.basetime
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inputs = data_loading(BASETIME, device)
    model = model_loading(model_type, device)

    # prediction
    with torch.no_grad():
        pred_clr_idx = model(
            inputs['Himawari'],
            inputs['WRF'],
            inputs['topo'],
            inputs['time_feat'],
            pred_step=36,
        )
    pred_clr_idx = pred_clr_idx.squeeze(2).clamp(0, 1)

    # transform clearsky index to solar radiation
    pred_srad = pred_clr_idx * inputs['clearsky']  # dim: (1, 36, 512, 512)

    # save prediction
    np.save(f'./pred_{BASETIME}_{model_type}.npy', pred_srad.cpu().numpy())
    print('Done')
```
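The final two steps of the script — clamping the predicted clearsky index to [0, 1] and multiplying by clearsky irradiance — can be reproduced with plain NumPy on a hypothetical 2x2 grid:

```python
import numpy as np

# Hypothetical tiny prediction: (batch, time, H, W) after squeezing the channel dim.
pred_clr_idx = np.array([[[[-0.2, 0.5], [0.9, 1.3]]]])    # raw network output
clearsky = np.array([[[[800.0, 800.0], [600.0, 600.0]]]])  # clearsky irradiance, W/m^2

# clamp(0, 1): the clearsky index is the fraction of clear-sky irradiance reaching the ground
pred_clr_idx = np.clip(pred_clr_idx, 0.0, 1.0)

# transform clearsky index to solar radiation
pred_srad = pred_clr_idx * clearsky
print(pred_srad)  # values: 0, 400, 540, 600
```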
model_architect/DGMR_SO/discriminator.py
ADDED
@@ -0,0 +1,143 @@

```python
import torch
import torch.nn as nn
import einops

from ..components.common import DBlock


class TemporalDiscriminator(nn.Module):
    def __init__(self, in_channel: int, base_c: int = 24):
        super().__init__()
        self.in_channel = in_channel

        # (N, C, D, H, W)
        self.down_sample = nn.AvgPool3d(
            kernel_size=(1, 2, 2),
            stride=(1, 2, 2)
        )

        # (N, D, C, H, W)
        self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)

        # Conv3D -> (N, C, D, H, W)
        chn = base_c * 2 * in_channel
        self.d3_1 = DBlock(in_channel=in_channel * 4,
                           out_channel=chn,
                           conv_type='3d', apply_relu=False, apply_down=True)

        self.d3_2 = DBlock(in_channel=chn,
                           out_channel=2 * chn,
                           conv_type='3d', apply_relu=True, apply_down=True)

        self.Dlist = nn.ModuleList()
        for i in range(3):
            chn = chn * 2
            self.Dlist.append(
                DBlock(in_channel=chn,
                       out_channel=2 * chn,
                       conv_type='2d', apply_relu=True, apply_down=True)
            )

        self.last_D = DBlock(in_channel=2 * chn,
                             out_channel=2 * chn,
                             conv_type='2d', apply_relu=True, apply_down=False)

        self.fc = nn.Linear(2 * chn, 1)
        self.relu = nn.ReLU()
        # TODO: close bn
        # self.bn = nn.BatchNorm1d(2 * chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.down_sample(x)
        x = self.space_to_depth(x)

        # go through the 3D blocks: from (N, D, C, H, W) -> (N, C, D, H, W)
        x = torch.permute(x, dims=(0, 2, 1, 3, 4))
        x = self.d3_1(x)
        x = self.d3_2(x)
        # go through the 2D blocks, permute back -> (N, D, C, H, W)
        x = torch.permute(x, dims=(0, 2, 1, 3, 4))
        n, d, c, h, w = list(x.size())

        fea = einops.rearrange(x, "n d c h w -> (n d) c h w")
        for dd in self.Dlist:
            fea = dd(fea)

        fea = self.last_D(fea)

        fea = torch.sum(self.relu(fea), dim=[2, 3])
        # fea = self.bn(fea)
        fea = self.fc(fea)

        y = torch.reshape(fea, (n, d, 1))      # dims -> (N, D, 1)
        y = torch.sum(y, keepdim=True, dim=1)  # dims -> (N, 1, 1)

        return y


class SpatialDiscriminator(nn.Module):
    def __init__(self, in_channel: int, base_c: int = 24):
        super().__init__()
        self.in_channel = in_channel

        # (N, C, H, W)
        self.down_sample = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))

        # (N, C, H, W)
        self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)

        # the first DBlock doesn't apply ReLU
        chn = base_c * in_channel
        self.d1 = DBlock(in_channel=in_channel * 4,
                         out_channel=chn * 2,
                         conv_type='2d', apply_relu=False, apply_down=True)

        self.Dlist = nn.ModuleList()
        for i in range(4):
            chn = chn * 2
            self.Dlist.append(
                DBlock(in_channel=chn,
                       out_channel=2 * chn,
                       conv_type='2d', apply_relu=True, apply_down=True)
            )

        self.last_D = DBlock(in_channel=2 * chn,
                             out_channel=2 * chn,
                             conv_type='2d', apply_relu=True, apply_down=False)

        self.fc = nn.Linear(2 * chn, 1)
        self.relu = nn.ReLU()
        # TODO: close BN
        # self.bn = nn.BatchNorm1d(2 * chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # note: input dims -> (N, D, C, H, W)
        # randomly pick 8 frames out of 18
        perm = torch.randperm(x.shape[1])
        random_idx = perm[:8]

        fea = x[:, random_idx, :, :, :]

        n, d, c, h, w = list(fea.size())

        fea = einops.rearrange(fea, "n d c h w -> (n d) c h w")
        fea = self.down_sample(fea)
        fea = self.space_to_depth(fea)

        # apply the DBlocks
        fea = self.d1(fea)
        for dd in self.Dlist:
            fea = dd(fea)

        fea = self.last_D(fea)

        # sum over the spatial dims
        fea = torch.sum(self.relu(fea), dim=[2, 3])
        # fea = self.bn(fea)
        fea = self.fc(fea)

        y = torch.reshape(fea, (n, d, 1))      # dims -> (N, D, 1)
        y = torch.sum(y, keepdim=True, dim=1)  # dims -> (N, 1, 1)

        return y
```
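Both discriminators (and the condition stacks below) use `nn.PixelUnshuffle(2)` to trade spatial resolution for channels before the first DBlock. A pure-NumPy sketch of that space-to-depth rearrangement, matching PixelUnshuffle's shape contract:

```python
import numpy as np

def pixel_unshuffle(x: np.ndarray, r: int = 2) -> np.ndarray:
    """NumPy equivalent of nn.PixelUnshuffle(r):
    (N, C, H, W) -> (N, C*r*r, H//r, W//r)."""
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // r, r, w // r, r)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # (N, C, r, r, H//r, W//r)
    return x.reshape(n, c * r * r, h // r, w // r)

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
y = pixel_unshuffle(x)
print(y.shape)  # (1, 4, 2, 2)
```

Each output channel collects one phase of the 2x2 sub-grid, which is why the first DBlock of each stack takes `in_channel * 4` channels.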
model_architect/DGMR_SO/generator.py
ADDED
@@ -0,0 +1,274 @@

```python
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm
from torch.distributions import normal
import einops

from ..components.ConvGRU import ConvGRU
from ..components.common import GBlock, Up_GBlock, LBlock, AttentionLayer, DBlock


class Sampler(nn.Module):
    def __init__(self, in_channels, base_channels=24, up_step=4):
        """
        up_step should be the same as down_step in the context-condition stack.
        """
        super().__init__()
        base_c = base_channels

        self.up_steps = up_step
        self.convgru_list = nn.ModuleList()
        self.conv1x1_list = nn.ModuleList()
        self.gblock_list = nn.ModuleList()
        self.upg_list = nn.ModuleList()

        for i in range(self.up_steps):
            # channels at this scale
            chs1 = base_c * 2**(self.up_steps - i + 1) * in_channels
            chs2 = base_c * 2**(self.up_steps - i) * in_channels
            # ConvGRU
            self.convgru_list.append(
                ConvGRU(chs1, chs2, 3)
            )
            # conv1x1
            self.conv1x1_list.append(
                spectral_norm(
                    nn.Conv2d(
                        in_channels=chs2,
                        out_channels=chs1,
                        kernel_size=(1, 1))))
            # GBlock
            self.gblock_list.append(
                GBlock(in_channel=chs1, out_channel=chs1)
            )
            # Up_GBlock
            self.upg_list.append(
                Up_GBlock(in_channel=chs1)
            )

        # output
        # TODO: close BatchNorm
        # self.bn = nn.BatchNorm2d(chs2)
        self.relu = nn.ReLU()
        self.last_conv1x1 = spectral_norm(
            nn.Conv2d(in_channels=chs2,
                      out_channels=4,
                      kernel_size=(1, 1))
        )
        self.depth_to_space = nn.PixelShuffle(upscale_factor=2)

    def forward(self, latents, img_feat, init_states, pred_step):
        """
        latents dim     -> (N, C, W, H)
        img_feat dim    -> (N, D, C, W, H)
        init_states dim -> (N, C, W, H)
        """
        # expand the time dim at axis=1
        latents = torch.unsqueeze(latents, dim=1)

        # repeat batch_size
        if latents.shape[0] == 1:
            # expand batch
            latents = einops.repeat(
                latents,
                "b d c h w -> (repeat b) d c h w",
                repeat=init_states[0].shape[0])
        # repeat time steps
        latents = einops.repeat(
            latents, "b d c h w -> b (repeat d) c h w", repeat=pred_step
        )
        # TODO: add image features
        # seq_out = latents + img_feat
        seq_out = torch.cat((img_feat, latents), dim=2)
        # init_states should be reversed
        # scale-up steps
        for i in range(self.up_steps):
            seq_out = self.convgru_list[i](
                seq_out, init_states[(self.up_steps - 1) - i])
            # seq_out output shape -> (D, N, C, W, H)

            # for-loop over time steps
            seq_out = [self.conv1x1_list[i](h) for h in seq_out]
            seq_out = [self.gblock_list[i](h) for h in seq_out]
            seq_out = [self.upg_list[i](h) for h in seq_out]
            # output: seq_out list dim -> D * [N, C, H, W];
            # stack at dim == 1 to become (N, D, C, H, W)
            seq_out = torch.stack(seq_out, dim=1)

        # final output
        # for-loop over time steps
        output = []
        for t in range(seq_out.shape[1]):
            y = seq_out[:, t, :, :, :]
            # y = self.bn(y)
            y = self.relu(y)
            y = self.last_conv1x1(y)
            y = self.depth_to_space(y)
            output.append(y)

        output = torch.stack(output, dim=1)

        return output


class LatentConditionStack(nn.Module):
    def __init__(self, out_channels, down_step, sigma, attn=True):
        """
        in_shape dims -> e.g. (8, 8) -> (H, W)
        base_c is 1/96 of out_channels, with a minimum of 4
        """
        super().__init__()

        self.down_step = down_step
        self.base_c = out_channels // 96
        if self.base_c < 4:
            self.base_c = 4

        self.out_channels = out_channels
        self.attn = attn

        # define the latent distribution
        self.dist = normal.Normal(loc=0.0, scale=sigma)

        self.conv3x3 = spectral_norm(
            nn.Conv2d(
                in_channels=self.base_c,
                out_channels=self.base_c,
                kernel_size=(3, 3),
                padding=1
            )
        )

        cc = self.base_c
        self.l1 = LBlock(cc, cc * 3)
        self.l2 = LBlock(cc * 3, cc * 6)
        self.l3 = LBlock(cc * 6, cc * 24)
        if self.attn:
            self.attn = AttentionLayer(cc * 24, cc * 24)
        self.l4 = LBlock(cc * 24, self.out_channels)

    def forward(self, x, batch_size=1, z=None):
        """
        x shape -> (batch_size, time, c, width, height)
        """
        width = x.shape[3]
        height = x.shape[4]
        # shape after the down-sampling steps
        s_w = width // (2 * 2**self.down_step)
        s_h = height // (2 * 2**self.down_step)

        in_shape = [self.base_c] + [s_w, s_h]

        target_shape = [batch_size] + in_shape
        if z is None:
            z = self.dist.sample(target_shape)
            z = z.type_as(x)

        # first conv
        z = self.conv3x3(z)

        # LBlocks
        z = self.l1(z)
        z = self.l2(z)
        z = self.l3(z)
        if self.attn:
            z = self.attn(z)

        z = self.l4(z)

        return z

# TODO: modification (change the number of parameters)


class ContextConditionStack(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 base_channels: int = 24,
                 down_step: int = 4,
                 prev_step: int = 4):
        """
        base_channels: e.g. 24 -> output_channel: 384
        output_channel: base_c * in_c * 2**(down_step - 2) * prev_step
        down_step: int
        prev_step: int
        """
        super().__init__()
        self.in_channels = in_channels
        self.down_step = down_step
        self.prev_step = prev_step

        base_c = base_channels
        in_c = in_channels

        # channels at the different scales
        chs = [4 * in_c] + [base_c * in_c * 2 **
                            (i + 1) for i in range(down_step)]

        self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
        self.Dlist = nn.ModuleList()
        self.convList = nn.ModuleList()
        for i in range(down_step):
            self.Dlist.append(
                DBlock(in_channel=chs[i],
                       out_channel=chs[i + 1],
                       apply_relu=True, apply_down=True)
            )

            self.convList.append(
                spectral_norm(
                    nn.Conv2d(in_channels=prev_step * chs[i + 1],
                              out_channels=prev_step * chs[i + 1] // 4,
                              kernel_size=(3, 3),
                              padding=1)
                )
            )

        # ReLU
        self.relu = nn.ReLU()

    def forward(self,
                x: torch.Tensor) -> Tuple[torch.Tensor,
                                          torch.Tensor,
                                          torch.Tensor,
                                          torch.Tensor]:
        """
        input dims -> (N, D, C, H, W)
        """
        x = self.space_to_depth(x)
        tsteps = x.shape[1]
        assert tsteps == self.prev_step

        # each feature index represents a different scale:
        # [scale1 -> [t1, t2, t3, t4], scale2 -> [t1, t2, t3, t4], scale3 -> [...]]
        features = [[] for i in range(tsteps)]

        for st in range(tsteps):
            in_x = x[:, st, :, :, :]
            # in_x -> (Batch(N), C, H, W)
            for scale in range(self.down_step):
                in_x = self.Dlist[scale](in_x)
                features[scale].append(in_x)

        out_scale = []
        for i, cc in enumerate(self.convList):
            # after stacking, dims -> (Batch, Time, C, H, W),
            # and the mixing layer concatenates Time and C
            stacked = self._mixing_layer(torch.stack(features[i], dim=1))
            out = self.relu(cc(stacked))
            out_scale.append(out)

        return out_scale

    def _mixing_layer(self, x):
        # convert (N, Time, C, H, W) -> (N, Time*C, H, W),
        # then apply Conv2d
        stacked = einops.rearrange(x, "b t c h w -> b (t c) h w")

        return stacked
```
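With the defaults (`in_channels=1`, `base_channels=24`, `down_step=4`, `prev_step=4`), the channel arithmetic in `ContextConditionStack` works out as follows; the docstring's `output_channel: 384` falls out of the mixing convs' `prev_step * chs[i+1] // 4`:

```python
in_c, base_c, down_step, prev_step = 1, 24, 4, 4

# channels at the different scales, as built in __init__
chs = [4 * in_c] + [base_c * in_c * 2**(i + 1) for i in range(down_step)]
print(chs)  # [4, 48, 96, 192, 384]

# each mixing conv maps prev_step*chs[i+1] -> prev_step*chs[i+1] // 4 channels
out_channels = [prev_step * c // 4 for c in chs[1:]]
print(out_channels)  # [48, 96, 192, 384]

# the deepest scale matches the docstring formula
assert out_channels[-1] == base_c * in_c * 2**(down_step - 2) * prev_step  # 384
```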
model_architect/DGMR_SO/img_extractor.py
ADDED
@@ -0,0 +1,49 @@

```python
import torch.nn as nn

from ..components.common import DBlock


class ImageExtractor(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            apply_down_flag,
            down_step=4):
        """
        in_c -> 1
        base_c is 1/96 of out_channels, with a minimum of 4
        """
        super().__init__()
        self.down_step = down_step

        self.base_c = out_channels // 96
        if self.base_c < 4:
            self.base_c = 4
        cc = self.base_c

        self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)

        chs = [in_channels * 4, cc * 3, cc * 6, cc * 24, out_channels]
        self.DList = nn.ModuleList()
        for i in range(down_step):
            self.DList.append(
                DBlock(
                    in_channel=chs[i],
                    out_channel=chs[i + 1],
                    conv_type='2d',
                    apply_down=apply_down_flag[i]
                ),
            )

    def forward(self, x):
        """
        x dims -> (N, C, H, W)
        """
        y = self.space_to_depth(x)
        # apply the DBlocks in sequence
        for i in range(self.down_step):
            y = self.DList[i](y)

        return y
```
model_architect/DGMR_SO/model.py
ADDED
@@ -0,0 +1,106 @@

```python
import torch
import torch.nn as nn

from .discriminator import TemporalDiscriminator, SpatialDiscriminator
from .generator import Sampler, ContextConditionStack, LatentConditionStack
from .img_extractor import ImageExtractor


class Generator(nn.Module):
    def __init__(
            self,
            in_channels,
            base_channels,
            down_step,
            prev_step,
            sigma
    ):
        super().__init__()
        out_channels = base_channels * \
            2**(down_step - 2) * prev_step * in_channels

        self.latentStack = LatentConditionStack(
            out_channels=out_channels,
            down_step=down_step,
            sigma=sigma
        )

        self.contextStack = ContextConditionStack(
            in_channels=in_channels,
            base_channels=base_channels,
            down_step=down_step,
            prev_step=prev_step
        )

        self.sampler = Sampler(
            in_channels=in_channels,
            base_channels=base_channels,
            up_step=down_step
        )

        self.encode_time = nn.Linear(
            4, base_channels * 2**(down_step) * in_channels)

        self.topo_extractor = ImageExtractor(
            in_channels=1,
            out_channels=base_channels * 2**(down_step - 1) * in_channels,
            apply_down_flag=[True, True, True, True],
            down_step=down_step
        )

        self.nwp_extractor = ImageExtractor(
            in_channels=1,  # TODO: fixed for now
            out_channels=base_channels * 2**(down_step - 1) * in_channels,
            apply_down_flag=[False, True, False, True],
            down_step=down_step
        )

    def forward(self, x, x2, topo, datetime_feat, pred_step=36):
        """
        x: input seq        -> dims (N, D, C, H, W)
        x2: input seq (WRF) -> dims (N, D, C, H, W)
        topo: topography    -> dims (N, 1, H=512, W=512)
        datetime_feat       -> dims (N, D, 4)
        """
        context_inits = self.contextStack(x)
        batch_size = context_inits[0].shape[0]
        zlatent = self.latentStack(x, batch_size=batch_size)

        # topography features
        topo_feat = self.topo_extractor(topo)
        # encode the time features
        time_feat = self.encode_time(datetime_feat)
        # extract the NWP features, looping over the time steps of x2
        nwp_feat = []
        for i in range(x2.shape[1]):
            nwp_ = self.nwp_extractor(x2[:, i, ...])
            # concat the topography and NWP features
            concat_feat = torch.cat((nwp_, topo_feat), dim=1)
            nwp_feat.append(concat_feat)
        nwp_feat = torch.stack(nwp_feat, dim=1)
        fuse_feat = nwp_feat + time_feat.unsqueeze(-1).unsqueeze(-1)

        pred = self.sampler(zlatent, fuse_feat, context_inits, pred_step)

        return pred


class Discriminator(nn.Module):
    def __init__(self, in_channels, base_channels):
        super().__init__()
        self.spatial = SpatialDiscriminator(
            in_channel=in_channels, base_c=base_channels)
        self.temporal = TemporalDiscriminator(
            in_channel=in_channels, base_c=base_channels)

    def forward(self, x, y):
        """
        x -> dims (N, D, C, H, W), e.g. input frames
        y -> dims (N, D, C, H, W), e.g. output frames
        """
        spatial_out = self.spatial(y)
        temporal_out = self.temporal(torch.cat([x, y], dim=1))

        dis_out = torch.cat([spatial_out, temporal_out], dim=1)

        return dis_out
```
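For the 512x512 inputs documented in `Generator.forward`, the shape arithmetic in `LatentConditionStack.forward` yields a 16x16 latent grid, and the `out_channels` computed in `Generator.__init__` matches the context stack's deepest output. A quick check, assuming the defaults from `ContextConditionStack` (`in_channels=1`, `base_channels=24`, `down_step=4`, `prev_step=4`):

```python
width = height = 512  # input grid, per the Generator.forward docstring
down_step = 4

# LatentConditionStack.forward: one space-to-depth halving plus down_step D-steps
s_w = width // (2 * 2**down_step)
s_h = height // (2 * 2**down_step)
print((s_w, s_h))  # (16, 16)

# Generator.__init__: channels of the latent conditioning output
in_channels, base_channels, prev_step = 1, 24, 4
out_channels = base_channels * 2**(down_step - 2) * prev_step * in_channels
print(out_channels)  # 384
```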
model_architect/Generator_only/generator_clr_idx_wrf_topot.py
ADDED
@@ -0,0 +1,282 @@
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm
import einops

from ..components.ConvGRU import ConvGRUCell
from ..components.common import GBlock, Up_GBlock, DBlock


class Sampler(nn.Module):
    def __init__(self, in_channels, base_channels=24, up_step=4):
        """
        up_step should be the same as down_step in the context-condition stack.
        """
        super().__init__()
        base_c = base_channels

        self.up_steps = up_step
        self.convgru_list = nn.ModuleList()
        self.conv1x1_list = nn.ModuleList()
        self.gblock_list = nn.ModuleList()
        self.upg_list = nn.ModuleList()

        # image extractor
        self.img_extractor = ImageExtractor(
            in_channels=2,
            out_channels=base_c * 2**(self.up_steps) * in_channels,
            apply_down_flag=[True, True, True, True],
            down_step=self.up_steps
        )

        self.nwp_extractor = ImageExtractor(
            in_channels=1,
            out_channels=base_c * 2**(self.up_steps) * in_channels,
            apply_down_flag=[False, True, False, True],
            down_step=self.up_steps
        )

        self.encode_time = nn.Linear(
            4, base_c * 2**(self.up_steps + 1) * in_channels)

        for i in range(self.up_steps):
            # channels at this scale
            chs1 = base_c * 2**(self.up_steps - i + 1) * in_channels
            chs2 = base_c * 2**(self.up_steps - i) * in_channels
            # ConvGRU
            self.convgru_list.append(
                ConvGRUCell(chs1, chs2, 3)
            )
            # conv1x1
            self.conv1x1_list.append(
                spectral_norm(
                    nn.Conv2d(
                        in_channels=chs2,
                        out_channels=chs1,
                        kernel_size=(1, 1))))
            # GBlock
            self.gblock_list.append(
                GBlock(in_channel=chs1, out_channel=chs1)
            )
            # Up_GBlock
            self.upg_list.append(
                Up_GBlock(in_channel=chs1)
            )

        # output
        # self.bn = nn.BatchNorm2d(chs2)
        self.relu = nn.ReLU()
        self.last_conv1x1 = spectral_norm(
            nn.Conv2d(in_channels=chs2,
                      out_channels=4,
                      kernel_size=(1, 1))
        )
        self.depth_to_space = nn.PixelShuffle(upscale_factor=2)

    def forward(
            self,
            input_img,
            nwp_inputs,
            topo,
            time_feat,
            init_states,
            pred_step,
            thres=None):
        """
        input_img dims -> (N, Tstep, C, W, H); Tstep can be 1 or pred_step
        nwp_inputs dims -> (N, Tstep, C, W, H)
        init_states dims -> [(N, C, W, H)-1, (N, C, W, H)-2, ...]
        probs -> (tsteps)
        """
        hh = init_states
        output = []
        img_t_len = input_img.shape[1]

        xx = None
        for t in range(pred_step):
            time_emb = self.encode_time(time_feat[:, t])  # time embedding

            # The first tstep image should be ground truth;
            # from the second tstep on, apply scheduled sampling.
            if t == 0:
                input_ = input_img[:, 0, :, :, :]
            else:
                if thres is not None and img_t_len > 1:
                    # use ground truth
                    if torch.rand(1) < thres:
                        input_ = input_img[:, t, :, :, :]
                    else:
                        input_ = xx
                else:
                    input_ = xx

            nwp_in = nwp_inputs[:, t]
            # image extractor -> extract the frame for step t;
            # xx is the output of the image extractor
            input_ = torch.cat((input_, topo), dim=1)
            xx = self.img_extractor(input_)
            nwp_feat = self.nwp_extractor(nwp_in)
            xx = torch.cat((xx, nwp_feat), dim=1)

            xx = xx + time_emb[:, :, None, None]  # add time embedding

            for up in range(self.up_steps):
                # ConvGRU
                # init_states should be reversed
                h_index = (self.up_steps - 1) - up
                xx, out_hh = self.convgru_list[up](xx, hh[h_index])
                hh[h_index] = out_hh
                # conv1x1
                xx = self.conv1x1_list[up](xx)
                # GBlock
                xx = self.gblock_list[up](xx)
                # Up_GBlock
                xx = self.upg_list[up](xx)

            # xx = self.bn(xx)
            xx = self.relu(xx)
            xx = self.last_conv1x1(xx)
            xx = self.depth_to_space(xx)

            # prediction
            output.append(xx)

        output = torch.stack(output, dim=1)

        return output


class ImageExtractor(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            apply_down_flag,
            down_step=4):
        """
        base_c is 1/96 of out_channels, with a floor of 4.
        """
        super().__init__()
        self.down_step = down_step

        self.base_c = out_channels // 96
        if self.base_c < 4:
            self.base_c = 4
        cc = self.base_c

        self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)

        chs = [in_channels * 4, cc * 3, cc * 6, cc * 24, out_channels]
        self.DList = nn.ModuleList()
        for i in range(down_step):
            self.DList.append(
                DBlock(
                    in_channel=chs[i],
                    out_channel=chs[i + 1],
                    conv_type='2d',
                    apply_down=apply_down_flag[i]
                ),
            )

    def forward(self, x):
        y = self.space_to_depth(x)
        # apply the DBlocks in sequence
        for i in range(self.down_step):
            y = self.DList[i](y)

        return y


class ContextConditionStack(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 base_channels: int = 24,
                 down_step: int = 4,
                 prev_step: int = 4):
        """
        base_channels: e.g. 24 -> output_channel: 384
        output_channel: base_c * in_c * 2**(down_step - 2) * prev_step
        down_step: int
        prev_step: int
        """
        super().__init__()
        self.in_channels = in_channels
        self.down_step = down_step
        self.prev_step = prev_step

        base_c = base_channels
        in_c = in_channels

        # channels at the different scales
        chs = [4 * in_c] + [base_c * in_c * 2 **
                            (i + 1) for i in range(down_step)]

        self.space_to_depth = nn.PixelUnshuffle(downscale_factor=2)
        self.Dlist = nn.ModuleList()
        self.convList = nn.ModuleList()
        for i in range(down_step):
            self.Dlist.append(
                DBlock(in_channel=chs[i],
                       out_channel=chs[i + 1],
                       apply_relu=True, apply_down=True)
            )

            self.convList.append(
                spectral_norm(
                    nn.Conv2d(in_channels=prev_step * chs[i + 1],
                              out_channels=prev_step * chs[i + 1] // 4,
                              kernel_size=(3, 3),
                              padding=1)
                )
            )

        # ReLU
        self.relu = nn.ReLU()

    def forward(self,
                x: torch.Tensor) -> Tuple[torch.Tensor,
                                          torch.Tensor,
                                          torch.Tensor,
                                          torch.Tensor]:
        """
        input dims -> (N, D, C, H, W)
        """
        x = self.space_to_depth(x)
        tsteps = x.shape[1]
        assert tsteps == self.prev_step

        # each feature index represents a different scale:
        # features
        # [scale1 -> [t1, t2, t3, t4], scale2 -> [t1, t2, t3, t4], scale3 -> [...]]
        features = [[] for i in range(tsteps)]

        for st in range(tsteps):
            in_x = x[:, st, :, :, :]
            # in_x -> (Batch(N), C, H, W)
            for scale in range(self.down_step):
                in_x = self.Dlist[scale](in_x)
                features[scale].append(in_x)

        out_scale = []
        for i, cc in enumerate(self.convList):
            # after stacking, dims -> (Batch, Time, C, H, W);
            # the mixing layer concatenates Time and C
            stacked = self._mixing_layer(torch.stack(features[i], dim=1))
            out = self.relu(cc(stacked))
            out_scale.append(out)

        return out_scale

    def _mixing_layer(self, x):
        # convert from (N, Time, C, H, W) -> (N, Time*C, H, W),
        # then apply Conv2d
        stacked = einops.rearrange(x, "b t c h w -> b (t c) h w")

        return stacked
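The scheduled-sampling branch in `Sampler.forward` (use ground truth at step 0, ground truth with probability `thres` afterwards, otherwise feed back the previous prediction) can be sketched with plain Python. This is a minimal sketch only; the function name `choose_input` and the string labels are hypothetical, and an explicit `random.Random` stands in for `torch.rand`:

```python
import random

def choose_input(t, teacher_available, thres, rng):
    """Decide which frame feeds the sampler at step t (hypothetical helper
    mirroring the branch in Sampler.forward).

    Step 0 always uses the ground-truth frame. Later steps use ground
    truth with probability `thres` when teacher frames exist, and the
    model's previous prediction otherwise.
    """
    if t == 0:
        return "truth"
    if thres is not None and teacher_available:
        # scheduled sampling: coin flip against the threshold
        if rng.random() < thres:
            return "truth"
        return "prediction"
    return "prediction"

rng = random.Random(0)
print([choose_input(t, True, 0.5, rng) for t in range(4)])
```

At inference time `thres` is `None`, so every step after the first consumes the model's own previous output.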
model_architect/Generator_only/model_clr_idx.py
ADDED
@@ -0,0 +1,57 @@
import torch.nn as nn

from .generator_clr_idx_wrf_topot import Sampler, ContextConditionStack


class Generator(nn.Module):
    def __init__(
            self,
            in_channels,
            base_channels,
            down_step,
            prev_step
    ):
        super().__init__()
        self.contextStack = ContextConditionStack(
            in_channels=in_channels,
            base_channels=base_channels,
            down_step=down_step,
            prev_step=prev_step
        )

        self.sampler = Sampler(
            in_channels=in_channels,
            base_channels=base_channels,
            up_step=down_step,
        )

    def forward(
            self,
            x,
            x2,
            topo,
            time_feat,
            pred_step=12,
            y=None,
            thres=None
    ):
        """
        x: input seq -> dims (N, T, C, H, W)
        x2: input seq for WRF -> dims (N, T, C, H, W)
        """
        context_inits = self.contextStack(x)
        if y is None:
            y = x[:, -1:, :, :, :]

        pred = self.sampler(
            y,
            x2,
            topo,
            time_feat,
            context_inits,
            pred_step,
            thres
        )

        return pred
model_architect/__init__.py
ADDED
File without changes
model_architect/components/ConvGRU.py
ADDED
@@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F
from torch.nn.utils.parametrizations import spectral_norm


class ConvGRUCell(torch.nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.out_channel = out_channel

        self.conv1 = spectral_norm(
            torch.nn.Conv2d(
                in_channels=in_channel + out_channel,
                out_channels=2 * out_channel,
                kernel_size=kernel_size,
                padding=padding
            )
        )

        self.conv2 = spectral_norm(
            torch.nn.Conv2d(
                in_channels=in_channel + out_channel,
                out_channels=out_channel,
                kernel_size=kernel_size,
                padding=padding
            )
        )

    def forward(self, x, h_st):
        """
        x -> dims (Batch, channels*2, width, height)
        h_st -> dims (Batch, channels, width, height)
        """
        x_shape = x.shape
        h_shape = h_st.shape
        # pad width/height when the hidden state is larger than the input
        w_l = 1 if h_shape[2] > x_shape[2] else 0
        h_b = 1 if h_shape[3] > x_shape[3] else 0

        x = F.pad(x, (0, h_b, w_l, 0), "reflect")

        xx = torch.cat([x, h_st], dim=1)
        xx = self.conv1(xx)
        gamma, beta = torch.split(xx, self.out_channel, dim=1)

        reset_gate = torch.sigmoid(gamma)
        update_gate = torch.sigmoid(beta)

        out = torch.cat([x, h_st * reset_gate], dim=1)
        out = torch.tanh(self.conv2(out))

        out = (1 - update_gate) * out + h_st * update_gate
        new_st = out

        return out, new_st


class ConvGRU(torch.nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size):
        super().__init__()

        self.out_channel = out_channel
        self.convgru_cell = ConvGRUCell(in_channel, out_channel, kernel_size)

    def _get_init_state(self, batch_size, img_w, img_h, dtype):
        # zero-initialised hidden state sized to the input frames
        state = torch.zeros(
            batch_size,
            self.out_channel,
            img_w,
            img_h).type(dtype)

        return state

    def forward(self, x_sequence, init_hidden=None):
        """
        Args:
            x_sequence shape -> (batch_size, time, c, width, height)
        Returns:
            outputs shape -> (time, batch_size, c, width, height)
        """
        seq_len = x_sequence.shape[1]

        img_w = x_sequence.shape[3]
        img_h = x_sequence.shape[4]

        dtype = x_sequence.type()
        if init_hidden is None:
            hidden_state = self._get_init_state(
                x_sequence.shape[0], img_w, img_h, dtype)
        else:
            hidden_state = init_hidden

        out_list = []
        for t in range(seq_len):
            out, hidden_state = self.convgru_cell(
                x_sequence[:, t, :, :, :], hidden_state)
            out_list.append(out)

        outputs = torch.stack(out_list, dim=0)

        return outputs
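The gating arithmetic inside `ConvGRUCell.forward` (reset gate, update gate, tanh candidate, convex blend with the old state) is easiest to see with scalars. A minimal sketch, assuming hypothetical scalar weights `w` in place of the two spectral-normed convolutions:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gru_cell_scalar(x, h, w):
    """Scalar analogue of ConvGRUCell.forward (a sketch; the weight
    dict `w` is hypothetical)."""
    reset = sigmoid(w["r_x"] * x + w["r_h"] * h)    # reset gate (gamma)
    update = sigmoid(w["u_x"] * x + w["u_h"] * h)   # update gate (beta)
    # candidate state sees the reset-scaled history
    candidate = math.tanh(w["c_x"] * x + w["c_h"] * (h * reset))
    # convex blend of candidate and previous state
    return (1.0 - update) * candidate + update * h

zero_w = {k: 0.0 for k in ("r_x", "r_h", "u_x", "u_h", "c_x", "c_h")}
print(gru_cell_scalar(1.0, 2.0, zero_w))  # gates at 0.5, candidate 0 -> 1.0
```

With all weights at zero both gates sit at 0.5 and the candidate is 0, so the output is half the previous state; a saturated update gate would pass the old state through unchanged.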
model_architect/components/common.py
ADDED
@@ -0,0 +1,261 @@
from torch.nn.utils.parametrizations import spectral_norm
import torch
from torch.nn import functional as F
import einops


class GBlock(torch.nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        # TODO: batch norm is disabled
        # self.bn1 = torch.nn.BatchNorm2d(in_channel)
        # self.bn2 = torch.nn.BatchNorm2d(out_channel)

        self.relu = torch.nn.ReLU()
        # conv1x1
        self.conv1x1 = spectral_norm(
            torch.nn.Conv2d(in_channel, out_channel, kernel_size=1)
        )

        self.conv3x3_1 = spectral_norm(
            torch.nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        )
        self.conv3x3_2 = spectral_norm(
            torch.nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        )

    def forward(self, x: torch.Tensor):
        # project the residual only when the channel count changes
        if x.shape[1] != self.out_channel:
            res = self.conv1x1(x)
        else:
            res = x.clone()

        # first
        # x = self.bn1(x)
        x = self.relu(x)
        x = self.conv3x3_1(x)
        # second
        # x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)

        y = x + res

        return y


class Up_GBlock(torch.nn.Module):
    def __init__(self, in_channel: int):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = in_channel // 2

        # TODO: batch norm is disabled
        # self.bn1 = torch.nn.BatchNorm2d(in_channel)
        # self.bn2 = torch.nn.BatchNorm2d(in_channel)

        self.relu = torch.nn.ReLU()
        self.up = torch.nn.Upsample(scale_factor=2, mode='nearest')

        self.conv1x1 = spectral_norm(
            torch.nn.Conv2d(in_channel, self.out_channel, kernel_size=1)
        )

        self.conv3x3_1 = spectral_norm(
            torch.nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1)
        )
        self.conv3x3_2 = spectral_norm(
            torch.nn.Conv2d(
                in_channel,
                self.out_channel,
                kernel_size=3,
                padding=1))

    def forward(self, x):
        res = self.up(x)
        res = self.conv1x1(res)

        # x = self.bn1(x)
        x = self.relu(x)
        x = self.up(x)
        x = self.conv3x3_1(x)

        # x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)

        y = x + res

        return y


class DBlock(torch.nn.Module):
    def __init__(
            self,
            in_channel: int,
            out_channel: int,
            conv_type='2d',
            apply_relu=True,
            apply_down=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.apply_relu = apply_relu
        self.apply_down = apply_down

        # construct layers
        if conv_type == '2d':
            self.avg_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            conv = torch.nn.Conv2d
        elif conv_type == '3d':
            self.avg_pool = torch.nn.AvgPool3d(kernel_size=2, stride=2)
            conv = torch.nn.Conv3d

        self.relu = torch.nn.ReLU()
        self.conv1x1 = spectral_norm(
            conv(in_channel, out_channel, kernel_size=1)
        )

        self.conv3x3_1 = spectral_norm(
            conv(in_channel, out_channel, kernel_size=3, padding=1)
        )
        self.conv3x3_2 = spectral_norm(
            conv(out_channel, out_channel, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # residual branch
        if x.shape[1] != self.out_channel:
            res = self.conv1x1(x)
        else:
            res = x.clone()
        if self.apply_down:
            res = self.avg_pool(res)

        # main branch
        if self.apply_relu:
            x = self.relu(x)
        x = self.conv3x3_1(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)
        if self.apply_down:
            x = self.avg_pool(x)

        # skip connection
        y = res + x

        return y


class LBlock(torch.nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        self.relu = torch.nn.ReLU()
        conv = torch.nn.Conv2d
        self.conv1x1 = conv(
            in_channel,
            (out_channel - in_channel),
            kernel_size=1)

        self.conv3x3_1 = conv(in_channel, in_channel, kernel_size=3, padding=1)
        self.conv3x3_2 = conv(
            in_channel,
            out_channel,
            kernel_size=3,
            padding=1)

    def forward(self, x):
        res = torch.cat([x, self.conv1x1(x)], dim=1)

        x = self.relu(x)
        x = self.conv3x3_1(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)

        y = x + res

        return y


# Attention layer


def attention_einsum(q, k, v):
    """
    Apply self-attention to tensors.
    """
    # Reshape 3D tensors to 2D with first dimension L = h x w
    k = einops.rearrange(k, "h w c -> (h w) c")  # [h, w, c] -> [L, c]
    v = einops.rearrange(v, "h w c -> (h w) c")  # [h, w, c] -> [L, c]

    # Einstein summation corresponding to the query * key operation
    beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)

    # Einstein summation corresponding to the attention * value operation
    out = torch.einsum("hwL, Lc->hwc", beta, v)

    return out


class AttentionLayer(torch.nn.Module):
    def __init__(self, in_channel, out_channel, ratio_kq=8, ratio_v=8):
        super().__init__()

        self.ratio_kq = ratio_kq
        self.ratio_v = ratio_v
        self.in_channel = in_channel
        self.out_channel = out_channel

        # compute query, key, and value with 1x1 convolutions
        self.query = torch.nn.Conv2d(
            in_channel,
            out_channel // ratio_kq,
            kernel_size=1,
            bias=False
        )

        self.key = torch.nn.Conv2d(
            in_channel,
            out_channel // ratio_kq,
            kernel_size=1,
            bias=False
        )

        self.value = torch.nn.Conv2d(
            in_channel,
            out_channel // ratio_v,
            kernel_size=1,
            bias=False
        )

        self.conv = torch.nn.Conv2d(
            out_channel // 8,
            out_channel,
            kernel_size=1,
            bias=False
        )

        self.gamma = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # apply attention per batch element
        out = []
        for i in range(x.shape[0]):
            out.append(attention_einsum(query[i], key[i], value[i]))

        out = torch.stack(out, dim=0)
        out = self.gamma * self.conv(out)
        out = out + x  # skip connection

        return out
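`attention_einsum` computes, for each query position, a softmax over query-key scores and a weighted sum of the values. A minimal sketch of that arithmetic for a single scalar query over a flattened grid, using plain lists instead of tensors (the helper names are hypothetical):

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention_1d(q, keys, values):
    """Single-query analogue of attention_einsum: the (h, w) grid is
    flattened into a list of scalar keys/values. The output is the
    softmax(q*k)-weighted sum of the values."""
    weights = softmax([q * k for k in keys])
    return sum(w * v for w, v in zip(weights, values))

print(attention_1d(0.0, [1.0, 2.0], [10.0, 20.0]))  # uniform weights -> 15.0
```

A zero query scores every key equally, so the output is the plain mean of the values; a strongly aligned query concentrates the weight on one value.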
model_architect/inference_model.py
ADDED
@@ -0,0 +1,40 @@
import torch.nn as nn

from .DGMR_SO.model import Generator as DGMR_SO
from .Generator_only.model_clr_idx import Generator as Generator_only


class Predictor(nn.Module):
    def __init__(
            self,
            model_type,
    ):
        super().__init__()

        if model_type == 'DGMR_SO':
            self.generator = DGMR_SO(
                in_channels=1,
                base_channels=24,
                down_step=4,
                prev_step=4,
                sigma=1
            )

        elif model_type == 'Generator_only':
            self.generator = Generator_only(
                in_channels=1,
                base_channels=24,
                down_step=4,
                prev_step=4,
            )

    def forward(self, x, x2, topo, datetime_feat, pred_step=36):
        """
        x: input seq -> dims (N, D, C, H, W)
        x2: input seq (WRF) -> dims (N, D, C, H, W)
        topo: topography -> dims (N, 1, H=512, W=512)
        datetime_feat -> dims (N, D, 4)
        """
        pred = self.generator(x, x2, topo, datetime_feat, pred_step=pred_step)

        return pred
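The docstring of `Predictor.forward` fixes the tensor ranks the pipeline expects. A minimal sanity-check over plain shape tuples can make those conventions explicit (a sketch only: the helper name and the example sizes, such as 4 past frames against a 36-step horizon, are illustrative assumptions, not the repo's API):

```python
def validate_predictor_inputs(x, x2, topo, datetime_feat):
    """Check the tensor ranks documented in Predictor.forward using
    plain shape tuples (hypothetical helper)."""
    n = x[0]
    assert len(x) == len(x2) == 5            # (N, D, C, H, W)
    assert len(topo) == 4 and topo[1] == 1   # (N, 1, H, W)
    assert len(datetime_feat) == 3 and datetime_feat[2] == 4  # (N, D, 4)
    assert x2[0] == topo[0] == datetime_feat[0] == n  # consistent batch
    return True

print(validate_predictor_inputs(
    (2, 4, 1, 512, 512),   # x: past satellite frames
    (2, 36, 1, 512, 512),  # x2: WRF frames over the forecast horizon
    (2, 1, 512, 512),      # topo
    (2, 36, 4),            # datetime features
))  # -> True
```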
model_weights/DGMR_SO/ft36/weights.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec17b13cb466248e335803a1aac17e9246ab24ea0518d609bcd0fcc04cd1f928
size 215336260
model_weights/Generator_only/ft36/weights.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:87e230402e7f1c8f1f65af63d994856bf6f5eb637a37fd613c4c83fb6f194dc1
size 222876572
requirements.txt
ADDED
@@ -0,0 +1,3 @@
numpy==1.26.4
torch==2.4.0
einops==0.8.0
sample_data/sample_202504131100.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6117356ac780a530645e192cc85d647b103c915b63433441d810dedc7cdd4ec1
size 33002900
sample_data/sample_202504161200.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7b9e6c7ed76f695f7c6f2f1a976f4c27128120e5c4328223809c27dc8feee52
size 33300209
sample_data/sample_202507151200.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e20eb69ea2bef6bb0074c376afa7e8398e7da8eb1edd9a1f11c343ffe711a299
size 33038261