| import copy |
| import math |
| import os |
| from glob import glob |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import cv2 |
| import numpy as np |
| import streamlit as st |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as TT |
| from einops import rearrange, repeat |
| from imwatermark import WatermarkEncoder |
| from omegaconf import ListConfig, OmegaConf |
| from PIL import Image |
| from safetensors.torch import load_file as load_safetensors |
| from torch import autocast |
| from torchvision import transforms |
| from torchvision.utils import make_grid, save_image |
|
|
| from scripts.demo.discretization import (Img2ImgDiscretizationWrapper, |
| Txt2NoisyDiscretizationWrapper) |
| from scripts.util.detection.nsfw_and_watermark_dectection import \ |
| DeepFloydDataFiltering |
| from sgm.inference.helpers import embed_watermark |
| from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider, |
| VanillaCFG) |
| from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, |
| DPMPP2SAncestralSampler, |
| EulerAncestralSampler, |
| EulerEDMSampler, |
| HeunEDMSampler, |
| LinearMultistepSampler) |
| from sgm.util import append_dims, default, instantiate_from_config |
|
|
|
|
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    """Load the model described by ``version_dict`` and bundle it into a state dict.

    The result is cached by Streamlit (``st.cache_resource``).

    Args:
        version_dict: Mapping with at least ``"config"`` (path passed to
            ``OmegaConf.load``) and ``"ckpt"`` (checkpoint path).
        load_ckpt: When False, the model is built without loading weights and
            ``state["ckpt"]`` is None.
        load_filter: When True, also create a ``DeepFloydDataFiltering`` filter.

    Returns:
        Dict with keys ``"msg"``, ``"model"``, ``"ckpt"``, ``"config"`` and,
        if ``load_filter``, ``"filter"``.
    """
    config_path = version_dict["config"]
    ckpt_path = version_dict["ckpt"]
    effective_ckpt = ckpt_path if load_ckpt else None

    config = OmegaConf.load(config_path)
    model, msg = load_model_from_config(config, effective_ckpt)

    state = {
        "msg": msg,
        "model": model,
        "ckpt": effective_ckpt,
        "config": config,
    }
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state
|
|
|
|
def load_model(model):
    """Move ``model`` onto the CUDA device by calling its ``cuda()`` method.

    The return value of ``cuda()`` is discarded; this function returns None.
    """
    target = model
    target.cuda()
|
|
|
|
# Module-wide low-VRAM switch, changed via set_lowvram_mode(). When True,
# initial_model_load() casts model.model to half precision instead of moving
# the model to CUDA, and unload_model() moves modules back to the CPU and
# empties the CUDA cache after each use.
lowvram_mode = False
|
|
|
|
def set_lowvram_mode(mode):
    """Set the module-level ``lowvram_mode`` flag read by the load/unload helpers."""
    globals()["lowvram_mode"] = mode
|
|
|
|
def initial_model_load(model):
    """Prepare a freshly built model for use and return it.

    In normal mode the whole model is moved to CUDA. In low-VRAM mode only the
    inner ``model.model`` is cast to half precision and nothing is moved.
    """
    if not lowvram_mode:
        model.cuda()
    else:
        model.model.half()
    return model
|
|
|
|
def unload_model(model):
    """Offload ``model`` to the CPU and free cached CUDA memory in low-VRAM mode.

    Does nothing when low-VRAM mode is off.
    """
    if not lowvram_mode:
        return
    model.cpu()
    torch.cuda.empty_cache()
