CodeJackR
commited on
Commit
·
ea3669d
1
Parent(s):
b0c81ad
add point for masking
Browse files- handler.py +4 -4
handler.py
CHANGED
|
@@ -52,11 +52,11 @@ class EndpointHandler():
|
|
| 52 |
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
| 53 |
|
| 54 |
# 2. Prepare prompts and process the image
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
-
inputs = self.processor(img, return_tensors="pt").to(device)
|
| 60 |
|
| 61 |
# 3. Generate masks
|
| 62 |
with torch.no_grad():
|
|
|
|
| 52 |
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
| 53 |
|
| 54 |
# 2. Prepare prompts and process the image
|
| 55 |
+
height, width = img.size[1], img.size[0]
|
| 56 |
+
input_points = [[[width // 2, height // 2]]]
|
| 57 |
+
input_labels = [[1]]
|
| 58 |
|
| 59 |
+
inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
|
| 60 |
|
| 61 |
# 3. Generate masks
|
| 62 |
with torch.no_grad():
|