Update handler.py
Browse files- handler.py +4 -4
handler.py
CHANGED
|
@@ -126,7 +126,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
| 126 |
# @update_db
|
| 127 |
@auto_clear_cuda_and_gc(controlnet)
|
| 128 |
@slack.auto_send_alert
|
| 129 |
-
def text2img(task: Task):
|
| 130 |
prompt, ori_prompt = get_patched_prompt(task)
|
| 131 |
|
| 132 |
lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
|
|
@@ -243,7 +243,7 @@ class EndpointHandler():
|
|
| 243 |
model_id = data.pop("model_id", None)
|
| 244 |
|
| 245 |
# check if model_id is in the list of models
|
| 246 |
-
if model_id is None or model_id not in self.
|
| 247 |
raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
|
| 248 |
|
| 249 |
# # pass inputs with all kwargs in data
|
|
@@ -263,7 +263,7 @@ class EndpointHandler():
|
|
| 263 |
parameters = data.pop("parameters", None)
|
| 264 |
model_id = data.pop("model_id", None)
|
| 265 |
|
| 266 |
-
|
| 267 |
print("Logs post: model_id is", model_id)
|
| 268 |
task = Task(data)
|
| 269 |
|
|
@@ -276,7 +276,7 @@ class EndpointHandler():
|
|
| 276 |
if "character sheet" in task.get_prompt().lower():
|
| 277 |
return pose(task, s3_outkey="", poses=pickPoses())
|
| 278 |
else:
|
| 279 |
-
return self.multi_text2image_model[model_id]
|
| 280 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
| 281 |
return img2img(task)
|
| 282 |
elif task_type == TaskType.CANNY:
|
|
|
|
| 126 |
# @update_db
|
| 127 |
@auto_clear_cuda_and_gc(controlnet)
|
| 128 |
@slack.auto_send_alert
|
| 129 |
+
def text2img(task: Task, text2img_pipe ):
|
| 130 |
prompt, ori_prompt = get_patched_prompt(task)
|
| 131 |
|
| 132 |
lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
|
|
|
|
| 243 |
model_id = data.pop("model_id", None)
|
| 244 |
|
| 245 |
# check if model_id is in the list of models
|
| 246 |
+
if model_id is None or model_id not in self.multi_model_list:
|
| 247 |
raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
|
| 248 |
|
| 249 |
# # pass inputs with all kwargs in data
|
|
|
|
| 263 |
parameters = data.pop("parameters", None)
|
| 264 |
model_id = data.pop("model_id", None)
|
| 265 |
|
| 266 |
+
|
| 267 |
print("Logs post: model_id is", model_id)
|
| 268 |
task = Task(data)
|
| 269 |
|
|
|
|
| 276 |
if "character sheet" in task.get_prompt().lower():
|
| 277 |
return pose(task, s3_outkey="", poses=pickPoses())
|
| 278 |
else:
|
| 279 |
+
return text2img(task, self.multi_text2image_model[model_id])
|
| 280 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
| 281 |
return img2img(task)
|
| 282 |
elif task_type == TaskType.CANNY:
|