|
|
from typing import List |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from controlnet_aux import OpenposeDetector |
|
|
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline, |
|
|
UniPCMultistepScheduler) |
|
|
from PIL import Image |
|
|
from util.cache import clear_cuda_and_gc |
|
|
from util.commons import disable_safety_checker, download_image |
|
|
|
|
|
|
|
|
class ControlNet:
    """Stable Diffusion ControlNet wrapper that switches between a canny-edge
    and an OpenPose conditioning model on a single pipeline.

    Call :meth:`load` first to build the pipeline (it starts with the canny
    ControlNet attached). Use :meth:`load_canny` / :meth:`load_pose` to switch
    the active ControlNet, then the matching ``process_*`` method to generate.
    """

    # Hugging Face repo ids of the ControlNet weights used for each task.
    _CANNY_MODEL_ID = "lllyasviel/control_v11p_sd15_canny"
    _POSE_MODEL_ID = "lllyasviel/control_v11p_sd15_openpose"
    # Repo id of the annotator weights loaded by detect_pose.
    _POSE_DETECTOR_ID = "lllyasviel/ControlNet"

    # Task whose ControlNet is currently attached: "canny", "pose", or "" for none.
    __current_task_name = ""

    # OpenposeDetector created on first detect_pose() call and reused afterwards;
    # dropped again by cleanup().
    _pose_detector = None

    def load(self, model_dir: str):
        """Build the Stable Diffusion ControlNet pipeline from ``model_dir``.

        The canny ControlNet is loaded first and passed to the pipeline. The
        pipeline is configured with the UniPC scheduler, model CPU offload and
        xformers memory-efficient attention, and its safety checker is disabled.

        :param model_dir: base model location passed to
            ``StableDiffusionControlNetPipeline.from_pretrained``.
        """
        self.load_canny()

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)

        self.pipe = pipe

    def _load_controlnet(self, task_name: str, model_id: str):
        """Make ``model_id`` the active ControlNet unless ``task_name`` already is.

        :param task_name: task label recorded as the current task.
        :param model_id: ControlNet repo id passed to ``from_pretrained``.
        """
        if self.__current_task_name == task_name:
            return

        controlnet = ControlNetModel.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to("cuda")

        self.__current_task_name = task_name
        self.controlnet = controlnet
        # Before load() there is no pipeline yet; load() will pick up
        # self.controlnet when it builds one.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet

        # The previous ControlNet is no longer referenced by this object;
        # reclaim its memory.
        clear_cuda_and_gc()

    def load_canny(self):
        """Make the canny-edge ControlNet active (no-op if it already is)."""
        self._load_controlnet("canny", self._CANNY_MODEL_ID)

    def load_pose(self):
        """Make the OpenPose ControlNet active (no-op if it already is)."""
        self._load_controlnet("pose", self._POSE_MODEL_ID)

    def cleanup(self):
        """Detach the active ControlNet and drop the cached pose detector.

        Safe to call before :meth:`load`. The base pipeline, if any, is kept so
        a later :meth:`load_canny` / :meth:`load_pose` can re-attach a model.
        """
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""
        self._pose_detector = None

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 9,
    ):
        """Generate images conditioned on the canny edges of the image at ``imageUrl``.

        :param prompt: prompt(s) to generate for.
        :param imageUrl: URL of the source image; its edge map is the control image.
        :param seed: value passed to ``torch.manual_seed`` before generation.
        :param steps: number of denoising steps.
        :param negative_prompt: negative prompt(s).
        :param height: output image height.
        :param width: output image width.
        :param guidance_scale: classifier-free guidance scale (default 9).
        :returns: the ``.images`` list of the pipeline output.
        :raises RuntimeError: if the canny ControlNet is not the active model.
        """
        if self.__current_task_name != "canny":
            raise RuntimeError("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        control_image = self.__canny_detect_edge(download_image(imageUrl))

        return self.pipe(
            prompt=prompt,
            image=control_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        ).images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on pre-computed pose image(s).

        :param prompt: prompt(s) to generate for.
        :param image: control image(s), e.g. as produced by :meth:`detect_pose`.
        :param seed: value passed to ``torch.manual_seed`` before generation.
        :param steps: number of denoising steps.
        :param negative_prompt: negative prompt(s).
        :param height: output image height.
        :param width: output image width.
        :returns: the ``.images`` list of the pipeline output.
        :raises RuntimeError: if the pose ControlNet is not the active model.
        """
        if self.__current_task_name != "pose":
            raise RuntimeError("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        return self.pipe(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download the image at ``imageUrl`` and return its OpenPose rendering.

        The detector weights are loaded on the first call and reused by later
        calls instead of being reloaded every time.
        """
        if self._pose_detector is None:
            self._pose_detector = OpenposeDetector.from_pretrained(
                self._POSE_DETECTOR_ID
            )
        return self._pose_detector(download_image(imageUrl))

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return the Canny edge map of ``image`` as a 3-channel image.

        ``cv2.Canny`` yields a single-channel map; it is replicated across three
        channels so it can be passed to the pipeline as a control image.

        :param image: source image (assumed 8-bit, as ``cv2.Canny`` requires).
        :param low_threshold: lower hysteresis threshold for ``cv2.Canny``.
        :param high_threshold: upper hysteresis threshold for ``cv2.Canny``.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # (H, W) -> (H, W, 3)
        return Image.fromarray(np.stack([edges, edges, edges], axis=2))
|
|
|