uncond-sd2-base-complex-4 / vae /pips.model.med_vae_wrapper.py
yurman's picture
End of training
d248400 verified
"""
Wrapper that exposes a MedVAE autoencoder through the same interface as the Stable Diffusion (diffusers) VAE.
"""
import os
import warnings
from types import SimpleNamespace

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from medvae.models import AutoencoderKL_2D
from medvae.utils.factory import (
    FILE_DICT_ASSOCIATIONS,
    create_model,
    download_model_weights,
)
from omegaconf import OmegaConf
class LatentDist:
    """Deterministic adapter over a distribution object exposing ``mode()``.

    Both ``sample`` and ``mode`` return the wrapped distribution's mode, so
    encoding through this adapter never draws random samples.
    """

    def __init__(self, dist):
        # The wrapped distribution; only its ``mode()`` method is used here.
        self.latent_dist = dist

    def sample(self, generator=None):
        """Return the wrapped distribution's mode (no stochastic sampling).

        Args:
            generator: Accepted for call compatibility with code that passes a
                random generator positionally, e.g. diffusers'
                ``retrieve_latents`` calling ``latent_dist.sample(generator)``.
                It is ignored because no sampling takes place.

        Returns:
            Whatever ``self.latent_dist.mode()`` returns.
        """
        return self.latent_dist.mode()

    def mode(self):
        """Return the wrapped distribution's mode (same result as ``sample``)."""
        return self.latent_dist.mode()
class MedVAEWrapper(ModelMixin, ConfigMixin):
    """Expose a MedVAE ``AutoencoderKL_2D`` through the diffusers SD-VAE interface.

    ``encode`` returns an object with a ``latent_dist`` attribute (or a 1-tuple
    when ``return_dict=False``) and ``decode`` returns a 1-tuple, which is the
    shape of result diffusers pipelines expect from their VAE.
    """

    config_name = "config.json"
    # The nn.Module itself must not be serialized into config.json.
    ignore_for_config = ["vae"]

    # Author-local 8x config that, when present on disk, replaces the config
    # downloaded for "medvae_8_4_2d_c".
    # NOTE(review): machine-specific path; it is only used if it exists.
    local_8x_config_path = (
        "/data/yurman/repos/fast-mri-ldm/submodules/medvae/configs/ours-8x1-new.yaml"
    )

    @register_to_config
    def __init__(self, vae=None, scaling_factor=1.0, downsampling_factor=4):
        """Wrap ``vae``, or build a default MedVAE 2D autoencoder if it is None.

        Args:
            vae: An existing autoencoder exposing ``encode``/``decode``. When
                None, a fresh ``AutoencoderKL_2D`` is built; no weights are
                loaded here.
            scaling_factor: Latent scaling factor recorded in the config (SD
                pipelines read ``config.scaling_factor``).
            downsampling_factor: Spatial downsampling of the default VAE; must
                be 4 or 8.

        Raises:
            ValueError: If ``downsampling_factor`` is not 4 or 8.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if downsampling_factor not in (4, 8):
            raise ValueError("Only 4x and 8x downsampling are currently supported")
        if vae is None:
            vae = self._build_default_vae(downsampling_factor)
        self.vae = vae
        # When using SD pipeline it uses `block_out_channels` to determine the
        # size of the image based on 2 ** (len(block_out_channels) - 1), so we
        # register log2(factor) + 1 dummy entries (4x -> 3 entries -> 2**2).
        n_blocks = int(np.log2(downsampling_factor)) + 1
        self.register_to_config(
            block_out_channels=[1] * n_blocks,
            in_channels=2,
            scaling_factor=scaling_factor,
            downsampling_factor=downsampling_factor,
        )

    @classmethod
    def _build_default_vae(cls, downsampling_factor):
        """Construct a MedVAE ``AutoencoderKL_2D`` with 2 I/O and 4 latent channels."""
        model_name = (
            "medvae_4_4_2d_c" if downsampling_factor == 4 else "medvae_8_4_2d_c"
        )
        config_fpath = None
        if model_name == "medvae_8_4_2d_c":
            if os.path.exists(cls.local_8x_config_path):
                config_fpath = cls.local_8x_config_path
            else:
                warnings.warn(
                    f"Local MedVAE 8x config not found at "
                    f"{cls.local_8x_config_path!r}; falling back to the "
                    f"downloaded {model_name} config, which may not match the "
                    f"checkpoint architecture.",
                    stacklevel=3,
                )
        if config_fpath is None:
            # Only download when no local override is used.
            config_fpath = download_model_weights(
                FILE_DICT_ASSOCIATIONS[model_name]["config"]
            )
        conf = OmegaConf.load(config_fpath)
        # Override channel counts: 4 latent channels, 2 image channels in/out.
        conf.embed_dim = 4
        conf.ddconfig.z_channels = 4
        conf["ddconfig"]["in_channels"] = 2
        conf["ddconfig"]["out_ch"] = 2
        return AutoencoderKL_2D(
            ddconfig=conf.ddconfig,
            embed_dim=conf.embed_dim,
        )

    def encode(self, x, return_dict=True):
        """Encode ``x`` and wrap the result so ``latent_dist`` sampling is deterministic.

        Assumes ``self.vae.encode(x)`` returns an object with a ``mode()``
        method (it is wrapped in ``LatentDist``).

        Args:
            x: Input batch passed straight to ``self.vae.encode``.
            return_dict: When True (default) return a namespace with a
                ``latent_dist`` attribute; when False return
                ``(latent_dist,)``, matching diffusers' ``AutoencoderKL.encode``.
        """
        latent_dist = LatentDist(self.vae.encode(x))
        if not return_dict:
            return (latent_dist,)
        return SimpleNamespace(latent_dist=latent_dist)

    def decode(self, x, return_dict=False, generator=None):
        """Decode latents ``x`` with CUDA autocast disabled.

        ``return_dict`` and ``generator`` are accepted for diffusers call
        compatibility and are ignored; the result is always a 1-tuple
        ``(decoded,)``.
        """
        # Presumably keeps the decoder out of autocast's reduced precision on
        # CUDA — NOTE(review): confirm intent; only the "cuda" device type is
        # affected.
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.vae.decode(x)
        return (x,)