# (removed non-Python extraction artifacts: file-size / hash / column-index header)
"""
Wrapper to match VAE interface to that of SD VAE.
"""
from types import SimpleNamespace
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from medvae.models import AutoencoderKL_2D
from medvae.utils.factory import (
FILE_DICT_ASSOCIATIONS,
create_model,
download_model_weights,
)
from omegaconf import OmegaConf
class LatentDist:
    """Adapter exposing an SD-VAE-style latent-distribution interface over a
    medvae posterior distribution.

    Deliberately deterministic: both ``sample()`` and ``mode()`` return the
    distribution mode, matching the original wrapper's behavior.
    """

    def __init__(self, dist):
        # Underlying medvae distribution; expected to expose ``.mode()``.
        self.latent_dist = dist

    def sample(self, generator=None):
        """Return the distribution mode (deterministic "sample").

        ``generator`` is accepted and ignored for compatibility with
        diffusers pipelines, which call ``latent_dist.sample(generator)``;
        the original signature took no arguments and would raise TypeError.
        """
        return self.latent_dist.mode()

    # Alias: mode() and sample() return the same deterministic latent.
    def mode(self):
        return self.latent_dist.mode()
class MedVAEWrapper(ModelMixin, ConfigMixin):
    """Adapter exposing a medvae ``AutoencoderKL_2D`` through the interface
    diffusers' Stable Diffusion pipelines expect from a VAE
    (``encode``/``decode`` plus ``config.scaling_factor`` and
    ``config.block_out_channels``).
    """
    config_name = "config.json"
    # Keep the torch module itself out of the serialized config.json.
    ignore_for_config = ["vae"]
    @register_to_config
    def __init__(self, vae=None, scaling_factor=1.0, downsampling_factor=4):
        """Wrap an existing medvae autoencoder, or build one from config.

        Parameters
        ----------
        vae : optional
            Pre-built ``AutoencoderKL_2D``. When ``None``, a model is
            constructed from the medvae config matching
            ``downsampling_factor``.
        scaling_factor : float
            Latent scaling factor stored in the diffusers config for
            pipelines that rescale latents.
        downsampling_factor : int
            Spatial downsampling of the autoencoder; only 4 and 8 are
            supported.
        """
        super().__init__()
        assert downsampling_factor in [
            4,
            8,
        ], "Only 4x and 8x downsampling are currently supported"
        if vae is None:
            model_name = (
                "medvae_4_4_2d_c" if downsampling_factor == 4 else "medvae_8_4_2d_c"
            )
            config_fpath = download_model_weights(
                FILE_DICT_ASSOCIATIONS[model_name]["config"]
            )
            if model_name == "medvae_8_4_2d_c":
                # HACK: machine-specific absolute path silently overrides the
                # downloaded config for the 8x model. This breaks on any other
                # machine — should be made configurable or removed.
                config_fpath = "/data/yurman/repos/fast-mri-ldm/submodules/medvae/configs/ours-8x1-new.yaml"
            conf = OmegaConf.load(config_fpath)
            # Override the loaded config: force a 4-channel latent and
            # 2-channel input/output regardless of the file's values.
            # NOTE(review): 2 channels presumably real/imag of complex MRI
            # data — confirm against the training setup.
            conf.embed_dim = 4
            conf.ddconfig.z_channels = 4
            conf["ddconfig"]["in_channels"] = 2
            conf["ddconfig"]["out_ch"] = 2
            vae = AutoencoderKL_2D(
                ddconfig=conf.ddconfig,
                embed_dim=conf.embed_dim,
            )
        self.vae = vae
        # When using SD pipeline it uses `block_out_channels` to determine the size of the image based on
        # 2 ** (len(block_out_channels) - 1)
        # The channel values themselves (all 1) are dummies; only the list
        # length matters to the pipeline.
        n_blocks = int(np.log2(downsampling_factor)) + 1
        self.register_to_config(
            block_out_channels=[
                1,
            ]
            * n_blocks,
            in_channels=2,
            scaling_factor=scaling_factor,
            downsampling_factor=downsampling_factor,
        )
    def encode(self, x):
        """Encode ``x`` into a latent distribution.

        Returns an object shaped like diffusers' VAE encoder output, so
        callers can do ``result.latent_dist.sample()`` / ``.mode()``.
        """
        dist = self.vae.encode(x)
        return SimpleNamespace(latent_dist=LatentDist(dist))
    def decode(self, x, return_dict=False, generator=None):
        """Decode latents ``x`` back to image space.

        ``return_dict`` and ``generator`` are accepted for diffusers
        signature compatibility but unused; a 1-tuple is always returned,
        mimicking diffusers' ``return_dict=False`` VAE output. Decoding runs
        with CUDA autocast disabled (full precision).
        """
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.vae.decode(x)
        return (x,)