Rename ler.py to handler.py
Browse files- ler.py → handler.py +23 -4
ler.py → handler.py
RENAMED
|
@@ -186,10 +186,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| 186 |
if device.type != 'cuda':
|
| 187 |
raise ValueError("need to run on GPU")
|
| 188 |
|
|
|
|
| 189 |
multi_model_list = [
|
| 190 |
-
{"model_id": "/
|
| 191 |
-
{"model_id": "/model_v2"},
|
| 192 |
-
{"model_id": "/model_v3"}
|
| 193 |
]
|
| 194 |
|
| 195 |
class EndpointHandler():
|
|
@@ -236,6 +235,26 @@ class EndpointHandler():
|
|
| 236 |
Return:
|
| 237 |
A :obj:`dict`:. base64 encoded image
|
| 238 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
print("Logs post: self.path",self.path)
|
| 240 |
print("Logs post: task is ", data)
|
| 241 |
inputs = data.pop("inputs", data)
|
|
@@ -255,7 +274,7 @@ class EndpointHandler():
|
|
| 255 |
if "character sheet" in task.get_prompt().lower():
|
| 256 |
return pose(task, s3_outkey="", poses=pickPoses())
|
| 257 |
else:
|
| 258 |
-
return self.multi_text2image_model[
|
| 259 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
| 260 |
return img2img(task)
|
| 261 |
elif task_type == TaskType.CANNY:
|
|
|
|
| 186 |
if device.type != 'cuda':
|
| 187 |
raise ValueError("need to run on GPU")
|
| 188 |
|
| 189 |
+
# multi-model list
|
| 190 |
multi_model_list = [
|
| 191 |
+
{"model_id": "jayparmr/icbinp", "task": "text-classification"},
|
|
|
|
|
|
|
| 192 |
]
|
| 193 |
|
| 194 |
class EndpointHandler():
|
|
|
|
| 235 |
Return:
|
| 236 |
A :obj:`dict`:. base64 encoded image
|
| 237 |
"""
|
| 238 |
+
# deserialize incomin request
|
| 239 |
+
inputs = data.pop("inputs", data)
|
| 240 |
+
# parameters = data.pop("parameters", None)
|
| 241 |
+
model_id = data.pop("model_id", None)
|
| 242 |
+
|
| 243 |
+
# check if model_id is in the list of models
|
| 244 |
+
if model_id is None or model_id not in self.multi_model:
|
| 245 |
+
raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
|
| 246 |
+
|
| 247 |
+
# # pass inputs with all kwargs in data
|
| 248 |
+
# if parameters is not None:
|
| 249 |
+
# prediction = self.multi_model[model_id](inputs, **parameters)
|
| 250 |
+
# else:
|
| 251 |
+
# prediction = self.multi_model[model_id](inputs)
|
| 252 |
+
# # postprocess the prediction
|
| 253 |
+
# return prediction
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
|
| 258 |
print("Logs post: self.path",self.path)
|
| 259 |
print("Logs post: task is ", data)
|
| 260 |
inputs = data.pop("inputs", data)
|
|
|
|
| 274 |
if "character sheet" in task.get_prompt().lower():
|
| 275 |
return pose(task, s3_outkey="", poses=pickPoses())
|
| 276 |
else:
|
| 277 |
+
return self.multi_text2image_model[model_id](task)
|
| 278 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
| 279 |
return img2img(task)
|
| 280 |
elif task_type == TaskType.CANNY:
|