BiliSakura committed on
Commit
0414154
·
verified ·
1 Parent(s): d0b2296

Update all files for EO-VAE

Browse files
Files changed (1) hide show
  1. _eo_vae/modeling.py +306 -0
_eo_vae/modeling.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Apache-2.0 - EO-VAE Encoder/Decoder
2
+ # Wavelength-conditioned VAE for multi-spectral imagery
3
+
4
+ import math
5
+ from typing import Any, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+
11
+ from .dynamic_conv import DynamicConv, DynamicConvDecoder
12
+ from .layers import AttnBlock, Downsample, ResnetBlock, Upsample, swish
13
+
14
+
15
+ def _shuffle_latent_pack(z: Tensor, pi: int = 2, pj: int = 2) -> Tensor:
16
+ """(B, C, H*pi, W*pj) -> (B, C*pi*pj, H, W)"""
17
+ b, c, h, w = z.shape
18
+ z = z.view(b, c, h // pi, pi, w // pj, pj)
19
+ z = z.permute(0, 1, 3, 5, 2, 4).reshape(b, c * pi * pj, h // pi, w // pj)
20
+ return z
21
+
22
+
23
+ def _shuffle_latent_unpack(z: Tensor, pi: int = 2, pj: int = 2) -> Tensor:
24
+ """(B, C*pi*pj, H, W) -> (B, C, H*pi, W*pj)"""
25
+ b, cp, h, w = z.shape
26
+ c = cp // (pi * pj)
27
+ z = z.view(b, c, pi, pj, h, w)
28
+ z = z.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h * pi, w * pj)
29
+ return z
30
+
31
+
32
class Encoder(nn.Module):
    """Convolutional encoder producing ``2 * z_channels`` output channels.

    Layout: input convolution -> one stage per entry of ``ch_mult`` (each with
    ``num_res_blocks`` ResnetBlocks, followed by a ``Downsample`` on every
    stage except the last) -> mid section (ResnetBlock, AttnBlock,
    ResnetBlock) -> GroupNorm + swish -> 3x3 ``conv_out`` -> 1x1
    ``quant_conv``.

    When ``use_dynamic_ops`` is true the input convolution is a
    ``DynamicConv`` that also receives ``wvs``; otherwise it is a plain
    ``nn.Conv2d`` over ``in_channels`` and ``wvs`` is ignored.
    """

    def __init__(
        self,
        resolution: int = 256,
        in_channels: int = 3,
        ch: int = 128,
        ch_mult: list = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        z_channels: int = 32,
        use_dynamic_ops: bool = True,
        dynamic_conv_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        # Work on a copy so popping keys never mutates the caller's dict.
        conv_opts = dict(dynamic_conv_kwargs or {"num_layers": 4, "wv_planes": 256})
        wv_planes = conv_opts.pop("wv_planes", 256)
        num_layers = conv_opts.pop("num_layers", 4)

        self.resolution = resolution
        self.in_channels = in_channels
        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.z_channels = z_channels
        self.use_dynamic_ops = use_dynamic_ops
        self.in_ch_mult = (1,) + tuple(ch_mult)
        self.num_resolutions = len(ch_mult)

        if not use_dynamic_ops:
            self.conv_in = nn.Conv2d(in_channels, ch, 3, stride=1, padding=1)
        else:
            self.conv_in = DynamicConv(
                wv_planes=wv_planes,
                inter_dim=conv_opts.get("inter_dim", 128),
                kernel_size=3,
                stride=1,
                padding=1,
                embed_dim=ch,
                num_layers=num_layers,
                num_heads=4,
            )

        self.down = nn.ModuleList()
        width = ch
        # Bookkeeping only: halved after each Downsample, not read afterwards.
        curr_res = resolution
        final_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            target = ch * ch_mult[level]
            res_blocks = nn.ModuleList()
            for _ in range(num_res_blocks):
                res_blocks.append(ResnetBlock(width, target, cond_dim=None))
                width = target
            stage = nn.Module()
            stage.block = res_blocks
            # Kept as an empty list; forward() never iterates it.
            stage.attn = nn.ModuleList()
            if level != final_level:
                stage.downsample = Downsample(width)
                curr_res = curr_res // 2
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(width, width, cond_dim=None)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(width, width, cond_dim=None)
        self.norm_out = nn.GroupNorm(32, width, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(width, 2 * z_channels, 3, stride=1, padding=1)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * z_channels, 1)

    def forward(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Encode ``x`` into a ``2 * z_channels``-channel feature map.

        ``wvs`` is forwarded to the dynamic input convolution only; it is
        unused when ``use_dynamic_ops`` is false.
        """
        h = self.conv_in(x, wvs) if self.use_dynamic_ops else self.conv_in(x)

        final_level = self.num_resolutions - 1
        for level, stage in enumerate(self.down):
            for idx in range(self.num_res_blocks):
                h = stage.block[idx](h)
            if level != final_level:
                h = stage.downsample(h)

        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = layer(h)

        h = swish(self.norm_out(h))
        return self.quant_conv(self.conv_out(h))
114
+
115
+
116
class Decoder(nn.Module):
    """Convolutional decoder mapping a ``z_channels`` latent to an image.

    Layout: 1x1 ``post_quant_conv`` -> 3x3 ``conv_in`` -> mid section
    (ResnetBlock, AttnBlock, ResnetBlock) -> one stage per entry of
    ``ch_mult`` walked from last to first (each with ``num_res_blocks + 1``
    ResnetBlocks, followed by an ``Upsample`` on every stage except index 0)
    -> GroupNorm + swish -> output convolution.

    When ``use_dynamic_ops`` is true the output convolution is a
    ``DynamicConvDecoder`` that also receives ``wvs``; otherwise it is a
    plain ``nn.Conv2d`` producing ``out_ch`` channels and ``wvs`` is ignored.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        ch_mult: list = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 32,
        use_dynamic_ops: bool = True,
        dynamic_conv_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        # Work on a copy so popping keys never mutates the caller's dict.
        conv_opts = dict(dynamic_conv_kwargs or {"num_layers": 4, "wv_planes": 256})
        wv_planes = conv_opts.pop("wv_planes", 256)
        num_layers = conv_opts.pop("num_layers", 4)

        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.z_channels = z_channels
        self.resolution = resolution
        self.use_dynamic_ops = use_dynamic_ops
        self.num_resolutions = len(ch_mult)
        self.ch_mult = ch_mult

        self.post_quant_conv = nn.Conv2d(z_channels, z_channels, 1)
        width = ch * ch_mult[-1]
        self.conv_in = nn.Conv2d(z_channels, width, 3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(width, width, cond_dim=None)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(width, width, cond_dim=None)

        self.up = nn.ModuleList()
        for level in range(self.num_resolutions - 1, -1, -1):
            target = ch * ch_mult[level]
            res_blocks = nn.ModuleList()
            for _ in range(num_res_blocks + 1):
                res_blocks.append(ResnetBlock(width, target, cond_dim=None))
                width = target
            stage = nn.Module()
            stage.block = res_blocks
            # Kept as an empty list; forward() never iterates it.
            stage.attn = nn.ModuleList()
            if level != 0:
                stage.upsample = Upsample(width)
            # Stages are built last-to-first; prepending keeps self.up[level]
            # aligned with ch_mult[level].
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(32, width, eps=1e-6, affine=True)
        if not use_dynamic_ops:
            self.conv_out = nn.Conv2d(width, out_ch, 3, stride=1, padding=1)
        else:
            self.conv_out = DynamicConvDecoder(
                wv_planes=wv_planes,
                inter_dim=conv_opts.get("inter_dim", 128),
                kernel_size=3,
                stride=1,
                padding=1,
                embed_dim=width,
                num_layers=num_layers,
                num_heads=4,
            )

    def forward(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode latent ``z``.

        ``wvs`` is forwarded to the dynamic output convolution only; it is
        unused when ``use_dynamic_ops`` is false.
        """
        h = self.conv_in(self.post_quant_conv(z))

        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = layer(h)

        blocks_per_stage = self.num_res_blocks + 1
        for level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[level]
            for idx in range(blocks_per_stage):
                h = stage.block[idx](h)
            if level != 0:
                h = stage.upsample(h)

        h = swish(self.norm_out(h))
        if self.use_dynamic_ops:
            return self.conv_out(h, wvs)
        return self.conv_out(h)
195
+
196
+
197
class EOVAEModel(nn.Module):
    """EO-VAE: wavelength-conditioned VAE for multi-spectral imagery.

    Wraps an ``Encoder`` / ``Decoder`` pair. Latents are handled in a
    "packed" layout: every ``ps[0] x ps[1]`` spatial patch is folded into the
    channel axis and the result is standardised by a non-affine
    ``BatchNorm2d`` that tracks running statistics. ``decode`` expects that
    packed, normalised layout.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, scaling_factor: float = 1.0):
        """Build the model.

        Args:
            encoder: Module returning the moments tensor; must expose
                ``z_channels``.
            decoder: Module mapping an unpacked latent back to image space.
            scaling_factor: Stored on the instance; not applied anywhere
                inside this class.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.scaling_factor = scaling_factor
        self.ps = (2, 2)
        self.bn_eps = 1e-4
        self.bn = nn.BatchNorm2d(
            math.prod(self.ps) * encoder.z_channels,
            # Must equal the eps used in _inv_normalize_latent; with the
            # BatchNorm default (1e-5) the two would not be exact inverses.
            eps=self.bn_eps,
            affine=False,
            track_running_stats=True,
        )

    @property
    def z_channels(self) -> int:
        """Number of latent channels before packing (taken from the encoder)."""
        return self.encoder.z_channels

    def _normalize_latent(self, z: Tensor) -> Tensor:
        """Standardise a packed latent with the BatchNorm layer.

        The BatchNorm mode is re-synchronised with ``self.training`` first,
        because ``_inv_normalize_latent`` forces it into eval mode. In
        training mode this uses batch statistics and updates the running
        statistics.
        """
        self.bn.train(mode=self.training)
        return self.bn(z)

    def _inv_normalize_latent(self, z: Tensor) -> Tensor:
        """Undo ``_normalize_latent`` using the running mean and variance."""
        self.bn.eval()
        s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
        m = self.bn.running_mean.view(1, -1, 1, 1)
        return z * s + m

    def encode(self, x: Tensor, wvs: Tensor) -> "EOVAEEncoderOutput":
        """Run the encoder and wrap its output as a diagonal Gaussian posterior.

        The returned distribution is in the unpacked, un-normalised layout.
        """
        from .distributions import DiagonalGaussianDistribution

        moments = self.encoder(x, wvs)
        posterior = DiagonalGaussianDistribution(moments)
        return EOVAEEncoderOutput(latent_dist=posterior)

    def decode(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode a packed, normalised latent."""
        z = self._inv_normalize_latent(z)
        z = _shuffle_latent_unpack(z, self.ps[0], self.ps[1])
        return self.decoder(z, wvs)

    def forward(self, x: Tensor, wvs: Tensor, sample_posterior: bool = True) -> tuple[Tensor, Any]:
        """Encode, pack + normalise, then decode.

        Args:
            x: Input passed to the encoder.
            wvs: Wavelength conditioning passed to encoder and decoder.
            sample_posterior: Draw a sample from the posterior when true,
                otherwise use its mode.

        Returns:
            ``(reconstruction, posterior)``.
        """
        out = self.encode(x, wvs)
        z = out.latent_dist.sample() if sample_posterior else out.latent_dist.mode()
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        z = self._normalize_latent(z)
        recon = self.decode(z, wvs)
        return recon, out.latent_dist

    @torch.no_grad()
    def encode_to_latent(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Return the posterior mode as a packed, normalised latent."""
        out = self.encode(x, wvs)
        z = out.latent_dist.mode()
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        return self._normalize_latent(z)

    @torch.no_grad()
    def encode_spatial_normalized(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Like ``encode_to_latent`` but unpacked back to the spatial layout."""
        z = self.encode_to_latent(x, wvs)
        return _shuffle_latent_unpack(z, self.ps[0], self.ps[1])

    @torch.no_grad()
    def decode_spatial_normalized(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode a normalised latent given in the unpacked spatial layout."""
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        return self.decode(z, wvs)

    @torch.no_grad()
    def reconstruct(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Deterministic reconstruction using the posterior mode."""
        recon, _ = self.forward(x, wvs, sample_posterior=False)
        return recon

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "EOVAEModel":
        """Build a model from a plain config dictionary.

        If ``config`` has a ``"model"`` key, that sub-dictionary is used.
        Encoder / decoder settings are read from the ``"encoder"`` /
        ``"decoder"`` sub-dictionaries when present, otherwise from the
        config itself. Keys starting with ``"_"`` are ignored, and missing
        keys fall back to the constructor defaults.
        """
        if "model" in config:
            config = config["model"]
        enc_cfg = {k: v for k, v in config.get("encoder", config).items() if not str(k).startswith("_")}
        dec_cfg = {k: v for k, v in config.get("decoder", config).items() if not str(k).startswith("_")}

        enc_dyn = enc_cfg.get("dynamic_conv_kwargs", {"num_layers": 4, "wv_planes": 256})
        dec_dyn = dec_cfg.get("dynamic_conv_kwargs", {"num_layers": 4, "wv_planes": 256})

        encoder = Encoder(
            resolution=enc_cfg.get("resolution", 256),
            in_channels=enc_cfg.get("in_channels", 3),
            ch=enc_cfg.get("ch", 128),
            ch_mult=enc_cfg.get("ch_mult", [1, 2, 4, 4]),
            num_res_blocks=enc_cfg.get("num_res_blocks", 2),
            z_channels=enc_cfg.get("z_channels", 32),
            use_dynamic_ops=enc_cfg.get("use_dynamic_ops", True),
            dynamic_conv_kwargs=enc_dyn,
        )
        decoder = Decoder(
            ch=dec_cfg.get("ch", 128),
            out_ch=dec_cfg.get("out_ch", 3),
            ch_mult=dec_cfg.get("ch_mult", [1, 2, 4, 4]),
            num_res_blocks=dec_cfg.get("num_res_blocks", 2),
            resolution=dec_cfg.get("resolution", 256),
            z_channels=dec_cfg.get("z_channels", 32),
            use_dynamic_ops=dec_cfg.get("use_dynamic_ops", True),
            dynamic_conv_kwargs=dec_dyn,
        )
        return cls(encoder, decoder, scaling_factor=config.get("scaling_factor", 1.0))
302
+
303
+
304
class EOVAEEncoderOutput:
    """Minimal result container returned by ``EOVAEModel.encode``."""

    def __init__(self, latent_dist) -> None:
        # The posterior object produced by the encoder; stored as-is.
        self.latent_dist = latent_dist