Spaces:
Running
on
Zero
Running
on
Zero
app.py
CHANGED
|
@@ -69,17 +69,25 @@ def predict_masks(image, points):
|
|
| 69 |
return image # Return the original image if no points are selected
|
| 70 |
PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
input_labels = [1] * len(points_list)
|
| 75 |
|
| 76 |
with torch.inference_mode():
|
| 77 |
-
PREDICTOR.set_image(
|
| 78 |
masks, _, _ = PREDICTOR.predict(
|
| 79 |
point_coords=points_list, point_labels=input_labels, multimask_output=False
|
| 80 |
)
|
| 81 |
|
| 82 |
# Prepare the overlay image
|
|
|
|
| 83 |
red_mask = np.zeros_like(image_np)
|
| 84 |
if masks and len(masks) > 0:
|
| 85 |
red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
|
|
@@ -90,7 +98,6 @@ def predict_masks(image, points):
|
|
| 90 |
else:
|
| 91 |
return image_np
|
| 92 |
|
| 93 |
-
|
| 94 |
def update_mask(prompts):
|
| 95 |
"""Update the mask based on the prompts."""
|
| 96 |
image = prompts["image"]
|
|
|
|
| 69 |
return image # Return the original image if no points are selected
|
| 70 |
PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
|
| 71 |
|
| 72 |
+
# Debugging: Print the structure of points
|
| 73 |
+
print(f"Points structure: {points}")
|
| 74 |
+
|
| 75 |
+
# Ensure points is a list of lists with at least two elements
|
| 76 |
+
if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
|
| 77 |
+
points_list = [[point[0], point[1]] for point in points]
|
| 78 |
+
else:
|
| 79 |
+
return image # Return the original image if points structure is unexpected
|
| 80 |
+
|
| 81 |
input_labels = [1] * len(points_list)
|
| 82 |
|
| 83 |
with torch.inference_mode():
|
| 84 |
+
PREDICTOR.set_image(np.array(image))
|
| 85 |
masks, _, _ = PREDICTOR.predict(
|
| 86 |
point_coords=points_list, point_labels=input_labels, multimask_output=False
|
| 87 |
)
|
| 88 |
|
| 89 |
# Prepare the overlay image
|
| 90 |
+
image_np = np.array(image)
|
| 91 |
red_mask = np.zeros_like(image_np)
|
| 92 |
if masks and len(masks) > 0:
|
| 93 |
red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
|
|
|
|
| 98 |
else:
|
| 99 |
return image_np
|
| 100 |
|
|
|
|
| 101 |
def update_mask(prompts):
|
| 102 |
"""Update the mask based on the prompts."""
|
| 103 |
image = prompts["image"]
|