| from typing import Dict, List, Any | |
| import torch | |
| from torch import autocast | |
| from diffusers import StableDiffusionPipeline | |
| import base64 | |
| from io import BytesIO | |
| from typing import List, Optional | |
| import torch | |
| from data.dataAccessor import update_db | |
| from data.task import Task, TaskType | |
| from pipelines.commons import Img2Img, Text2Img | |
| from pipelines.controlnets import ControlNet | |
| from pipelines.prompt_modifier import PromptModifier | |
| from util.cache import auto_clear_cuda_and_gc, clear_cuda | |
| from util.commons import add_code_names, pickPoses, upload_images | |
| from util.lora_style import LoraStyle | |
| from util.slack import Slack | |
# Process-wide torch backend switches, applied once when the module is imported.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Number of results to generate per request.
num_return_sequences = 4
auto_mode = False

# Helper singletons shared by every request. They are only constructed here;
# EndpointHandler.__init__ calls .load() on prompt_modifier and lora_style.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
lora_style = LoraStyle()
slack = Slack()
| # def get_patched_prompt(task: Task): | |
| # def add_style_and_character(prompt: List[str]): | |
| # for i in range(len(prompt)): | |
| # prompt[i] = add_code_names(prompt[i]) | |
| # prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style()) | |
| # prompt = task.get_prompt() | |
| # if task.is_prompt_engineering(): | |
| # prompt = prompt_modifier.modify(prompt) | |
| # else: | |
| # prompt = [prompt] * num_return_sequences | |
| # ori_prompt = [task.get_prompt()] * num_return_sequences | |
| # add_style_and_character(ori_prompt) | |
| # add_style_and_character(prompt) | |
| # print({"prompts": prompt}) | |
| # return (prompt, ori_prompt) | |
| # # @update_db | |
| # @auto_clear_cuda_and_gc(controlnet) | |
| # @slack.auto_send_alert | |
| # def canny(task: Task): | |
| # prompt, _ = get_patched_prompt(task) | |
| # controlnet.load_canny() | |
| # lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style()) | |
| # lora_patcher.patch() | |
| # images = controlnet.process_canny( | |
| # prompt=prompt, | |
| # imageUrl=task.get_imageUrl(), | |
| # seed=task.get_seed(), | |
| # steps=task.get_steps(), | |
| # width=task.get_width(), | |
| # height=task.get_height(), | |
| # negative_prompt=[ | |
| # f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}" | |
| # ] | |
| # * num_return_sequences, | |
| # **lora_patcher.kwargs(), | |
| # ) | |
| # generated_image_urls = upload_images(images, "_canny", task.get_taskId()) | |
| # lora_patcher.cleanup() | |
| # controlnet.cleanup() | |
| # return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls} | |
| # # @update_db | |
| # @auto_clear_cuda_and_gc(controlnet) | |
| # @slack.auto_send_alert | |
| # def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None): | |
| # prompt, _ = get_patched_prompt(task) | |
| # controlnet.load_pose() | |
| # lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style()) | |
| # lora_patcher.patch() | |
| # if poses is None: | |
| # poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences | |
| # images = controlnet.process_pose( | |
| # prompt=prompt, | |
| # image=poses, | |
| # seed=task.get_seed(), | |
| # steps=task.get_steps(), | |
| # negative_prompt=[task.get_negative_prompt()] * num_return_sequences, | |
| # width=task.get_width(), | |
| # height=task.get_height(), | |
| # **lora_patcher.kwargs(), | |
| # ) | |
| # generated_image_urls = upload_images(images, s3_outkey, task.get_taskId()) | |
| # lora_patcher.cleanup() | |
| # controlnet.cleanup() | |
| # return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls} | |
| # @update_db | |
| # @auto_clear_cuda_and_gc(controlnet) | |
def get_patched_prompt(task: Task):
    """Build the prompt batches used for one generation request.

    Restored from the commented-out definition earlier in this file, because
    ``text2img`` calls it and would otherwise fail with ``NameError``.

    Args:
        task: Request wrapper; ``get_prompt()``, ``get_style()`` and
            ``is_prompt_engineering()`` are read from it.

    Returns:
        A ``(prompt, ori_prompt)`` tuple of lists. ``ori_prompt`` is the task
        prompt repeated ``num_return_sequences`` times; ``prompt`` is either
        the output of ``prompt_modifier.modify`` (when prompt engineering is
        enabled) or the same repetition. Every entry of both lists is passed
        through ``add_code_names`` and ``lora_style.prepend_style_to_prompt``.
    """
    style = task.get_style()

    def add_style_and_character(prompts: List[str]) -> None:
        # In-place update, matching the original helper's contract.
        prompts[:] = [
            lora_style.prepend_style_to_prompt(add_code_names(text), style)
            for text in prompts
        ]

    prompt = task.get_prompt()
    if task.is_prompt_engineering():
        # presumably returns one prompt per requested sequence -- TODO confirm
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    ori_prompt = [task.get_prompt()] * num_return_sequences
    add_style_and_character(ori_prompt)
    add_style_and_character(prompt)
    print({"prompts": prompt})
    return (prompt, ori_prompt)


