jayparmr committed on
Commit
f4dcbf7
·
1 Parent(s): 9b49fa9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +107 -98
handler.py CHANGED
@@ -26,101 +26,100 @@ num_return_sequences = 4 # the number of results to generate
26
  auto_mode = False
27
 
28
  prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
29
- controlnet = ControlNet()
30
  lora_style = LoraStyle()
31
- text2img_pipe = Text2Img()
32
- img2img_pipe = Img2Img()
33
  slack = Slack()
34
 
35
 
36
 
37
- def get_patched_prompt(task: Task):
38
- def add_style_and_character(prompt: List[str]):
39
- for i in range(len(prompt)):
40
- prompt[i] = add_code_names(prompt[i])
41
- prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
42
 
43
- prompt = task.get_prompt()
44
 
45
- if task.is_prompt_engineering():
46
- prompt = prompt_modifier.modify(prompt)
47
- else:
48
- prompt = [prompt] * num_return_sequences
49
 
50
- ori_prompt = [task.get_prompt()] * num_return_sequences
51
 
52
- add_style_and_character(ori_prompt)
53
- add_style_and_character(prompt)
54
 
55
- print({"prompts": prompt})
56
 
57
- return (prompt, ori_prompt)
58
 
59
 
60
- # @update_db
61
- @auto_clear_cuda_and_gc(controlnet)
62
- @slack.auto_send_alert
63
- def canny(task: Task):
64
- prompt, _ = get_patched_prompt(task)
65
 
66
- controlnet.load_canny()
67
 
68
- lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
69
- lora_patcher.patch()
70
 
71
- images = controlnet.process_canny(
72
- prompt=prompt,
73
- imageUrl=task.get_imageUrl(),
74
- seed=task.get_seed(),
75
- steps=task.get_steps(),
76
- width=task.get_width(),
77
- height=task.get_height(),
78
- negative_prompt=[
79
- f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
80
- ]
81
- * num_return_sequences,
82
- **lora_patcher.kwargs(),
83
- )
84
 
85
- generated_image_urls = upload_images(images, "_canny", task.get_taskId())
86
 
87
- lora_patcher.cleanup()
88
- controlnet.cleanup()
89
 
90
- return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
91
 
92
 
93
- # @update_db
94
- @auto_clear_cuda_and_gc(controlnet)
95
- @slack.auto_send_alert
96
- def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
97
- prompt, _ = get_patched_prompt(task)
98
 
99
- controlnet.load_pose()
100
 
101
- lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
102
- lora_patcher.patch()
103
 
104
- if poses is None:
105
- poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
106
 
107
- images = controlnet.process_pose(
108
- prompt=prompt,
109
- image=poses,
110
- seed=task.get_seed(),
111
- steps=task.get_steps(),
112
- negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
113
- width=task.get_width(),
114
- height=task.get_height(),
115
- **lora_patcher.kwargs(),
116
- )
117
 
118
- generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
119
 
120
- lora_patcher.cleanup()
121
- controlnet.cleanup()
122
 
123
- return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
124
 
125
 
126
  # @update_db
@@ -153,30 +152,30 @@ def text2img(task: Task, text2img_pipe ):
153
  return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
154
 
155
 
156
- # @update_db
157
- @auto_clear_cuda_and_gc(controlnet)
158
- @slack.auto_send_alert
159
- def img2img(task: Task):
160
- prompt, _ = get_patched_prompt(task)
161
 
162
- lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
163
- lora_patcher.patch()
164
 
165
- torch.manual_seed(task.get_seed())
166
 
167
- images = img2img_pipe.process(
168
- prompt=prompt,
169
- imageUrl=task.get_imageUrl(),
170
- negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
171
- steps=task.get_steps(),
172
- **lora_patcher.kwargs(),
173
- )
174
 
175
- generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
176
 
177
- lora_patcher.cleanup()
178
 
179
- return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
180
 
181
 
182
 