|
|
|
|
def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate ``config.model`` and optionally load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Optional checkpoint path. Paths ending in ``"ckpt"`` are read
            with ``torch.load`` (weights taken from ``["state_dict"]``); paths
            ending in ``"safetensors"`` are read with safetensors.
        verbose: Print missing/unexpected keys reported by the non-strict
            ``load_state_dict``.

    Returns:
        ``(model, msg)``: the model after ``initial_model_load`` and ``eval()``,
        and ``msg``, which is currently always None.

    Raises:
        NotImplementedError: If ``ckpt`` has an unsupported suffix.
    """
    model = instantiate_from_config(config.model)
    msg = None

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                global_step = pl_sd["global_step"]
                st.info(f"loaded ckpt from global step {global_step}")
                print(f"Global Step: {global_step}")
            state_dict = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            state_dict = load_safetensors(ckpt)
        else:
            raise NotImplementedError

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if verbose:
            if len(missing) > 0:
                print("missing keys:")
                print(missing)
            if len(unexpected) > 0:
                print("unexpected keys:")
                print(unexpected)

    model = initial_model_load(model)
    model.eval()
    return model, msg
|
|
|
|
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    The result is a list built from a set, so its order is unspecified.
    """
    unique_keys = {embedder.input_key for embedder in conditioner.embedders}
    return list(unique_keys)
|
|
|
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Collect conditioning values for the given embedder keys via Streamlit widgets.

    Hardcoded demo settings; might undergo some changes in the future.

    Args:
        keys: Embedder input keys. Handled keys: ``"txt"``,
            ``"original_size_as_tuple"``, ``"crop_coords_top_left"``,
            ``"aesthetic_score"``, ``"target_size_as_tuple"``, ``"fps"`` /
            ``"fps_id"``, ``"motion_bucket_id"``, ``"pool_image"``. Other keys
            are ignored.
        init_dict: Defaults; must provide ``"orig_width"``/``"orig_height"`` and
            ``"target_width"``/``"target_height"`` when the corresponding keys
            are present.
        prompt: Default prompt text (a built-in example when None).
        negative_prompt: Default negative prompt (empty when None).

    Returns:
        Dict of raw values keyed by the names that ``get_batch`` reads.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = "A professional photograph of an astronaut riding a pig"
            if negative_prompt is None:
                negative_prompt = ""

            prompt = st.text_input("Prompt", prompt)
            negative_prompt = st.text_input("Negative prompt", negative_prompt)

            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = negative_prompt

        if key == "original_size_as_tuple":
            orig_width = st.number_input(
                "orig_width",
                value=init_dict["orig_width"],
                min_value=16,
            )
            orig_height = st.number_input(
                "orig_height",
                value=init_dict["orig_height"],
                min_value=16,
            )

            value_dict["orig_width"] = orig_width
            value_dict["orig_height"] = orig_height

        if key == "crop_coords_top_left":
            crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
            crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)

            value_dict["crop_coords_top"] = crop_coord_top
            value_dict["crop_coords_left"] = crop_coord_left

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

        if key in ["fps_id", "fps"]:
            # "fps" and "fps_id" share one widget. Create it only once: a second
            # identical unkeyed number_input would raise Streamlit's
            # DuplicateWidgetID error when both keys are present.
            if "fps" not in value_dict:
                fps = st.number_input("fps", value=6, min_value=1)

                value_dict["fps"] = fps
                value_dict["fps_id"] = fps - 1

        if key == "motion_bucket_id":
            mb_id = st.number_input("motion bucket id", 0, 511, value=127)
            value_dict["motion_bucket_id"] = mb_id

        if key == "pool_image":
            st.text("Image for pool conditioning")
            image = load_img(
                key="pool_image_input",
                size=224,
                center_crop=True,
            )
            if image is None:
                st.info("Need an image here")
                # Placeholder so downstream batching still has a tensor.
                image = torch.zeros(1, 3, 224, 224)
            value_dict["pool_image"] = image

    return value_dict
