lcm-lora-sdv1-5 / launcher.py
yongqiang
Update: supports generating images with a resolution of 1024x768
6143275
from typing import List, Optional, Tuple, Union
import argparse
import os
import time
import warnings
import numpy as np
import onnxruntime
import axengine
import torch
from PIL import Image
from transformers import CLIPTokenizer, PreTrainedTokenizer
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils import load_image, make_image_grid
from diffusers.utils.torch_utils import randn_tensor
############ Img2Img
# Image inputs accepted by `preprocess`: one PIL image / NumPy array / torch tensor,
# or a list made of one of those kinds.
PipelineImageInput = Union[
    # single images
    Image.Image,
    np.ndarray,
    torch.Tensor,
    # homogeneous lists of images
    List[Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]
PipelineDepthInput = PipelineImageInput

# Feed name under which `denoise_loop` passes the per-step row of the time-input
# .npy file to the UNet session.
TIME_EMBED_KEY = "/down_blocks.0/resnets.0/act_1/Mul_output_0"

# Full 4-step schedule; txt2img runs all of it.
TXT2IMG_TIMESTEPS = np.array([999, 759, 499, 259], dtype=np.int64)
# img2img shares the same full schedule but only executes the entries at
# IMG2IMG_STEP_INDEX (i.e. the last two steps: 499 and 259).
IMG2IMG_SELF_TIMESTEPS = TXT2IMG_TIMESTEPS.copy()
IMG2IMG_STEP_INDEX = [2, 3]
IMG2IMG_TIMESTEPS = IMG2IMG_SELF_TIMESTEPS[IMG2IMG_STEP_INDEX]
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Forward-diffuse `original_samples` to the given `timesteps`.

    Computes ``sqrt(a_t) * x0 + sqrt(1 - a_t) * noise``, where ``a_t`` is the cumulative
    product of alphas from a 1000-step "scaled linear" beta schedule (betas interpolated
    linearly in sqrt space between 0.00085 and 0.012).

    Args:
        original_samples: Clean samples; their device and dtype are used for the schedule.
        noise: Noise tensor broadcastable against `original_samples`.
        timesteps: Integer timesteps (0..999), one per leading-axis entry of the samples
            (or a single timestep that broadcasts).

    Returns:
        The noisy samples.
    """
    # Rebuild the schedule on every call (cheap: 1000 elements).
    sqrt_betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32)
    cumulative = torch.cumprod(1.0 - sqrt_betas ** 2, dim=0)
    cumulative = cumulative.to(device=original_samples.device).to(dtype=original_samples.dtype)

    picked = cumulative[timesteps.to(original_samples.device)]

    # Flatten the per-timestep factors, then give them trailing singleton axes so they
    # broadcast over every non-leading dimension of the samples.
    rank = len(original_samples.shape)
    broadcast_shape = [-1] + [1] * max(rank - 1, 0)
    signal_scale = (picked ** 0.5).flatten().reshape(broadcast_shape)
    noise_scale = ((1 - picked) ** 0.5).flatten().reshape(broadcast_shape)

    return signal_scale * original_samples + noise_scale * noise
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of a VAE encoder result.

    Despite the annotation, `encoder_output` is any object exposing either a
    `latent_dist` (with `sample(generator)` and `mode()`) or a `latents` attribute.

    Args:
        encoder_output: The encoder result to read from.
        generator: Passed to `latent_dist.sample` when sampling.
        sample_mode: "sample" draws from `latent_dist`; "argmax" takes its mode. Any other
            value falls through to `latents`.

    Returns:
        The selected latents.

    Raises:
        AttributeError: if no usable latents are found.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
    r"""
    Convert a NumPy image batch to a PyTorch tensor, moving the last axis to position 1.

    Args:
        images (`np.ndarray`):
            A 4D array, assumed channel-last (batch, height, width, channels), or a 3D
            array (batch, height, width) that is given a trailing singleton channel axis.

    Returns:
        `torch.Tensor`:
            The (batch, channels, height, width) view of the data, built with
            `torch.from_numpy` (no copy).
    """
    batch = images[..., np.newaxis] if images.ndim == 3 else images
    return torch.from_numpy(np.transpose(batch, (0, 3, 1, 2)))
def pil_to_numpy(images: Union[List[Image.Image], Image.Image]) -> np.ndarray:
    r"""
    Convert one PIL image, or a list of them, into a single stacked float32 array.

    Args:
        images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
            The image(s) to convert. A lone image is treated as a batch of one.

    Returns:
        `np.ndarray`:
            The images stacked along a new leading batch axis, as float32 with pixel
            values divided by 255.
    """
    batch = images if isinstance(images, list) else [images]
    arrays = []
    for picture in batch:
        arrays.append(np.array(picture).astype(np.float32) / 255.0)
    return np.stack(arrays, axis=0)
def is_valid_image(image) -> bool:
    r"""
    Tell whether `image` is a single image this pipeline can consume.

    Accepted:
        - any `PIL.Image.Image`;
        - a 2D or 3D `np.ndarray` / `torch.Tensor` (a single grayscale or color image).

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            Candidate image.

    Returns:
        `bool`: `True` when the input matches one of the accepted forms.
    """
    if isinstance(image, Image.Image):
        return True
    if not isinstance(image, (np.ndarray, torch.Tensor)):
        return False
    return image.ndim in (2, 3)
def is_valid_image_imagelist(images):
    r"""
    Tell whether `images` is an acceptable image batch for `preprocess`.

    Accepted:
        - a 4D `np.ndarray` / `torch.Tensor` (already batched);
        - anything `is_valid_image` accepts (a single image);
        - a list whose every element satisfies `is_valid_image` (note: an empty list
          therefore counts as valid).

    Args:
        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
            Candidate batch, single image, or list of images.

    Returns:
        `bool`: `True` when the input matches one of the accepted forms.
    """
    if isinstance(images, list):
        return all(is_valid_image(item) for item in images)
    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
        return True
    return is_valid_image(images)
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    r"""
    Map values from the [0, 1] range to [-1, 1] via ``x * 2 - 1``.

    Args:
        images (`np.ndarray` or `torch.Tensor`):
            Values expected in [0, 1].

    Returns:
        `np.ndarray` or `torch.Tensor`:
            A new array/tensor of the same kind with the affine map applied.
    """
    doubled = images * 2.0
    return doubled - 1.0
# Adapted from diffusers `VaeImageProcessor.preprocess` (diffusers/image_processor.py).
# The config-driven steps of the original (resizing, RGB/grayscale conversion, latent
# pass-through and binarization) are intentionally not applied here: inputs are expected
# to already have the target size and channel layout (prepare_init_image resizes and
# converts to RGB before calling this).
def preprocess(
    image: PipelineImageInput,
    height: Optional[int] = None,
    width: Optional[int] = None,
    resize_mode: str = "default",  # "default", "fill", "crop"
    crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> torch.Tensor:
    """
    Convert image input(s) into a batched, channel-first tensor normalized to [-1, 1].

    Args:
        image (`PipelineImageInput`):
            A PIL image, NumPy array or PyTorch tensor, or a list of one of those kinds.
            PIL images and NumPy arrays are treated as channel-last and transposed to
            channel-first; tensors are stacked/concatenated as-is (so presumably they are
            already channel-first).
        height (`int`, *optional*):
            Accepted for signature compatibility with diffusers; not used.
        width (`int`, *optional*):
            Accepted for signature compatibility with diffusers; not used.
        resize_mode (`str`, *optional*, defaults to `"default"`):
            Accepted for signature compatibility with diffusers; not used (no resizing).
        crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
            Box passed to `Image.crop` for every PIL input. Ignored for array/tensor input.

    Returns:
        `torch.Tensor`:
            The batched images. If the minimum value is >= 0 the data is assumed to lie in
            [0, 1] and is mapped to [-1, 1]; otherwise it is returned un-normalized and a
            FutureWarning is emitted.

    Raises:
        ValueError: if `image` is an empty list or is not in a supported format.
    """
    supported_formats = (Image.Image, np.ndarray, torch.Tensor)

    # Guard before any `image[0]` access below, which would otherwise raise IndexError.
    if isinstance(image, list) and len(image) == 0:
        raise ValueError("Input is an empty list; expected at least one image.")

    # Deprecated form: a list of already-batched (4D) arrays/tensors -> single batch.
    if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
        warnings.warn(
            "Passing `image` as a list of 4d np.ndarray is deprecated. "
            "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
            FutureWarning,
        )
        image = np.concatenate(image, axis=0)
    if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
        warnings.warn(
            "Passing `image` as a list of 4d torch.Tensor is deprecated. "
            "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
            FutureWarning,
        )
        image = torch.cat(image, dim=0)

    if not is_valid_image_imagelist(image):
        raise ValueError(
            f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
        )

    if not isinstance(image, list):
        image = [image]

    if isinstance(image[0], Image.Image):
        if crops_coords is not None:
            image = [i.crop(crops_coords) for i in image]
        image = pil_to_numpy(image)  # stacked float32 array, pixel values / 255
        image = numpy_to_pt(image)  # channel-last array -> channel-first tensor
    elif isinstance(image[0], np.ndarray):
        image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
        # Convert to a tensor like the PIL branch so the return type matches the contract.
        image = numpy_to_pt(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)

    # Expected range [0, 1]; map to [-1, 1] unless the data already contains negatives.
    do_normalize = True
    if do_normalize and image.min() < 0:
        warnings.warn(
            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
            FutureWarning,
        )
        do_normalize = False
    if do_normalize:
        image = normalize(image)
    return image
##########
def get_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with backend, prompt, model_dir, time_input, init_image,
        isize, height, width, save_dir and seed.
    """
    parser = argparse.ArgumentParser(
        prog="StableDiffusion",
        description="Stable Diffusion txt2img/img2img inference",
    )
    add = parser.add_argument

    # Runtime / inputs
    add("--backend", default="axe", choices=["axe", "onnx"], help="Inference backend (axe or onnx)")
    add(
        "--prompt",
        type=str,
        default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
        help="Input text prompt",
    )
    add(
        "--model_dir",
        type=str,
        default="./models",
        help="Directory containing tokenizer, text encoder, UNet, VAE, time inputs",
    )
    add("--time_input", type=str, default=None, help="Optional override for time input numpy file")
    add("--init_image", type=str, default=None, help="Provide an init image to enable img2img")

    # Output geometry
    add(
        "--isize",
        type=str,
        default="512",
        help="Output image size. Accepts a single integer (square) or '<height>x<width>'. Overridden by --height/--width if provided",
    )
    add("--height", type=int, default=None, help="Output image height (must be multiple of 8)")
    add("--width", type=int, default=None, help="Output image width (must be multiple of 8)")

    # Output file and reproducibility
    add("-o", "--save_dir", type=str, default="./output.png", help="Path to save the generated image")
    add("--seed", type=int, default=None, help="Random seed (img2img defaults to 0 if unspecified)")

    namespace = parser.parse_args()
    return namespace
def maybe_convert_prompt(prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):  # noqa: F821
    """Expand multi-vector textual-inversion tokens in one prompt or a list of prompts.

    A single string yields a single string; a list yields a new list with each entry
    converted via `_maybe_convert_prompt`.
    """
    if isinstance(prompt, list):
        return [_maybe_convert_prompt(item, tokenizer) for item in prompt]
    return _maybe_convert_prompt(prompt, tokenizer)
def _maybe_convert_prompt(prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
tokens = tokenizer.tokenize(prompt)
unique_tokens = set(tokens)
for token in unique_tokens:
if token in tokenizer.added_tokens_encoder:
replacement = token
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
replacement += f" {token}_{i}"
i += 1
prompt = prompt.replace(token, replacement)
return prompt
def create_session(model_path: str, backend: str):
    """Open an inference session for `model_path`.

    `backend == "onnx"` selects ONNX Runtime restricted to the CPU execution provider;
    any other value (the comparison is case-sensitive) uses `axengine.InferenceSession`.
    """
    if backend != "onnx":
        return axengine.InferenceSession(model_path)
    return onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
def ensure_multiple_of_eight(size: int) -> int:
    """Validate an output image dimension and return it unchanged.

    The latent grid is computed as ``size // 8`` (see compute_latent_shape), so the size
    must be a positive multiple of 8. Zero and negative values also satisfy ``% 8 == 0``
    and are rejected explicitly; previously they slipped through and produced empty or
    negative latent shapes further down the pipeline.

    Raises:
        ValueError: if `size` is not positive or not a multiple of 8.
    """
    if size <= 0:
        raise ValueError(f"Image size must be a positive multiple of 8, got {size}")
    if size % 8 != 0:
        raise ValueError("Image size must be a multiple of 8")
    return size
def parse_isize(isize: Union[int, str]) -> Tuple[int, int]:
    """Turn an ``--isize`` value into a ``(height, width)`` pair.

    Accepts an int (square), a digit string (square), or ``"<height>x<width>"``. Strings
    are case-insensitive and spaces are ignored.

    Raises:
        ValueError: for malformed strings or unsupported types.
    """
    format_error = "isize format should be <height>x<width> or a single integer"
    if isinstance(isize, int):
        return isize, isize
    if not isinstance(isize, str):
        raise ValueError("isize must be int or string")

    cleaned = isize.lower().replace(" ", "")
    fields = cleaned.split("x") if "x" in cleaned else [cleaned]
    if len(fields) > 2 or not all(field.isdigit() for field in fields):
        raise ValueError(format_error)
    if len(fields) == 1:
        side = int(fields[0])
        return side, side
    return int(fields[0]), int(fields[1])
def resolve_dimensions(isize: Union[int, str], height: Optional[int], width: Optional[int]) -> Tuple[int, int]:
    """Combine ``--isize`` with optional ``--height``/``--width`` overrides.

    `isize` is always parsed (so it must be valid even when both overrides are given);
    each override, when not None, replaces the matching parsed side. Both final sides
    are validated with `ensure_multiple_of_eight`, height first.
    """
    default_h, default_w = parse_isize(isize)
    final_h = default_h if height is None else height
    final_w = default_w if width is None else width
    checked_h = ensure_multiple_of_eight(final_h)
    checked_w = ensure_multiple_of_eight(final_w)
    return checked_h, checked_w
def compute_latent_shape(height: int, width: int, batch_size: int = 1) -> Tuple[int, int, int, int]:
    """Return the ``(batch, 4, height // 8, width // 8)`` latent shape for an image size.

    Both sides are validated with `ensure_multiple_of_eight` (height first).
    """
    latent_h = ensure_multiple_of_eight(height) // 8
    latent_w = ensure_multiple_of_eight(width) // 8
    latent_channels = 4
    return batch_size, latent_channels, latent_h, latent_w
def prepare_init_image(image_path: str, height: int, width: int) -> Tuple[Image.Image, np.ndarray]:
    """Load the img2img source image and turn it into a VAE-encoder input.

    The image is resized to ``(width, height)`` and converted to RGB while loading.

    Returns:
        A tuple of (a copy of the loaded PIL image for display, the preprocessed image as a
        NumPy array; tensors from `preprocess` are detached and moved to CPU first).
    """
    loaded = load_image(image_path, convert_method=lambda img: img.resize((width, height)).convert("RGB"))
    preview = loaded.copy()
    prepared = preprocess(loaded)
    if not isinstance(prepared, torch.Tensor):
        return preview, prepared
    return preview, prepared.detach().cpu().numpy()
def ensure_parent(path: str) -> None:
    """Create the directory that will contain `path`, if `path` names one.

    A bare filename (no directory component) is a no-op; existing directories are fine.
    """
    directory = os.path.dirname(path)
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
def resolve_with_base(path: str, base_dir: str) -> str:
    """Locate `path`, trying it relative to `base_dir`.

    Lookup order: `path` itself if it is absolute and exists, then
    ``os.path.join(base_dir, path)`` if that exists. If neither exists, `path` is
    returned unchanged. A relative `path` is not checked against the current directory,
    so a match under `base_dir` takes priority.
    """
    candidates = []
    if os.path.isabs(path):
        candidates.append(path)
    candidates.append(os.path.join(base_dir, path))
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return path
def get_prev_timestep(
    index: int,
    timestep: int,
    timesteps: np.ndarray,
    self_timesteps: Optional[np.ndarray] = None,
    step_index: Optional[List[int]] = None,
) -> int:
    """Return the timestep that follows the current step.

    When both `self_timesteps` and `step_index` are given (img2img), the current position
    is ``step_index[index]`` within the full schedule `self_timesteps`. Otherwise it is
    `index` within `timesteps`. If the next position is past the end of the schedule,
    the current `timestep` is returned.
    """
    use_full_schedule = self_timesteps is not None and step_index is not None
    if use_full_schedule:
        schedule, position = self_timesteps, step_index[index]
    else:
        schedule, position = timesteps, index
    following = position + 1
    return int(schedule[following]) if following < len(schedule) else int(timestep)
def denoise_loop(
    latent: np.ndarray,
    prompt_embeds: np.ndarray,
    time_inputs: np.ndarray,
    timesteps: np.ndarray,
    unet_session,
    alphas_cumprod: np.ndarray,
    final_alphas_cumprod: float,
    generator: Optional[torch.Generator],
    noise_dtype: torch.dtype,
    self_timesteps: Optional[np.ndarray] = None,
    step_index: Optional[List[int]] = None,
) -> np.ndarray:
    """Run the multi-step consistency (LCM-style) denoising loop on a NumPy latent.

    For each timestep the UNet session predicts noise. That prediction is turned into an
    x0 estimate and combined with the current sample using the c_skip / c_out boundary
    scalings. Fresh Gaussian noise is then re-injected for every step except the last.
    The constants (0.5 and the x10 timestep scaling) appear to mirror diffusers'
    LCMScheduler defaults (sigma_data=0.5, timestep_scaling=10) — confirm if changed.

    Args:
        latent: Starting latent; cast to float32 before every UNet call.
        prompt_embeds: Text-encoder output, fed as "encoder_hidden_states".
        time_inputs: One precomputed time-input row per step. Row i is fed under
            TIME_EMBED_KEY with a leading axis of size 1 added.
        timesteps: Timesteps to execute, in order.
        unet_session: Session object exposing ``run(None, feeds)``; output 0 is used as
            the noise prediction.
        alphas_cumprod: Cumulative alpha products indexed by integer timestep.
        final_alphas_cumprod: Used as alpha_prod_t_prev when the previous timestep is < 0.
        generator: Optional torch generator used for the injected noise.
        noise_dtype: dtype of the injected noise.
        self_timesteps, step_index: Optional full schedule and positions used to find the
            previous timestep (see get_prev_timestep).

    Returns:
        The final denoised latent as float32.

    Raises:
        ValueError: if `time_inputs` has fewer rows than there are `timesteps`.
    """
    # One time-input row is consumed per step.
    # Error message (Chinese): "time_input has fewer steps than the number of inference steps".
    if time_inputs.shape[0] < len(timesteps):
        raise ValueError("time_input 的步数少于推理步数")
    device = torch.device("cpu")
    for i, timestep in enumerate(timesteps):
        unet_start = time.time()
        latent = latent.astype(np.float32)
        feeds = {
            "sample": latent,
            TIME_EMBED_KEY: np.expand_dims(time_inputs[i], axis=0),
            "encoder_hidden_states": prompt_embeds,
        }
        noise_pred = unet_session.run(None, feeds)[0]
        print(f"unet once take {(1000 * (time.time() - unet_start)):.1f}ms")
        sample = latent
        model_output = noise_pred
        prev_timestep = get_prev_timestep(i, int(timestep), timesteps, self_timesteps, step_index)
        alpha_prod_t = alphas_cumprod[int(timestep)]
        alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alphas_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # Boundary-condition scalings: c_skip -> 1 and c_out -> 0 as the timestep -> 0.
        scaled_timestep = int(timestep) * 10
        c_skip = 0.5 ** 2 / (scaled_timestep ** 2 + 0.5 ** 2)
        c_out = scaled_timestep / (scaled_timestep ** 2 + 0.5 ** 2) ** 0.5
        # Epsilon parameterization: x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t).
        predicted_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)
        denoised = c_out * predicted_original_sample + c_skip * sample
        if i != len(timesteps) - 1:
            # Not the last step: re-noise the denoised estimate to the previous timestep.
            # Both branches draw standard normal noise; the first skips randn_tensor when
            # no generator is supplied and the dtype is float32.
            if noise_dtype == torch.float32 and generator is None:
                noise = torch.randn(model_output.shape, device=device, dtype=noise_dtype).cpu().numpy()
            else:
                noise_tensor = randn_tensor(model_output.shape, generator=generator, device=device, dtype=noise_dtype)
                noise = noise_tensor.cpu().numpy()
            prev_sample = (alpha_prod_t_prev ** 0.5) * denoised + (beta_prod_t_prev ** 0.5) * noise
        else:
            # Last step: return the denoised estimate without re-injecting noise.
            prev_sample = denoised
        latent = prev_sample.astype(np.float32)
    return latent
def get_embeds(
    prompt: Union[str, List[str]] = "Portrait of a pretty girl",
    tokenizer_dir: str = "./models/tokenizer",
    text_encoder_path: str = "./models/text_encoder/sd15_text_encoder_sim.axmodel",
    backend: str = "axe",
):
    """Tokenize `prompt` and run it through the text-encoder model.

    The prompt is padded/truncated to 77 CLIP tokens. For the "axe" backend the token
    ids are cast to int32 (presumably the compiled model's input dtype); other backends
    receive the tokenizer's default integer dtype. A new tokenizer and encoder session
    are created on every call.

    Returns:
        The first output of the text-encoder session (the prompt embeddings).
    """
    tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
    encoded = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
    token_ids = encoded.input_ids.to("cpu").numpy()
    if backend == "axe":
        token_ids = token_ids.astype(np.int32)

    encoder = create_session(text_encoder_path, backend)
    started = time.time()
    outputs = encoder.run(None, {"input_ids": token_ids})
    elapsed_ms = 1000 * (time.time() - started)
    print(f"text encoder running take {elapsed_ms:.1f}ms")
    return outputs[0]
def get_alphas_cumprod():
    """Build the 1000-step scaled-linear noise schedule as NumPy arrays.

    Returns:
        A tuple ``(alphas_cumprod, final_alphas_cumprod, self_timesteps)``:
        - alphas_cumprod: float32 cumulative products of ``1 - beta``, where the betas are
          interpolated linearly in sqrt space between 0.00085 and 0.012;
        - final_alphas_cumprod: the first entry of alphas_cumprod;
        - self_timesteps: int64 timesteps 999, 998, ..., 0.
    """
    sqrt_betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float32)
    cumulative = torch.cumprod(1.0 - sqrt_betas ** 2, dim=0).detach().numpy()
    descending = np.arange(999, -1, -1).astype(np.int64)
    return cumulative, cumulative[0], descending
