jayparmr committed
Commit 4adca93 · 1 Parent(s): 3d5d3f1

Upload 18 files

data/__init__.py ADDED
File without changes
data/dataAccessor.py ADDED
@@ -0,0 +1,59 @@
+ import requests
+ from data.task import Task
+ from util.slack import Slack
+
+ comic_url = "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
+
+
+ def updateSource(sourceId, userId, state):
+     print("update source is called")
+     url = comic_url + f"/comic-crecoai/source/{sourceId}"
+     headers = {"Content-Type": "application/json", "user-id": str(userId)}
+
+     data = {"state": state}
+
+     try:
+         response = requests.patch(url, headers=headers, json=data, timeout=10)
+         print("update source response", response)
+     except requests.exceptions.Timeout:
+         print("Request timed out while updating source")
+     except requests.exceptions.RequestException as e:
+         print(f"Error while updating source: {e}")
+
+     return
+
+
+ def saveGeneratedImages(sourceId, userId):
+     print("save generation called")
+     url = comic_url + "/comic-crecoai/source/" + str(sourceId) + "/generatedImages"
+     headers = {"Content-Type": "application/json", "user-id": str(userId)}
+     data = {"state": "ACTIVE"}
+
+     try:
+         # timeout is required for the Timeout handler below to ever fire
+         requests.patch(url, headers=headers, json=data, timeout=10)
+     except requests.exceptions.Timeout:
+         print("Request timed out while saving image")
+     except requests.exceptions.RequestException as e:
+         print("Failed to mark source as active: ", e)
+     return
+
+
+ def update_db(func):
+     def caller(*args, **kwargs):
+         if not isinstance(args[0], Task):
+             raise Exception("First argument must be a Task object")
+         task = args[0]
+         try:
+             updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
+             rargs = func(*args, **kwargs)
+             updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
+             saveGeneratedImages(task.get_sourceId(), task.get_userId())
+             return rargs
+         except Exception as e:
+             print("Error processing image: {}".format(str(e)))
+             slack = Slack()
+             slack.error_alert(task, e)
+             updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
+
+     return caller
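
Usage sketch of the update_db decorator; the handler name and payload fields below are illustrative, not part of this upload:

    from data.dataAccessor import update_db
    from data.task import Task

    @update_db
    def handle(task: Task):
        # the source is marked INPROGRESS before this body runs, then
        # COMPLETED (plus saveGeneratedImages) once it returns
        return {"generated_image_urls": []}

    handle(Task({"task_id": "t1", "source_id": "s1", "userId": "u1",
                 "prompt": "a comic panel"}))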
data/task.py ADDED
@@ -0,0 +1,85 @@
+ from enum import Enum
+ from typing import Union
+
+ import numpy as np
+
+
+ class TaskType(Enum):
+     TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
+     IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
+     POSE = "POSE"
+     CANNY = "CANNY"
+     REMOVE_BG = "REMOVE_BG"
+     INPAINT = "INPAINT"
+     UPSCALE_IMAGE = "UPSCALE_IMAGE"
+
+
+ class ModelType(Enum):
+     REAL = 10000
+     ANIME = 10001
+     COMIC = 10002
+
+
+ class Task:
+     def __init__(self, data):
+         self.__data = data
+         if data.get("seed", -1) is None or self.get_seed() == -1:
+             self.__data["seed"] = np.random.randint(0, np.iinfo(np.int64).max)
+
+     def get_taskId(self) -> str:
+         return self.__data.get("task_id")
+
+     def get_sourceId(self) -> str:
+         return self.__data.get("source_id")
+
+     def get_imageUrl(self) -> str:
+         return self.__data.get("imageUrl")
+
+     def get_prompt(self) -> str:
+         return self.__data.get("prompt")
+
+     def get_userId(self) -> str:
+         return self.__data.get("userId", "")
+
+     def get_email(self) -> str:
+         return self.__data.get("email", "")
+
+     def get_style(self) -> str:
+         return self.__data.get("style", None)
+
+     def get_iteration(self) -> float:
+         return float(self.__data.get("iteration", 3.0))
+
+     def get_modelType(self) -> ModelType:
+         id = int(self.__data.get("model_id", 10000))
+         return ModelType(id)
+
+     def get_width(self) -> int:
+         return int(self.__data.get("width", 512))
+
+     def get_height(self) -> int:
+         return int(self.__data.get("height", 512))
+
+     def get_seed(self) -> int:
+         return int(self.__data.get("seed", -1))
+
+     def get_steps(self) -> int:
+         return int(self.__data.get("steps", "75"))
+
+     def get_type(self) -> Union[TaskType, None]:
+         try:
+             return TaskType(self.__data.get("task_type"))
+         except ValueError:
+             return None
+
+     def get_maskImageUrl(self) -> str:
+         return self.__data.get("maskImageUrl")
+
+     def get_negative_prompt(self) -> str:
+         return self.__data.get("negative_prompt", "")
+
+     def is_prompt_engineering(self) -> bool:
+         return self.__data.get("auto_mode", True)
+
+     def get_raw(self) -> dict:
+         return self.__data.copy()
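
Usage sketch for Task; the payload keys follow the accessors above, the values are illustrative:

    from data.task import Task, TaskType

    task = Task({
        "task_id": "123",
        "task_type": "GENERATE_AI_IMAGE",
        "prompt": "a girl reading a comic",
    })

    assert task.get_type() == TaskType.TEXT_TO_IMAGE
    print(task.get_width(), task.get_height())  # defaults: 512 512
    print(task.get_seed())  # random seed filled in by __init__ when absent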
inference.py ADDED
@@ -0,0 +1,212 @@
+ from typing import List, Optional
+
+ import torch
+ from data.dataAccessor import update_db
+ from data.task import Task, TaskType
+ from pipelines.commons import Img2Img, Text2Img
+ from pipelines.controlnets import ControlNet
+ from pipelines.prompt_modifier import PromptModifier
+ from util.cache import auto_clear_cuda_and_gc, clear_cuda
+ from util.commons import add_code_names, pickPoses, upload_images
+ from util.lora_style import LoraStyle
+ from util.slack import Slack
+
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ num_return_sequences = 4  # the number of results to generate
+ auto_mode = False
+
+ prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
+ controlnet = ControlNet()
+ lora_style = LoraStyle()
+ text2img_pipe = Text2Img()
+ img2img_pipe = Img2Img()
+ slack = Slack()
+
+
+ def get_patched_prompt(task: Task):
+     def add_style_and_character(prompt: List[str]):
+         for i in range(len(prompt)):
+             prompt[i] = add_code_names(prompt[i])
+             prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
+
+     prompt = task.get_prompt()
+
+     if task.is_prompt_engineering():
+         prompt = prompt_modifier.modify(prompt)
+     else:
+         prompt = [prompt] * num_return_sequences
+
+     ori_prompt = [task.get_prompt()] * num_return_sequences
+
+     add_style_and_character(ori_prompt)
+     add_style_and_character(prompt)
+
+     print({"prompts": prompt})
+
+     return (prompt, ori_prompt)
+
+
+ @update_db
+ @auto_clear_cuda_and_gc(controlnet)
+ @slack.auto_send_alert
+ def canny(task: Task):
+     prompt, _ = get_patched_prompt(task)
+
+     controlnet.load_canny()
+
+     lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
+     lora_patcher.patch()
+
+     images = controlnet.process_canny(
+         prompt=prompt,
+         imageUrl=task.get_imageUrl(),
+         seed=task.get_seed(),
+         steps=task.get_steps(),
+         width=task.get_width(),
+         height=task.get_height(),
+         negative_prompt=[
+             f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
+         ]
+         * num_return_sequences,
+         **lora_patcher.kwargs(),
+     )
+
+     generated_image_urls = upload_images(images, "_canny", task.get_taskId())
+
+     lora_patcher.cleanup()
+     controlnet.cleanup()
+
+     return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
+
+
+ @update_db
+ @auto_clear_cuda_and_gc(controlnet)
+ @slack.auto_send_alert
+ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
+     prompt, _ = get_patched_prompt(task)
+
+     controlnet.load_pose()
+
+     lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
+     lora_patcher.patch()
+
+     if poses is None:
+         poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
+
+     images = controlnet.process_pose(
+         prompt=prompt,
+         image=poses,
+         seed=task.get_seed(),
+         steps=task.get_steps(),
+         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+         width=task.get_width(),
+         height=task.get_height(),
+         **lora_patcher.kwargs(),
+     )
+
+     generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
+
+     lora_patcher.cleanup()
+     controlnet.cleanup()
+
+     return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
+
+
+ @update_db
+ @auto_clear_cuda_and_gc(controlnet)
+ @slack.auto_send_alert
+ def text2img(task: Task):
+     prompt, ori_prompt = get_patched_prompt(task)
+
+     lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
+     lora_patcher.patch()
+
+     torch.manual_seed(task.get_seed())
+
+     images = text2img_pipe.process(
+         prompt=ori_prompt,
+         modified_prompts=prompt,
+         num_inference_steps=task.get_steps(),
+         guidance_scale=7.5,
+         height=task.get_height(),
+         width=task.get_width(),
+         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+         iteration=task.get_iteration(),
+         **lora_patcher.kwargs(),
+     )
+
+     generated_image_urls = upload_images(images, "", task.get_taskId())
+
+     lora_patcher.cleanup()
+
+     return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
+
+
+ @update_db
+ @auto_clear_cuda_and_gc(controlnet)
+ @slack.auto_send_alert
+ def img2img(task: Task):
+     prompt, _ = get_patched_prompt(task)
+
+     lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
+     lora_patcher.patch()
+
+     torch.manual_seed(task.get_seed())
+
+     images = img2img_pipe.process(
+         prompt=prompt,
+         imageUrl=task.get_imageUrl(),
+         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+         steps=task.get_steps(),
+         **lora_patcher.kwargs(),
+     )
+
+     generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
+
+     lora_patcher.cleanup()
+
+     return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
+
+
+ def model_fn(model_dir):
+     print("Logs: model loaded .... starts")
+
+     prompt_modifier.load()
+
+     lora_style.load(model_dir)
+     controlnet.load(model_dir)
+
+     text2img_pipe.load(model_dir)
+     img2img_pipe.load(model_dir)
+
+     print("Logs: model loaded ....")
+     return
+
+
+ def predict_fn(data, pipe):
+     task = Task(data)
+     print("task is ", data)
+
+     try:
+         task_type = task.get_type()
+
+         if task_type == TaskType.TEXT_TO_IMAGE:
+             # character sheet
+             if "character sheet" in task.get_prompt().lower():
+                 return pose(task, s3_outkey="", poses=pickPoses())
+             else:
+                 return text2img(task)
+         elif task_type == TaskType.IMAGE_TO_IMAGE:
+             return img2img(task)
+         elif task_type == TaskType.CANNY:
+             return canny(task)
+         elif task_type == TaskType.POSE:
+             return pose(task)
+         else:
+             raise Exception("Invalid task type")
+     except Exception as e:
+         print(f"Error: {e}")
+         slack.error_alert(task, e)
+         return None
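
model_fn and predict_fn follow the SageMaker-style inference contract (load once, then one call per request). A minimal local driver sketch, assuming a model directory path; the payload fields are illustrative:

    import inference

    pipe = inference.model_fn("/opt/ml/model")  # loads all pipelines; returns None
    result = inference.predict_fn(
        {
            "task_id": "42",
            "source_id": "7",
            "userId": "u1",
            "task_type": "GENERATE_AI_IMAGE",
            "prompt": "two friends on a rooftop at dusk",
        },
        pipe,  # ignored by predict_fn
    )
    # result: {"modified_prompts": [...], "generated_image_urls": [...]}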
inference2.py ADDED
@@ -0,0 +1,116 @@
+ from io import BytesIO
+
+ import torch
+ from data.dataAccessor import update_db
+ from data.task import ModelType, Task, TaskType
+ from pipelines.inpainter import InPainter
+ from pipelines.prompt_modifier import PromptModifier
+ from pipelines.remove_background import RemoveBackground
+ from pipelines.upscaler import Upscaler
+ from util.cache import clear_cuda
+ from util.commons import (
+     add_code_names,
+     construct_default_s3_url,
+     upload_image,
+     upload_images,
+ )
+ from util.slack import Slack
+
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ num_return_sequences = 4
+ auto_mode = False
+
+ slack = Slack()
+
+ prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
+ upscaler = Upscaler()
+ inpainter = InPainter()
+
+
+ @update_db
+ @slack.auto_send_alert
+ def remove_bg(task: Task):
+     remove_background = RemoveBackground()
+     output_image = remove_background.remove(task.get_imageUrl())
+
+     output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
+     upload_image(output_image, output_key)
+
+     return {"generated_image_url": construct_default_s3_url(output_key)}
+
+
+ @update_db
+ @slack.auto_send_alert
+ def inpaint(task: Task):
+     prompt = add_code_names(task.get_prompt())
+     if task.is_prompt_engineering():
+         prompt = prompt_modifier.modify(prompt)
+     else:
+         prompt = [prompt] * num_return_sequences
+
+     print({"prompts": prompt})
+
+     images = inpainter.process(
+         prompt=prompt,
+         image_url=task.get_imageUrl(),
+         mask_image_url=task.get_maskImageUrl(),
+         width=task.get_width(),
+         height=task.get_height(),
+         seed=task.get_seed(),
+         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+     )
+     generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
+
+     clear_cuda()
+
+     return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
+
+
+ @update_db
+ @slack.auto_send_alert
+ def upscale_image(task: Task):
+     output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
+     out_img = None
+     if task.get_modelType() == ModelType.ANIME:
+         print("Using Anime model")
+         out_img = upscaler.upscale_anime(task.get_imageUrl())
+     else:
+         print("Using Real model")
+         out_img = upscaler.upscale(task.get_imageUrl())
+
+     upload_image(BytesIO(out_img), output_key)
+     return {"generated_image_url": construct_default_s3_url(output_key)}
+
+
+ def model_fn(model_dir):
+     print("Logs: model loaded .... starts")
+
+     prompt_modifier.load()
+     upscaler.load()
+     inpainter.load()
+
+     print("Logs: model loaded ....")
+     return
+
+
+ def predict_fn(data, pipe):
+     task = Task(data)
+     print("task is ", data)
+
+     try:
+         task_type = task.get_type()
+
+         if task_type == TaskType.REMOVE_BG:
+             return remove_bg(task)
+         elif task_type == TaskType.INPAINT:
+             return inpaint(task)
+         elif task_type == TaskType.UPSCALE_IMAGE:
+             return upscale_image(task)
+         else:
+             raise Exception("Invalid task type")
+     except Exception as e:
+         print(f"Error: {e}")
+         slack.error_alert(task, e)
+         return None
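
The same contract applies to this second entry point, which serves REMOVE_BG, INPAINT and UPSCALE_IMAGE. A sketch with an illustrative upscale payload:

    import inference2

    inference2.model_fn("/opt/ml/model")
    result = inference2.predict_fn(
        {
            "task_id": "99",
            "source_id": "7",
            "userId": "u1",
            "task_type": "UPSCALE_IMAGE",
            "model_id": 10001,  # ModelType.ANIME routes to the anime upscaler
            "imageUrl": "https://example.com/panel.png",
        },
        None,  # the pipe argument is unused by predict_fn
    )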
pipelines/commons.py ADDED
@@ -0,0 +1,85 @@
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import torch
+ from diffusers import StableDiffusionImg2ImgPipeline
+ from pipelines.twoStepPipeline import two_step_pipeline
+ from util.commons import disable_safety_checker, download_image
+
+
+ class Text2Img:
+     def load(self, model_dir: str):
+         self.pipe = two_step_pipeline.from_pretrained(
+             model_dir, torch_dtype=torch.float16
+         ).to("cuda")
+         self.pipe.enable_xformers_memory_efficient_attention()
+         disable_safety_checker(self.pipe)
+
+     @torch.inference_mode()
+     def process(
+         self,
+         prompt: Optional[Union[str, List[str]]] = None,
+         modified_prompts: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         iteration: float = 3.0,
+     ):
+         return self.pipe.two_step_pipeline(
+             prompt=prompt,
+             modified_prompts=modified_prompts,
+             height=height,
+             width=width,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             negative_prompt=negative_prompt,
+             num_images_per_prompt=num_images_per_prompt,
+             eta=eta,
+             generator=generator,
+             latents=latents,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             output_type=output_type,
+             return_dict=return_dict,
+             callback=callback,
+             callback_steps=callback_steps,
+             cross_attention_kwargs=cross_attention_kwargs,
+             iteration=iteration,
+         ).images
+
+
+ class Img2Img:
+     def load(self, model_dir: str):
+         self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+             model_dir, torch_dtype=torch.float16
+         ).to("cuda")
+         self.pipe.enable_xformers_memory_efficient_attention()
+         disable_safety_checker(self.pipe)
+
+     @torch.inference_mode()
+     def process(
+         self, prompt: List[str], imageUrl: str, negative_prompt: List[str], steps: int
+     ):
+         image = download_image(imageUrl)
+
+         return self.pipe.__call__(
+             prompt=prompt,
+             image=image,
+             strength=0.75,
+             negative_prompt=negative_prompt,
+             guidance_scale=7.5,
+             num_images_per_prompt=1,
+             num_inference_steps=steps,
+         ).images
pipelines/controlnets.py ADDED
@@ -0,0 +1,133 @@
+ from typing import List
+
+ import cv2
+ import numpy as np
+ import torch
+ from controlnet_aux import OpenposeDetector
+ from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
+                        UniPCMultistepScheduler)
+ from PIL import Image
+ from util.cache import clear_cuda_and_gc
+ from util.commons import disable_safety_checker, download_image
+
+
+ class ControlNet:
+     __current_task_name = ""
+
+     def load(self, model_dir: str):
+         # we will load canny by default
+         self.load_canny()
+
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(
+             model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
+         )
+         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+         pipe.enable_model_cpu_offload()
+         pipe.enable_xformers_memory_efficient_attention()
+         disable_safety_checker(pipe)
+         self.pipe = pipe
+
+     def load_canny(self):
+         if self.__current_task_name == "canny":
+             return
+         canny = ControlNetModel.from_pretrained(
+             "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
+         ).to("cuda")
+         self.__current_task_name = "canny"
+         self.controlnet = canny
+         if hasattr(self, "pipe"):
+             self.pipe.controlnet = canny
+         clear_cuda_and_gc()
+
+     def load_pose(self):
+         if self.__current_task_name == "pose":
+             return
+         pose = ControlNetModel.from_pretrained(
+             "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
+         ).to("cuda")
+         self.__current_task_name = "pose"
+         self.controlnet = pose
+         if hasattr(self, "pipe"):
+             self.pipe.controlnet = pose
+         clear_cuda_and_gc()
+
+     def cleanup(self):
+         self.pipe.controlnet = None
+         self.controlnet = None
+         self.__current_task_name = ""
+
+         clear_cuda_and_gc()
+
+     @torch.inference_mode()
+     def process_canny(
+         self,
+         prompt: List[str],
+         imageUrl: str,
+         seed: int,
+         steps: int,
+         negative_prompt: List[str],
+         height: int,
+         width: int,
+     ):
+         if self.__current_task_name != "canny":
+             raise Exception("ControlNet is not loaded with canny model")
+
+         torch.manual_seed(seed)
+
+         init_image = download_image(imageUrl)
+         init_image = self.__canny_detect_edge(init_image)
+
+         return self.pipe.__call__(
+             prompt=prompt,
+             image=init_image,
+             guidance_scale=9,
+             num_images_per_prompt=1,
+             negative_prompt=negative_prompt,
+             num_inference_steps=steps,
+             height=height,
+             width=width,
+         ).images
+
+     @torch.inference_mode()
+     def process_pose(
+         self,
+         prompt: List[str],
+         image: List[Image.Image],
+         seed: int,
+         steps: int,
+         negative_prompt: List[str],
+         height: int,
+         width: int,
+     ):
+         if self.__current_task_name != "pose":
+             raise Exception("ControlNet is not loaded with pose model")
+
+         torch.manual_seed(seed)
+
+         return self.pipe.__call__(
+             prompt=prompt,
+             image=image,
+             num_images_per_prompt=1,
+             num_inference_steps=steps,
+             negative_prompt=negative_prompt,
+             height=height,
+             width=width,
+         ).images
+
+     def detect_pose(self, imageUrl: str) -> Image.Image:
+         detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+         image = download_image(imageUrl)
+         image = detector.__call__(image)
+         return image
+
+     def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
+         image_array = np.array(image)
+
+         low_threshold = 100
+         high_threshold = 200
+
+         image_array = cv2.Canny(image_array, low_threshold, high_threshold)
+         image_array = image_array[:, :, None]
+         image_array = np.concatenate([image_array, image_array, image_array], axis=2)
+         canny_image = Image.fromarray(image_array)
+         return canny_image
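
Sketch of the canny flow this class supports (it mirrors inference.canny); the model directory and image URL are illustrative:

    from pipelines.controlnets import ControlNet

    controlnet = ControlNet()
    controlnet.load("/opt/ml/model")  # loads the canny ControlNet by default
    controlnet.load_canny()           # no-op when canny is already active

    images = controlnet.process_canny(
        prompt=["a comic panel of a city street"] * 4,
        imageUrl="https://example.com/sketch.png",
        seed=1234,
        steps=30,
        negative_prompt=["monochrome"] * 4,
        height=512,
        width=512,
    )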
pipelines/inpainter.py ADDED
@@ -0,0 +1,40 @@
+ from typing import List, Union
+
+ import torch
+ from diffusers import StableDiffusionInpaintPipeline
+ from util.commons import disable_safety_checker, download_image
+
+
+ class InPainter:
+     def load(self):
+         self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
+             "runwayml/stable-diffusion-inpainting",
+             torch_dtype=torch.float16,
+             revision="fp16",
+         ).to("cuda")
+         disable_safety_checker(self.pipe)
+
+     @torch.inference_mode()
+     def process(
+         self,
+         image_url: str,
+         mask_image_url: str,
+         width: int,
+         height: int,
+         seed: int,
+         prompt: Union[str, List[str]],
+         negative_prompt: Union[str, List[str]],
+     ):
+         torch.manual_seed(seed)
+
+         input_img = download_image(image_url).resize((width, height))
+         mask_img = download_image(mask_image_url).resize((width, height))
+
+         return self.pipe.__call__(
+             prompt=prompt,
+             image=input_img,
+             mask_image=mask_img,
+             height=height,
+             width=width,
+             negative_prompt=negative_prompt,
+         ).images
pipelines/prompt_modifier.py ADDED
@@ -0,0 +1,54 @@
+ from typing import List, Optional
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+
+ class PromptModifier:
+     def __init__(self, num_of_sequences: Optional[int] = 4):
+         self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
+         self.__num_of_sequences = num_of_sequences
+
+     def load(self):
+         self.prompter_model = AutoModelForCausalLM.from_pretrained(
+             "Gustavosta/MagicPrompt-Stable-Diffusion"
+         )
+         self.prompter_tokenizer = AutoTokenizer.from_pretrained(
+             "Gustavosta/MagicPrompt-Stable-Diffusion"
+         )
+         self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
+         self.prompter_tokenizer.padding_side = "left"
+
+     def modify(self, text: str) -> List[str]:
+         eos_id = self.prompter_tokenizer.eos_token_id
+         # restricted_words_list = ["octane", "cyber"]
+         # restricted_words_token_ids = prompter_tokenizer(
+         #     restricted_words_list, add_special_tokens=False
+         # ).input_ids
+
+         generation_config = GenerationConfig(
+             do_sample=False,
+             max_new_tokens=75,
+             num_beams=4,
+             num_return_sequences=self.__num_of_sequences,
+             eos_token_id=eos_id,
+             pad_token_id=eos_id,
+             length_penalty=-1.0,
+         )
+
+         input_ids = self.prompter_tokenizer(text.strip(), return_tensors="pt").input_ids
+         outputs = self.prompter_model.generate(
+             input_ids, generation_config=generation_config
+         )
+         output_texts = self.prompter_tokenizer.batch_decode(
+             outputs, skip_special_tokens=True
+         )
+         output_texts = self.__patch_blacklist_words(output_texts)
+         return output_texts
+
+     def __patch_blacklist_words(self, texts: List[str]):
+         def replace_all(text, dic):
+             for i, j in dic.items():
+                 text = text.replace(i, j)
+             return text
+
+         return [replace_all(text, self.__blacklist) for text in texts]
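
Usage sketch: modify() expands one prompt into num_of_sequences beam-search variants, with blacklisted artist names stripped:

    from pipelines.prompt_modifier import PromptModifier

    modifier = PromptModifier(num_of_sequences=4)
    modifier.load()  # fetches Gustavosta/MagicPrompt-Stable-Diffusion
    for variant in modifier.modify("a cyberpunk alley at night"):
        print(variant)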
pipelines/remove_background.py ADDED
@@ -0,0 +1,15 @@
+ import io
+ from typing import Union
+
+ from PIL import Image
+ from rembg import remove
+ from util.commons import read_url
+
+
+ class RemoveBackground:
+     def remove(self, image: Union[str, Image.Image]) -> Image.Image:
+         if isinstance(image, str):
+             image = Image.open(io.BytesIO(read_url(image)))
+
+         output = remove(image)
+         return output
pipelines/twoStepPipeline.py ADDED
@@ -0,0 +1,250 @@
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import torch
+ from diffusers import StableDiffusionPipeline
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+ class two_step_pipeline(StableDiffusionPipeline):
+     @torch.no_grad()
+     def two_step_pipeline(
+         self,
+         prompt: Union[str, List[str]] = None,
+         modified_prompts: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         iteration: float = 3.0,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                 instead.
+             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+         Examples:
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             height,
+             width,
+             callback_steps,
+             negative_prompt,
+             prompt_embeds,
+             negative_prompt_embeds,
+         )
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         modified_embeds = self._encode_prompt(
+             modified_prompts,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+         print("mod prompt size : ", modified_embeds.size(), modified_embeds.dtype)
+
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+
+         print("prompt size : ", prompt_embeds.size(), prompt_embeds.dtype)
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = (
+                     torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 )
+                 latent_model_input = self.scheduler.scale_model_input(
+                     latent_model_input, t
+                 )
+
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                 ).sample
+
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(
+                     noise_pred, t, latents, **extra_step_kwargs
+                 ).prev_sample
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or (
+                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                 ):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+
+                 if i == int(len(timesteps) / iteration):
+                     print("modified prompts")
+                     prompt_embeds = modified_embeds
+
+         if output_type == "latent":
+             image = latents
+             has_nsfw_concept = None
+         elif output_type == "pil":
+             # 8. Post-processing
+             image = self.decode_latents(latents)
+
+             # 9. Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(
+                 image, device, prompt_embeds.dtype
+             )
+
+             # 10. Convert to PIL
+             image = self.numpy_to_pil(image)
+         else:
+             # 8. Post-processing
+             image = self.decode_latents(latents)
+
+             # 9. Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(
+                 image, device, prompt_embeds.dtype
+             )
+
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(
+             images=image, nsfw_content_detected=has_nsfw_concept
+         )
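
The "two step" in the name: the loop denoises with the original prompt embeddings until step int(len(timesteps) / iteration), then swaps in the modified (MagicPrompt-expanded) embeddings for the remaining steps, so composition follows the user's prompt while detail follows the expanded one. A sketch, assuming the model directory holds Stable Diffusion weights:

    import torch
    from pipelines.twoStepPipeline import two_step_pipeline

    pipe = two_step_pipeline.from_pretrained(
        "/opt/ml/model", torch_dtype=torch.float16
    ).to("cuda")
    images = pipe.two_step_pipeline(
        prompt=["a girl reading"] * 4,
        modified_prompts=["a girl reading, intricate, cinematic lighting"] * 4,
        num_inference_steps=75,
        iteration=3.0,  # swap embeddings after the first third of the steps
    ).images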
pipelines/upscaler.py ADDED
@@ -0,0 +1,77 @@
+ import os
+ from pathlib import Path
+ from typing import Union
+
+ import cv2
+ import numpy as np
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils.download_util import load_file_from_url
+ from PIL import Image
+ from realesrgan import RealESRGANer
+ from util.commons import read_url
+
+
+ class Upscaler:
+     __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
+     __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
+
+     def load(self):
+         download_dir = Path(Path.home() / ".cache" / "realesrgan")
+         download_dir.mkdir(parents=True, exist_ok=True)
+
+         self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
+         self.__model_path_anime = self.__preload_model(
+             self.__model_esrgan_anime_url, download_dir
+         )
+
+     def upscale(self, image: Union[str, bytes]) -> bytes:
+         model = RRDBNet(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_block=23,
+             num_grow_ch=32,
+             scale=4,
+         )
+         return self.__internal_upscale(image, self.__model_path, model)
+
+     def upscale_anime(self, image: Union[str, bytes]) -> bytes:
+         model = RRDBNet(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_block=6,  # the x4plus_anime_6B checkpoint has 6 RRDB blocks
+             num_grow_ch=32,
+             scale=4,
+         )
+         return self.__internal_upscale(image, self.__model_path_anime, model)
+
+     def __preload_model(self, url: str, download_dir: Path):
+         name = url.split("/")[-1]
+         if not os.path.exists(str(download_dir / name)):
+             return load_file_from_url(
+                 url=url,
+                 model_dir=str(download_dir),
+                 progress=True,
+                 file_name=None,
+             )
+         else:
+             return str(download_dir / name)
+
+     def __internal_upscale(
+         self,
+         image: Union[str, bytes],
+         model_path: str,
+         rrbdnet: RRDBNet,
+     ) -> bytes:
+         if isinstance(image, str):
+             image = read_url(image)
+
+         upsampler = RealESRGANer(
+             scale=4, model_path=model_path, model=rrbdnet, half=True, gpu_id=0
+         )
+         image_array = np.frombuffer(image, dtype=np.uint8)
+         input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
+         output, _ = upsampler.enhance(input_image, outscale=4)
+         out_bytes = cv2.imencode(".png", output)[1].tobytes()
+         return out_bytes
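
Usage sketch: load() caches both Real-ESRGAN checkpoints under ~/.cache/realesrgan; the input may be a URL or raw image bytes and the result is PNG bytes (the URL here is illustrative):

    from pipelines.upscaler import Upscaler

    upscaler = Upscaler()
    upscaler.load()
    png_bytes = upscaler.upscale("https://example.com/panel.png")  # 4x upscale
    with open("panel_4x.png", "wb") as f:
        f.write(png_bytes)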
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ aioredis==1.3.1
+ boto3==1.24.61
+ triton==2.0.0
+ diffusers==0.14.0
+ fastapi==0.87.0
+ Pillow==9.3.0
+ redis==4.3.4
+ requests==2.28.1
+ transformers
+ rembg==2.0.30
+ accelerate==0.17.0
+ gfpgan==1.3.8
+ controlnet-aux==0.0.1
+ realesrgan==0.3.0
+ compel==1.0.4
+ xformers
+ torchvision
+ git+https://github.com/cloneofsimo/lora.git
util/__init__.py ADDED
File without changes
util/cache.py ADDED
@@ -0,0 +1,31 @@
+ import gc
+
+ import torch
+
+
+ def clear_cuda_and_gc():
+     clear_cuda()
+     clear_gc()
+
+
+ def clear_cuda():
+     torch.cuda.empty_cache()
+
+
+ def clear_gc():
+     gc.collect()
+
+
+ def auto_clear_cuda_and_gc(controlnet):
+     def auto_clear_cuda_and_gc_wrapper(func):
+         def wrapper(*args, **kwargs):
+             try:
+                 return func(*args, **kwargs)
+             except Exception as e:
+                 controlnet.cleanup()
+                 clear_cuda_and_gc()
+                 raise e
+
+         return wrapper
+
+     return auto_clear_cuda_and_gc_wrapper
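
Sketch of the decorator factory: it takes the shared ControlNet instance so a failing handler unloads its weights before the exception propagates (the handler below is illustrative):

    from pipelines.controlnets import ControlNet
    from util.cache import auto_clear_cuda_and_gc

    controlnet = ControlNet()

    @auto_clear_cuda_and_gc(controlnet)
    def generate(task):
        ...  # any exception triggers controlnet.cleanup() + cache clearing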
util/commons.py ADDED
@@ -0,0 +1,216 @@
+ import io
+ import pprint
+ import random
+ import re
+ import urllib.request
+ from io import BytesIO
+ from typing import Union
+
+ import boto3
+ import requests
+ from PIL import Image
+
+ s3 = boto3.client("s3")
+
+ black_list = {"alphonse mucha": "", "adolphe bouguereau": ""}
+ pp = pprint.PrettyPrinter(indent=4)
+ Avatar = [
+     {
+         "avatarName": "niomi",
+         "codename": "1jMGp1kFkG",
+         "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/7_char_assets/niyomi_1jMGp1kFkG.png",
+         "extraPrompt": "1jMGp1kFkG girl",
+     },
+     {
+         "avatarName": "riya",
+         "codename": "vW6AUQtoaY",
+         "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/12_char_assets/riya_vW6AUQtoaY.png",
+         "extraPrompt": "vW6AUQtoaY girl",
+     },
+     {
+         "avatarName": "rajveer",
+         "codename": "fSLF0OPkBw",
+         "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/12_character_assets/rajveer_fSLF0OPkBw.png",
+         "extraPrompt": "fSLF0OPkBw guy",
+     },
+     {
+         "avatarName": "bheem",
+         "codename": "HL79CB3ODZ",
+         "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/7_char_assets/scene01001_1.png",
+         "extraPrompt": "HL79CB3ODZ boy",
+     },
+     {
+         "avatarName": "chutki",
+         "codename": "SJ7JVIS9M7",
+         "avatarImage": "https://comic-assets.s3.ap-south-1.amazonaws.com/7_char_assets/14.png",
+         "extraPrompt": "SJ7JVIS9M7 girl",
+     },
+ ]
+
+ webhook_url = (
+     "https://hooks.slack.com/services/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"
+ )
+ error_webhook = (
+     "https://hooks.slack.com/services/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"
+ )
+
+ characterSheets = [
+     "character+sheets/1.1.png",
+     "character+sheets/10.1.png",
+     "character+sheets/11.1.png",
+     "character+sheets/12.1.png",
+     "character+sheets/13.1.png",
+     "character+sheets/14.1.png",
+     "character+sheets/16.1.png",
+     "character+sheets/17.1.png",
+     "character+sheets/18.1.png",
+     "character+sheets/19.1.png",
+     "character+sheets/2.1.png",
+     "character+sheets/20.1.png",
+     "character+sheets/21.1.png",
+     "character+sheets/22.1.png",
+     "character+sheets/23.1.png",
+     "character+sheets/24.1.png",
+     "character+sheets/25.1.png",
+     "character+sheets/26.1.png",
+     "character+sheets/27.1.png",
+     "character+sheets/28.1.png",
+     "character+sheets/29.1.png",
+     "character+sheets/3.1.png",
+     "character+sheets/30.1.png",
+     "character+sheets/31.1.png",
+     "character+sheets/32.1.png",
+     "character+sheets/33.1.png",
+     "character+sheets/34.1.png",
+     "character+sheets/35.1.png",
+     "character+sheets/36.1.png",
+     "character+sheets/38.1.png",
+     "character+sheets/39.1.png",
+     "character+sheets/4.1.png",
+     "character+sheets/40.1.png",
+     "character+sheets/42.1.png",
+     "character+sheets/43.1.png",
+     "character+sheets/44.1.png",
+     "character+sheets/45.1.png",
+     "character+sheets/46.1.png",
+     "character+sheets/47.1.png",
+     "character+sheets/48.1.png",
+     "character+sheets/49.1.png",
+     "character+sheets/5.1.png",
+     "character+sheets/50.1.png",
+     "character+sheets/51.1.png",
+     "character+sheets/52.1.png",
+     "character+sheets/53.1.png",
+     "character+sheets/54.1.png",
+     "character+sheets/55.1.png",
+     "character+sheets/56.1.png",
+     "character+sheets/57.1.png",
+     "character+sheets/58.1.png",
+     "character+sheets/59.1.png",
+     "character+sheets/60.1.png",
+     "character+sheets/61.1.png",
+     "character+sheets/62.1.png",
+     "character+sheets/63.1.png",
+     "character+sheets/64.1.png",
+     "character+sheets/65.1.png",
+     "character+sheets/66.1.png",
+     "character+sheets/7.1.png",
+     "character+sheets/8.1.png",
+     "character+sheets/9.1.png",
+ ]
+
+
+ def add_code_names(sentence):
+     array_of_objects = Avatar
+
+     for obj in array_of_objects:
+         sentence = (
+             re.sub(
+                 r"\b" + obj["avatarName"] + r"\b",
+                 obj["extraPrompt"],
+                 sentence,
+                 flags=re.IGNORECASE,
+             )
+             + " "
+         )
+     print(sentence)
+     return sentence
+
+
+ def upload_images(images, processName: str, taskId: str):
+     imageUrls = []
+     for i, image in enumerate(images):
+         # img_io = BytesIO()
+         # image.save(img_io, "JPEG", quality=70)
+         # img_io.seek(0)
+         # key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
+         # t = s3.put_object(
+         #     Bucket="comic-assets", Key=key, Body=img_io.getvalue(), ACL="public-read"
+         # )
+         # print("uploading done to s3", key, t)
+         imageUrls.append(
+             "https://comic-assets.s3.ap-south-1.amazonaws.com/crecoAI/{}{}_{}.png".format(
+                 taskId, processName, i
+             )
+         )
+
+     print({"promptImages": imageUrls})
+
+     return imageUrls
+
+
+ def upload_image(image: Union[Image.Image, BytesIO], out_path):
+     if isinstance(image, Image.Image):
+         buffer = io.BytesIO()
+         image.save(buffer, format="PNG")
+         image = buffer
+
+     image.seek(0)
+     s3.upload_fileobj(image, "comic-assets", out_path, ExtraArgs={"ACL": "public-read"})
+     image.close()
+
+
+ def download_image(url) -> Image.Image:
+     response = requests.get(url)
+     return Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ def pickPoses():
+     random_images = random.sample(characterSheets, 4)
+     poses = []
+     prefix = "https://comic-assets.s3.ap-south-1.amazonaws.com/"
+
+     # Use list comprehension to add prefix to all elements in the array
+     random_images_with_prefix = [prefix + img for img in random_images]
+
+     print(random_images_with_prefix)
+     for imageUrl in random_images_with_prefix:
+         # Download and resize the image
+         init_image = download_image(imageUrl).resize((512, 512))
+
+         # Open the pose image
+         imageUrlPose = imageUrl
+         input_image_bytes = read_url(imageUrlPose)
+         pose_image = Image.open(io.BytesIO(input_image_bytes)).convert("RGB")
+         pose_image = pose_image.resize((512, 512))
+         # Append the result to the poses array
+         poses.append(pose_image)
+
+     return poses
+
+
+ def construct_default_s3_url(key):
+     return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
+
+
+ def read_url(url: str):
+     with urllib.request.urlopen(url) as u:
+         return u.read()
+
+
+ def disable_safety_checker(pipe):
+     pipe.safety_checker = None
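
add_code_names swaps avatar names in a prompt for their trained code names, e.g. (note the trailing whitespace the loop appends per avatar):

    from util.commons import add_code_names

    add_code_names("riya and rajveer at the beach")
    # -> "vW6AUQtoaY girl and fSLF0OPkBw guy at the beach" + trailing spaces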
util/lora_style.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ from typing import Any, Dict, Union
+
+ from lora_diffusion import patch_pipe, tune_lora_scale
+
+
+ class LoraStyle:
+     class LoraPatcher:
+         def __init__(self, pipe, style: Dict[str, Any]):
+             self.__style = style
+             self.pipe = pipe
+
+         def patch(self):
+             patch_pipe(self.pipe, self.__style["path"])
+             tune_lora_scale(self.pipe.unet, self.__style["weight"])
+             tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])
+
+         def kwargs(self):
+             return {}
+
+         def cleanup(self):
+             tune_lora_scale(self.pipe.unet, 0.0)
+             tune_lora_scale(self.pipe.text_encoder, 0.0)
+
+     class EmptyLoraPatcher:
+         def patch(self):
+             pass
+
+         def kwargs(self):
+             return {}
+
+         def cleanup(self):
+             pass
+
+     def load(self, model_dir: str):
+         self.__styles = {
+             "nq6akX1CIp": {
+                 "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
+                 "weight": 0.5,
+                 "negativePrompt": [""],
+                 "type": "custom",
+             },
+             "ghibli": {
+                 "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
+                 "weight": 1,
+                 "negativePrompt": [""],
+                 "type": "custom",
+             },
+             "eQAmnK2kB2": {
+                 "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
+                 "weight": 0.5,
+                 "negativePrompt": [""],
+                 "type": "custom",
+             },
+             "to8contrast": {
+                 "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
+                 "weight": 0.5,
+                 "negativePrompt": [""],
+                 "type": "custom",
+             },
+             "jim lee": {
+                 "path": model_dir + "/laur_style/e2j9mz0jqj/final_lora.bin",
+                 "weight": 0.8,
+                 "negativePrompt": [""],
+                 "type": "custom",
+             },
+         }
+         self.__verify()
+
+     def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
+         if key in self.__styles:
+             return f"{key} style {prompt}"
+         return prompt
+
+     def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
+         if key in self.__styles:
+             style = self.__styles[key]
+             return self.LoraPatcher(pipe, style)
+         return self.EmptyLoraPatcher()
+
+     def __verify(self):
+         """Verify that each LoRA style file exists at its path; raise otherwise."""
+
+         for item in self.__styles.keys():
+             if not os.path.exists(self.__styles[item]["path"]):
+                 raise Exception(
+                     "Lora style model "
+                     + item
+                     + " not found at path: "
+                     + self.__styles[item]["path"]
+                 )
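
Sketch of the patch/cleanup lifecycle the inference handlers use; pipe stands for any loaded diffusers pipeline, and the model directory is illustrative:

    from util.lora_style import LoraStyle

    lora_style = LoraStyle()
    lora_style.load("/opt/ml/model")  # raises if any listed LoRA file is missing

    patcher = lora_style.get_patcher(pipe, "ghibli")  # EmptyLoraPatcher for unknown keys
    patcher.patch()
    images = pipe(prompt="ghibli style a quiet village", **patcher.kwargs()).images
    patcher.cleanup()  # tunes LoRA scales back to 0.0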
util/slack.py ADDED
@@ -0,0 +1,54 @@
+ from typing import Optional
+
+ import requests
+ from data.task import Task
+
+
+ class Slack:
+     def __init__(self):
+         # self.webhook_url = "https://hooks.slack.com/services/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"
+         self.webhook_url = "https://hooks.slack.com/services/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"
+         self.error_webhook = "https://hooks.slack.com/services/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"
+
+     def send_alert(self, task: Task, args: Optional[dict]):
+         raw = task.get_raw().copy()
+
+         raw.pop("queue_name", None)
+         raw.pop("attempt", None)
+         raw.pop("timestamp", None)
+         raw.pop("task_id", None)
+         raw.pop("maskImageUrl", None)
+
+         if args is not None:
+             raw.update(args.items())
+
+         message = ""
+         for key, value in raw.items():
+             if value:
+                 if isinstance(value, list):
+                     message += f"*{key}*: {', '.join(value)}\n"
+                 else:
+                     message += f"*{key}*: {value}\n"
+
+         requests.post(
+             self.webhook_url,
+             headers={"Content-Type": "application/json"},
+             json={"text": message},
+         )
+
+     def error_alert(self, task: Task, e: Exception):
+         requests.post(
+             self.error_webhook,
+             headers={"Content-Type": "application/json"},
+             json={
+                 "text": "Task failed:\n{} \n error is: \n {}".format(task.get_raw(), e)
+             },
+         )
+
+     def auto_send_alert(self, func):
+         def inner(*args, **kwargs):
+             rargs = func(*args, **kwargs)
+             self.send_alert(args[0], rargs)
+             return rargs
+
+         return inner
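
Sketch of auto_send_alert: it forwards the decorated handler's first argument (the Task) and its return dict to send_alert after a successful run (handler and fields illustrative):

    from data.task import Task
    from util.slack import Slack

    slack = Slack()

    @slack.auto_send_alert
    def handler(task: Task):
        return {"generated_image_urls": ["https://example.com/out.png"]}

    handler(Task({"task_id": "1", "prompt": "hello"}))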