def text2img(task: Task, text2img_pipe):
    """Run a text-to-image request and upload the results.

    Args:
        task: Request wrapper supplying prompt, style, seed, steps, size,
            negative prompt, iteration and task id.
        text2img_pipe: Loaded pipeline wrapper; its ``pipe`` attribute is
            handed to the LoRA patcher and its ``process`` method generates
            the images.

    Returns:
        A dict with ``"modified_prompts"`` (the patched prompt list) and
        ``"generated_image_urls"`` (the value returned by ``upload_images``).
    """
    prompt, ori_prompt = get_patched_prompt(task)
    print("logs post: text2img_pipe", text2img_pipe)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())
        images = text2img_pipe.process(
            prompt=ori_prompt,
            modified_prompts=prompt,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )
        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # The pipeline object is reused across requests, so the patch must be
        # undone even when generation or upload fails.
        lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
| # # @update_db | |
| # @auto_clear_cuda_and_gc(controlnet) | |
| # @slack.auto_send_alert | |
| # def img2img(task: Task): | |
| # prompt, _ = get_patched_prompt(task) | |
| # lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style()) | |
| # lora_patcher.patch() | |
| # torch.manual_seed(task.get_seed()) | |
| # images = img2img_pipe.process( | |
| # prompt=prompt, | |
| # imageUrl=task.get_imageUrl(), | |
| # negative_prompt=[task.get_negative_prompt()] * num_return_sequences, | |
| # steps=task.get_steps(), | |
| # **lora_patcher.kwargs(), | |
| # ) | |
| # generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId()) | |
| # lora_patcher.cleanup() | |
| # return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls} | |
# Resolve the compute device; this module refuses to import without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    raise ValueError("need to run on GPU")

# Model repositories that EndpointHandler loads at start-up.
multi_model_list = [
    {"model_id": repo_id}
    for repo_id in (
        "jayparmr/icbinp",
        "jayparmr/anything-v5",
        "jayparmr/v5onad-new",
    )
]

# Registries of loaded pipelines keyed by model id; filled by EndpointHandler.__init__.
multi_controlnet_model = dict()
multi_text2image_model = dict()
multi_image2image_model = dict()
class EndpointHandler():
    """Custom inference handler: loads every configured model, then serves requests."""

    def __init__(self, path=""):
        """Load the shared helpers and one set of pipelines per configured model.

        Args:
            path: Location passed to ``lora_style.load``; also stored on the
                instance as ``self.path``.

        For every entry of ``multi_model_list`` a ControlNet, Text2Img and
        Img2Img wrapper is created, registered in the matching module-level
        registry under its model id, and loaded.
        """
        print("Logs: model loaded .... starts")
        prompt_modifier.load()
        lora_style.load(path)
        self.path = path

        for model in multi_model_list:
            model_id = model["model_id"]
            controlnet = ControlNet()
            img2img_pipe = Img2Img()
            text2img_pipe = Text2Img()

            multi_controlnet_model[model_id] = controlnet
            controlnet.load(model_id)

            multi_text2image_model[model_id] = text2img_pipe
            text2img_pipe.load(model_id)

            multi_image2image_model[model_id] = img2img_pipe
            img2img_pipe.load(model_id)

            print("Logs: multimodel controlnet pipelines are", model, multi_controlnet_model)
            print("Logs: multimodel text2img pipelines are", model, multi_text2image_model)
            print("Logs: multimodel imgtoimage pipelines are", model, multi_image2image_model)

        print("Logs: self.multi_image2image_model")
        print("Logs: self.multi_text2image_model", multi_text2image_model)
        print("Logs: self.multi_controlnet_model", multi_controlnet_model)
        print("Logs: model loaded ....")

    def __call__(self, data: Any) -> Optional[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Incoming request payload. The ``"inputs"``, ``"parameters"``
                and ``"model_id"`` keys are removed from it before the rest is
                wrapped in a ``Task``.

        Returns:
            ``None`` when the task type is accepted (the generation dispatch
            is currently disabled, see the TODO below), or a dict with
            ``"error"``, ``"data"`` and ``"task_type"`` keys when handling
            fails or the task type is not supported.
        """
        print("Logs post: init")
        print("Logs post: task is ", data)

        # Deserialize the incoming request. The pops are kept for their side
        # effect: these keys must not reach Task(data).
        data.pop("inputs", data)
        data.pop("parameters", None)
        model_id = data.pop("model_id", None)
        print("Logs post: model_id is", model_id)

        task = Task(data)
        # Bound before the try so the error payload can always reference it,
        # even when task.get_type() itself raises.
        task_type = None
        try:
            task_type = task.get_type()
            print("logs post: self.multi_text2image_model[model_id]", multi_text2image_model)

            if task_type == TaskType.TEXT_TO_IMAGE:
                # character sheet
                if "character sheet" in task.get_prompt().lower():
                    print("pose is here")
                    # TODO: return pose(task, s3_outkey="", poses=pickPoses())
                else:
                    print("pose is not here")
                    # TODO: return text2img(task, multi_text2image_model["jayparmr/icbinp"])
            # TODO: re-enable the remaining task types once their handlers are restored:
            # elif task_type == TaskType.IMAGE_TO_IMAGE:
            #     return img2img(task)
            # elif task_type == TaskType.CANNY:
            #     return canny(task)
            # elif task_type == TaskType.POSE:
            #     return pose(task)
            else:
                raise Exception("Invalid task type")
        except Exception as e:
            print(f"Error: {e}")
            slack.error_alert(data, e)
            # The exception is stringified so the payload can be serialized
            # by the serving layer; exception objects are not JSON-encodable.
            return {"error": str(e), "data": data, "task_type": task_type}

        return None