# HSIGene/vae/model.py
"""HSIGene AutoencoderKL - nn.Module, no Lightning. Loss = Identity."""
import torch
import torch.nn as nn
from .vae_blocks import Encoder, Decoder, DiagonalGaussianDistribution
class AutoencoderKL(nn.Module):
    """
    KL-regularized autoencoder implemented as a plain ``nn.Module`` (no
    Lightning dependency).

    Data path: ``Encoder -> quant_conv -> DiagonalGaussianDistribution`` on
    the way in, ``post_quant_conv -> Decoder`` on the way out.  ``encode``
    returns the posterior distribution object, ``decode`` maps a latent back
    to input space.  The ``loss`` attribute is a deliberate no-op
    (``nn.Identity``), kept only for interface parity.
    """
    def __init__(
        self,
        ddconfig,
        embed_dim=4,
        lossconfig=None,
        **kwargs,
    ):
        super().__init__()
        # Encoder and decoder share one architectural config dict.
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # double_z must be enabled: the encoder emits mean + logvar,
        # i.e. twice the latent channel count.
        assert ddconfig.get("double_z", True)
        z_channels = ddconfig["z_channels"]
        # 1x1 convolutions projecting between encoder moments (2x channels
        # for mean/logvar) and the embedding space, and back again.
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1)
        self.embed_dim = embed_dim
        # lossconfig is accepted for signature compatibility but ignored.
        self.loss = nn.Identity()
    def encode(self, x):
        """Encode ``x`` into a diagonal Gaussian posterior over latents."""
        moments = self.quant_conv(self.encoder(x))
        # NOTE(review): deterministic=True makes sample() == mode() (zero
        # variance) — presumably intended for inference-only use; confirm.
        return DiagonalGaussianDistribution(moments, deterministic=True)
    def decode(self, z):
        """Map a latent ``z`` back to input space through the decoder."""
        return self.decoder(self.post_quant_conv(z))
    def forward(self, input, sample_posterior=True):
        """Full round trip: encode, pick a latent, decode.

        Returns ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        latent = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(latent), posterior
class HSIGeneAutoencoderKL(AutoencoderKL):
    """
    HSIGene VAE exposing a diffusers-style configuration surface.

    Translates diffusers-style keyword arguments (``in_channels``,
    ``out_channels``, ``latent_channels``, ``block_out_channels``, ...)
    into the ``ddconfig`` dict that the base ``AutoencoderKL`` (and its
    ``Encoder``/``Decoder``) expects.
    """
    def __init__(
        self,
        in_channels: int = 48,
        out_channels: int = 48,
        latent_channels: int = 96,
        embed_dim: int = 4,
        block_out_channels: tuple = (64, 128, 256),
        num_res_blocks: int = 4,
        attn_resolutions: tuple = (16, 32, 64),
        dropout: float = 0.0,
        double_z: bool = True,
        resolution: int = 256,
        **kwargs,
    ):
        """Build the ddconfig dict and delegate to ``AutoencoderKL``.

        Args:
            in_channels: number of input image channels (spectral bands).
            out_channels: number of reconstructed output channels.
            latent_channels: ``z_channels`` of the encoder/decoder.
            embed_dim: channel width of the quantized embedding space.
            block_out_channels: per-level widths; each entry must be an
                exact multiple of the first entry.
            num_res_blocks: residual blocks per resolution level.
            attn_resolutions: resolutions at which attention is applied.
            dropout: dropout probability inside the blocks.
            double_z: whether the encoder emits mean+logvar (required True
                by the base class).
            resolution: nominal input resolution.
            **kwargs: forwarded to ``AutoencoderKL.__init__``.

        Raises:
            ValueError: if any ``block_out_channels`` entry is not an exact
                multiple of the first entry.
        """
        ch = block_out_channels[0]
        # The base config wants per-level multipliers relative to the first
        # level's width.  Reject widths that are not exact multiples: with
        # plain integer division the multipliers would silently reconstruct
        # the wrong channel counts downstream.
        if any(c % ch for c in block_out_channels):
            raise ValueError(
                "each entry of block_out_channels must be a multiple of "
                f"the first ({ch}); got {block_out_channels}"
            )
        ch_mult = [c // ch for c in block_out_channels]
        ddconfig = dict(
            double_z=double_z,
            z_channels=latent_channels,
            resolution=resolution,
            in_channels=in_channels,
            out_ch=out_channels,
            ch=ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=list(attn_resolutions),
            dropout=dropout,
        )
        super().__init__(ddconfig=ddconfig, embed_dim=embed_dim, **kwargs)