from typing import List

import cv2
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image

from util.cache import clear_cuda_and_gc
from util.commons import disable_safety_checker, download_image


class ControlNet:
    """Stable Diffusion ControlNet pipeline that can swap between a canny-edge
    and an openpose ControlNet on a single shared pipeline.

    Typical flow: ``load(model_dir)`` (attaches canny), optionally
    ``load_pose()`` / ``load_canny()`` to switch, then the matching
    ``process_*`` method. ``cleanup()`` detaches the active ControlNet.
    """

    # Which ControlNet is currently attached: "canny", "pose", or "" for none.
    __current_task_name = ""
    # OpenPose annotator, created on the first detect_pose() call and reused.
    _pose_detector = None

    def load(self, model_dir: str):
        """Build the ControlNet pipeline from the base model in *model_dir*.

        The canny ControlNet is attached by default.
        """
        # we will load canny by default
        self.load_canny()
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)
        self.pipe = pipe

    def load_canny(self):
        """Attach the canny-edge ControlNet; no-op if it is already active."""
        self._load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Attach the openpose ControlNet; no-op if it is already active."""
        self._load_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def _load_controlnet(self, task_name: str, repo_id: str):
        """Load the fp16 ControlNet *repo_id* onto CUDA and make it active.

        Args:
            task_name: Tag recorded as the active task ("canny" or "pose").
            repo_id: Model id passed to ``ControlNetModel.from_pretrained``.
        """
        if self.__current_task_name == task_name:
            return
        controlnet = ControlNetModel.from_pretrained(
            repo_id, torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = task_name
        self.controlnet = controlnet
        # Before load() has run there is no pipeline; load() passes
        # self.controlnet to the pipeline itself.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        # The previously attached ControlNet is no longer referenced here;
        # presumably this helper empties the CUDA cache / runs gc to free it.
        clear_cuda_and_gc()

    def cleanup(self):
        """Detach the active ControlNet and drop the cached pose detector."""
        # load() may never have been called, in which case there is no pipeline.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        self.controlnet = None
        self._pose_detector = None
        self.__current_task_name = ""
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the canny edges of the image at *imageUrl*.

        Returns:
            The pipeline output's ``images``.

        Raises:
            RuntimeError: If the canny ControlNet is not the active one.
        """
        if self.__current_task_name != "canny":
            raise RuntimeError("ControlNet is not loaded with canny model")
        torch.manual_seed(seed)
        init_image = download_image(imageUrl)
        init_image = self.__canny_detect_edge(init_image)
        return self.pipe(
            prompt=prompt,
            image=init_image,
            guidance_scale=9,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        ).images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the given pose image(s).

        *image* is passed straight to the pipeline; presumably it comes from
        ``detect_pose`` (TODO confirm with callers). Unlike ``process_canny``,
        no ``guidance_scale`` is passed, so the pipeline default applies.

        Raises:
            RuntimeError: If the pose ControlNet is not the active one.
        """
        if self.__current_task_name != "pose":
            raise RuntimeError("ControlNet is not loaded with pose model")
        torch.manual_seed(seed)
        return self.pipe(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download the image at *imageUrl* and run the OpenPose annotator on it."""
        # Creating the detector loads its weights, so do it once and reuse it.
        if self._pose_detector is None:
            self._pose_detector = OpenposeDetector.from_pretrained(
                "lllyasviel/ControlNet"
            )
        image = download_image(imageUrl)
        return self._pose_detector(image)

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return the Canny edge map of *image* as a 3-channel PIL image."""
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Canny yields a single-channel (H, W) map; copy it into three
        # identical channels to get an (H, W, 3) RGB-shaped image.
        edges = np.repeat(edges[:, :, None], 3, axis=2)
        return Image.fromarray(edges)