# EO-VAE / _eo_vae/modeling.py
# Uploaded by BiliSakura — commit "Update all files for EO-VAE" (0414154, verified)
# Apache-2.0 - EO-VAE Encoder/Decoder
# Wavelength-conditioned VAE for multi-spectral imagery
import math
from typing import Any, Optional
import torch
import torch.nn as nn
from torch import Tensor
from .dynamic_conv import DynamicConv, DynamicConvDecoder
from .layers import AttnBlock, Downsample, ResnetBlock, Upsample, swish
def _shuffle_latent_pack(z: Tensor, pi: int = 2, pj: int = 2) -> Tensor:
"""(B, C, H*pi, W*pj) -> (B, C*pi*pj, H, W)"""
b, c, h, w = z.shape
z = z.view(b, c, h // pi, pi, w // pj, pj)
z = z.permute(0, 1, 3, 5, 2, 4).reshape(b, c * pi * pj, h // pi, w // pj)
return z
def _shuffle_latent_unpack(z: Tensor, pi: int = 2, pj: int = 2) -> Tensor:
"""(B, C*pi*pj, H, W) -> (B, C, H*pi, W*pj)"""
b, cp, h, w = z.shape
c = cp // (pi * pj)
z = z.view(b, c, pi, pj, h, w)
z = z.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h * pi, w * pj)
return z
class Encoder(nn.Module):
    """Multi-scale convolutional encoder producing Gaussian posterior moments.

    The input is embedded by a stem conv (a wavelength-conditioned
    DynamicConv when ``use_dynamic_ops`` is set), run through
    ``len(ch_mult)`` resolution levels of ResnetBlocks with a 2x
    Downsample between levels, a mid section (ResnetBlock / AttnBlock /
    ResnetBlock), and projected to ``2 * z_channels`` channels —
    consumed downstream as DiagonalGaussianDistribution moments.
    """

    def __init__(
        self,
        resolution: int = 256,
        in_channels: int = 3,
        ch: int = 128,
        ch_mult: tuple = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        z_channels: int = 32,
        use_dynamic_ops: bool = True,
        dynamic_conv_kwargs: Optional[dict] = None,
    ) -> None:
        """
        Args:
            resolution: nominal input resolution (stored; not enforced here).
            in_channels: image band count (used only by the plain-conv stem).
            ch: base channel width; level ``i`` uses ``ch * ch_mult[i]``.
            num_res_blocks: ResnetBlocks per resolution level.
            z_channels: latent channels; the output carries ``2 * z_channels``.
            use_dynamic_ops: use the wavelength-conditioned DynamicConv stem.
            dynamic_conv_kwargs: DynamicConv options ("wv_planes",
                "num_layers", "inter_dim"); defaults are applied when omitted.
        """
        super().__init__()
        dyn = dynamic_conv_kwargs or {"num_layers": 4, "wv_planes": 256}
        # Copy before popping so the caller's dict is left untouched.
        dyn = dict(dyn)
        wv_planes = dyn.pop("wv_planes", 256)
        num_layers = dyn.pop("num_layers", 4)
        self.resolution = resolution
        self.in_channels = in_channels
        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.z_channels = z_channels
        self.use_dynamic_ops = use_dynamic_ops
        # Not read again below — presumably kept for introspection by external
        # callers; verify before removing.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.num_resolutions = len(ch_mult)
        if use_dynamic_ops:
            # Wavelength-conditioned stem; consumes (x, wvs) in forward.
            self.conv_in = DynamicConv(
                wv_planes=wv_planes, inter_dim=dyn.get("inter_dim", 128),
                kernel_size=3, stride=1, padding=1, embed_dim=ch,
                num_layers=num_layers, num_heads=4,
            )
        else:
            self.conv_in = nn.Conv2d(in_channels, ch, 3, stride=1, padding=1)
        self.down = nn.ModuleList()
        block_in = ch
        curr_res = resolution
        for i in range(self.num_resolutions):
            block_out = ch * ch_mult[i]
            block = nn.ModuleList()
            for _ in range(num_res_blocks):
                block.append(ResnetBlock(block_in, block_out, cond_dim=None))
                block_in = block_out
            down = nn.Module()
            down.block = block
            # Empty attn list — presumably kept so the state_dict layout matches
            # reference VAE checkpoints; verify before removing.
            down.attn = nn.ModuleList()
            if i != self.num_resolutions - 1:
                # Halve spatial resolution between levels (not after the last).
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6, affine=True)
        # Moments head (2 * z_channels) followed by a 1x1 "quant" conv.
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, 3, stride=1, padding=1)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * z_channels, 1)

    def forward(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Return posterior moments of shape (B, 2 * z_channels, H', W').

        Args:
            x: input imagery (B, C, H, W).
            wvs: per-band wavelengths; only consumed by the DynamicConv stem.
        """
        if self.use_dynamic_ops:
            h = self.conv_in(x, wvs)
        else:
            h = self.conv_in(x)
        for i in range(self.num_resolutions):
            for j in range(self.num_res_blocks):
                h = self.down[i].block[j](h)
            if i != self.num_resolutions - 1:
                h = self.down[i].downsample(h)
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h
class Decoder(nn.Module):
    """Multi-scale convolutional decoder, mirror of the encoder.

    A latent (B, z_channels, H, W) goes through a 1x1 post-quant conv, a
    stem conv into the deepest width, the mid section, then
    ``len(ch_mult)`` levels of ResnetBlocks with a 2x Upsample between
    levels, and a final projection — a wavelength-conditioned
    DynamicConvDecoder when ``use_dynamic_ops`` is set, else a Conv2d.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        ch_mult: tuple = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 32,
        use_dynamic_ops: bool = True,
        dynamic_conv_kwargs: Optional[dict] = None,
    ) -> None:
        """
        Args:
            ch: base channel width; level ``i`` uses ``ch * ch_mult[i]``.
            out_ch: output band count (plain-conv head only).
            ch_mult: per-level channel multipliers (mirrors the encoder's).
            num_res_blocks: encoder blocks per level; this decoder uses one
                extra block per level (``num_res_blocks + 1``).
            resolution: nominal output resolution (stored; not enforced here).
            z_channels: latent channel count fed to the stem.
            use_dynamic_ops: use the wavelength-conditioned output head.
            dynamic_conv_kwargs: DynamicConvDecoder options ("wv_planes",
                "num_layers", "inter_dim"); defaults are applied when omitted.
        """
        super().__init__()
        dyn = dynamic_conv_kwargs or {"num_layers": 4, "wv_planes": 256}
        # Copy before popping so the caller's dict is left untouched.
        dyn = dict(dyn)
        wv_planes = dyn.pop("wv_planes", 256)
        num_layers = dyn.pop("num_layers", 4)
        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.z_channels = z_channels
        self.resolution = resolution
        self.use_dynamic_ops = use_dynamic_ops
        self.num_resolutions = len(ch_mult)
        self.ch_mult = ch_mult
        self.post_quant_conv = nn.Conv2d(z_channels, z_channels, 1)
        # Start at the deepest (widest) level and build upward.
        block_in = ch * ch_mult[-1]
        self.conv_in = nn.Conv2d(z_channels, block_in, 3, stride=1, padding=1)
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(block_in, block_in, cond_dim=None)
        self.up = nn.ModuleList()
        for i in reversed(range(self.num_resolutions)):
            block_out = ch * ch_mult[i]
            block = nn.ModuleList()
            # One more ResnetBlock per level than the encoder uses.
            for _ in range(num_res_blocks + 1):
                block.append(ResnetBlock(block_in, block_out, cond_dim=None))
                block_in = block_out
            up = nn.Module()
            up.block = block
            # Empty attn list — presumably kept so the state_dict layout matches
            # reference VAE checkpoints; verify before removing.
            up.attn = nn.ModuleList()
            if i != 0:
                up.upsample = Upsample(block_in)
            # Prepend so self.up[i] corresponds to resolution level i.
            self.up.insert(0, up)
        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6, affine=True)
        if use_dynamic_ops:
            # Wavelength-conditioned output head; consumes (h, wvs) in forward.
            self.conv_out = DynamicConvDecoder(
                wv_planes=wv_planes, inter_dim=dyn.get("inter_dim", 128),
                kernel_size=3, stride=1, padding=1, embed_dim=block_in,
                num_layers=num_layers, num_heads=4,
            )
        else:
            self.conv_out = nn.Conv2d(block_in, out_ch, 3, stride=1, padding=1)

    def forward(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode latent ``z`` back to image space.

        Args:
            z: latent (B, z_channels, H, W).
            wvs: per-band wavelengths; only consumed by the dynamic head.
        """
        z = self.post_quant_conv(z)
        h = self.conv_in(z)
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # self.up is indexed by resolution level, so iterate in reverse to
        # walk from the coarsest level to the finest.
        for i in reversed(range(self.num_resolutions)):
            for j in range(self.num_res_blocks + 1):
                h = self.up[i].block[j](h)
            if i != 0:
                h = self.up[i].upsample(h)
        h = self.norm_out(h)
        h = swish(h)
        if self.use_dynamic_ops:
            h = self.conv_out(h, wvs)
        else:
            h = self.conv_out(h)
        return h
class EOVAEModel(nn.Module):
    """EO-VAE: wavelength-conditioned VAE for multi-spectral imagery.

    Wraps an Encoder / Decoder pair. Posterior latents are additionally
    packed (2x2 spatial patch -> channels) and whitened with a non-affine
    BatchNorm before being handed out; ``decode`` undoes both steps.
    """

    def __init__(self, encoder: "Encoder", decoder: "Decoder", scaling_factor: float = 1.0):
        """
        Args:
            encoder: moment-producing encoder (``2 * z_channels`` output).
            decoder: wavelength-conditioned decoder.
            scaling_factor: stored for downstream consumers; not applied here.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.scaling_factor = scaling_factor
        # (pi, pj) patch size used by the latent pack/unpack shuffle.
        self.ps = (2, 2)
        self.bn_eps = 1e-4
        # Non-affine BatchNorm whitens the packed latent channel-wise.
        # Fix: pass eps=self.bn_eps explicitly — previously the norm ran with
        # BatchNorm2d's default eps (1e-5) while _inv_normalize_latent used
        # bn_eps (1e-4), so the inverse did not exactly undo the forward.
        self.bn = nn.BatchNorm2d(
            math.prod(self.ps) * encoder.z_channels,
            eps=self.bn_eps,
            affine=False,
            track_running_stats=True,
        )

    @property
    def z_channels(self) -> int:
        """Latent channel count before packing (mirrors the encoder)."""
        return self.encoder.z_channels

    def _normalize_latent(self, z: Tensor) -> Tensor:
        # Re-sync bn's mode with the model: _inv_normalize_latent forces
        # bn.eval() and leaves it that way, so restore it on every call.
        self.bn.train(mode=self.training)
        return self.bn(z)

    def _inv_normalize_latent(self, z: Tensor) -> Tensor:
        # Always invert with the running statistics, even during training.
        # NOTE: in training mode the forward bn uses batch statistics, so
        # forward/inverse only match exactly once running stats converge.
        self.bn.eval()
        s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn.eps)
        m = self.bn.running_mean.view(1, -1, 1, 1)
        return z * s + m

    def encode(self, x: Tensor, wvs: Tensor) -> "EOVAEEncoderOutput":
        """Encode image ``x`` (with band wavelengths ``wvs``) to a posterior."""
        # Local import avoids a circular dependency at module load time.
        from .distributions import DiagonalGaussianDistribution

        moments = self.encoder(x, wvs)
        posterior = DiagonalGaussianDistribution(moments)
        return EOVAEEncoderOutput(latent_dist=posterior)

    def decode(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode a packed, normalized latent back to image space."""
        z = self._inv_normalize_latent(z)
        z = _shuffle_latent_unpack(z, self.ps[0], self.ps[1])
        return self.decoder(z, wvs)

    def forward(self, x: Tensor, wvs: Tensor, sample_posterior: bool = True) -> tuple[Tensor, Any]:
        """Full pass; returns (reconstruction, posterior distribution).

        Args:
            x: input imagery (B, C, H, W).
            wvs: per-band wavelengths, forwarded to encoder and decoder.
            sample_posterior: sample the posterior if True, else use its mode.
        """
        out = self.encode(x, wvs)
        z = out.latent_dist.sample() if sample_posterior else out.latent_dist.mode()
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        z = self._normalize_latent(z)
        recon = self.decode(z, wvs)
        return recon, out.latent_dist

    @torch.no_grad()
    def encode_to_latent(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Deterministic (posterior-mode) packed + normalized latent."""
        out = self.encode(x, wvs)
        z = out.latent_dist.mode()
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        return self._normalize_latent(z)

    @torch.no_grad()
    def encode_spatial_normalized(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Normalized latent unpacked back to the encoder's spatial layout."""
        z = self.encode_to_latent(x, wvs)
        return _shuffle_latent_unpack(z, self.ps[0], self.ps[1])

    @torch.no_grad()
    def decode_spatial_normalized(self, z: Tensor, wvs: Tensor) -> Tensor:
        """Decode a spatial-layout normalized latent (re-packs first)."""
        z = _shuffle_latent_pack(z, self.ps[0], self.ps[1])
        return self.decode(z, wvs)

    @torch.no_grad()
    def reconstruct(self, x: Tensor, wvs: Tensor) -> Tensor:
        """Deterministic reconstruction (posterior mode, no sampling)."""
        recon, _ = self.forward(x, wvs, sample_posterior=False)
        return recon

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "EOVAEModel":
        """Build an EOVAEModel from a config dict.

        Accepts a flat dict or one nested under "model"; encoder/decoder
        settings may live in "encoder"/"decoder" sub-sections or at the top
        level. Keys starting with "_" are treated as private and dropped.
        """
        if "model" in config:
            config = config["model"]
        enc_cfg = {k: v for k, v in config.get("encoder", config).items() if not str(k).startswith("_")}
        dec_cfg = {k: v for k, v in config.get("decoder", config).items() if not str(k).startswith("_")}

        def g(d: dict, k: str, default: Any):
            # One-line helper: read k from d, falling back to default.
            return d.get(k, default)

        enc_dyn = g(enc_cfg, "dynamic_conv_kwargs", {"num_layers": 4, "wv_planes": 256})
        dec_dyn = g(dec_cfg, "dynamic_conv_kwargs", {"num_layers": 4, "wv_planes": 256})
        encoder = Encoder(
            resolution=g(enc_cfg, "resolution", 256),
            in_channels=g(enc_cfg, "in_channels", 3),
            ch=g(enc_cfg, "ch", 128),
            ch_mult=g(enc_cfg, "ch_mult", [1, 2, 4, 4]),
            num_res_blocks=g(enc_cfg, "num_res_blocks", 2),
            z_channels=g(enc_cfg, "z_channels", 32),
            use_dynamic_ops=g(enc_cfg, "use_dynamic_ops", True),
            dynamic_conv_kwargs=enc_dyn,
        )
        decoder = Decoder(
            ch=g(dec_cfg, "ch", 128),
            out_ch=g(dec_cfg, "out_ch", 3),
            ch_mult=g(dec_cfg, "ch_mult", [1, 2, 4, 4]),
            num_res_blocks=g(dec_cfg, "num_res_blocks", 2),
            resolution=g(dec_cfg, "resolution", 256),
            z_channels=g(dec_cfg, "z_channels", 32),
            use_dynamic_ops=g(dec_cfg, "use_dynamic_ops", True),
            dynamic_conv_kwargs=dec_dyn,
        )
        return cls(encoder, decoder, scaling_factor=config.get("scaling_factor", 1.0))
class EOVAEEncoderOutput:
    """Lightweight wrapper for ``EOVAEModel.encode`` results.

    Attributes:
        latent_dist: the posterior distribution over the latent (a
            DiagonalGaussianDistribution in this package).
    """

    def __init__(self, latent_dist) -> None:
        self.latent_dist = latent_dist

    def __repr__(self) -> str:
        # Added for debuggability; does not change existing caller behavior.
        return f"{type(self).__name__}(latent_dist={self.latent_dist!r})"