@@ -213,14 +212,23 @@ class EndpointHandler():
213
  # self.multi_controlnet_model[model["model_id"]] = controlnet.load(model["model_id"])
214
  # self.multi_text2image_model[model["model_id"]] = text2img_pipe.load(model["model_id"])
215
  # self.multi_image2image_model[model["model_id"]] = img2img_pipe.load(model["model_id"])
216
- self.multi_controlnet_model[model["model_id"]] = controlnet.load(model["model_id"])
217
- self.multi_text2image_model[model["model_id"]] = text2img_pipe.load( model["model_id"])
218
- self.multi_image2image_model[model["model_id"]] = img2img_pipe.load( model["model_id"])
 
 
 
 
 
 
 
 
 
219
 
220
  print(" Logs: model[model_id]", model["model_id"])
221
- print("Logs: multimodel controlnet pipelines are", path + model["model_id"])
222
- print("Logs: multimodel text2img pipelines are", path + model["model_id"])
223
- print("Logs: multimodel imgtoimage pipelines are", path + model["model_id"])
224
  # controlnet.load(path)
225
  # text2img_pipe.load(path)
226
  # img2img_pipe.load(path)
@@ -274,15 +282,16 @@ class EndpointHandler():
274
  if task_type == TaskType.TEXT_TO_IMAGE:
275
  # character sheet
276
  if "character sheet" in task.get_prompt().lower():
277
- return pose(task, s3_outkey="", poses=pickPoses())
 
278
  else:
279
  return text2img(task, self.multi_text2image_model[model_id])
280
- elif task_type == TaskType.IMAGE_TO_IMAGE:
281
- return img2img(task)
282
- elif task_type == TaskType.CANNY:
283
- return canny(task)
284
- elif task_type == TaskType.POSE:
285
- return pose(task)
286
  else:
287
  raise Exception("Invalid task type")
288
  except Exception as e:
 
26
  auto_mode = False
27
 
28
  prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
29
+
30
  lora_style = LoraStyle()
31
+
 
32
  slack = Slack()
33
 
34
 
35
 
36
+ # def get_patched_prompt(task: Task):
37
+ # def add_style_and_character(prompt: List[str]):
38
+ # for i in range(len(prompt)):
39
+ # prompt[i] = add_code_names(prompt[i])
40
+ # prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
41
 
42
+ # prompt = task.get_prompt()
43
 
44
+ # if task.is_prompt_engineering():
45
+ # prompt = prompt_modifier.modify(prompt)
46
+ # else:
47
+ # prompt = [prompt] * num_return_sequences
48
 
49
+ # ori_prompt = [task.get_prompt()] * num_return_sequences
50
 
51
+ # add_style_and_character(ori_prompt)
52
+ # add_style_and_character(prompt)
53
 
54
+ # print({"prompts": prompt})
55
 
56
+ # return (prompt, ori_prompt)
57
 
58
 
59
+ # # @update_db
60
+ # @auto_clear_cuda_and_gc(controlnet)
61
+ # @slack.auto_send_alert
62
+ # def canny(task: Task):
63
+ # prompt, _ = get_patched_prompt(task)
64
 
65
+ # controlnet.load_canny()
66
 
67
+ # lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
68
+ # lora_patcher.patch()
69
 
70
+ # images = controlnet.process_canny(
71
+ # prompt=prompt,
72
+ # imageUrl=task.get_imageUrl(),
73
+ # seed=task.get_seed(),
74
+ # steps=task.get_steps(),
75
+ # width=task.get_width(),
76
+ # height=task.get_height(),
77
+ # negative_prompt=[
78
+ # f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
79
+ # ]
80
+ # * num_return_sequences,
81
+ # **lora_patcher.kwargs(),
82
+ # )
83
 
84
+ # generated_image_urls = upload_images(images, "_canny", task.get_taskId())
85
 
86
+ # lora_patcher.cleanup()
87
+ # controlnet.cleanup()
88
 
89
+ # return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
90
 