|
|
|
|
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write them to ``save_path`` as numbered PNG files.

    The directory is created if missing. Numbering continues from the count of
    entries already in the directory, using 9-digit zero-padded names. Each
    sample is rearranged from ``c h w`` to ``h w c`` and multiplied by 255,
    then truncated to ``uint8``; presumably values lie in [0, 1] — confirm with
    callers.
    """
    target_dir = os.path.join(save_path)
    os.makedirs(target_dir, exist_ok=True)
    next_index = len(os.listdir(target_dir))
    for offset, sample in enumerate(embed_watermark(samples)):
        hwc = rearrange(sample.cpu().numpy(), "c h w -> h w c")
        pixels = (255.0 * hwc).astype(np.uint8)
        Image.fromarray(pixels).save(
            os.path.join(target_dir, f"{next_index + offset:09}.png")
        )
|
|
|
|
def init_save_locally(_dir, init_value: bool = False):
    """Ask whether to save outputs and, if so, where.

    Args:
        _dir: Base directory; the suggested save path is ``<_dir>/samples``.
        init_value: Initial state of the sidebar checkbox.

    Returns:
        ``(save_locally, save_path)``; ``save_path`` is None when saving is off.
    """
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    save_path = (
        st.text_input("Save path", value=os.path.join(_dir, "samples"))
        if save_locally
        else None
    )
    return save_locally, save_path
|
|
|
|
def get_guider(options, key):
    """Build a guider config from Streamlit widgets.

    Args:
        options: Per-model defaults. Keys read:
            - ``"guider"``: default selectbox index.
            - ``"cfg"`` and ``"min_cfg"``: scale defaults.
            - ``"num_frames"``: required for LinearPredictionGuider.
            - ``"additional_guider_kwargs"``: extra params merged into the
              VanillaCFG / LinearPredictionGuider config.
            The dict is not modified.
        key: Suffix appended to widget labels so several sampling stages can
            coexist on one page.

    Returns:
        A dict with a ``"target"`` import path and, except for IdentityGuider,
        a ``"params"`` dict.

    Raises:
        NotImplementedError: For an unknown guider or scale schedule.
    """
    guider = st.sidebar.selectbox(
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
            "LinearPredictionGuider",
        ],
        options.get("guider", 0),
    )

    # Read (not pop) so the caller's options dict is left intact and the extra
    # kwargs are still there if the same dict is passed again.
    additional_guider_kwargs = options.get("additional_guider_kwargs", {})

    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale_schedule = st.sidebar.selectbox(
            f"Scale schedule #{key}",
            ["Identity", "Oscillating"],
        )

        if scale_schedule == "Identity":
            scale = st.number_input(
                f"cfg-scale #{key}",
                value=options.get("cfg", 5.0),
                min_value=0.0,
            )

            scale_schedule_config = {
                "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
                "params": {"scale": scale},
            }

        elif scale_schedule == "Oscillating":
            small_scale = st.number_input(
                f"small cfg-scale #{key}",
                value=4.0,
                min_value=0.0,
            )

            large_scale = st.number_input(
                f"large cfg-scale #{key}",
                value=16.0,
                min_value=0.0,
            )

            sigma_cutoff = st.number_input(
                f"sigma cutoff #{key}",
                value=1.0,
                min_value=0.0,
            )

            scale_schedule_config = {
                "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
                "params": {
                    "small_scale": small_scale,
                    "large_scale": large_scale,
                    "sigma_cutoff": sigma_cutoff,
                },
            }
        else:
            raise NotImplementedError

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {
                "scale_schedule_config": scale_schedule_config,
                **additional_guider_kwargs,
            },
        }
    elif guider == "LinearPredictionGuider":
        max_scale = st.number_input(
            f"max-cfg-scale #{key}",
            value=options.get("cfg", 1.5),
            min_value=1.0,
        )
        # Keyed like the other widgets so two stages do not collide.
        min_scale = st.number_input(
            f"min guidance scale #{key}",
            value=options.get("min_cfg", 1.0),
            min_value=1.0,
            max_value=10.0,
        )

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
            "params": {
                "max_scale": max_scale,
                "min_scale": min_scale,
                "num_frames": options["num_frames"],
                **additional_guider_kwargs,
            },
        }
    else:
        raise NotImplementedError
    return guider_config
|
|
|
|
def init_sampling(
    key=1,
    img2img_strength: Optional[float] = None,
    specify_num_samples: bool = True,
    stage2strength: Optional[float] = None,
    options: Optional[Dict[str, int]] = None,
):
    """Render the sampler controls and return a configured sampler.

    Widgets are labelled with the ``#{key}`` suffix:
        - A column count (main area, only if ``specify_num_samples``).
        - Step count, sampler and discretization selectors (sidebar).
        - Whatever ``get_discretization``, ``get_guider`` and ``get_sampler``
          add.

    Args:
        key: Stage identifier used in widget labels.
        img2img_strength: If set, the sampler's discretization is wrapped in
            ``Img2ImgDiscretizationWrapper`` with this strength.
        specify_num_samples: Whether to show the column-count input.
        stage2strength: If set, the (possibly already wrapped) discretization
            is wrapped in ``Txt2NoisyDiscretizationWrapper``.
        options: Widget defaults such as ``"num_steps"``, ``"sampler"`` and
            ``"discretization"``; also forwarded to the helper builders.

    Returns:
        ``(sampler, num_rows, num_cols)``; ``num_rows`` is always 1.
    """
    if options is None:
        options = {}

    num_rows = 1
    num_cols = 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=num_cols, min_value=1, max_value=10
        )

    steps = st.sidebar.number_input(
        f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
    )
    sampler_choices = [
        "EulerEDMSampler",
        "HeunEDMSampler",
        "EulerAncestralSampler",
        "DPMPP2SAncestralSampler",
        "DPMPP2MSampler",
        "LinearMultistepSampler",
    ]
    sampler_name = st.sidebar.selectbox(
        f"Sampler #{key}", sampler_choices, options.get("sampler", 0)
    )
    discretization_choices = [
        "LegacyDDPMDiscretization",
        "EDMDiscretization",
    ]
    discretization_name = st.sidebar.selectbox(
        f"Discretization #{key}",
        discretization_choices,
        options.get("discretization", 0),
    )

    discretization_config = get_discretization(
        discretization_name, options=options, key=key
    )
    guider_config = get_guider(options=options, key=key)
    sampler = get_sampler(
        sampler_name, steps, discretization_config, guider_config, key=key
    )

    if img2img_strength is not None:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols
|
|
|
|
def get_discretization(discretization, options, key=1):
    """Build a discretization config for the chosen noise schedule.

    Args:
        discretization: ``"LegacyDDPMDiscretization"`` or ``"EDMDiscretization"``.
        options: Defaults for the EDM widgets (``"sigma_min"``, ``"sigma_max"``,
            ``"rho"``).
        key: Suffix appended to widget labels.

    Returns:
        A dict with a ``"target"`` import path and, for EDM, ``"params"``.

    Raises:
        ValueError: If ``discretization`` is not a supported name.
    """
    if discretization == "LegacyDDPMDiscretization":
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    elif discretization == "EDMDiscretization":
        sigma_min = st.number_input(
            f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
        )
        sigma_max = st.number_input(
            f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
        )
        rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"unknown discretization {discretization}!")

    return discretization_config
|
|
|
|
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
    """Instantiate the sampler named ``sampler_name`` with sidebar-tuned settings.

    Args:
        sampler_name: One of ``EulerEDMSampler``, ``HeunEDMSampler``,
            ``EulerAncestralSampler``, ``DPMPP2SAncestralSampler``,
            ``DPMPP2MSampler`` or ``LinearMultistepSampler``.
        steps: Number of sampling steps.
        discretization_config: Config passed to the sampler.
        guider_config: Config passed to the sampler.
        key: Suffix appended to every widget label so that several sampling
            stages on one page get independent widgets.

    Returns:
        The sampler instance (created with ``verbose=True``).

    Raises:
        ValueError: If ``sampler_name`` is not supported.
    """
    if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)

        if sampler_name == "EulerEDMSampler":
            sampler = EulerEDMSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                s_churn=s_churn,
                s_tmin=s_tmin,
                s_tmax=s_tmax,
                s_noise=s_noise,
                verbose=True,
            )
        elif sampler_name == "HeunEDMSampler":
            sampler = HeunEDMSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                s_churn=s_churn,
                s_tmin=s_tmin,
                s_tmax=s_tmax,
                s_noise=s_noise,
                verbose=True,
            )
    elif (
        sampler_name == "EulerAncestralSampler"
        or sampler_name == "DPMPP2SAncestralSampler"
    ):
        # Keyed per stage (as in the EDM branch): unkeyed identical widgets in
        # two stages would raise DuplicateWidgetID and share one value.
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
        eta = st.sidebar.number_input(f"eta #{key}", value=1.0, min_value=0.0)

        if sampler_name == "EulerAncestralSampler":
            sampler = EulerAncestralSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                eta=eta,
                s_noise=s_noise,
                verbose=True,
            )
        elif sampler_name == "DPMPP2SAncestralSampler":
            sampler = DPMPP2SAncestralSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                eta=eta,
                s_noise=s_noise,
                verbose=True,
            )
    elif sampler_name == "DPMPP2MSampler":
        sampler = DPMPP2MSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            verbose=True,
        )
    elif sampler_name == "LinearMultistepSampler":
        order = st.sidebar.number_input(f"order #{key}", value=4, min_value=1)
        sampler = LinearMultistepSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=order,
            verbose=True,
        )
    else:
        raise ValueError(f"unknown sampler {sampler_name}!")

    return sampler
|
|
|
|
def get_interactive_image(key=None) -> Image.Image:
    """Show a file uploader and return the uploaded image converted to RGB.

    NOTE(review): this module defines ``get_interactive_image`` a second time
    further down, and that later definition is the one bound at runtime. This
    version takes the same ``key`` argument so the two stay interchangeable.

    Args:
        key: Optional Streamlit widget key, needed when several uploaders
            appear on one page.

    Returns:
        A PIL image in RGB mode, or None if nothing has been uploaded.
    """
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if image is None:
        return None
    image = Image.open(image)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    return image
|
|
|
|
def load_img(
    display: bool = True,
    size: Union[None, int, Tuple[int, int]] = None,
    center_crop: bool = False,
    key=None,
):
    """Upload an image and convert it to a ``[-1, 1]``-scaled batched tensor.

    Args:
        display: Show the uploaded image in the page.
        size: Optional target passed to ``transforms.Resize`` (and to
            ``transforms.CenterCrop`` when ``center_crop`` is set).
        center_crop: Center-crop to ``size`` after resizing.
        key: Streamlit widget key for the uploader. Callers in this module pass
            it (e.g. ``key="pool_image_input"``), so it must be accepted here.

    Returns:
        A tensor with a leading batch dimension of 1, or None if no image was
        uploaded.
    """
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    transform = []
    if size is not None:
        transform.append(transforms.Resize(size))
    if center_crop:
        transform.append(transforms.CenterCrop(size))
    transform.append(transforms.ToTensor())
    # ToTensor yields [0, 1]; map to [-1, 1].
    transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))

    transform = transforms.Compose(transform)
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img
|
|
|
|
def get_init_img(batch_size=1, key=None):
    """Load an uploaded image onto CUDA and repeat it ``batch_size`` times.

    Args:
        batch_size: Number of copies along the leading batch dimension.
        key: Streamlit widget key forwarded to ``load_img``.

    Returns:
        The repeated image tensor, or None if no image has been uploaded yet
        (mirroring ``load_img``).
    """
    init_image = load_img(key=key)
    if init_image is None:
        return None
    init_image = init_image.cuda()
    return repeat(init_image, "1 ... -> b ...", b=batch_size)
|
|
|
|
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
    T=None,
    additional_batch_uc_fields=None,
    decoding_t=None,
):
    """Sample from noise with ``sampler``, decode, and display the results.

    Steps:
        1. Build conditioning batches from ``value_dict`` with ``get_batch``.
        2. Encode them with ``model.conditioner``.
        3. Draw Gaussian noise of shape
           ``(prod(num_samples), C, H // F, W // F)`` and run ``sampler``.
        4. Decode with ``model.decode_first_stage`` and clamp to [0, 1].
        5. Show the result in the Streamlit page.

    Args:
        model: Engine exposing ``conditioner``, ``denoiser``, ``model``,
            ``first_stage_model``, ``ema_scope()`` and ``decode_first_stage``.
        sampler: Callable sampler; its ``guider`` is inspected when
            ``"image_only_indicator"`` is in ``batch2model_input``.
        value_dict: Raw conditioning values consumed by ``get_batch``.
        num_samples: Number of samples (number of videos when ``T`` is given).
        H, W: Requested output size in pixels.
        C: Channel dimension of the sampled latent tensor.
        F: Divisor applied to ``H`` and ``W`` to get the latent spatial size.
        force_uc_zero_embeddings, force_cond_zero_embeddings: Forwarded to
            ``model.conditioner.get_unconditional_conditioning``.
        batch2model_input: Batch keys passed to the denoiser as extra keyword
            inputs.
        return_latents: Also return the sampled latents.
        filter: Optional callable applied to the decoded samples.
        T: Frames per video, or None for image sampling.
        additional_batch_uc_fields: Forwarded to ``get_batch``.
        decoding_t: Assigned to ``model.en_and_decode_n_samples_a_time`` before
            decoding.

    Returns:
        Decoded samples in [0, 1], or ``(samples, samples_z)`` when
        ``return_latents`` is true.
    """
    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
    batch2model_input = default(batch2model_input, [])
    additional_batch_uc_fields = default(additional_batch_uc_fields, [])

    st.text("Sampling")

    # Placeholder that the image grid is written into at the end.
    outputs = st.empty()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                # From here on num_samples is a shape list: [batch] or [batch, T].
                if T is not None:
                    num_samples = [num_samples, T]
                else:
                    num_samples = [num_samples]

                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                    T=T,
                    additional_batch_uc_fields=additional_batch_uc_fields,
                )

                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_model(model.conditioner)

                for k in c:
                    # Trim every non-crossattn entry to prod(num_samples) rows
                    # and move it to CUDA.
                    if not k == "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                        )
                    # For video, repeat crossattn/concat once per frame and fold
                    # the frame axis into the batch axis.
                    if k in ["crossattn", "concat"] and T is not None:
                        uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
                        uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
                        c[k] = repeat(c[k], "b ... -> b t ...", t=T)
                        c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)

                additional_model_inputs = {}
                for k in batch2model_input:
                    if k == "image_only_indicator":
                        assert T is not None

                        # Doubled leading dim for these guiders, presumably
                        # because they evaluate cond and uncond in one batch.
                        if isinstance(
                            sampler.guider, (VanillaCFG, LinearPredictionGuider)
                        ):
                            additional_model_inputs[k] = torch.zeros(
                                num_samples[0] * 2, num_samples[1]
                            ).to("cuda")
                        else:
                            additional_model_inputs[k] = torch.zeros(num_samples).to(
                                "cuda"
                            )
                    else:
                        additional_model_inputs[k] = batch[k]

                shape = (math.prod(num_samples), C, H // F, W // F)
                randn = torch.randn(shape).to("cuda")

                # Closure binding the network and the extra model inputs.
                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                model.en_and_decode_n_samples_a_time = (
                    decoding_t  # decode this many samples per call; None passed through as-is
                )
                samples_x = model.decode_first_stage(samples_z)
                # Map decoder output from [-1, 1] to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                unload_model(model.first_stage_model)

                if filter is not None:
                    samples = filter(samples)

                if T is None:
                    # Images: one horizontal strip.
                    grid = torch.stack([samples])
                    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                    outputs.image(grid.cpu().numpy())
                else:
                    # Videos: one 4-column frame grid per sample.
                    as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
                    for i, vid in enumerate(as_vids):
                        grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
                        st.image(
                            grid.cpu().numpy(),
                            f"Sample #{i} as image",
                        )

                if return_latents:
                    return samples, samples_z
                return samples
|
|
|
|
def get_batch(
    keys,
    value_dict: dict,
    N: "Union[List, ListConfig]",
    device: str = "cuda",
    T: Optional[int] = None,
    additional_batch_uc_fields: Optional[List[str]] = None,
):
    """Assemble the conditional and unconditional conditioner input batches.

    Hardcoded demo setups; might undergo some changes in the future.

    Args:
        keys: Embedder input keys to populate.
        value_dict: Raw values keyed by the names read below.
        N: Batch layout. The total batch size is ``math.prod(N)``; ``N[0]`` is
            the repeat count for the ``cond_frames*`` entries.
        device: Device for the tensors created here (``cond_aug`` included).
        T: If given, stored as ``batch["num_video_frames"]``.
        additional_batch_uc_fields: Non-tensor batch keys to also shallow-copy
            into ``batch_uc``.

    Returns:
        ``(batch, batch_uc)``. Any tensor entry missing from ``batch_uc`` is
        cloned from ``batch``.
    """
    if additional_batch_uc_fields is None:
        additional_batch_uc_fields = []

    batch = {}
    batch_uc = {}
    total = math.prod(N)

    for key in keys:
        if key == "txt":
            batch["txt"] = [value_dict["prompt"]] * total
            batch_uc["txt"] = [value_dict["negative_prompt"]] * total
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(total, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor(
                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                )
                .to(device)
                .repeat(total, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]])
                .to(device)
                .repeat(total, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]])
                .to(device)
                .repeat(total, 1)
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(total, 1)
            )
        elif key == "fps":
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(total)
        elif key == "fps_id":
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(total)
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(total)
            )
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=total).to(
                device, dtype=torch.half
            )
        elif key == "cond_aug":
            # Honour `device` like every other branch (it used to be hard-coded
            # to "cuda"); shape is (total,), same as the fps branches.
            batch[key] = torch.tensor([value_dict["cond_aug"]]).to(device).repeat(total)
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch:
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
        elif key in additional_batch_uc_fields and key not in batch_uc:
            batch_uc[key] = copy.copy(batch[key])
    return batch, batch_uc
|
|
|
|
@torch.no_grad()
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    """Noise an input image (or latent) and denoise it with ``sampler``.

    Args:
        img: Input image tensor, or a latent when ``skip_encode`` is True.
        model: Engine exposing ``conditioner``, ``denoiser``, ``model``,
            ``first_stage_model``, ``ema_scope()``, ``encode_first_stage`` and
            ``decode_first_stage``.
        sampler: Sampler with ``discretization`` and ``num_steps``.
        value_dict: Raw conditioning values consumed by ``get_batch``.
        num_samples: Batch size.
        force_uc_zero_embeddings, force_cond_zero_embeddings: Forwarded to
            ``model.conditioner.get_unconditional_conditioning``.
        additional_kwargs: Extra entries written into both ``c`` and ``uc``.
        offset_noise_level: Scale of an extra per-sample noise offset (0 disables it).
        return_latents: Also return the sampled latents.
        skip_encode: Treat ``img`` as an already-encoded latent.
        filter: Optional callable applied to the decoded samples.
        add_noise: When False, the latent is only rescaled, not noised.

    Returns:
        Decoded samples in [0, 1], or ``(samples, samples_z)`` when
        ``return_latents`` is true.
    """
    if additional_kwargs is None:
        additional_kwargs = {}

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [num_samples],
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_model(model.conditioner)
                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))

                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]
                if skip_encode:
                    z = img
                else:
                    load_model(model.first_stage_model)
                    z = model.encode_first_stage(img)
                    unload_model(model.first_stage_model)

                noise = torch.randn_like(z)

                # The first sigma of the (possibly wrapped) schedule sets the noise level.
                sigmas = sampler.discretization(sampler.num_steps).cuda()
                sigma = sigmas[0]

                st.info(f"all sigmas: {sigmas}")
                st.info(f"noising sigma: {sigma}")
                if offset_noise_level > 0.0:
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                # NOTE: divides by sqrt(1 + sigma_0^2), i.e. hard-coded
                # DDPM-like input scaling.
                if add_noise:
                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                    noised_z = noised_z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
                else:
                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                unload_model(model.first_stage_model)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                # Build the display grid (previously `grid` was used undefined).
                grid = torch.stack([samples])
                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())
                if return_latents:
                    return samples, samples_z
                return samples
|
|
|
|
def get_resizing_factor(
    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
    """Return the smallest uniform scale at which ``current_shape`` covers ``desired_shape``.

    Both shapes must use the same axis order, e.g. ``(height, width)``.
    Scaling by the returned factor makes the image at least as large as
    ``desired_shape`` along both axes, so it can then be center-cropped to it.
    This is the larger of the two per-axis ratios, which the aspect-ratio case
    analysis this replaces selected branch by branch.

    Example: desired (576, 1024) from current (1000, 1000) gives 1.024.
    """
    height_ratio = desired_shape[0] / current_shape[0]
    width_ratio = desired_shape[1] / current_shape[1]
    return max(height_ratio, width_ratio)
|
|
|
|
def get_interactive_image(key=None) -> Image.Image:
    """Show a jpg/png file uploader and return the upload as an RGB PIL image.

    Args:
        key: Optional Streamlit widget key for the uploader.

    Returns:
        The uploaded image (converted to RGB if needed), or None when nothing
        has been uploaded.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if uploaded is None:
        return None
    picture = Image.open(uploaded)
    if picture.mode != "RGB":
        picture = picture.convert("RGB")
    return picture
|
|
|
|
def load_img_for_prediction(
    W: int, H: int, display=True, key=None, device="cuda"
) -> torch.Tensor:
    """Upload an image, cover-resize it and center-crop it to ``H x W``.

    Args:
        W, H: Target width and height in pixels.
        display: Show the uploaded image and the cropped result.
        key: Streamlit widget key for the uploader.
        device: Device of the returned tensor.

    Returns:
        None if nothing was uploaded. Otherwise a ``(1, 3, H, W)`` float tensor
        on ``device`` with values mapped from [0, 1] to [-1, 1].
    """
    pil_img = get_interactive_image(key=key)
    if pil_img is None:
        return None
    if display:
        st.image(pil_img)
    orig_w, orig_h = pil_img.size

    chw = np.array(pil_img).transpose(2, 0, 1)
    tensor = (torch.from_numpy(chw).to(dtype=torch.float32) / 255.0).unsqueeze(0)

    # Scale so both sides cover the target, rounding sizes up, then crop the centre.
    scale = get_resizing_factor((H, W), (orig_h, orig_w))
    new_h, new_w = (int(np.ceil(scale * dim)) for dim in (orig_h, orig_w))
    resized = torch.nn.functional.interpolate(
        tensor, [new_h, new_w], mode="area", antialias=False
    )
    cropped = TT.functional.crop(
        resized, top=(new_h - H) // 2, left=(new_w - W) // 2, height=H, width=W
    )

    if display:
        preview = np.transpose(cropped[0].numpy(), (1, 2, 0))
        st.image(Image.fromarray((preview * 255).astype(np.uint8)))
    return cropped.to(device) * 2.0 - 1.0
|
|
|
|
def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
    """Save each video as a PNG frame grid and an MP4, then show it in the page.

    Args:
        video_batch: Frames flattened as ``(b t) c h w``. Values are multiplied
            by 255 before encoding, so presumably they are in [0, 1] — confirm
            with callers.
        save_path: Output directory (created if missing). May come from a user
            text input, so it is never passed through a shell.
        T: Frames per video, used to split ``video_batch``.
        fps: Frame rate of the written MP4.
    """
    import subprocess

    os.makedirs(save_path, exist_ok=True)
    base_count = len(glob(os.path.join(save_path, "*.mp4")))

    video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
    video_batch = embed_watermark(video_batch)
    for vid in video_batch:
        save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)

        video_path = os.path.join(save_path, f"{base_count:06d}.mp4")

        writer = cv2.VideoWriter(
            video_path,
            cv2.VideoWriter_fourcc(*"MP4V"),
            fps,
            (vid.shape[-1], vid.shape[-2]),
        )

        frames = (
            (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
        )
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()

        # Re-encode with libx264 (presumably for playback in st.video). An
        # argument list avoids the shell, so paths with spaces or shell
        # metacharacters are handled correctly and cannot inject commands.
        video_path_h264 = video_path[:-4] + "_h264.mp4"
        subprocess.run(
            ["ffmpeg", "-i", video_path, "-c:v", "libx264", video_path_h264],
            check=False,
        )

        with open(video_path_h264, "rb") as f:
            video_bytes = f.read()
        st.video(video_bytes)

        base_count += 1
|
|