| | """ |
| | Wrapper to match VAE interface to that of SD VAE. |
| | """ |
| |
|
| | from types import SimpleNamespace |
| |
|
| | import numpy as np |
| | import torch |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| | from medvae.models import AutoencoderKL_2D |
| | from medvae.utils.factory import ( |
| | FILE_DICT_ASSOCIATIONS, |
| | create_model, |
| | download_model_weights, |
| | ) |
| | from omegaconf import OmegaConf |
| |
|
| |
|
class LatentDist:
    """Deterministic facade over a latent distribution object.

    The wrapped object only needs to provide a ``mode()`` method. Both
    ``sample()`` and ``mode()`` on this facade return that mode, so no
    random draw is ever taken through this class.
    """

    def __init__(self, dist):
        # Keep the wrapped distribution reachable under the attribute name
        # that the two accessors below read from.
        self.latent_dist = dist

    def sample(self):
        """Return the wrapped distribution's mode (not a random sample)."""
        wrapped = self.latent_dist
        return wrapped.mode()

    def mode(self):
        """Return the wrapped distribution's mode."""
        wrapped = self.latent_dist
        return wrapped.mode()
| |
|
| |
|
class MedVAEWrapper(ModelMixin, ConfigMixin):
    """Wrap a MedVAE 2D autoencoder behind an SD-VAE-style interface.

    ``encode`` returns an object exposing ``latent_dist`` and ``decode``
    returns a 1-tuple holding the reconstruction.

    Args:
        vae: Autoencoder to wrap. When ``None``, an ``AutoencoderKL_2D`` is
            built from a MedVAE config file with 4 latent channels and
            2 input/output channels. This constructor does not itself load
            a checkpoint into that model. ``vae`` is listed in
            ``ignore_for_config`` so it is not stored in the config.
        scaling_factor: Value recorded in the config. This class does not
            apply it to latents itself.
        downsampling_factor: Either 4 or 8. Selects which MedVAE config is
            used when ``vae`` is ``None`` and determines the length of the
            registered ``block_out_channels`` list.

    Raises:
        ValueError: If ``downsampling_factor`` is neither 4 nor 8.
    """

    config_name = "config.json"
    ignore_for_config = ["vae"]

    @register_to_config
    def __init__(self, vae=None, scaling_factor=1.0, downsampling_factor=4):
        super().__init__()
        # Explicit raise instead of ``assert``: assertions are removed under
        # ``python -O``, which would let an unsupported factor silently fall
        # through to the 8x branch below.
        if downsampling_factor not in (4, 8):
            raise ValueError("Only 4x and 8x downsampling are currently supported")

        if vae is None:
            model_name = (
                "medvae_4_4_2d_c" if downsampling_factor == 4 else "medvae_8_4_2d_c"
            )
            config_fpath = download_model_weights(
                FILE_DICT_ASSOCIATIONS[model_name]["config"]
            )
            if model_name == "medvae_8_4_2d_c":
                # NOTE(review): machine-specific absolute path that overrides
                # the downloaded config for the 8x model. This only works on
                # a host where this file exists; make it configurable.
                config_fpath = "/data/yurman/repos/fast-mri-ldm/submodules/medvae/configs/ours-8x1-new.yaml"

            # Override the loaded config: 4 latent channels, 2 image channels.
            conf = OmegaConf.load(config_fpath)
            conf.embed_dim = 4
            conf.ddconfig.z_channels = 4
            conf["ddconfig"]["in_channels"] = 2
            conf["ddconfig"]["out_ch"] = 2

            vae = AutoencoderKL_2D(
                ddconfig=conf.ddconfig,
                embed_dim=conf.embed_dim,
            )

        self.vae = vae

        # log2(factor) + 1 placeholder entries: 3 for 4x, 4 for 8x.
        # Presumably so consumers that derive the spatial scale as
        # 2 ** (len(block_out_channels) - 1) recover ``downsampling_factor``
        # -- TODO confirm against the pipeline that reads this config.
        n_blocks = int(np.log2(downsampling_factor)) + 1
        self.register_to_config(
            block_out_channels=[
                1,
            ]
            * n_blocks,
            in_channels=2,
            scaling_factor=scaling_factor,
            downsampling_factor=downsampling_factor,
        )

    def encode(self, x):
        """Encode ``x`` with the wrapped VAE.

        Returns:
            A ``SimpleNamespace`` whose ``latent_dist`` attribute is a
            ``LatentDist`` wrapping the result of ``self.vae.encode(x)``.
        """
        dist = self.vae.encode(x)

        return SimpleNamespace(latent_dist=LatentDist(dist))

    def decode(self, x, return_dict=False, generator=None):
        """Decode latents ``x`` with the wrapped VAE.

        ``return_dict`` and ``generator`` are accepted but not used; the
        result is always a 1-tuple.

        Returns:
            ``(decoded,)`` where ``decoded`` is ``self.vae.decode(x)``.
        """
        # CUDA autocast is explicitly disabled around the decode call
        # (presumably to keep decoding out of reduced precision -- confirm).
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.vae.decode(x)
        return (x,)
| |
|