|
|
""" |
|
|
Wrapper to match VAE interface to that of SD VAE. |
|
|
""" |
|
|
|
|
|
from types import SimpleNamespace |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
from medvae.models import AutoencoderKL_2D |
|
|
from medvae.utils.factory import ( |
|
|
FILE_DICT_ASSOCIATIONS, |
|
|
create_model, |
|
|
download_model_weights, |
|
|
) |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
|
|
|
class LatentDist:
    """Deterministic stand-in for the ``latent_dist`` object of an SD-style VAE.

    Wraps a distribution object that exposes ``mode()`` and offers the
    ``sample()`` / ``mode()`` pair. Both methods return the wrapped
    distribution's mode, so "sampling" through this wrapper is deterministic.
    """

    def __init__(self, dist):
        """Store the wrapped distribution.

        Args:
            dist: Object providing a ``mode()`` method.
        """
        self.latent_dist = dist

    def sample(self, generator=None):
        """Return the mode of the wrapped distribution.

        Args:
            generator: Accepted so that call sites written as
                ``latent_dist.sample(generator)`` keep working; it is
                ignored because no random draw is performed.

        Returns:
            Whatever ``self.latent_dist.mode()`` returns.
        """
        return self.latent_dist.mode()

    def mode(self):
        """Return the mode of the wrapped distribution."""
        return self.latent_dist.mode()
|
|
|
|
|
|
|
|
class MedVAEWrapper(ModelMixin, ConfigMixin):
    """Expose a MedVAE 2D autoencoder through an SD-VAE-like interface.

    ``encode`` returns an object with a ``latent_dist`` attribute and
    ``decode`` returns a 1-tuple holding the reconstruction. The wrapped
    autoencoder is stored on ``self.vae``.
    """

    config_name = "config.json"
    # Keeps the ``vae`` constructor argument out of the registered config.
    ignore_for_config = ["vae"]

    @register_to_config
    def __init__(self, vae=None, scaling_factor=1.0, downsampling_factor=4):
        """Build the wrapper, creating a default autoencoder when none is given.

        Args:
            vae: Autoencoder to wrap. When ``None``, an ``AutoencoderKL_2D``
                is constructed from the config associated with
                ``downsampling_factor``.
            scaling_factor: Stored in the config under ``scaling_factor``.
            downsampling_factor: Spatial downsampling of the autoencoder;
                must be 4 or 8.

        Raises:
            ValueError: If ``downsampling_factor`` is neither 4 nor 8.
        """
        super().__init__()

        # Raised explicitly rather than asserted: assertions are stripped
        # under ``python -O``, and an unsupported factor would then silently
        # fall through to the 8x model below.
        if downsampling_factor not in (4, 8):
            raise ValueError("Only 4x and 8x downsampling are currently supported")

        if vae is None:
            vae = self._build_default_vae(downsampling_factor)
        self.vae = vae

        # One placeholder channel entry per resolution level (log2(f) + 1).
        # NOTE(review): presumably so that code computing
        # 2 ** (len(block_out_channels) - 1) recovers ``downsampling_factor``
        # -- confirm against the consuming pipeline.
        n_blocks = int(np.log2(downsampling_factor)) + 1
        self.register_to_config(
            block_out_channels=[1] * n_blocks,
            in_channels=2,
            scaling_factor=scaling_factor,
            downsampling_factor=downsampling_factor,
        )

    @staticmethod
    def _build_default_vae(downsampling_factor):
        """Create the default ``AutoencoderKL_2D`` for a supported factor.

        The config is loaded, overridden to 4 latent channels and 2
        input/output channels, and used to instantiate the model. No
        checkpoint is loaded in this method.
        """
        model_name = (
            "medvae_4_4_2d_c" if downsampling_factor == 4 else "medvae_8_4_2d_c"
        )
        config_fpath = download_model_weights(
            FILE_DICT_ASSOCIATIONS[model_name]["config"]
        )
        if model_name == "medvae_8_4_2d_c":
            # NOTE(review): machine-specific absolute path that replaces the
            # downloaded config for the 8x model. This raises on any host
            # where the file is absent; it should become configurable.
            config_fpath = "/data/yurman/repos/fast-mri-ldm/submodules/medvae/configs/ours-8x1-new.yaml"

        conf = OmegaConf.load(config_fpath)
        conf.embed_dim = 4
        conf.ddconfig.z_channels = 4
        conf["ddconfig"]["in_channels"] = 2
        conf["ddconfig"]["out_ch"] = 2

        return AutoencoderKL_2D(
            ddconfig=conf.ddconfig,
            embed_dim=conf.embed_dim,
        )

    def encode(self, x):
        """Encode ``x`` with the wrapped autoencoder.

        Returns:
            A ``SimpleNamespace`` whose ``latent_dist`` attribute is a
            ``LatentDist`` wrapping the result of ``self.vae.encode(x)``.
        """
        dist = self.vae.encode(x)
        return SimpleNamespace(latent_dist=LatentDist(dist))

    def decode(self, x, return_dict=False, generator=None):
        """Decode latents ``x`` with the wrapped autoencoder.

        Args:
            x: Latents passed straight to ``self.vae.decode``.
            return_dict: Accepted for interface compatibility; unused. The
                result is always a tuple.
            generator: Accepted for interface compatibility; unused.

        Returns:
            A 1-tuple containing the output of ``self.vae.decode(x)``.
        """
        # CUDA autocast is explicitly switched off for the decoder call.
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.vae.decode(x)
        return (x,)
|
|
|