# Snapshot id 4adca93; file size 4,085 bytes.
from typing import List
import cv2
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
UniPCMultistepScheduler)
from PIL import Image
from util.cache import clear_cuda_and_gc
from util.commons import disable_safety_checker, download_image
class ControlNet:
    """Stable Diffusion ControlNet wrapper that swaps between canny and openpose models.

    Only one ControlNet model is attached at a time. ``load()`` builds the
    pipeline with the canny model; ``load_canny()`` / ``load_pose()`` switch the
    active model, and ``process_canny()`` / ``process_pose()`` refuse to run
    unless the matching model is active.
    """

    # Active ControlNet task: "canny", "pose", or "" when none is loaded.
    __current_task_name = ""
    # OpenposeDetector reused across detect_pose() calls; created on first use.
    __pose_detector = None

    def load(self, model_dir: str) -> None:
        """Build the ControlNet pipeline from ``model_dir`` with the canny model attached.

        Args:
            model_dir: Path or hub id passed to
                ``StableDiffusionControlNetPipeline.from_pretrained``.
        """
        # we will load canny by default
        self.load_canny()
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)
        self.pipe = pipe

    def load_canny(self) -> None:
        """Make the canny-edge ControlNet the active model (no-op if it already is)."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self) -> None:
        """Make the openpose ControlNet the active model (no-op if it already is)."""
        self.__load_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def __load_controlnet(self, task_name: str, model_id: str) -> None:
        """Load ``model_id`` in fp16 onto CUDA and make it the active ControlNet.

        Args:
            task_name: Task label recorded as active ("canny" or "pose").
            model_id: Hub id of the ControlNet weights for that task.
        """
        if self.__current_task_name == task_name:
            return
        controlnet = ControlNetModel.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = task_name
        self.controlnet = controlnet
        # load() calls load_canny() before the pipeline exists, so only
        # re-attach when a pipeline has already been built.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        clear_cuda_and_gc()

    def cleanup(self) -> None:
        """Detach and drop the active ControlNet and cached pose detector, then free memory."""
        # Safe to call before load(): there may be no pipeline to detach from yet.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        self.controlnet = None
        self.__pose_detector = None
        self.__current_task_name = ""
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the canny edges of the image at ``imageUrl``.

        Args:
            prompt: Prompt(s) passed to the pipeline.
            imageUrl: URL of the source image; its canny edge map is the control image.
            seed: Seed applied to torch's global RNG before generation.
            steps: Number of inference steps.
            negative_prompt: Negative prompt(s) passed to the pipeline.
            height: Output image height.
            width: Output image width.

        Returns:
            The pipeline output's ``images``.

        Raises:
            RuntimeError: If the canny ControlNet is not the active model.
        """
        if self.__current_task_name != "canny":
            raise RuntimeError("ControlNet is not loaded with canny model")
        torch.manual_seed(seed)
        init_image = download_image(imageUrl)
        init_image = self.__canny_detect_edge(init_image)
        return self.pipe(
            prompt=prompt,
            image=init_image,
            guidance_scale=9,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        ).images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on pre-computed pose image(s).

        ``image`` is passed to the pipeline unchanged; presumably it comes from
        ``detect_pose()`` — confirm against callers. The pipeline's default
        guidance scale is used.

        Args:
            prompt: Prompt(s) passed to the pipeline.
            image: Control image(s) passed to the pipeline.
            seed: Seed applied to torch's global RNG before generation.
            steps: Number of inference steps.
            negative_prompt: Negative prompt(s) passed to the pipeline.
            height: Output image height.
            width: Output image width.

        Returns:
            The pipeline output's ``images``.

        Raises:
            RuntimeError: If the openpose ControlNet is not the active model.
        """
        if self.__current_task_name != "pose":
            raise RuntimeError("ControlNet is not loaded with pose model")
        torch.manual_seed(seed)
        return self.pipe(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download ``imageUrl`` and return the OpenposeDetector output for it.

        The detector is loaded on first use and reused by later calls instead
        of being re-created (and its weights re-fetched) on every request.
        """
        detector = self.__pose_detector
        if detector is None:
            detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            self.__pose_detector = detector
        image = download_image(imageUrl)
        return detector(image)

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return a 3-channel canny edge map of ``image``.

        Args:
            image: Source image.
            low_threshold: Lower hysteresis threshold for ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold for ``cv2.Canny``.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # cv2.Canny yields a single-channel map; replicate it into 3 identical channels.
        return Image.fromarray(np.stack([edges, edges, edges], axis=2))
|