91
 
92
+ # # @update_db
93
+ # @auto_clear_cuda_and_gc(controlnet)
94
+ # @slack.auto_send_alert
95
+ # def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
96
+ # prompt, _ = get_patched_prompt(task)
97
 
98
+ # controlnet.load_pose()
99
 
100
+ # lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
101
+ # lora_patcher.patch()
102
 
103
+ # if poses is None:
104
+ # poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
105
 
106
+ # images = controlnet.process_pose(
107
+ # prompt=prompt,
108
+ # image=poses,
109
+ # seed=task.get_seed(),
110
+ # steps=task.get_steps(),
111
+ # negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
112
+ # width=task.get_width(),
113
+ # height=task.get_height(),
114
+ # **lora_patcher.kwargs(),
115
+ # )
116
 
117
+ # generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
118
 
119
+ # lora_patcher.cleanup()
120
+ # controlnet.cleanup()
121
 
122
+ # return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
123
 
124
 
125
  # @update_db
 
152
  return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
153
 
154
 
155
+ # # @update_db
156
+ # @auto_clear_cuda_and_gc(controlnet)
157
+ # @slack.auto_send_alert
158
+ # def img2img(task: Task):
159
+ # prompt, _ = get_patched_prompt(task)
160
 
161
+ # lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
162
+ # lora_patcher.patch()
163
 
164
+ # torch.manual_seed(task.get_seed())
165
 
166
+ # images = img2img_pipe.process(
167
+ # prompt=prompt,
168
+ # imageUrl=task.get_imageUrl(),
169
+ # negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
170
+ # steps=task.get_steps(),
171
+ # **lora_patcher.kwargs(),
172
+ # )
173
 
174
+ # generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
175
 
176
+ # lora_patcher.cleanup()
177
 
178
+ # return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
179
 
180
 
181
 
 
212
  # self.multi_controlnet_model[model["model_id"]] = controlnet.load(model["model_id"])
213
  # self.multi_text2image_model[model["model_id"]] = text2img_pipe.load(model["model_id"])
214
  # self.multi_image2image_model[model["model_id"]] = img2img_pipe.load(model["model_id"])
215
+ controlnet = ControlNet()
216
+ img2img_pipe = Img2Img()
217
+ text2img_pipe = Text2Img()
218
+ self.multi_controlnet_model[model["model_id"]] = controlnet;
219
+ controlnet.load(model["model_id"])
220
+
221
+
222
+ self.multi_text2image_model[model["model_id"]] = text2img_pipe;
223
+ text2img_pipe.load( model["model_id"])
224
+
225
+ self.multi_image2image_model[model["model_id"]] = img2img_pipe;
226
+ img2img_pipe.load( model["model_id"])
227
 
228
  print(" Logs: model[model_id]", model["model_id"])
229
+ print("Logs: multimodel controlnet pipelines are", self.multi_controlnet_model[model["model_id"]])
230
+ print("Logs: multimodel text2img pipelines are", self.multi_text2image_model[model["model_id"]])
231
+ print("Logs: multimodel imgtoimage pipelines are", self.multi_image2image_model[model["model_id"]])
232
  # controlnet.load(path)
233
  # text2img_pipe.load(path)
234
  # img2img_pipe.load(path)
 
282
  if task_type == TaskType.TEXT_TO_IMAGE:
283
  # character sheet
284
  if "character sheet" in task.get_prompt().lower():
285
+ print("pose is here")
286
+ # return pose(task, s3_outkey="", poses=pickPoses())
287
  else:
288
  return text2img(task, self.multi_text2image_model[model_id])
289
+ # elif task_type == TaskType.IMAGE_TO_IMAGE:
290
+ # return img2img(task)
291
+ # elif task_type == TaskType.CANNY:
292
+ # return canny(task)
293
+ # elif task_type == TaskType.POSE:
294
+ # return pose(task)
295
  else:
296
  raise Exception("Invalid task type")
297
  except Exception as e: