jayparmr commited on
Commit
ea8fc97
·
1 Parent(s): 4adca93

Create ler.py

Browse files
Files changed (1) hide show
  1. ler.py +284 -0
ler.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import base64
6
+ from io import BytesIO
7
+
8
+
9
+ from typing import List, Optional
10
+
11
+ import torch
12
+ from data.dataAccessor import update_db
13
+ from data.task import Task, TaskType
14
+ from pipelines.commons import Img2Img, Text2Img
15
+ from pipelines.controlnets import ControlNet
16
+ from pipelines.prompt_modifier import PromptModifier
17
+ from util.cache import auto_clear_cuda_and_gc, clear_cuda
18
+ from util.commons import add_code_names, pickPoses, upload_images
19
+ from util.lora_style import LoraStyle
20
+ from util.slack import Slack
21
+
22
# Enable cuDNN autotuner and TF32 matmuls to speed up GPU inference
# (trades a little float32 precision for throughput on Ampere+ GPUs).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

num_return_sequences = 4  # the number of results to generate
auto_mode = False

# Module-level singletons shared by every task handler below.
# NOTE(review): these are constructed at import time; actual model weights
# appear to be loaded later via their .load() methods — confirm.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
slack = Slack()
34
+
35
+
36
+
37
def get_patched_prompt(task: Task):
    """Produce the prompt lists used by the generation pipelines.

    Expands the task prompt to ``num_return_sequences`` entries — via the
    prompt modifier when prompt engineering is enabled, otherwise by plain
    repetition — and applies character code names plus the task's LoRA
    style prefix to every entry.

    Returns:
        A ``(prompt, ori_prompt)`` tuple: the (possibly modified) prompt
        list and the styled repetition of the original prompt.
    """

    def apply_names_and_style(prompts: List[str]):
        # Mutates the list in place: code-name substitution first,
        # then the style prefix for the task's selected style.
        for idx in range(len(prompts)):
            with_names = add_code_names(prompts[idx])
            prompts[idx] = lora_style.prepend_style_to_prompt(
                with_names, task.get_style()
            )

    base_prompt = task.get_prompt()

    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(base_prompt)
    else:
        prompt = [base_prompt] * num_return_sequences

    ori_prompt = [task.get_prompt()] * num_return_sequences

    apply_names_and_style(ori_prompt)
    apply_names_and_style(prompt)

    print({"prompts": prompt})

    return (prompt, ori_prompt)
58
+
59
+
60
# @update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images with the canny-edge ControlNet and upload them.

    Loads the canny ControlNet, applies the task's LoRA style patch, runs
    the pipeline against the task's source image, uploads the results
    under the ``_canny`` key, then cleans up the patcher and ControlNet.

    Returns:
        Dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()

    patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    patcher.patch()

    # Canny-specific artifacts are prepended to the user's negative prompt.
    negative = [
        f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
    ] * num_return_sequences

    images = controlnet.process_canny(
        prompt=prompt,
        imageUrl=task.get_imageUrl(),
        seed=task.get_seed(),
        steps=task.get_steps(),
        width=task.get_width(),
        height=task.get_height(),
        negative_prompt=negative,
        **patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, "_canny", task.get_taskId())

    patcher.cleanup()
    controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
91
+
92
+
93
# @update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images with the pose ControlNet and upload them.

    When *poses* is not supplied, the pose is detected from the task's
    source image and repeated for every returned sequence.

    Args:
        task: The inference task describing prompt, seed, size, etc.
        s3_outkey: Suffix used for the uploaded object names.
        poses: Optional pre-selected pose images; detected when ``None``.

    Returns:
        Dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()

    patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    patcher.patch()

    if poses is None:
        detected = controlnet.detect_pose(task.get_imageUrl())
        poses = [detected] * num_return_sequences

    images = controlnet.process_pose(
        prompt=prompt,
        image=poses,
        seed=task.get_seed(),
        steps=task.get_steps(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        width=task.get_width(),
        height=task.get_height(),
        **patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())

    patcher.cleanup()
    controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
124
+
125
+
126
# @update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Run the text-to-image pipeline for *task* and upload the results.

    Seeds torch's global RNG from the task seed, applies the LoRA style
    patch to the text2img pipeline, generates ``num_return_sequences``
    images, and uploads them with an empty key suffix.

    Returns:
        Dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt, ori_prompt = get_patched_prompt(task)

    patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    patcher.patch()

    # Deterministic output for a given task seed.
    torch.manual_seed(task.get_seed())

    images = text2img_pipe.process(
        prompt=ori_prompt,
        modified_prompts=prompt,
        num_inference_steps=task.get_steps(),
        guidance_scale=7.5,
        height=task.get_height(),
        width=task.get_width(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        iteration=task.get_iteration(),
        **patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, "", task.get_taskId())

    patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
154
+
155
+
156
# @update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Run the image-to-image pipeline for *task* and upload the results.

    Seeds torch's global RNG from the task seed, applies the LoRA style
    patch, transforms the task's source image, and uploads the results
    under the ``_imgtoimg`` key.

    Returns:
        Dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt, _ = get_patched_prompt(task)

    patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    patcher.patch()

    # Deterministic output for a given task seed.
    torch.manual_seed(task.get_seed())

    images = img2img_pipe.process(
        prompt=prompt,
        imageUrl=task.get_imageUrl(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        steps=task.get_steps(),
        **patcher.kwargs(),
    )

    generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())

    patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
180
+
181
+
182
+
183
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fail fast at import time: every pipeline in this module assumes CUDA.
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# Checkpoint subdirectories to preload; each "model_id" is joined with the
# handler's base path in EndpointHandler.__init__ (note the leading slash —
# presumably the base path has no trailing slash; verify against deployment).
multi_model_list = [
    {"model_id": "/model_v4"},
    {"model_id": "/model_v2"},
    {"model_id": "/model_v3"}
]
194
+
195
class EndpointHandler():
    """Inference endpoint handler.

    Preloads the ControlNet, text2img, and img2img pipelines for every
    entry in ``multi_model_list`` and dispatches incoming requests to the
    matching task handler based on the task type.
    """

    def __init__(self, path=""):
        """Load prompt modifier, LoRA styles, and all model variants.

        Args:
            path: Base directory that each ``model_id`` from
                ``multi_model_list`` is appended to.
        """
        # load the optimized model
        print("Logs: model loaded .... starts")
        print("Logs: path is ", path)
        prompt_modifier.load()

        lora_style.load(path)

        # One preloaded pipeline per model variant, keyed by the plain
        # "model_id" string (e.g. "/model_v4") — NOT by path + model_id.
        self.multi_controlnet_model = {}
        self.multi_text2image_model = {}
        self.multi_image2image_model = {}
        self.path = path

        for model in multi_model_list:
            model_id = model["model_id"]
            print("Logs: model value is", model)
            print("Logs: model path value is", path + model_id)
            self.multi_controlnet_model[model_id] = controlnet.load(path + model_id)
            self.multi_text2image_model[model_id] = text2img_pipe.load(path + model_id)
            self.multi_image2image_model[model_id] = img2img_pipe.load(path + model_id)

            print(" Logs: model[model_id]", model_id)
            print("Logs: multimodel controlnet pipelines are", path + model_id)
            print("Logs: multimodel text2img pipelines are", path + model_id)
            print("Logs: multimodel imgtoimage pipelines are", path + model_id)

        print("Logs: model loaded ....")

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """
        Args:
            data (:obj:):
                includes the input data and the parameters for the inference.
        Return:
            A :obj:`dict` with modified prompts and generated image URLs,
            or ``None`` when the task raised an error (alerted via Slack).
        """
        print("Logs post: self.path", self.path)
        print("Logs post: task is ", data)
        # Strip request envelope keys before handing the payload to Task.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        model_id = data.pop("model_id", None)

        print("Logs post: model_id is", model_id)
        task = Task(data)

        try:
            task_type = task.get_type()

            if task_type == TaskType.TEXT_TO_IMAGE:
                # character sheet
                if "character sheet" in task.get_prompt().lower():
                    return pose(task, s3_outkey="", poses=pickPoses())
                else:
                    # BUG FIX: the original looked up
                    # self.multi_text2image_model[self.path + multi_model_list[0][model_id]]
                    # with model_id hard-coded to "" — indexing the dict
                    # {"model_id": ...} with "" raises KeyError, and the
                    # self.path prefix never matches the keys stored in
                    # __init__. Use the first variant's "model_id" key.
                    # NOTE(review): assumes the object stored by
                    # text2img_pipe.load(...) is callable with a Task —
                    # confirm against the pipelines module.
                    default_id = multi_model_list[0]["model_id"]
                    return self.multi_text2image_model[default_id](task)
            elif task_type == TaskType.IMAGE_TO_IMAGE:
                return img2img(task)
            elif task_type == TaskType.CANNY:
                return canny(task)
            elif task_type == TaskType.POSE:
                return pose(task)
            else:
                raise Exception("Invalid task type")
        except Exception as e:
            # Best-effort handling: alert Slack and signal failure with None
            # rather than crashing the endpoint worker.
            print(f"Error: {e}")
            slack.error_alert(task, e)
            return None