csaybar commited on
Commit
ea2435e
·
verified ·
1 Parent(s): 8306454

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +2 -0
  2. example_data.safetensor +3 -0
  3. load.py +89 -0
  4. mlm.json +203 -0
  5. model.py +348 -0
  6. model.safetensor +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example_data.safetensor filter=lfs diff=lfs merge=lfs -text
37
+ model.safetensor filter=lfs diff=lfs merge=lfs -text
example_data.safetensor ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:958673beb8f61c0b11dd680340914baf862ef4d2b46876d9cc785b3c945fbbab
3
+ size 1310816
load.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import pathlib
3
+
4
+ import matplotlib.pyplot as plt
5
+ import safetensors.torch
6
+ import torch
7
+
8
+
9
def load_model_module(model_path: pathlib.Path):
    """Dynamically import the Python file at *model_path* and return it as a module.

    The module is registered under the name ``"model"``; no caching is done,
    so repeated calls re-execute the file.
    """
    resolved = model_path.resolve()

    # Build a module object straight from the file location and execute it.
    spec = importlib.util.spec_from_file_location("model", resolved)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return module
17
+
18
+
19
class GeneratorNormal(torch.nn.Module):
    """Inference wrapper around the generator.

    Sanitizes the input (clamp to [0, 1], NaN -> 1.0), moves channels to the
    last axis for the wrapped model, and restores the original axis order on
    the way out.
    """

    def __init__(self, model):
        super(GeneratorNormal, self).__init__()
        self.model = model

    def forward(self, X):
        # Clamp first, then replace any remaining NaNs with 1.0.
        sanitized = torch.nan_to_num(torch.clamp(X, 0, 1), nan=1.0)
        # Channels-first -> channels-last layout expected by the inner model.
        channels_last = sanitized.permute(0, 2, 3, 4, 1).contiguous()
        prediction = self.model(channels_last)[0]
        # Restore the original channels-first layout.
        return prediction.permute(0, 4, 1, 2, 3)
29
+
30
+
31
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the bundled example tensor and return it as a batched float tensor.

    Reads ``example_data.safetensor`` from *path* and prepends a batch axis.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    return tensors["example_data"][None].float()
34
+
35
+
36
def trainable_model(path, device="cpu", *args, **kwargs):
    """Instantiate the Generator from ``model.py`` and load its safetensor weights.

    Returns the model in its default (trainable) state.
    """
    module = load_model_module(path / "model.py")
    generator = module.Generator(device=device, inputChannels=4, outputChannels=4)
    state = safetensors.torch.load_file(path / "model.safetensor")
    generator.load_state_dict(state)
    return generator
43
+
44
+
45
def compiled_model(path, device="cpu", *args, **kwargs):
    """Load the Generator ready for inference.

    Reuses :func:`trainable_model` for weight loading (previously this body
    duplicated that logic line for line), then switches to eval mode, freezes
    all parameters, and wraps the model in :class:`GeneratorNormal` so inputs
    are sanitized and permuted automatically.
    """
    model = trainable_model(path, device)
    model = model.eval()
    # Freeze: inference-only, no gradients needed.
    for param in model.parameters():
        param.requires_grad = False
    return GeneratorNormal(model.to(device))
55
+
56
+
57
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the gap-filling model on the example data and plot original vs. filled frames.

    Returns the matplotlib Figure with two rows (original / filled) and one
    column per time step.
    """
    # Load model and example time series.
    model = compiled_model(path, device)
    series = example_data(path)

    # Inference.
    filled = model(series.to(device))

    # Drop the batch axis and move to CPU for plotting -> [T, C, H, W].
    series = series.squeeze(0).detach().cpu()
    filled = filled.squeeze(0).detach().cpu()

    n_steps = series.shape[0]
    rgb = [2, 1, 0]  # Assuming RGB is BGR in channel order (4 bands)

    fig, axes = plt.subplots(2, n_steps, figsize=(3 * n_steps, 6))

    for t in range(n_steps):
        before = series[t, rgb].permute(1, 2, 0).clamp(0, 1).numpy()
        after = filled[t, rgb].permute(1, 2, 0).clamp(0, 1).numpy()

        axes[0, t].imshow(before * 3)
        axes[0, t].axis("off")
        axes[0, t].set_title(f"Original t={t}")

        axes[1, t].imshow(after * 3)
        axes[1, t].axis("off")
        axes[1, t].set_title(f"Filled t={t}")

    plt.tight_layout()
    return fig
mlm.json ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "Feature",
3
+ "stac_version": "1.1.0",
4
+ "stac_extensions": [
5
+ "https://stac-extensions.github.io/mlm/v1.4.0/schema.json"
6
+ ],
7
+ "id": "UNetMobV2_V2 model",
8
+ "geometry": {
9
+ "type": "Polygon",
10
+ "coordinates": [
11
+ [
12
+ [
13
+ -180.0,
14
+ -90.0
15
+ ],
16
+ [
17
+ -180.0,
18
+ 90.0
19
+ ],
20
+ [
21
+ 180.0,
22
+ 90.0
23
+ ],
24
+ [
25
+ 180.0,
26
+ -90.0
27
+ ],
28
+ [
29
+ -180.0,
30
+ -90.0
31
+ ]
32
+ ]
33
+ ]
34
+ },
35
+ "bbox": [
36
+ -180,
37
+ -90,
38
+ 180,
39
+ 90
40
+ ],
41
+ "properties": {
42
+ "start_datetime": "1900-01-01T00:00:00Z",
43
+ "end_datetime": "9999-01-01T00:00:00Z",
44
+ "description": "A UNet model trained on Sentinel-2 imagery for cloud segmentation.",
45
+ "forward_backward_pass": {
46
+ "32": 3.229184,
47
+ "64": 12.916736,
48
+ "128": 51.666944,
49
+ "256": 206.667776,
50
+ "512": 826.671104,
51
+ "1024": 3306.684416,
52
+ "2048": 13226.737664
53
+ },
54
+ "dependencies": [
55
+ "torch",
56
+ "segmentation-models-pytorch",
57
+ "safetensors.torch"
58
+ ],
59
+ "mlm:framework": "pytorch",
60
+ "mlm:framework_version": "2.1.2+cu121",
61
+ "file:size": 26529040,
62
+ "mlm:memory_size": 1,
63
+ "mlm:accelerator": "cuda",
64
+ "mlm:accelerator_constrained": false,
65
+ "mlm:accelerator_summary": "Unknown",
66
+ "mlm:name": "UNetMobV2_V1",
67
+ "mlm:architecture": "UNetMobV2",
68
+ "mlm:tasks": [
69
+ "semantic-segmentation"
70
+ ],
71
+ "mlm:input": [
72
+ {
73
+ "name": "13 Band Sentinel-2 Batch",
74
+ "bands": [
75
+ "B01",
76
+ "B02",
77
+ "B03",
78
+ "B04",
79
+ "B05",
80
+ "B06",
81
+ "B07",
82
+ "B08",
83
+ "B8A",
84
+ "B09",
85
+ "B10",
86
+ "B11",
87
+ "B12"
88
+ ],
89
+ "input": {
90
+ "shape": [
91
+ -1,
92
+ 13,
93
+ 512,
94
+ 512
95
+ ],
96
+ "dim_order": [
97
+ "batch",
98
+ "channel",
99
+ "height",
100
+ "width"
101
+ ],
102
+ "data_type": "float32"
103
+ },
104
+ "pre_processing_function": null
105
+ }
106
+ ],
107
+ "mlm:output": [
108
+ {
109
+ "name": "semantic-segmentation",
110
+ "tasks": [
111
+ "semantic-segmentation"
112
+ ],
113
+ "result": {
114
+ "shape": [
115
+ -1,
116
+ 4,
117
+ 512,
118
+ 512
119
+ ],
120
+ "dim_order": [
121
+ "batch",
122
+ "channel",
123
+ "height",
124
+ "width"
125
+ ],
126
+ "data_type": "float32"
127
+ },
128
+ "classification:classes": [
129
+ {
130
+ "value": 0,
131
+ "name": "Clear",
132
+ "description": "Clear"
133
+ },
134
+ {
135
+ "value": 1,
136
+ "name": "Thick Clouds",
137
+ "description": "Thick Clouds"
138
+ },
139
+ {
140
+ "value": 2,
141
+ "name": "Thin Clouds",
142
+ "description": "Thin Clouds"
143
+ },
144
+ {
145
+ "value": 3,
146
+ "name": "Cloud Shadows",
147
+ "description": "Cloud Shadows"
148
+ }
149
+ ],
150
+ "post_processing_function": null
151
+ }
152
+ ],
153
+ "mlm:total_parameters": 6632260,
154
+ "mlm:pretrained": true,
155
+ "datetime": null
156
+ },
157
+ "links": [],
158
+ "assets": {
159
+ "trainable": {
160
+ "href": "https://huggingface.co/tacofoundation/GANFilling/resolve/main/model.safetensor",
161
+ "type": "application/octet-stream; application=safetensors",
162
+ "title": "Pytorch weights checkpoint",
163
+ "description": "A UNet model trained on Sentinel-2 imagery for cloud segmentation. The model was trained using the CloudSEN12 dataset.",
164
+ "mlm:artifact_type": "safetensors.torch.save_file",
165
+ "roles": [
166
+ "mlm:model",
167
+ "mlm:weights",
168
+ "data"
169
+ ]
170
+ },
171
+ "source_code": {
172
+ "href": "https://huggingface.co/tacofoundation/GANFilling/resolve/main/load.py",
173
+ "type": "text/x-python",
174
+ "title": "Model load script",
175
+ "description": "Source code to run the model.",
176
+ "roles": [
177
+ "mlm:source_code",
178
+ "code"
179
+ ]
180
+ },
181
+ "source_code_model": {
182
+ "href": "https://huggingface.co/tacofoundation/GANFilling/resolve/main/model.py",
183
+ "type": "text/x-python",
184
+ "title": "Model load script",
185
+ "description": "Source code to run the model.",
186
+ "roles": [
187
+ "mlm:source_code",
188
+ "code"
189
+ ]
190
+ },
191
+ "example_data": {
192
+ "href": "https://huggingface.co/tacofoundation/GANFilling/resolve/main/example_data.safetensor",
193
+ "type": "application/octet-stream; application=safetensors",
194
+ "title": "Example Sentinel-2 image",
195
+ "description": "Example Sentinel-2 image for model inference.",
196
+ "roles": [
197
+ "mlm:example_data",
198
+ "data"
199
+ ]
200
+ }
201
+ },
202
+ "collection": "GANFilling"
203
+ }
model.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell operating on 2-D feature maps."""

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True, device=None):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        device: torch.device or str, optional
            Device for lazily-created states. When None, falls back to the
            device of the cell's own conv weights. (Defaults added as a bug
            fix: ConvLSTM constructs cells without a device argument, which
            previously raised TypeError.)
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # "Same" padding for odd kernel sizes.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.device = device

        # One conv produces all four gate pre-activations (i, f, o, g) at once.
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def __initStates(self, size):
        # Fall back to the conv weight's device when no explicit device was
        # configured, so the cell also works when built without one.
        dev = self.device if self.device is not None else self.conv.weight.device
        return torch.zeros(size, device=dev), torch.zeros(size, device=dev)

    def forward(self, input_tensor, cur_state):
        """One recurrent step; returns (h_next, c_next).

        cur_state may be None, in which case zero states are created to match
        the input's batch size and spatial dims.
        """
        if cur_state is None:  # fixed: identity check instead of `== None`
            h_cur, c_cur = self.__initStates(
                [
                    input_tensor.shape[0],
                    self.hidden_dim,
                    input_tensor.shape[2],
                    input_tensor.shape[3],
                ]
            )
        else:
            h_cur, c_cur = cur_state

        combined = torch.cat(
            [input_tensor, h_cur], dim=1
        )  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        # Standard LSTM gate equations, applied per-pixel.
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zeroed (h, c) on the same device as the conv weights."""
        height, width = image_size
        return (
            torch.zeros(
                batch_size,
                self.hidden_dim,
                height,
                width,
                device=self.conv.weight.device,
            ),
            torch.zeros(
                batch_size,
                self.hidden_dim,
                height,
                width,
                device=self.conv.weight.device,
            ),
        )
93
+
94
+
95
class ConvLSTM(nn.Module):
    """

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels
        kernel_size: Size of kernel in convolutions
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        device: Optional device forwarded to each ConvLSTMCell. (Bug fix: the
            original constructed cells without the required ``device``
            argument, raising TypeError on instantiation.)
    Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of lists of length T of each output
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        kernel_size,
        num_layers,
        batch_first=False,
        bias=True,
        return_all_layers=False,
        device=None,
    ):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError("Inconsistent list length.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 sees the raw input; deeper layers see the previous
            # layer's hidden channels.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(
                ConvLSTMCell(
                    input_dim=cur_input_dim,
                    hidden_dim=self.hidden_dim[i],
                    kernel_size=self.kernel_size[i],
                    bias=self.bias,
                    device=device,  # bug fix: was omitted -> TypeError
                )
            )

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            None. Stateful operation is not implemented yet.

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward. Can send image size here
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            # NOTE: `h` is reused here as the hidden state (shadowing the
            # spatial height read above), matching the original behavior.
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
                )
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        # One zeroed (h, c) pair per layer.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (
            isinstance(kernel_size, tuple)
            or (
                isinstance(kernel_size, list)
                and all([isinstance(elem, tuple) for elem in kernel_size])
            )
        ):
            raise ValueError("`kernel_size` must be tuple or list of tuples")

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Broadcast a scalar hyper-parameter to one entry per layer.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
242
+
243
+
244
def normal_init(m, mean, std):
    """Initialize conv/deconv weights from N(mean, std) and zero the bias.

    Modules other than Conv2d / ConvTranspose2d are left untouched. Guards
    against bias-less layers (bug fix: the original dereferenced ``m.bias``
    unconditionally, crashing on ``Conv2d(..., bias=False)`` where the
    attribute is None).
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()
248
+
249
+
250
class Generator(nn.Module):
    """Recurrent U-Net-style generator.

    Seven stride-2 encoder convs downsample the input; a ConvLSTMCell at
    every encoder and decoder level carries hidden state across time steps;
    the decoder upsamples with skip connections from the encoder.

    The forward pass iterates over the LAST axis of the input as time, so the
    expected layout is (B, C, H, W, T) — consistent with the permute done by
    the GeneratorNormal wrapper in load.py. TODO confirm against callers.
    """

    def __init__(self, device, inputChannels=4, outputChannels=3, d=64):
        # device: where lazily-created ConvLSTM states live.
        # d: base channel width; deeper levels use multiples of it.
        super().__init__()
        self.d = d
        self.device = device

        # Encoder: each conv halves the spatial resolution (kernel 3, stride 2, pad 1).
        self.conv1 = nn.Conv2d(inputChannels, d, 3, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 3, 2, 1)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 3, 2, 1)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 3, 2, 1)
        self.conv5 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv6 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)
        self.conv7 = nn.Conv2d(d * 8, d * 8, 3, 2, 1)

        # Decoder-side recurrent cells; the *2 input widths account for the
        # skip-connection concatenation done in forward_step.
        self.conv_lstm_d1 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_d2 = ConvLSTMCell(d * 8 * 2, d * 8, (3, 3), False, device)
        self.conv_lstm_d3 = ConvLSTMCell(d * 8 * 2, d * 8, (3, 3), False, device)
        self.conv_lstm_d4 = ConvLSTMCell(d * 8 * 2, d * 4, (3, 3), False, device)
        self.conv_lstm_d5 = ConvLSTMCell(d * 4 * 2, d * 2, (3, 3), False, device)
        self.conv_lstm_d6 = ConvLSTMCell(d * 2 * 2, d, (3, 3), False, device)
        self.conv_lstm_d7 = ConvLSTMCell(d * 2, d, (3, 3), False, device)

        # Encoder-side recurrent cells (one per encoder level, same width in/out).
        self.conv_lstm_e1 = ConvLSTMCell(d, d, (3, 3), False, device)
        self.conv_lstm_e2 = ConvLSTMCell(d * 2, d * 2, (3, 3), False, device)
        self.conv_lstm_e3 = ConvLSTMCell(d * 4, d * 4, (3, 3), False, device)
        self.conv_lstm_e4 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e5 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e6 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)
        self.conv_lstm_e7 = ConvLSTMCell(d * 8, d * 8, (3, 3), False, device)

        self.up = nn.Upsample(scale_factor=2)
        self.conv_out = nn.Conv2d(d, outputChannels, 3, 1, 1)

        # Negative slope for the encoder's leaky ReLUs.
        self.slope = 0.2

    def weight_init(self, mean, std):
        """Apply normal_init (N(mean, std) weights, zero bias) to every child module."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward_step(self, input, states_encoder, states_decoder):
        """Process ONE time step.

        input: a single frame (B, C, H, W).
        states_encoder / states_decoder: per-level (h, c) tuples from the
        previous step, or None entries on the first step (ConvLSTMCell then
        zero-initializes them).
        Returns (output_frame, (new_encoder_states, new_decoder_states)).
        """

        # --- Encoder: conv -> ConvLSTM at each level; leaky ReLU between levels.
        e1 = self.conv1(input)
        states_e1 = self.conv_lstm_e1(e1, states_encoder[0])
        e2 = self.conv2(F.leaky_relu(states_e1[0], self.slope))
        states_e2 = self.conv_lstm_e2(e2, states_encoder[1])
        e3 = self.conv3(F.leaky_relu(states_e2[0], self.slope))
        states_e3 = self.conv_lstm_e3(e3, states_encoder[2])
        e4 = self.conv4(F.leaky_relu(states_e3[0], self.slope))
        states_e4 = self.conv_lstm_e4(e4, states_encoder[3])
        e5 = self.conv5(F.leaky_relu(states_e4[0], self.slope))
        states_e5 = self.conv_lstm_e5(e5, states_encoder[4])
        e6 = self.conv6(F.leaky_relu(states_e5[0], self.slope))
        states_e6 = self.conv_lstm_e6(e6, states_encoder[5])
        e7 = self.conv7(F.leaky_relu(states_e6[0], self.slope))

        # --- Decoder: ConvLSTM -> upsample -> concat encoder skip at each level.
        states1 = self.conv_lstm_d1(F.relu(e7), states_decoder[0])
        d1 = self.up(states1[0])
        d1 = torch.cat([d1, e6], 1)

        states2 = self.conv_lstm_d2(F.relu(d1), states_decoder[1])
        d2 = self.up(states2[0])
        d2 = torch.cat([d2, e5], 1)

        states3 = self.conv_lstm_d3(F.relu(d2), states_decoder[2])
        d3 = self.up(states3[0])
        d3 = torch.cat([d3, e4], 1)

        states4 = self.conv_lstm_d4(F.relu(d3), states_decoder[3])
        d4 = self.up(states4[0])
        d4 = torch.cat([d4, e3], 1)

        states5 = self.conv_lstm_d5(F.relu(d4), states_decoder[4])
        d5 = self.up(states5[0])
        d5 = torch.cat([d5, e2], 1)

        states6 = self.conv_lstm_d6(F.relu(d5), states_decoder[5])
        d6 = self.up(states6[0])
        d6 = torch.cat([d6, e1], 1)

        states7 = self.conv_lstm_d7(F.relu(d6), states_decoder[6])
        d7 = self.up(states7[0])

        # tanh then clip to [0, 1]: negative outputs are zeroed (min=-0.0).
        o = torch.clip(torch.tanh(self.conv_out(d7)), min=-0.0, max=1)

        # NOTE: encoder cell 7 (conv_lstm_e7) is declared but unused, so only
        # six encoder states are threaded through time.
        states_e = [states_e1, states_e2, states_e3, states_e4, states_e5, states_e6]
        states_d = [states1, states2, states3, states4, states5, states6, states7]

        return o, (states_e, states_d)

    def forward(self, tensor):
        """Run forward_step over every time slice (last axis of *tensor*).

        Returns (output, states) where output has the same shape as the input
        and states holds the final (encoder, decoder) recurrent states.
        """
        # First step starts with no recurrent state; cells zero-init lazily.
        states_encoder = (None, None, None, None, None, None, None)
        states_decoder = (None, None, None, None, None, None, None)
        # NOTE(review): writing time slices in-place into empty_like breaks
        # autograd through `output` — presumably this path is inference-only;
        # confirm before training through forward().
        output = torch.empty_like(tensor)
        for timeStep in range(tensor.shape[4]):
            output[:, :, :, :, timeStep], states = self.forward_step(
                tensor[:, :, :, :, timeStep], states_encoder, states_decoder
            )
            states_encoder, states_decoder = states[0], states[1]
        return output, states
model.safetensor ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6764b118a59ed60f7c89ec5ab34b960b436ec55cbdaba6ef2980fdfe2234cd0c
3
+ size 726989248