# CRS-Diff / crs_core/autoencoder.py
# (From the Hugging Face file view: uploaded by BiliSakura, "Add files using
# upload-large-folder tool", commit b6acc0a (verified), 912 bytes.)
import torch
import torch.nn as nn
from crs_core.modules.diffusionmodules.model import Encoder, Decoder
from crs_core.modules.distributions.distributions import DiagonalGaussianDistribution
class AutoencoderKL(nn.Module):
    """Autoencoder whose encoder output is wrapped in a diagonal Gaussian.

    Only the inference path is implemented here:

    * ``encode`` returns a ``DiagonalGaussianDistribution`` over the latent.
    * ``decode`` maps a latent with ``embed_dim`` channels back through the
      decoder.

    No loss or training logic is constructed. ``lossconfig`` is discarded.
    """

    def __init__(self, ddconfig: dict, lossconfig=None, embed_dim: int = 4, **kwargs):
        """Build the encoder, the decoder and the two 1x1 latent projections.

        Args:
            ddconfig: Keyword arguments forwarded unchanged to both ``Encoder``
                and ``Decoder``. It must contain ``"z_channels"`` and a truthy
                ``"double_z"``.
            lossconfig: Accepted and ignored. Presumably this keeps configs
                written for a training variant of this class loadable;
                confirm with callers.
            embed_dim: Channel count of the latent consumed by ``decode``.
                The encoder side produces ``2 * embed_dim`` channels of
                distribution moments.
            **kwargs: Accepted and ignored, for the same reason as
                ``lossconfig``.

        Raises:
            KeyError: If ``ddconfig`` has no ``"double_z"`` key.
            ValueError: If ``ddconfig["double_z"]`` is falsy.
        """
        super().__init__()
        del lossconfig, kwargs  # intentionally unused; see docstring
        # Validate before building the encoder/decoder. Raise explicitly
        # rather than ``assert`` so the check is not stripped under
        # ``python -O``. ``quant_conv`` below expects 2 * z_channels input
        # channels, which the Encoder presumably emits only when double_z is
        # set (TODO confirm against Encoder).
        if not ddconfig["double_z"]:
            raise ValueError("AutoencoderKL requires ddconfig['double_z'] to be True")
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # 1x1 conv: encoder features (2 * z_channels) -> 2 * embed_dim moments.
        self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        # 1x1 conv: latent (embed_dim) -> z_channels expected by the decoder.
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim  # latent channel count, kept for callers

    def encode(self, x) -> "DiagonalGaussianDistribution":
        """Encode ``x`` into a diagonal Gaussian distribution over the latent.

        The encoder output is projected by ``quant_conv`` to
        ``2 * embed_dim`` channels of "moments". Those moments are passed to
        ``DiagonalGaussianDistribution``, presumably as mean and log-variance
        stacked along the channel axis (confirm against that class).
        """
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode a latent ``z`` that has ``embed_dim`` channels.

        ``post_quant_conv`` first lifts ``z`` to ``z_channels`` channels, and
        the result is then passed through the decoder.
        """
        z = self.post_quant_conv(z)
        return self.decoder(z)