def main():
    """Command-line entry point: txt2img by default, img2img when --init_image is given.

    Pipeline:
    text encoder -> (img2img only: VAE encoder + forward noising) -> UNet denoise loop
    -> VAE decoder -> PNG. Model files are looked up in --model_dir with a
    backend-specific suffix (".axmodel" for axe, ".onnx" otherwise). Timing for each
    stage is printed. For img2img, a side-by-side grid of input and output is also saved
    next to the output file.
    """
    args = get_args()
    backend = args.backend.lower()
    prompt = args.prompt
    is_img2img = args.init_image is not None
    model_dir = args.model_dir
    tokenizer_dir = os.path.join(model_dir, "tokenizer")
    text_encoder_dir = os.path.join(model_dir, "text_encoder")
    model_suffix = ".axmodel" if backend == "axe" else ".onnx"
    text_encoder_path = os.path.join(text_encoder_dir, f"sd15_text_encoder_sim{model_suffix}")
    unet_model = os.path.join(model_dir, f"unet{model_suffix}")
    vae_decoder_model = os.path.join(model_dir, f"vae_decoder{model_suffix}")
    vae_encoder_model = os.path.join(model_dir, f"vae_encoder{model_suffix}")
    # The time-input file is mode-specific; an explicit --time_input is resolved against model_dir.
    time_input_default = "time_input_img2img.npy" if is_img2img else "time_input_txt2img.npy"
    time_input_path = args.time_input or os.path.join(model_dir, time_input_default)
    if args.time_input:
        time_input_path = resolve_with_base(args.time_input, model_dir)
    init_image_path = None
    if is_img2img:
        init_image_path = resolve_with_base(args.init_image, model_dir)
    height, width = resolve_dimensions(args.isize, args.height, args.width)
    print(f"backend: {backend}")
    print(f"prompt: {prompt}")
    print(f"model_dir: {model_dir}")
    print(f"tokenizer_dir: {tokenizer_dir}")
    print(f"text_encoder: {text_encoder_path}")
    print(f"unet_model: {unet_model}")
    print(f"vae_decoder_model: {vae_decoder_model}")
    if is_img2img:
        # ref prompt: "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
        print(f"vae_encoder_model: {vae_encoder_model}")
        print(f"init image: {init_image_path}")
    print(f"time_input: {time_input_path}")
    print(f"image_size: {height}x{width}")
    print(f"save_dir: {args.save_dir}")
    device = torch.device("cpu")
    # A single generator, when set, is shared by every random draw below. Its consumption
    # order (posterior sample, forward noise, loop noise) therefore determines the output.
    generator: Optional[torch.Generator] = None
    if args.seed is not None:
        generator = torch.manual_seed(args.seed)
    # img2img draws float16 noise, txt2img float32 (presumably matching the reference pipelines).
    noise_dtype = torch.float16 if is_img2img else torch.float32
    encode_start = time.time()
    prompt_embeds_npy = get_embeds(prompt, tokenizer_dir, text_encoder_path, backend)
    print(f"text encoder take {(1000 * (time.time() - encode_start)):.1f}ms")
    alphas_cumprod, final_alphas_cumprod, _ = get_alphas_cumprod()
    load_start = time.time()
    vae_encoder_session = None
    if is_img2img:
        vae_encoder_session = create_session(vae_encoder_model, backend)
    unet_session = create_session(unet_model, backend)
    vae_decoder_session = create_session(vae_decoder_model, backend)
    print(f"load models take {(1000 * (time.time() - load_start)):.1f}ms")
    time_input = np.load(time_input_path)
    if is_img2img:
        init_image_show, init_image_np = prepare_init_image(init_image_path, height, width)
        vae_start = time.time()
        vae_encoder_inp_name = vae_encoder_session.get_inputs()[0].name
        vae_encoder_out = vae_encoder_session.run(None, {vae_encoder_inp_name: init_image_np})[0]
        print(f"vae encoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
        # The encoder output is wrapped as the `parameters` of a diagonal Gaussian
        # (presumably mean and log-variance concatenated along channels — confirm against the model).
        posterior = DiagonalGaussianDistribution(torch.from_numpy(vae_encoder_out).to(torch.float32))
        vae_encode_info = AutoencoderKLOutput(latent_dist=posterior)
        # img2img defaults to seed 0 so the posterior sample is reproducible.
        if generator is None:
            generator = torch.manual_seed(0)
        init_latents = retrieve_latents(vae_encode_info, generator=generator)
        # 0.18215: presumably the SD v1.x VAE latent scaling factor (undone before decoding below).
        init_latents = init_latents * 0.18215
        init_latents = torch.cat([init_latents], dim=0)
        # Noise the encoded image to the first img2img timestep.
        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=noise_dtype)
        timestep_tensor = torch.tensor([int(IMG2IMG_TIMESTEPS[0])], device=device)
        init_latents = add_noise(init_latents.to(device), noise, timestep_tensor)
        latent = init_latents.detach().cpu().numpy()
        timesteps = IMG2IMG_TIMESTEPS
        self_timesteps = IMG2IMG_SELF_TIMESTEPS
        step_index = IMG2IMG_STEP_INDEX
    else:
        # txt2img starts from pure Gaussian noise of the latent shape.
        batch, channels, latent_h, latent_w = compute_latent_shape(height, width)
        if generator is None:
            latents = torch.randn((batch, channels, latent_h, latent_w), device=device, dtype=torch.float32)
        else:
            latents = randn_tensor((batch, channels, latent_h, latent_w), generator=generator, device=device, dtype=torch.float32)
        latent = latents.cpu().numpy()
        init_image_show = None
        timesteps = TXT2IMG_TIMESTEPS
        self_timesteps = None
        step_index = None
    unet_loop_start = time.time()
    latent = denoise_loop(
        latent=latent,
        prompt_embeds=prompt_embeds_npy,
        time_inputs=time_input,
        timesteps=timesteps,
        unet_session=unet_session,
        alphas_cumprod=alphas_cumprod,
        final_alphas_cumprod=final_alphas_cumprod,
        generator=generator,
        noise_dtype=noise_dtype,
        self_timesteps=self_timesteps,
        step_index=step_index,
    )
    print(f"unet loop take {(1000 * (time.time() - unet_loop_start)):.1f}ms")
    vae_start = time.time()
    latent = latent / 0.18215
    vae_decoder_inp_name = vae_decoder_session.get_inputs()[0].name
    image = vae_decoder_session.run(None, {vae_decoder_inp_name: latent.astype(np.float32)})[0]
    print(f"vae decoder inference take {(1000 * (time.time() - vae_start)):.1f}ms")
    save_start = time.time()
    # Decoder output -> channel-last single image; map [-1, 1] to [0, 1], then to uint8.
    image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
    image_denorm = np.clip(image / 2 + 0.5, 0, 1)
    image_uint8 = (image_denorm * 255).round().astype("uint8")
    pil_image = Image.fromarray(image_uint8[:, :, :3])
    ensure_parent(args.save_dir)
    pil_image.save(args.save_dir)
    if is_img2img:
        # Side-by-side comparison saved as "<output stem>_grid.png".
        grid_path = os.path.splitext(args.save_dir)[0] + "_grid.png"
        grid_img = make_image_grid([init_image_show, pil_image], rows=1, cols=2)
        ensure_parent(grid_path)
        grid_img.save(grid_path)
        print(f"grid image saved in {grid_path}")
    print(f"save image take {(1000 * (time.time() - save_start)):.1f}ms")
if __name__ == "__main__":
    main()