# Origin: model-sd-multi/pipelines/controlnets.py -- uploaded by jayparmr ("Upload 18 files", revision 4adca93).
from typing import List
import cv2
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
UniPCMultistepScheduler)
from PIL import Image
from util.cache import clear_cuda_and_gc
from util.commons import disable_safety_checker, download_image
class ControlNet:
    """Stable Diffusion ControlNet pipeline with a swappable ControlNet model.

    One ``StableDiffusionControlNetPipeline`` is built by :meth:`load`; the
    attached ControlNet can then be switched between the canny and openpose
    checkpoints with :meth:`load_canny` / :meth:`load_pose`.
    """

    # Name of the ControlNet currently attached ("canny", "pose" or "" for none).
    __current_task_name = ""

    def load(self, model_dir: str) -> None:
        """Build the pipeline from ``model_dir`` with the canny ControlNet attached.

        Args:
            model_dir: Model path or id handed to
                ``StableDiffusionControlNetPipeline.from_pretrained``.
        """
        # we will load canny by default
        self.load_canny()
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)
        self.pipe = pipe

    def load_canny(self) -> None:
        """Attach the canny ControlNet; does nothing if it is already attached."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self) -> None:
        """Attach the openpose ControlNet; does nothing if it is already attached."""
        self.__load_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def cleanup(self) -> None:
        """Detach the current ControlNet and release cached memory.

        Safe to call before :meth:`load`: the pipeline is only touched when it
        exists, mirroring the guard used when swapping models.
        """
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the Canny edges of a downloaded image.

        Args:
            prompt: Prompts passed to the pipeline.
            imageUrl: Location of the source image, fetched with ``download_image``.
            seed: Value passed to ``torch.manual_seed`` before sampling.
            steps: Number of inference steps.
            negative_prompt: Negative prompts passed to the pipeline.
            height: Requested output height.
            width: Requested output width.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            RuntimeError: If the canny ControlNet is not the one currently loaded.
        """
        if self.__current_task_name != "canny":
            raise RuntimeError("ControlNet is not loaded with canny model")

        # Reseeds torch's global RNG (no per-call generator is handed to the pipe).
        torch.manual_seed(seed)

        source_image = download_image(imageUrl)
        edge_image = self.__canny_detect_edge(source_image)

        result = self.pipe(
            prompt=prompt,
            image=edge_image,
            guidance_scale=9,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return result.images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on already-extracted pose images.

        Args:
            prompt: Prompts passed to the pipeline.
            image: Conditioning images (e.g. as produced by :meth:`detect_pose`).
            seed: Value passed to ``torch.manual_seed`` before sampling.
            steps: Number of inference steps.
            negative_prompt: Negative prompts passed to the pipeline.
            height: Requested output height.
            width: Requested output width.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            RuntimeError: If the pose ControlNet is not the one currently loaded.
        """
        if self.__current_task_name != "pose":
            raise RuntimeError("ControlNet is not loaded with pose model")

        # Reseeds torch's global RNG (no per-call generator is handed to the pipe).
        torch.manual_seed(seed)

        # Unlike process_canny, no guidance_scale is passed here, so the
        # pipeline's own default applies.
        result = self.pipe(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        )
        return result.images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download the image at ``imageUrl`` and run the openpose detector on it.

        NOTE(review): the detector is re-created with ``from_pretrained`` on
        every call; consider caching it if this shows up in profiles.
        """
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        source_image = download_image(imageUrl)
        return detector(source_image)

    def __load_controlnet(self, task_name: str, model_id: str) -> None:
        """Load ``model_id`` onto CUDA and make it the active ControlNet."""
        if self.__current_task_name == task_name:
            return

        controlnet = ControlNetModel.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to("cuda")

        self.__current_task_name = task_name
        self.controlnet = controlnet
        # The pipeline does not exist yet when this runs as part of load().
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        # Drop whatever the previous ControlNet was still holding.
        clear_cuda_and_gc()

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return a 3-channel Canny edge map of ``image``.

        Args:
            image: Source picture.
            low_threshold: Lower hysteresis threshold given to ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold given to ``cv2.Canny``.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Replicate the single edge channel three times so the result has a
        # trailing channel axis of size 3.
        edges = edges[:, :, None]
        edges = np.concatenate([edges, edges, edges], axis=2)
        return Image.fromarray(edges)