Upload 18 files
Browse files

- data/__init__.py +0 -0
- data/dataAccessor.py +60 -0
- data/task.py +85 -0
- inference.py +212 -0
- inference2.py +116 -0
- pipelines/commons.py +85 -0
- pipelines/controlnets.py +133 -0
- pipelines/inpainter.py +40 -0
- pipelines/prompt_modifier.py +54 -0
- pipelines/remove_background.py +15 -0
- pipelines/twoStepPipeline.py +252 -0
- pipelines/upscaler.py +77 -0
- requirements.txt +19 -0
- util/__init__.py +0 -0
- util/cache.py +31 -0
- util/commons.py +220 -0
- util/lora_style.py +92 -0
- util/slack.py +55 -0
data/__init__.py
ADDED
File without changes
data/dataAccessor.py
ADDED
@@ -0,0 +1,60 @@
import requests
from data.task import Task
from util.slack import Slack

comic_url = "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"


def updateSource(sourceId, userId, state):
    print("update source is called")
    url = comic_url + f"/comic-crecoai/source/{sourceId}"
    headers = {"Content-Type": "application/json", "user-id": str(userId)}

    data = {"state": state}

    try:
        response = requests.patch(url, headers=headers, json=data, timeout=10)
        print("update source response", response)
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")

    return


def saveGeneratedImages(sourceId, userId):
    print("save generation called")
    url = comic_url + "/comic-crecoai/source/" + str(sourceId) + "/generatedImages"
    headers = {"Content-Type": "application/json", "user-id": str(userId)}
    data = {"state": "ACTIVE"}

    try:
        requests.patch(url, headers=headers, json=data)
        # print("save generation response", response)
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
        return
    return


def update_db(func):
    def caller(*args, **kwargs):
        if type(args[0]) is not Task:
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId())
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            slack = Slack()
            slack.error_alert(task, e)
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller
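A minimal sketch of how update_db is meant to wrap a task handler (the handlers in inference.py use it the same way). The handler name and payload keys below are illustrative only, and the decorator will PATCH the internal comic service, so this assumes that endpoint is reachable:

# Hypothetical usage sketch for the update_db decorator (not part of this upload).
from data.dataAccessor import update_db
from data.task import Task


@update_db
def process(task: Task):
    # The wrapped function must take a Task as its first argument; update_db marks
    # the source INPROGRESS before the call, COMPLETED + saveGeneratedImages on
    # success, and FAILED (plus a Slack alert) if the handler raises.
    return {"generated_image_urls": []}


process(Task({"task_id": "t-1", "source_id": "s-1", "userId": "u-1"}))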
data/task.py
ADDED
@@ -0,0 +1,85 @@
from enum import Enum
from typing import Union

import numpy as np


class TaskType(Enum):
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"


class ModelType(Enum):
    REAL = 10000
    ANIME = 10001
    COMIC = 10002


class Task:
    def __init__(self, data):
        self.__data = data
        if data.get("seed", -1) == None or self.get_seed() == -1:
            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int64).max)

    def get_taskId(self) -> str:
        return self.__data.get("task_id")

    def get_sourceId(self) -> str:
        return self.__data.get("source_id")

    def get_imageUrl(self) -> str:
        return self.__data.get("imageUrl")

    def get_prompt(self) -> str:
        return self.__data.get("prompt")

    def get_userId(self) -> str:
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        return self.__data.get("email", "")

    def get_style(self) -> str:
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> ModelType:
        id = int(self.__data.get("model_id", 10000))
        return ModelType(id)

    def get_width(self) -> int:
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        return int(self.__data.get("steps", "75"))

    def get_type(self) -> Union[TaskType, None]:
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> str:
        return self.__data.get("maskImageUrl")

    def get_negative_prompt(self) -> str:
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        return self.__data.get("auto_mode", True)

    def get_raw(self) -> dict:
        return self.__data.copy()
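For reference, a small sketch of how a queue payload maps onto the Task accessors above; the field values are made up, but the keys match what the getters read:

# Illustrative payload; keys mirror the getters in data/task.py.
from data.task import ModelType, Task, TaskType

payload = {
    "task_id": "abc123",
    "source_id": "42",
    "userId": "7",
    "task_type": "GENERATE_AI_IMAGE",
    "prompt": "a comic panel of a city at night",
    "model_id": 10001,
    "width": 512,
    "height": 512,
    "steps": 50,
    "auto_mode": True,
}

task = Task(payload)
assert task.get_type() == TaskType.TEXT_TO_IMAGE
assert task.get_modelType() == ModelType.ANIME
print(task.get_seed())  # seed is auto-filled with a random int64 when not supplied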
inference.py
ADDED
@@ -0,0 +1,212 @@
from typing import List, Optional

import torch
from data.dataAccessor import update_db
from data.task import Task, TaskType
from pipelines.commons import Img2Img, Text2Img
from pipelines.controlnets import ControlNet
from pipelines.prompt_modifier import PromptModifier
from util.cache import auto_clear_cuda_and_gc, clear_cuda
from util.commons import add_code_names, pickPoses, upload_images
from util.lora_style import LoraStyle
from util.slack import Slack

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

num_return_sequences = 4  # the number of results to generate
auto_mode = False

prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
slack = Slack()


def get_patched_prompt(task: Task):
    def add_style_and_character(prompt: List[str]):
        for i in range(len(prompt)):
            prompt[i] = add_code_names(prompt[i])
            prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())

    prompt = task.get_prompt()

    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    ori_prompt = [task.get_prompt()] * num_return_sequences

    add_style_and_character(ori_prompt)
    add_style_and_character(prompt)

    print({"prompts": prompt})

    return (prompt, ori_prompt)


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()

    images = controlnet.process_canny(
        prompt=prompt,
        imageUrl=task.get_imageUrl(),
        seed=task.get_seed(),
        steps=task.get_steps(),
        width=task.get_width(),
        height=task.get_height(),
        negative_prompt=[
            f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
        ]
        * num_return_sequences,
        **lora_patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, "_canny", task.get_taskId())

    lora_patcher.cleanup()
    controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()

    if poses is None:
        poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences

    images = controlnet.process_pose(
        prompt=prompt,
        image=poses,
        seed=task.get_seed(),
        steps=task.get_steps(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        width=task.get_width(),
        height=task.get_height(),
        **lora_patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())

    lora_patcher.cleanup()
    controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    prompt, ori_prompt = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()

    torch.manual_seed(task.get_seed())

    images = text2img_pipe.process(
        prompt=ori_prompt,
        modified_prompts=prompt,
        num_inference_steps=task.get_steps(),
        guidance_scale=7.5,
        height=task.get_height(),
        width=task.get_width(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        iteration=task.get_iteration(),
        **lora_patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, "", task.get_taskId())

    lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    prompt, _ = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()

    torch.manual_seed(task.get_seed())

    images = img2img_pipe.process(
        prompt=prompt,
        imageUrl=task.get_imageUrl(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        steps=task.get_steps(),
        **lora_patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())

    lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


def model_fn(model_dir):
    print("Logs: model loaded .... starts")

    prompt_modifier.load()

    lora_style.load(model_dir)
    controlnet.load(model_dir)

    text2img_pipe.load(model_dir)
    img2img_pipe.load(model_dir)

    print("Logs: model loaded ....")
    return


def predict_fn(data, pipe):
    task = Task(data)
    print("task is ", data)

    try:
        task_type = task.get_type()

        if task_type == TaskType.TEXT_TO_IMAGE:
            # character sheet
            if "character sheet" in task.get_prompt().lower():
                return pose(task, s3_outkey="", poses=pickPoses())
            else:
                return text2img(task)
        elif task_type == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        elif task_type == TaskType.CANNY:
            return canny(task)
        elif task_type == TaskType.POSE:
            return pose(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
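As a rough sketch, this is the shape of request predict_fn above expects for a plain text-to-image task. The values and the model directory are illustrative; model_fn must run first so the pipelines, LoRA styles, and prompt modifier are loaded, and the required GPU and model weights are assumed to be present:

# Illustrative SageMaker-style invocation of inference.model_fn / predict_fn.
import inference

inference.model_fn("/opt/ml/model")  # assumed model directory, not fixed by this code

result = inference.predict_fn(
    {
        "task_id": "t-001",
        "source_id": "s-001",
        "userId": "u-001",
        "task_type": "GENERATE_AI_IMAGE",
        "prompt": "niomi reading a book in a library",
        "steps": 50,
        "auto_mode": True,
    },
    pipe=None,  # predict_fn does not use this argument
)
if result:
    print(result["generated_image_urls"])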
inference2.py
ADDED
@@ -0,0 +1,116 @@
from io import BytesIO

import torch
from data.dataAccessor import update_db
from data.task import ModelType, Task, TaskType
from pipelines.inpainter import InPainter
from pipelines.prompt_modifier import PromptModifier
from pipelines.remove_background import RemoveBackground
from pipelines.upscaler import Upscaler
from util.cache import clear_cuda
from util.commons import (
    add_code_names,
    construct_default_s3_url,
    upload_image,
    upload_images,
)
from util.slack import Slack

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

num_return_sequences = 4
auto_mode = False

slack = Slack()

prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    remove_background = RemoveBackground()
    output_image = remove_background.remove(task.get_imageUrl())

    output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    upload_image(output_image, output_key)

    return {"generated_image_url": construct_default_s3_url(output_key)}


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    prompt = add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    print({"prompts": prompt})

    images = inpainter.process(
        prompt=prompt,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
    )
    generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())

    clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
    out_img = None
    if task.get_modelType() == ModelType.ANIME:
        print("Using Anime model")
        out_img = upscaler.upscale_anime(task.get_imageUrl())
    else:
        print("Using Real model")
        out_img = upscaler.upscale(task.get_imageUrl())

    upload_image(BytesIO(out_img), output_key)
    return {"generated_image_url": construct_default_s3_url(output_key)}


def model_fn(model_dir):
    print("Logs: model loaded .... starts")

    prompt_modifier.load()
    upscaler.load()
    inpainter.load()

    print("Logs: model loaded ....")
    return


def predict_fn(data, pipe):
    task = Task(data)
    print("task is ", data)

    try:
        task_type = task.get_type()

        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
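The second entry point follows the same contract; a sketch of an upscale request with illustrative values (the ANIME model id defined in data/task.py routes to the anime ESRGAN weights, the URL is a placeholder):

# Illustrative invocation of inference2.predict_fn for an upscale task.
import inference2

inference2.model_fn("/opt/ml/model")  # assumed model directory

result = inference2.predict_fn(
    {
        "task_id": "t-002",
        "source_id": "s-002",
        "userId": "u-001",
        "task_type": "UPSCALE_IMAGE",
        "imageUrl": "https://example.com/input.png",  # placeholder URL
        "model_id": 10001,  # ModelType.ANIME -> upscale_anime
    },
    pipe=None,
)
if result:
    print(result["generated_image_url"])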
pipelines/commons.py
ADDED
@@ -0,0 +1,85 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionImg2ImgPipeline
from pipelines.twoStepPipeline import two_step_pipeline
from util.commons import disable_safety_checker, download_image


class Text2Img:
    def load(self, model_dir: str):
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        prompt: Union[str, List[str]] = None,
        modified_prompts: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        return self.pipe.two_step_pipeline(
            prompt=prompt,
            modified_prompts=modified_prompts,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
            iteration=iteration,
        ).images


class Img2Img:
    def load(self, model_dir: str):
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self, prompt: List[str], imageUrl: str, negative_prompt: List[str], steps: int
    ):
        image = download_image(imageUrl)

        return self.pipe.__call__(
            prompt=prompt,
            image=image,
            strength=0.75,
            negative_prompt=negative_prompt,
            guidance_scale=7.5,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        ).images
pipelines/controlnets.py
ADDED
@@ -0,0 +1,133 @@
from typing import List

import cv2
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)
from PIL import Image
from util.cache import clear_cuda_and_gc
from util.commons import disable_safety_checker, download_image


class ControlNet:
    __current_task_name = ""

    def load(self, model_dir: str):
        # we will load canny by default
        self.load_canny()

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)
        self.pipe = pipe

    def load_canny(self):
        if self.__current_task_name == "canny":
            return
        canny = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = "canny"
        self.controlnet = canny
        if hasattr(self, "pipe"):
            self.pipe.controlnet = canny
        clear_cuda_and_gc()

    def load_pose(self):
        if self.__current_task_name == "pose":
            return
        pose = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = "pose"
        self.controlnet = pose
        if hasattr(self, "pipe"):
            self.pipe.controlnet = pose
        clear_cuda_and_gc()

    def cleanup(self):
        self.pipe.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl)
        init_image = self.__canny_detect_edge(init_image)

        return self.pipe.__call__(
            prompt=prompt,
            image=init_image,
            guidance_scale=9,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        ).images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        return self.pipe.__call__(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        image = detector.__call__(image)
        return image

    def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
        image_array = np.array(image)

        low_threshold = 100
        high_threshold = 200

        image_array = cv2.Canny(image_array, low_threshold, high_threshold)
        image_array = image_array[:, :, None]
        image_array = np.concatenate([image_array, image_array, image_array], axis=2)
        canny_image = Image.fromarray(image_array)
        return canny_image
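A short sketch of the intended call order for this wrapper: load() once with the base model directory (it starts with the canny adapter), then hot-swap to pose per task; the model path, URL, and prompts are placeholders:

# Illustrative flow for pipelines/controlnets.py (placeholder inputs).
from pipelines.controlnets import ControlNet

cn = ControlNet()
cn.load("/path/to/base-model")   # loads the canny adapter by default, builds the pipe
cn.load_pose()                   # swaps the adapter on the same pipe

pose_image = cn.detect_pose("https://example.com/reference.png")
images = cn.process_pose(
    prompt=["a hero standing on a rooftop"] * 4,
    image=[pose_image] * 4,
    seed=1234,
    steps=30,
    negative_prompt=[""] * 4,
    height=512,
    width=512,
)
cn.cleanup()                     # drops the adapter and clears CUDA memory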
pipelines/inpainter.py
ADDED
@@ -0,0 +1,40 @@
from typing import List, Union

import torch
from diffusers import StableDiffusionInpaintPipeline
from util.commons import disable_safety_checker, download_image


class InPainter:
    def load(self):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            torch_dtype=torch.float16,
            revision="fp16",
        ).to("cuda")
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
    ):
        torch.manual_seed(seed)

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        return self.pipe.__call__(
            prompt=prompt,
            image=input_img,
            mask_image=mask_img,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
        ).images
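A minimal sketch of calling the inpainter directly; the URLs are placeholders, and the mask is assumed to follow the usual Stable Diffusion inpainting convention (white where content should be regenerated), which this file relies on rather than enforces:

# Illustrative use of pipelines/inpainter.py with placeholder inputs.
from pipelines.inpainter import InPainter

painter = InPainter()
painter.load()
images = painter.process(
    image_url="https://example.com/panel.png",
    mask_image_url="https://example.com/panel_mask.png",
    width=512,
    height=512,
    seed=42,
    prompt=["replace the sky with a sunset"] * 4,
    negative_prompt=[""] * 4,
)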
pipelines/prompt_modifier.py
ADDED
@@ -0,0 +1,54 @@
from typing import List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class PromptModifier:
    def __init__(self, num_of_sequences: Optional[int] = 4):
        self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
        self.__num_of_sequences = num_of_sequences

    def load(self):
        self.prompter_model = AutoModelForCausalLM.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        self.prompter_tokenizer.padding_side = "left"

    def modify(self, text: str) -> List[str]:
        eos_id = self.prompter_tokenizer.eos_token_id
        # restricted_words_list = ["octane", "cyber"]
        # restricted_words_token_ids = prompter_tokenizer(
        #     restricted_words_list, add_special_tokens=False
        # ).input_ids

        generation_config = GenerationConfig(
            do_sample=False,
            max_new_tokens=75,
            num_beams=4,
            num_return_sequences=self.__num_of_sequences,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            length_penalty=-1.0,
        )

        input_ids = self.prompter_tokenizer(text.strip(), return_tensors="pt").input_ids
        outputs = self.prompter_model.generate(
            input_ids, generation_config=generation_config
        )
        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        output_texts = self.__patch_blacklist_words(output_texts)
        return output_texts

    def __patch_blacklist_words(self, texts: List[str]):
        def replace_all(text, dic):
            for i, j in dic.items():
                text = text.replace(i, j)
            return text

        return [replace_all(text, self.__blacklist) for text in texts]
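A small usage sketch for the prompt expander: it returns num_of_sequences beam-searched variants of the input prompt with the blacklisted artist names stripped (the example prompt is illustrative, and the model weights are downloaded on first use):

# Illustrative use of pipelines/prompt_modifier.py.
from pipelines.prompt_modifier import PromptModifier

modifier = PromptModifier(num_of_sequences=4)
modifier.load()  # fetches Gustavosta/MagicPrompt-Stable-Diffusion on first run
variants = modifier.modify("a castle on a cliff")
for v in variants:
    print(v)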
pipelines/remove_background.py
ADDED
@@ -0,0 +1,15 @@
import io
from typing import Union

from PIL import Image
from rembg import remove
from util.commons import read_url


class RemoveBackground:
    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        if type(image) is str:
            image = Image.open(io.BytesIO(read_url(image)))

        output = remove(image)
        return output
pipelines/twoStepPipeline.py
ADDED
@@ -0,0 +1,252 @@
import torch
from diffusers import StableDiffusionPipeline

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

from typing import Any, Callable, Dict, List, Optional, Union

from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput


class two_step_pipeline(StableDiffusionPipeline):
    @torch.no_grad()
    def two_step_pipeline(
        self,
        prompt: Union[str, List[str]] = None,
        modified_prompts: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        Examples:
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        modified_embeds = self._encode_prompt(
            modified_prompts,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        print("mod prompt size : ", modified_embeds.size(), modified_embeds.dtype)

        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        print("prompt size : ", prompt_embeds.size(), prompt_embeds.dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

                if i == int(len(timesteps) / iteration):
                    print("modified prompts")
                    prompt_embeds = modified_embeds

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
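In plain terms, the loop above denoises with the embeddings of the original prompt and swaps in the embeddings of modified_prompts once i == int(len(timesteps) / iteration), so the user's prompt sets the overall composition during the early steps and the expanded prompt refines detail afterwards. A quick check of the switch point with the defaults used elsewhere in this upload (75 steps, iteration 3.0, assuming the scheduler yields one timestep per inference step):

# Switch-step arithmetic for the two-step pipeline (illustrative).
num_inference_steps = 75   # Task.get_steps() default
iteration = 3.0            # Task.get_iteration() default
switch_step = int(num_inference_steps / iteration)
print(switch_step)  # 25 -> modified prompt embeddings take over from step 25 onward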
pipelines/upscaler.py
ADDED
@@ -0,0 +1,77 @@
import os
from pathlib import Path
from typing import Union

import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from PIL import Image
from realesrgan import RealESRGANer
from util.commons import read_url


class Upscaler:
    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"

    def load(self):
        download_dir = Path(Path.home() / ".cache" / "realesrgan")
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )

    def upscale(self, image: Union[str, bytes]) -> bytes:
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(image, self.__model_path, model)

    def upscale_anime(self, image: Union[str, bytes]) -> bytes:
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(image, self.__model_path_anime, model)

    def __preload_model(self, url: str, download_dir: Path):
        name = url.split("/")[-1]
        if not os.path.exists(str(download_dir / name)):
            return load_file_from_url(
                url=url,
                model_dir=str(download_dir),
                progress=True,
                file_name=None,
            )
        else:
            return str(download_dir / name)

    def __internal_upscale(
        self,
        image: Union[str, bytes],
        model_path: str,
        rrbdnet: RRDBNet,
    ) -> bytes:
        if type(image) is str:
            image = read_url(image)

        upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=rrbdnet, half="fp16", gpu_id="0"
        )
        image_array = np.frombuffer(image, dtype=np.uint8)
        input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        output, _ = upsampler.enhance(input_image, outscale=4)
        out_bytes = cv2.imencode(".png", output)[1].tobytes()
        return out_bytes
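A sketch of direct use: upscale() accepts either a URL or raw bytes and returns PNG-encoded bytes, so writing the result to disk is enough to inspect it. The URL and output path are placeholders, and a GPU plus the Real-ESRGAN weights are assumed:

# Illustrative use of pipelines/upscaler.py.
from pipelines.upscaler import Upscaler

up = Upscaler()
up.load()  # downloads the Real-ESRGAN weights into ~/.cache/realesrgan if missing

png_bytes = up.upscale("https://example.com/low_res.png")  # placeholder URL
with open("upscaled.png", "wb") as f:
    f.write(png_bytes)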
requirements.txt
ADDED
@@ -0,0 +1,19 @@
aioredis==1.3.1
boto3==1.24.61
triton==2.0.0
diffusers==0.14.0
fastapi==0.87.0
Pillow==9.3.0
redis==4.3.4
requests==2.28.1
transformers
rembg==2.0.30
accelerate==0.17.0
gfpgan==1.3.8
rembg==2.0.30
controlnet-aux==0.0.1
realesrgan==0.3.0
compel==1.0.4
xformers
torchvision
git+https://github.com/cloneofsimo/lora.git
util/__init__.py
ADDED
File without changes
util/cache.py
ADDED
@@ -0,0 +1,31 @@
import gc

import torch


def clear_cuda_and_gc():
    clear_cuda()
    clear_gc()


def clear_cuda():
    torch.cuda.empty_cache()


def clear_gc():
    gc.collect()


def auto_clear_cuda_and_gc(controlnet):
    def auto_clear_cuda_and_gc_wrapper(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                controlnet.cleanup()
                clear_cuda_and_gc()
                raise e

        return wrapper

    return auto_clear_cuda_and_gc_wrapper
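auto_clear_cuda_and_gc is a decorator factory: it captures a controlnet-like object so that, if the wrapped handler raises, the adapter is dropped and CUDA memory is reclaimed before the exception propagates. A minimal sketch with a stand-in object (the real callers in inference.py pass the shared ControlNet instance):

# Illustrative use of util/cache.py's auto_clear_cuda_and_gc.
from util.cache import auto_clear_cuda_and_gc


class DummyControlNet:
    def cleanup(self):
        print("cleanup called")


dummy = DummyControlNet()


@auto_clear_cuda_and_gc(dummy)
def flaky_handler():
    raise RuntimeError("simulated CUDA OOM")


try:
    flaky_handler()
except RuntimeError:
    pass  # cleanup() and cache clearing ran before the exception reached us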
util/commons.py
ADDED
@@ -0,0 +1,220 @@
import pprint
import random
import re
from io import BytesIO
from typing import Union

import boto3
import requests

s3 = boto3.client("s3")
import io
import urllib.request

from PIL import Image

black_list = {"alphonse mucha": "", "adolphe bouguereau": ""}
pp = pprint.PrettyPrinter(indent=4)
Avatar = [
    {
        "avatarName": "niomi",
        "codename": "1jMGp1kFkG",
        "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/7_char_assets/niyomi_1jMGp1kFkG.png",
        "extraPrompt": "1jMGp1kFkG girl",
    },
    {
        "avatarName": "riya",
        "codename": "vW6AUQtoaY",
        "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/12_char_assets/riya_vW6AUQtoaY.png",
        "extraPrompt": "vW6AUQtoaY girl",
    },
    {
        "avatarName": "rajveer",
        "codename": "fSLF0OPkBw",
        "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/12_character_assets/rajveer_fSLF0OPkBw.png",
        "extraPrompt": "fSLF0OPkBw guy",
    },
    {
        "avatarName": "bheem",
        "codename": "HL79CB3ODZ",
        "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/7_char_assets/scene01001_1.png",
        "extraPrompt": "HL79CB3ODZ boy",
    },
    {
        "avatarName": "chutki",
        "codename": "SJ7JVIS9M7",
        "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/7_char_assets/14.png",
        "extraPrompt": "SJ7JVIS9M7 girl",
    },
]

webhook_url = (
    "https://hooks.slack.com/services/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"
)
error_webhook = (
    "https://hooks.slack.com/services/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"
)

characterSheets = [
    "character+sheets/1.1.png",
    "character+sheets/10.1.png",
    "character+sheets/11.1.png",
    "character+sheets/12.1.png",
    "character+sheets/13.1.png",
    "character+sheets/14.1.png",
    "character+sheets/16.1.png",
    "character+sheets/17.1.png",
    "character+sheets/18.1.png",
    "character+sheets/19.1.png",
    "character+sheets/2.1.png",
    "character+sheets/20.1.png",
    "character+sheets/21.1.png",
    "character+sheets/22.1.png",
    "character+sheets/23.1.png",
    "character+sheets/24.1.png",
    "character+sheets/25.1.png",
    "character+sheets/26.1.png",
    "character+sheets/27.1.png",
    "character+sheets/28.1.png",
    "character+sheets/29.1.png",
    "character+sheets/3.1.png",
    "character+sheets/30.1.png",
    "character+sheets/31.1.png",
    "character+sheets/32.1.png",
    "character+sheets/33.1.png",
    "character+sheets/34.1.png",
    "character+sheets/35.1.png",
    "character+sheets/36.1.png",
    "character+sheets/38.1.png",
    "character+sheets/39.1.png",
    "character+sheets/4.1.png",
    "character+sheets/40.1.png",
    "character+sheets/42.1.png",
    "character+sheets/43.1.png",
    "character+sheets/44.1.png",
    "character+sheets/45.1.png",
    "character+sheets/46.1.png",
    "character+sheets/47.1.png",
    "character+sheets/48.1.png",
    "character+sheets/49.1.png",
    "character+sheets/5.1.png",
    "character+sheets/50.1.png",
    "character+sheets/51.1.png",
    "character+sheets/52.1.png",
    "character+sheets/53.1.png",
    "character+sheets/54.1.png",
    "character+sheets/55.1.png",
    "character+sheets/56.1.png",
    "character+sheets/57.1.png",
    "character+sheets/58.1.png",
    "character+sheets/59.1.png",
    "character+sheets/60.1.png",
    "character+sheets/61.1.png",
    "character+sheets/62.1.png",
    "character+sheets/63.1.png",
    "character+sheets/64.1.png",
    "character+sheets/65.1.png",
    "character+sheets/66.1.png",
    "character+sheets/7.1.png",
    "character+sheets/8.1.png",
    "character+sheets/9.1.png",
]


def add_code_names(sentence):
    array_of_objects = Avatar

    for obj in array_of_objects:
        sentence = (
            re.sub(
                r"\b" + obj["avatarName"] + r"\b",
                obj["extraPrompt"],
                sentence,
                flags=re.IGNORECASE,
            )
            + " "
        )
    print(sentence)
    return sentence

| 140 |
+
|
| 141 |
+
def upload_images(images, processName: str, taskId: str):
|
| 142 |
+
imageUrls = []
|
| 143 |
+
for i, image in enumerate(images):
|
| 144 |
+
# img_io = BytesIO()
|
| 145 |
+
# image.save(img_io, "JPEG", quality=70)
|
| 146 |
+
# img_io.seek(0)
|
| 147 |
+
# key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
|
| 148 |
+
# t = s3.put_object(
|
| 149 |
+
# Bucket="comic-assets", Key=key, Body=img_io.getvalue(), ACL="public-read"
|
| 150 |
+
# )
|
| 151 |
+
# print("uploading done to s3", key, t)
|
| 152 |
+
imageUrls.append(
|
| 153 |
+
"https://comic-assets.s3.ap-south-1.amazonaws.com/crecoAI/{}{}_{}.png".format(
|
| 154 |
+
taskId, processName, i
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
print({"promptImages": imageUrls})
|
| 159 |
+
|
| 160 |
+
return imageUrls
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# def upload_image(image: Union[Image.Image, BytesIO], out_path):
|
| 164 |
+
# if type(image) is Image.Image:
|
| 165 |
+
# buffer = io.BytesIO()
|
| 166 |
+
# image.save(buffer, format="PNG")
|
| 167 |
+
# image = buffer
|
| 168 |
+
|
| 169 |
+
# image.seek(0)
|
| 170 |
+
# s3.upload_fileobj(image, "comic-assets", out_path, ExtraArgs={"ACL": "public-read"})
|
| 171 |
+
# image.close()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def download_image(url) -> Image.Image:
|
| 175 |
+
response = requests.get(url)
|
| 176 |
+
return Image.open(BytesIO(response.content)).convert("RGB")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def pickPoses():
|
| 180 |
+
random_images = random.sample(characterSheets, 4)
|
| 181 |
+
poses = []
|
| 182 |
+
prefix = "https://comic-assets.s3.ap-south-1.amazonaws.com/"
|
| 183 |
+
|
| 184 |
+
# Use list comprehension to add prefix to all elements in the array
|
| 185 |
+
random_images_with_prefix = [prefix + img for img in random_images]
|
| 186 |
+
|
| 187 |
+
print(random_images_with_prefix)
|
| 188 |
+
for imageUrl in random_images_with_prefix:
|
| 189 |
+
# Download and resize the image
|
| 190 |
+
init_image = download_image(imageUrl).resize((512, 512))
|
| 191 |
+
|
| 192 |
+
# Open the pose image
|
| 193 |
+
imageUrlPose = imageUrl
|
| 194 |
+
# print(imageUrl)
|
| 195 |
+
input_image_bytes = read_url(imageUrlPose)
|
| 196 |
+
# print(input_image_bytes)
|
| 197 |
+
pose_image = Image.open(io.BytesIO(input_image_bytes)).convert("RGB")
|
| 198 |
+
# print(pose_image)
|
| 199 |
+
pose_image = pose_image.resize((512, 512))
|
| 200 |
+
# print(pose_image)
|
| 201 |
+
# Append the result to the poses array
|
| 202 |
+
poses.append(pose_image)
|
| 203 |
+
|
| 204 |
+
return poses
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def construct_default_s3_url(key):
|
| 208 |
+
return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def read_url(url: str):
|
| 212 |
+
with urllib.request.urlopen(url) as u:
|
| 213 |
+
return u.read()
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def disable_safety_checker(pipe):
|
| 217 |
+
def dummy(images, **kwargs):
|
| 218 |
+
return images, False
|
| 219 |
+
|
| 220 |
+
pipe.safety_checker = None
|
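util/commons.py bundles the avatar codename substitution, S3 URL helpers, and image download utilities used by the pipelines. A rough usage sketch of add_code_names and construct_default_s3_url (the example prompt and the expected outputs are assumptions based on the Avatar table above, not output from this commit):

from util.commons import add_code_names, construct_default_s3_url

# Avatar names are whole-word matched, case-insensitively, and replaced with
# their trained token codenames; each pass also appends a trailing space.
prompt = add_code_names("Riya and bheem ride a bicycle")
# roughly: "vW6AUQtoaY girl and HL79CB3ODZ boy ride a bicycle" (plus trailing spaces)

# Keys are resolved against the comic-assets bucket.
url = construct_default_s3_url("character+sheets/1.1.png")
# "https://comic-assets.s3.ap-south-1.amazonaws.com/character+sheets/1.1.png"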
util/lora_style.py
ADDED
@@ -0,0 +1,92 @@
+import os
+from typing import Any, Dict, Union
+
+from lora_diffusion import patch_pipe, tune_lora_scale
+
+
+class LoraStyle:
+    class LoraPatcher:
+        def __init__(self, pipe, style: Dict[str, Any]):
+            self.__style = style
+            self.pipe = pipe
+
+        def patch(self):
+            patch_pipe(self.pipe, self.__style["path"])
+            tune_lora_scale(self.pipe.unet, self.__style["weight"])
+            tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])
+
+        def kwargs(self):
+            return {}
+
+        def cleanup(self):
+            tune_lora_scale(self.pipe.unet, 0.0)
+            tune_lora_scale(self.pipe.text_encoder, 0.0)
+            pass
+
+    class EmptyLoraPatcher:
+        def patch(self):
+            pass
+
+        def kwargs(self):
+            return {}
+
+        def cleanup(self):
+            pass
+
+    def load(self, model_dir: str):
+        self.__styles = {
+            "nq6akX1CIp": {
+                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
+                "weight": 0.5,
+                "negativePrompt": [""],
+                "type": "custom",
+            },
+            "ghibli": {
+                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
+                "weight": 1,
+                "negativePrompt": [""],
+                "type": "custom",
+            },
+            "eQAmnK2kB2": {
+                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
+                "weight": 0.5,
+                "negativePrompt": [""],
+                "type": "custom",
+            },
+            "to8contrast": {
+                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
+                "weight": 0.5,
+                "negativePrompt": [""],
+                "type": "custom",
+            },
+            "jim lee": {
+                "path": model_dir + "/laur_style/e2j9mz0jqj/final_lora.bin",
+                "weight": 0.8,
+                "negativePrompt": [""],
+                "type": "custom",
+            },
+        }
+        self.__verify()
+
+    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
+        if key in self.__styles:
+            return f"{key} style {prompt}"
+        return prompt
+
+    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
+        if key in self.__styles:
+            style = self.__styles[key]
+            return self.LoraPatcher(pipe, style)
+        return self.EmptyLoraPatcher()
+
+    def __verify(self):
+        "A method to verify if lora exists within the required path otherwise throw error"
+
+        for item in self.__styles.keys():
+            if not os.path.exists(self.__styles[item]["path"]):
+                raise Exception(
+                    "Lora style model "
+                    + item
+                    + " not found at path: "
+                    + self.__styles[item]["path"]
+                )
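LoraStyle.load registers the available LoRA styles and get_patcher hands back either a LoraPatcher or an EmptyLoraPatcher, so callers can patch and clean up unconditionally. A sketch of the intended call sequence, assuming a diffusers pipeline and a model directory that actually contains the LoRA files (the model id and path below are placeholders, not part of this commit):

from diffusers import StableDiffusionPipeline

from util.lora_style import LoraStyle

# assumed base pipeline; any compatible diffusers pipeline would do
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

lora_style = LoraStyle()
lora_style.load("/path/to/models")  # raises if any configured LoRA file is missing

prompt = lora_style.prepend_style_to_prompt("a girl reading a comic", "ghibli")
patcher = lora_style.get_patcher(pipe, "ghibli")  # EmptyLoraPatcher for unknown keys

patcher.patch()  # no-op for the empty patcher
images = pipe(prompt, **patcher.kwargs()).images
patcher.cleanup()  # resets LoRA scales to 0.0 so the next style starts clean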
util/slack.py
ADDED
@@ -0,0 +1,55 @@
+from time import sleep
+from typing import Optional
+
+import requests
+from data.task import Task
+
+
+class Slack:
+    def __init__(self):
+        # self.webhook_url = "https://hooks.slack.com/services/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"
+        self.webhook_url = "https://hooks.slack.com/services/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"
+        self.error_webhook = "https://hooks.slack.com/services/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"
+
+    def send_alert(self, task: Task, args: Optional[dict]):
+        raw = task.get_raw().copy()
+
+        raw.pop("queue_name", None)
+        raw.pop("attempt", None)
+        raw.pop("timestamp", None)
+        raw.pop("task_id", None)
+        raw.pop("maskImageUrl", None)
+
+        if args is not None:
+            raw.update(args.items())
+
+        message = ""
+        for key, value in raw.items():
+            if value:
+                if type(value) == list:
+                    message += f"*{key}*: {', '.join(value)}\n"
+                else:
+                    message += f"*{key}*: {value}\n"
+
+        requests.post(
+            self.webhook_url,
+            headers={"Content-Type": "application/json"},
+            json={"text": message},
+        )
+
+    def error_alert(self, task: Task, e: Exception):
+        requests.post(
+            self.error_webhook,
+            headers={"Content-Type": "application/json"},
+            json={
+                "text": "Task failed:\n{} \n error is: \n {}".format(task.get_raw(), e)
+            },
+        )
+
+    def auto_send_alert(self, func):
+        def inner(*args, **kwargs):
+            rargs = func(*args, **kwargs)
+            self.send_alert(args[0], rargs)
+            return rargs
+
+        return inner
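Slack.auto_send_alert is meant to be applied as a decorator to a task handler: it calls the wrapped function, then forwards the Task (the handler's first positional argument) and the returned dict to send_alert. A usage sketch (the process function and its return shape are hypothetical):

from data.task import Task
from util.slack import Slack

slack = Slack()


@slack.auto_send_alert
def process(task: Task):
    # the real handler would run the generation pipeline here
    return {"generatedImages": ["https://example.com/panel_0.png"]}

# Calling process(task) posts the task's fields plus the returned values to the
# configured Slack webhook once the handler completes.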