Meng Chen commited on
Commit ·
d48d98c
1
Parent(s): b8264a8
update handler
Browse files- handler.py +10 -5
handler.py
CHANGED
|
@@ -11,8 +11,9 @@ class EndpointHandler():
|
|
| 11 |
# Preload all the elements you are going to need at inference.
|
| 12 |
# pseudo:
|
| 13 |
# self.model= load_model(path)
|
|
|
|
| 14 |
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 15 |
-
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 16 |
self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
|
| 17 |
|
| 18 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
@@ -23,14 +24,18 @@ class EndpointHandler():
|
|
| 23 |
Return:
|
| 24 |
A :obj:`list` | `dict`: will be serialized and returned
|
| 25 |
"""
|
| 26 |
-
if "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
return [{"error": "Missing 'image' or 'text' key in input data"}]
|
| 28 |
|
| 29 |
try:
|
| 30 |
# Decode base64 image
|
| 31 |
-
image = self.decode_image(
|
| 32 |
-
prompts =
|
| 33 |
-
|
| 34 |
# Preprocess input
|
| 35 |
inputs = self.processor(
|
| 36 |
text=prompts,
|
|
|
|
| 11 |
# Preload all the elements you are going to need at inference.
|
| 12 |
# pseudo:
|
| 13 |
# self.model= load_model(path)
|
| 14 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
| 16 |
+
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(self.device)
|
| 17 |
self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
|
| 18 |
|
| 19 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
|
| 24 |
Return:
|
| 25 |
A :obj:`list` | `dict`: will be serialized and returned
|
| 26 |
"""
|
| 27 |
+
if "inputs" not in data:
|
| 28 |
+
return [{"error": "Missing 'inputs' key"}]
|
| 29 |
+
|
| 30 |
+
inputs_data = data["inputs"]
|
| 31 |
+
if "image" not in inputs_data or "text" not in inputs_data:
|
| 32 |
return [{"error": "Missing 'image' or 'text' key in input data"}]
|
| 33 |
|
| 34 |
try:
|
| 35 |
# Decode base64 image
|
| 36 |
+
image = self.decode_image(inputs_data["image"])
|
| 37 |
+
prompts = inputs_data["text"]
|
| 38 |
+
|
| 39 |
# Preprocess input
|
| 40 |
inputs = self.processor(
|
| 41 |
text=prompts,
|