| |
| |
|
|
| import math |
| from typing import Any, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
|
|
| from .dynamic_conv import DynamicConv, DynamicConvDecoder |
| from .layers import AttnBlock, Downsample, ResnetBlock, Upsample, swish |
|
|
|
|
| def _shuffle_latent_pack(z: Tensor, pi: int = 2, pj: int = 2) -> Tensor: |
| """(B, C, H*pi, W*pj) -> (B, C*pi*pj, H, W)""" |
| b, c, h, w = z.shape |
| z = z.view(b, c, h // pi, pi, w // pj, pj) |
| z = z.permute(0, 1, 3, 5, 2, 4).reshape(b, c * pi * pj, h // pi, w // pj) |
| return z |
|
|
|
|
| def _shuffle_latent_unpack(z: Tensor, pi: int = 2, pj: int = 2) -> Tensor: |
| """(B, C*pi*pj, H, W) -> (B, C, H*pi, W*pj)""" |
| b, cp, h, w = z.shape |
| c = cp // (pi * pj) |
| z = z.view(b, c, pi, pj, h, w) |
| z = z.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h * pi, w * pj) |
| return z |
|
|
|
|
class Encoder(nn.Module):
    """Convolutional VAE encoder that outputs Gaussian moments.

    The output of ``forward`` has ``2 * z_channels`` channels (mean and
    log-variance concatenated, split downstream by the posterior class).
    With ``use_dynamic_ops`` the stem is a wavelength-conditioned
    ``DynamicConv`` (fed both pixels and wavelengths ``wvs``); otherwise a
    plain ``nn.Conv2d`` over ``in_channels`` bands is used and ``wvs`` is
    ignored.
    """

    def __init__(
        self,
        resolution: int = 256,
        in_channels: int = 3,
        ch: int = 128,
        ch_mult: list = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        z_channels: int = 32,
        use_dynamic_ops: bool = True,
        dynamic_conv_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        # Work on a copy so the pops below never mutate the caller's dict.
        dyn = dict(dynamic_conv_kwargs or {"num_layers": 4, "wv_planes": 256})
        wv_planes = dyn.pop("wv_planes", 256)
        num_layers = dyn.pop("num_layers", 4)

        self.resolution = resolution
        self.in_channels = in_channels
        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.z_channels = z_channels
        self.use_dynamic_ops = use_dynamic_ops
        self.in_ch_mult = (1,) + tuple(ch_mult)
        self.num_resolutions = len(ch_mult)

        # Stem: wavelength-conditioned when dynamic ops are enabled.
        if use_dynamic_ops:
            self.conv_in = DynamicConv(
                wv_planes=wv_planes, inter_dim=dyn.get("inter_dim", 128),
                kernel_size=3, stride=1, padding=1, embed_dim=ch,
                num_layers=num_layers, num_heads=4,
            )
        else:
            self.conv_in = nn.Conv2d(in_channels, ch, 3, stride=1, padding=1)

        # Downsampling trunk: one stage per ch_mult entry, each with
        # num_res_blocks ResnetBlocks, and a Downsample between stages.
        self.down = nn.ModuleList()
        block_in = ch
        curr_res = resolution
        for level, mult in enumerate(ch_mult):
            block_out = ch * mult
            stage_blocks = nn.ModuleList()
            for _ in range(num_res_blocks):
                stage_blocks.append(ResnetBlock(block_in, block_out, cond_dim=None))
                block_in = block_out
            stage = nn.Module()
            stage.block = stage_blocks
            stage.attn = nn.ModuleList()  # kept empty; preserved for state-dict layout
            if level != self.num_resolutions - 1:
                stage.downsample = Downsample(block_in)
                curr_res //= 2
            self.down.append(stage)

        # Bottleneck: resnet -> attention -> resnet, then projection to moments.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, 3, stride=1, padding=1)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * z_channels, 1)

    def forward(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Encode pixels ``x`` (conditioned on wavelengths ``wvs``) to moments."""
        h = self.conv_in(x, wvs) if self.use_dynamic_ops else self.conv_in(x)

        for level, stage in enumerate(self.down):
            for res_block in stage.block:
                h = res_block(h)
            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
        h = swish(self.norm_out(h))
        return self.quant_conv(self.conv_out(h))
|
|
|
|
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latents back to pixel space.

    With ``use_dynamic_ops`` the output projection is a wavelength-
    conditioned ``DynamicConvDecoder`` (fed features and wavelengths
    ``wvs``); otherwise a plain ``nn.Conv2d`` to ``out_ch`` channels is
    used and ``wvs`` is ignored.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        ch_mult: list = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 32,
        use_dynamic_ops: bool = True,
        dynamic_conv_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        # Work on a copy so the pops below never mutate the caller's dict.
        dyn = dict(dynamic_conv_kwargs or {"num_layers": 4, "wv_planes": 256})
        wv_planes = dyn.pop("wv_planes", 256)
        num_layers = dyn.pop("num_layers", 4)

        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.z_channels = z_channels
        self.resolution = resolution
        self.use_dynamic_ops = use_dynamic_ops
        self.num_resolutions = len(ch_mult)
        self.ch_mult = ch_mult

        self.post_quant_conv = nn.Conv2d(z_channels, z_channels, 1)
        block_in = ch * ch_mult[-1]
        self.conv_in = nn.Conv2d(z_channels, block_in, 3, stride=1, padding=1)

        # Bottleneck mirrors the encoder: resnet -> attention -> resnet.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(block_in, block_in, cond_dim=None)

        # Upsampling trunk, built from the lowest resolution outward.
        # insert(0, ...) keeps self.up[i] aligned with ch_mult level i.
        self.up = nn.ModuleList()
        for level in reversed(range(self.num_resolutions)):
            block_out = ch * ch_mult[level]
            stage_blocks = nn.ModuleList()
            for _ in range(num_res_blocks + 1):
                stage_blocks.append(ResnetBlock(block_in, block_out, cond_dim=None))
                block_in = block_out
            stage = nn.Module()
            stage.block = stage_blocks
            stage.attn = nn.ModuleList()  # kept empty; preserved for state-dict layout
            if level != 0:
                stage.upsample = Upsample(block_in)
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6, affine=True)
        if use_dynamic_ops:
            self.conv_out = DynamicConvDecoder(
                wv_planes=wv_planes, inter_dim=dyn.get("inter_dim", 128),
                kernel_size=3, stride=1, padding=1, embed_dim=block_in,
                num_layers=num_layers, num_heads=4,
            )
        else:
            self.conv_out = nn.Conv2d(block_in, out_ch, 3, stride=1, padding=1)

    def forward(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode latent ``z`` (conditioned on wavelengths ``wvs``) to pixels."""
        h = self.conv_in(self.post_quant_conv(z))
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # Walk the stages from lowest to highest resolution.
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for res_block in stage.block:
                h = res_block(h)
            if level != 0:
                h = stage.upsample(h)

        h = swish(self.norm_out(h))
        return self.conv_out(h, wvs) if self.use_dynamic_ops else self.conv_out(h)
|
|
|
|
class EOVAEModel(nn.Module):
    """EO-VAE: wavelength-conditioned VAE for multi-spectral imagery.

    The sampled latent is space-to-depth packed with a 2x2 patch
    (``_shuffle_latent_pack``) and whitened by a non-affine BatchNorm before
    being handed to downstream consumers; ``decode`` undoes both steps.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, scaling_factor: float = 1.0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.scaling_factor = scaling_factor
        # (pi, pj) patch used by the latent pack/unpack shuffles.
        self.ps = (2, 2)
        self.bn_eps = 1e-4
        # Non-affine BatchNorm that whitens the packed latent.
        # BUGFIX: eps is now passed explicitly so the forward normalization
        # uses the same epsilon as the manual inversion in
        # _inv_normalize_latent (previously the BN silently used the default
        # 1e-5 while the inverse used 1e-4, so the two did not cancel).
        self.bn = nn.BatchNorm2d(
            math.prod(self.ps) * encoder.z_channels,
            eps=self.bn_eps,
            affine=False,
            track_running_stats=True,
        )

    @property
    def z_channels(self) -> int:
        """Latent channel count produced by the encoder (before packing)."""
        return self.encoder.z_channels

    def _normalize_latent(self, z: Tensor) -> Tensor:
        """Whiten a packed latent with the tracked BatchNorm.

        ``_inv_normalize_latent`` forces ``self.bn`` into eval mode, so the
        mode is re-synced with this module's training flag on every call.
        """
        self.bn.train(mode=self.training)
        return self.bn(z)

    def _inv_normalize_latent(self, z: Tensor) -> Tensor:
        """Invert ``_normalize_latent`` using the BN *running* statistics.

        NOTE(review): in training mode ``_normalize_latent`` whitens with
        batch statistics, so this running-stats inverse is only approximate
        during training.
        """
        self.bn.eval()
        s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
        m = self.bn.running_mean.view(1, -1, 1, 1)
        return z * s + m

    def encode(self, x: Tensor, wvs: Tensor) -> "EOVAEEncoderOutput":
        """Encode pixels to a diagonal-Gaussian posterior over the latent."""
        # Imported lazily — presumably to avoid a circular import; confirm.
        from .distributions import DiagonalGaussianDistribution

        moments = self.encoder(x, wvs)
        return EOVAEEncoderOutput(latent_dist=DiagonalGaussianDistribution(moments))

    def decode(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode a packed, normalized latent back to pixel space."""
        z = self._inv_normalize_latent(z)
        z = _shuffle_latent_unpack(z, self.ps[0], self.ps[1])
        return self.decoder(z, wvs)

    def forward(self, x: Tensor, wvs: Tensor, sample_posterior: bool = True) -> tuple[Tensor, Any]:
        """Run a full encode/decode pass.

        Returns ``(reconstruction, posterior)``; the latent is sampled from
        the posterior when ``sample_posterior`` is True, else its mode is used.
        """
        out = self.encode(x, wvs)
        z = out.latent_dist.sample() if sample_posterior else out.latent_dist.mode()
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        z = self._normalize_latent(z)
        recon = self.decode(z, wvs)
        return recon, out.latent_dist

    @torch.no_grad()
    def encode_to_latent(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Deterministically encode to the packed, normalized latent."""
        out = self.encode(x, wvs)
        z = out.latent_dist.mode()
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        return self._normalize_latent(z)

    @torch.no_grad()
    def encode_spatial_normalized(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Like ``encode_to_latent`` but unpacked back to the spatial layout."""
        z = self.encode_to_latent(x, wvs)
        return _shuffle_latent_unpack(z, self.ps[0], self.ps[1])

    @torch.no_grad()
    def decode_spatial_normalized(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode a spatial-layout normalized latent (inverse of the above)."""
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        return self.decode(z, wvs)

    @torch.no_grad()
    def reconstruct(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Deterministic reconstruction (posterior mode, no sampling)."""
        recon, _ = self.forward(x, wvs, sample_posterior=False)
        return recon

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "EOVAEModel":
        """Build an ``EOVAEModel`` from a (possibly nested) config dict.

        Accepts a flat dict, a dict with "encoder"/"decoder" sub-dicts, or
        everything nested under a "model" key. Keys starting with "_" are
        dropped; missing keys fall back to the constructor defaults.
        """
        if "model" in config:
            config = config["model"]
        enc_cfg = {k: v for k, v in config.get("encoder", config).items() if not str(k).startswith("_")}
        dec_cfg = {k: v for k, v in config.get("decoder", config).items() if not str(k).startswith("_")}
        default_dyn = {"num_layers": 4, "wv_planes": 256}

        encoder = Encoder(
            resolution=enc_cfg.get("resolution", 256),
            in_channels=enc_cfg.get("in_channels", 3),
            ch=enc_cfg.get("ch", 128),
            ch_mult=enc_cfg.get("ch_mult", [1, 2, 4, 4]),
            num_res_blocks=enc_cfg.get("num_res_blocks", 2),
            z_channels=enc_cfg.get("z_channels", 32),
            use_dynamic_ops=enc_cfg.get("use_dynamic_ops", True),
            dynamic_conv_kwargs=enc_cfg.get("dynamic_conv_kwargs", default_dyn),
        )
        decoder = Decoder(
            ch=dec_cfg.get("ch", 128),
            out_ch=dec_cfg.get("out_ch", 3),
            ch_mult=dec_cfg.get("ch_mult", [1, 2, 4, 4]),
            num_res_blocks=dec_cfg.get("num_res_blocks", 2),
            resolution=dec_cfg.get("resolution", 256),
            z_channels=dec_cfg.get("z_channels", 32),
            use_dynamic_ops=dec_cfg.get("use_dynamic_ops", True),
            dynamic_conv_kwargs=dec_cfg.get("dynamic_conv_kwargs", default_dyn),
        )
        return cls(encoder, decoder, scaling_factor=config.get("scaling_factor", 1.0))
|
|
|
|
class EOVAEEncoderOutput:
    """Lightweight result container returned by ``EOVAEModel.encode``."""

    def __init__(self, latent_dist) -> None:
        # Posterior distribution over the latent (diagonal Gaussian upstream).
        self.latent_dist = latent_dist
|