| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import typing |
| | from typing import Optional, Union |
| |
|
| | import torch |
| | from PIL import Image |
| | from torchvision import transforms |
| |
|
| | from diffusers.image_processor import VaeImageProcessor |
| | from diffusers.models.autoencoders.autoencoder_kl import ( |
| | AutoencoderKL, |
| | AutoencoderKLOutput, |
| | ) |
| | from diffusers.models.autoencoders.autoencoder_tiny import ( |
| | AutoencoderTiny, |
| | AutoencoderTinyOutput, |
| | ) |
| | from diffusers.models.autoencoders.vae import DecoderOutput |
| |
|
| |
|
# Type alias for every diffusers autoencoder class this script knows how to
# load (load_vae_model), encode with (to_latent) and decode with (from_latent).
SupportedAutoencoder = typing.Union[AutoencoderKL, AutoencoderTiny]
| |
|
| |
|
def load_vae_model(
    *,
    device: torch.device,
    model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> SupportedAutoencoder:
    """Load a pretrained VAE, move it to ``device`` and put it in eval mode.

    Args:
        device: Device the loaded model is moved to.
        model_name_or_path: Model id or local path given to ``from_pretrained``.
        revision: Optional revision forwarded to ``from_pretrained``.
        variant: Optional weights variant (e.g. ``"fp16"``) forwarded to
            ``from_pretrained``.
        subfolder: Optional subfolder forwarded to ``from_pretrained``.
        use_tiny_nn: If True, load ``AutoencoderTiny`` (with down/up scaling
            factors of 2); otherwise load ``AutoencoderKL``.

    Returns:
        The loaded autoencoder, on ``device`` and in eval mode.
    """
    model_cls = AutoencoderTiny if use_tiny_nn else AutoencoderKL
    # Only the tiny autoencoder receives the extra scaling-factor kwargs.
    extra_kwargs = (
        {"downscaling_scaling_factor": 2, "upsampling_scaling_factor": 2}
        if use_tiny_nn
        else {}
    )

    model = model_cls.from_pretrained(
        model_name_or_path,
        subfolder=subfolder,
        revision=revision,
        variant=variant,
        **extra_kwargs,
    )
    # Narrow the type for static checkers / fail fast on an unexpected class.
    assert isinstance(model, model_cls)

    model = model.to(device)
    model.eval()
    return model
| |
|
| |
|
def pil_to_nhwc(
    *,
    device: torch.device,
    image: Image.Image,
) -> torch.Tensor:
    """Convert an RGB PIL image into a batched float tensor on ``device``.

    Despite the function name, the result is channels-first,
    ``(1, C, H, W)``: ``transforms.ToTensor`` produces a ``(C, H, W)`` float
    tensor scaled to ``[0, 1]``, and a leading batch dimension is added here.

    Args:
        device: Device the tensor is moved to.
        image: Source image; must be in ``"RGB"`` mode.

    Returns:
        A ``(1, 3, H, W)`` tensor on ``device``.

    Raises:
        ValueError: If ``image`` is not in ``"RGB"`` mode.
    """
    # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
    if image.mode != "RGB":
        raise ValueError(f"Expected an RGB image, got mode {image.mode!r}")
    batched = transforms.ToTensor()(image).unsqueeze(0).to(device)
    assert isinstance(batched, torch.Tensor)
    return batched
| |
|
| |
|
def nhwc_to_pil(
    *,
    nhwc: torch.Tensor,
) -> Image.Image:
    """Convert a single-image batch tensor back into a PIL image.

    Despite the parameter name, the tensor is expected channels-first:
    after the batch dimension is removed, ``transforms.ToPILImage`` treats
    a 3-D tensor as ``(C, H, W)``. The tensor is moved to the CPU first.

    Args:
        nhwc: Tensor whose leading (batch) dimension has size 1.

    Returns:
        The corresponding PIL image.

    Raises:
        ValueError: If the tensor has no batch dimension or its size is not 1.
    """
    # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
    if nhwc.dim() == 0 or nhwc.shape[0] != 1:
        raise ValueError(
            f"Expected a batch of exactly one image, got shape {tuple(nhwc.shape)}"
        )
    single = nhwc.squeeze(0).cpu()
    return transforms.ToPILImage()(single)
| |
|
| |
|
def concatenate_images(
    *,
    left: Image.Image,
    right: Image.Image,
    vertical: bool = False,
) -> Image.Image:
    """Paste two images onto one new RGB canvas, side by side or stacked.

    ``left`` always sits at the top-left corner.
    - With ``vertical=False`` (default), ``right`` is placed immediately to
      the right of ``left``, and the canvas is as tall as the taller image.
    - With ``vertical=True``, ``right`` is placed directly below ``left``, and
      the canvas is as wide as the wider image.

    Any canvas area not covered by either image keeps ``Image.new``'s default
    fill colour.
    """
    left_w, left_h = left.size
    right_w, right_h = right.size

    if not vertical:
        canvas_size = (left_w + right_w, max(left_h, right_h))
        right_origin = (left_w, 0)
    else:
        canvas_size = (max(left_w, right_w), left_h + right_h)
        right_origin = (0, left_h)

    canvas = Image.new("RGB", canvas_size)
    canvas.paste(left, (0, 0))
    canvas.paste(right, right_origin)
    return canvas
| |
|
| |
|
def to_latent(
    *,
    rgb_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    """Encode an RGB image batch into VAE latents.

    Args:
        rgb_nchw: Image batch. It is passed through
            ``VaeImageProcessor.normalize`` before encoding (presumably
            expects ``[0, 1]`` input, as produced by ``pil_to_nhwc``).
        vae: Autoencoder used for encoding.

    Returns:
        The latents:
        - for an ``AutoencoderKLOutput``, a ``sample()`` drawn from
          ``latent_dist`` (so repeated calls may differ);
        - for an ``AutoencoderTinyOutput``, its ``latents`` field.

    Raises:
        TypeError: If ``vae.encode`` returns any other output type.
    """
    rgb_nchw = VaeImageProcessor.normalize(rgb_nchw)
    encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
    if isinstance(encoding_nchw, AutoencoderKLOutput):
        # Random sample from the latent distribution, not a deterministic value.
        latent = encoding_nchw.latent_dist.sample()
    elif isinstance(encoding_nchw, AutoencoderTinyOutput):
        latent = encoding_nchw.latents
        # Manual debugging toggle (off by default).
        # When enabled, latents are scaled, quantized to uint8
        # (*255 -> round -> byte), rescaled by /255 and unscaled again.
        do_internal_vae_scaling = False
        if do_internal_vae_scaling:
            latent = vae.scale_latents(latent).mul(255).round().byte()
            latent = vae.unscale_latents(latent / 255.0)
    else:
        # Raise instead of ``assert False``. Under ``python -O`` the assert
        # disappears and ``return latent`` would fail with UnboundLocalError.
        raise TypeError(f"Unknown encoding type: {type(encoding_nchw)}")
    assert isinstance(latent, torch.Tensor)
    return latent
| |
|
| |
|
def from_latent(
    *,
    latent_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    """Decode VAE latents back into an RGB image batch.

    The decoder's ``sample`` is passed through
    ``VaeImageProcessor.denormalize``. This is the counterpart of the
    ``normalize`` call made before encoding in ``to_latent``.

    Args:
        latent_nchw: Latents to decode.
        vae: Autoencoder used for decoding.

    Returns:
        The denormalized decoded image batch.
    """
    decoder_out = vae.decode(latent_nchw)
    # Narrow the type: ``decode`` is expected to return a DecoderOutput here.
    assert isinstance(decoder_out, DecoderOutput)

    image_batch = VaeImageProcessor.denormalize(decoder_out.sample)
    assert isinstance(image_batch, torch.Tensor)
    return image_batch
| |
|
| |
|
def main_kwargs(
    *,
    device: torch.device,
    input_image_path: str,
    pretrained_model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> None:
    """Round-trip one image through a VAE and show original vs. reconstruction.

    Steps:
    1. Load the VAE.
    2. Read ``input_image_path`` as RGB.
    3. Encode it to latents and decode them back (under ``torch.no_grad``).
    4. Display both images side by side.

    The original, latent and reconstructed tensor shapes are printed.

    Args:
        device: Device used for the model and tensors.
        input_image_path: Path of the image to round-trip.
        pretrained_model_name_or_path: Model id or path for ``load_vae_model``.
        revision: Optional model revision.
        variant: Optional weights variant.
        subfolder: Optional subfolder inside the model repo/directory.
        use_tiny_nn: Load ``AutoencoderTiny`` instead of ``AutoencoderKL``.
    """
    vae = load_vae_model(
        device=device,
        model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )

    # Close the file handle opened by Image.open.
    # ``convert`` returns a new image, so it remains valid after the file closes.
    with Image.open(input_image_path) as opened:
        original_pil = opened.convert("RGB")
    original_image = pil_to_nhwc(
        device=device,
        image=original_pil,
    )
    print(f"Original image shape: {original_image.shape}")

    # Inference only: no autograd graph is needed for encode/decode.
    with torch.no_grad():
        latent_image = to_latent(rgb_nchw=original_image, vae=vae)
        print(f"Latent shape: {latent_image.shape}")
        reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)

    reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
    combined_image = concatenate_images(
        left=original_pil,
        right=reconstructed_pil,
        vertical=False,
    )
    combined_image.show("Original | Reconstruction")
    print(f"Reconstructed image shape: {reconstructed_image.shape}")
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Define the VAE round-trip command-line interface and parse ``sys.argv``.

    Options:
    - String options: ``--input_image`` and
      ``--pretrained_model_name_or_path`` are required; ``--revision``,
      ``--variant`` and ``--subfolder`` are optional and default to None.
    - Boolean switches: ``--use_cuda`` and ``--use_tiny_nn``.
    """
    parser = argparse.ArgumentParser(description="Inference with VAE")

    # (flag, required?, help) for string-valued options, in registration order.
    string_options = (
        ("--input_image", True, "Path to the input image for inference."),
        ("--pretrained_model_name_or_path", True, "Path to pretrained VAE model."),
        ("--revision", False, "Model version."),
        ("--variant", False, "Model file variant, e.g., 'fp16'."),
        ("--subfolder", False, "Subfolder in the model file."),
    )
    for flag, is_required, help_text in string_options:
        if is_required:
            parser.add_argument(flag, type=str, required=True, help=help_text)
        else:
            parser.add_argument(flag, type=str, default=None, help=help_text)

    # (flag, help) for store_true switches.
    switch_options = (
        ("--use_cuda", "Use CUDA if available."),
        ("--use_tiny_nn", "Use tiny neural network."),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser.parse_args()
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
def main_cli() -> None:
    """Command-line entry point: parse the flags and run ``main_kwargs``.

    ``--use_cuda`` selects the ``cuda`` device only if
    ``torch.cuda.is_available()`` is True. Otherwise a message is printed and
    the CPU is used, as the flag's help text ("Use CUDA if available.")
    promises. The ``cpu`` device is also used when the flag is absent.
    """
    args = parse_args()

    # The isinstance asserts narrow argparse's untyped Namespace attributes
    # for static type checkers.
    input_image_path = args.input_image
    assert isinstance(input_image_path, str)

    pretrained_model_name_or_path = args.pretrained_model_name_or_path
    assert isinstance(pretrained_model_name_or_path, str)

    revision = args.revision
    assert isinstance(revision, (str, type(None)))

    variant = args.variant
    assert isinstance(variant, (str, type(None)))

    subfolder = args.subfolder
    assert isinstance(subfolder, (str, type(None)))

    use_cuda = args.use_cuda
    assert isinstance(use_cuda, bool)

    use_tiny_nn = args.use_tiny_nn
    assert isinstance(use_tiny_nn, bool)

    # Honour "if available". Without this check, torch.device("cuda") would
    # make the first .to(device) fail on a machine without CUDA.
    if use_cuda and not torch.cuda.is_available():
        print("CUDA was requested but is not available; falling back to CPU.")
        use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    main_kwargs(
        device=device,
        input_image_path=input_image_path,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )
| |
|
| |
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main_cli()
| |
|