| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | from torchvision.transforms import GaussianBlur |
| |
|
| |
|
| |
|
class BasePipeline(torch.nn.Module):
    """Common helpers shared by image/video generation pipelines.

    Provides size rounding, image <-> tensor conversion, mask-weighted
    merging of latents, prompt extension, optional CPU offloading of the
    sub-models listed in ``self.model_names``, and seeded noise generation.

    ``extend_prompt`` reads ``self.prompter``, which is not created here and
    must be supplied by whoever uses that method (e.g. a subclass).
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
        """Store device/dtype settings and the size-rounding factors.

        Args:
            device: Device that models are moved to when they are loaded.
            torch_dtype: Stored as ``self.torch_dtype``; not used by the
                methods of this base class.
            height_division_factor: Heights are rounded up to a multiple of this.
            width_division_factor: Widths are rounded up to a multiple of this.
        """
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        # CPU offloading is opt-in, see enable_cpu_offload().
        self.cpu_offload = False
        # Attribute names of the sub-models managed by load_models_to_device().
        self.model_names = []

    def check_resize_height_width(self, height, width):
        """Round ``height`` and ``width`` up to their division factors.

        A message is printed for every dimension that had to be changed.

        Returns:
            The ``(height, width)`` tuple after rounding.
        """
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
        return height, width

    def preprocess_image(self, image):
        """Convert an image into a float tensor with a leading batch axis.

        The input is read as a float32 array, linearly mapped so that 0
        becomes -1 and 255 becomes 1, its last axis is moved to the front and
        a batch axis of size 1 is prepended.

        NOTE(review): the ``permute(2, 0, 1)`` requires a 3-D array, so this
        presumably expects an H x W x C image (e.g. RGB); a 2-D grayscale
        input would raise.
        """
        array = np.array(image, dtype=np.float32) * (2 / 255) - 1
        return torch.Tensor(array).permute(2, 0, 1).unsqueeze(0)

    def preprocess_images(self, images):
        """Apply :meth:`preprocess_image` to every item and return a list."""
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output):
        """Convert the first element of ``vae_output`` into a PIL image.

        Values are mapped from [-1, 1] to [0, 255] (clipped) and cast to uint8.
        """
        image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
        return Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))

    def vae_output_to_video(self, vae_output):
        """Convert ``vae_output`` into a list of PIL images.

        The tensor is permuted with ``(1, 2, 0)`` and each slice along the
        resulting first axis is mapped from [-1, 1] to [0, 255] (clipped).

        NOTE(review): the intended layout of ``vae_output`` is not visible
        from this class - confirm against the caller or overriding subclass.
        """
        # Cast to float32 first, matching vae_output_to_image: Tensor.numpy()
        # does not support bfloat16.
        frames = vae_output.cpu().float().permute(1, 2, 0).numpy()
        return [Image.fromarray(((frame / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for frame in frames]

    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
        """Blend ``latents`` into ``value`` using blurred, scaled masks.

        ``value`` starts with weight 1 everywhere. For each
        ``(latent, mask, scale)`` triple the mask is resized to the spatial
        size of ``value``, binarised (channel mean > 0), Gaussian-blurred and
        used as a per-pixel weight ``mask * scale``. The result is the
        weighted average of ``value`` and all latents.

        NOTE: ``value`` is modified in place and also returned. When
        ``latents`` is empty it is returned untouched.

        Args:
            value: Base tensor; its last two dimensions are (height, width).
            latents: Tensors to blend in.
            masks: One mask per latent; must offer ``resize((width, height))``
                (presumably PIL images - confirm with the caller).
            scales: One scalar weight per mask.
            blur_kernel_size: Kernel size passed to ``GaussianBlur``.
            blur_sigma: Sigma passed to ``GaussianBlur``.
        """
        if len(latents) == 0:
            return value

        blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
        height, width = value.shape[-2:]
        weight = torch.ones_like(value)
        for latent, mask, scale in zip(latents, masks, scales):
            # Binarise: any pixel whose channel mean is above -1..0 threshold 0 counts as "inside".
            mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
            mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
            # Soften the mask edge so regions fade into each other.
            mask = blur(mask)
            value += latent * mask * scale
            weight += mask * scale
        value /= weight
        return value

    def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
        """Run inference for a global and several local prompts and merge the results.

        Args:
            prompt_emb_global: Prompt embedding for the whole image.
            prompt_emb_locals: Prompt embeddings for the masked regions.
            masks: One mask per local prompt (see :meth:`merge_latents`).
            mask_scales: One weight per mask.
            inference_callback: Called as ``callback(prompt_emb)`` or
                ``callback(prompt_emb, kwargs)`` and must return the prediction.
            special_kwargs: Optional second argument for the global call.
            special_local_kwargs_list: Optional per-local-prompt second
                arguments, paired with ``prompt_emb_locals`` via ``zip``.

        Returns:
            The merged prediction produced by :meth:`merge_latents`.
        """
        if special_kwargs is None:
            noise_pred_global = inference_callback(prompt_emb_global)
        else:
            noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)

        if special_local_kwargs_list is None:
            noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
        else:
            noise_pred_locals = [
                inference_callback(prompt_emb_local, local_kwargs)
                for prompt_emb_local, local_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)
            ]
        return self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)

    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
        """Extend a prompt via ``self.prompter`` and collect local prompts/masks.

        ``self.prompter.extend_prompt(prompt)`` must return a dict; its
        optional ``"prompt"`` entry replaces the prompt, and its optional
        ``"prompts"`` / ``"masks"`` entries are appended to the local prompts
        and masks. Every appended mask gets a scale of ``100.0``.

        The caller's lists are never modified; new lists are returned.

        Args:
            prompt: The global prompt.
            local_prompts: Existing local prompts, or ``None``.
            masks: Existing masks, or ``None``.
            mask_scales: Existing mask scales, or ``None``.

        Returns:
            ``(prompt, local_prompts, masks, mask_scales)``.
        """
        # Copy so that the in-place extensions below do not leak into the
        # lists owned by the caller.
        local_prompts = list(local_prompts or [])
        masks = list(masks or [])
        mask_scales = list(mask_scales or [])

        extended_prompt_dict = self.prompter.extend_prompt(prompt)
        prompt = extended_prompt_dict.get("prompt", prompt)
        extra_masks = extended_prompt_dict.get("masks", [])
        local_prompts += extended_prompt_dict.get("prompts", [])
        masks += extra_masks
        mask_scales += [100.0] * len(extra_masks)
        return prompt, local_prompts, masks, mask_scales

    def enable_cpu_offload(self):
        """Turn on CPU offloading for :meth:`load_models_to_device`."""
        self.cpu_offload = True

    def load_models_to_device(self, loadmodel_names=None):
        """Keep only the requested models on ``self.device``.

        Does nothing unless CPU offloading is enabled. Otherwise every model
        in ``self.model_names`` that is not requested is offloaded, then every
        requested model is loaded, and finally the CUDA cache is emptied.

        Models with a truthy ``vram_management_enabled`` attribute are handled
        by calling ``offload()`` / ``onload()`` on each of their sub-modules
        that defines it; all other models are moved with ``cpu()`` /
        ``to(self.device)``. Attributes that are ``None`` are skipped.

        Args:
            loadmodel_names: Attribute names of the models to load. ``None``
                (the default) means "none", i.e. offload everything.
        """
        if not self.cpu_offload:
            return
        if loadmodel_names is None:
            loadmodel_names = []

        # Offload first so the memory is free before anything is loaded.
        for model_name in self.model_names:
            if model_name in loadmodel_names:
                continue
            model = getattr(self, model_name)
            if model is None:
                continue
            if getattr(model, "vram_management_enabled", False):
                for module in model.modules():
                    if hasattr(module, "offload"):
                        module.offload()
            else:
                model.cpu()

        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if model is None:
                continue
            if getattr(model, "vram_management_enabled", False):
                for module in model.modules():
                    if hasattr(module, "onload"):
                        module.onload()
            else:
                model.to(self.device)

        torch.cuda.empty_cache()

    def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
        """Sample standard-normal noise, reproducibly when ``seed`` is given.

        Args:
            shape: Shape of the returned tensor.
            seed: Seed for a dedicated generator; ``None`` uses the global RNG.
            device: Device of both the generator and the returned tensor.
            dtype: Dtype of the returned tensor.
        """
        generator = None if seed is None else torch.Generator(device).manual_seed(seed)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)
| |
|