push app
Browse files
app.py
CHANGED
|
@@ -73,7 +73,7 @@ def load_models():
|
|
| 73 |
return fastsam_model, sd_pipe
|
| 74 |
|
| 75 |
# Initialize the models (this will run only once thanks to caching)
|
| 76 |
-
|
| 77 |
|
| 78 |
# Ensure we have a state for removing_dots
|
| 79 |
if "is_removing_dot" not in st.session_state:
|
|
@@ -111,13 +111,38 @@ else:
|
|
| 111 |
if "coords_list" not in st.session_state:
|
| 112 |
st.session_state.coords_list = []
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")
|
| 120 |
|
|
|
|
|
|
|
| 121 |
# Use the interactive component as the display canvas, showing the image with all dots.
|
| 122 |
new_coord = streamlit_image_coordinates(img_with_dots, key="click_img")
|
| 123 |
|
|
@@ -139,7 +164,7 @@ else:
|
|
| 139 |
else:
|
| 140 |
st.session_state.is_removing_dot = False
|
| 141 |
|
| 142 |
-
st.write("
|
| 143 |
|
| 144 |
|
| 145 |
|
|
|
|
| 73 |
return fastsam_model, sd_pipe
|
| 74 |
|
| 75 |
# Initialize the models (this will run only once thanks to caching)
|
| 76 |
+
fastsam_model, sd_pipe = load_models()
|
| 77 |
|
| 78 |
# Ensure we have a state for removing_dots
|
| 79 |
if "is_removing_dot" not in st.session_state:
|
|
|
|
| 111 |
if "coords_list" not in st.session_state:
|
| 112 |
st.session_state.coords_list = []
|
| 113 |
|
| 114 |
+
# --- Compute Segmentation Overlay ---
|
| 115 |
+
# If any points have been stored, run segmentation with FastSAM.
|
| 116 |
+
if st.session_state.coords_list:
|
| 117 |
+
points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
|
| 118 |
+
labels = [1] * len(points)
|
| 119 |
+
results = fastsam_model(img, points=points, labels=labels)
|
| 120 |
+
# Assume results[0].masks.data is a tensor with shape (N, H, W)
|
| 121 |
+
masks_tensor = results[0].masks.data
|
| 122 |
+
masks = masks_tensor.cpu().numpy()
|
| 123 |
+
if masks.ndim == 3 and masks.shape[0] > 0:
|
| 124 |
+
# Combine masks (logical OR via max)
|
| 125 |
+
combined_mask = np.max(masks, axis=0)
|
| 126 |
+
combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
|
| 127 |
+
# Create a red overlay with transparency
|
| 128 |
+
overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
|
| 129 |
+
base = img.convert("RGBA")
|
| 130 |
+
mask_for_overlay = combined_mask_img.convert("L")
|
| 131 |
+
seg_overlay = Image.composite(overlay, base, mask_for_overlay)
|
| 132 |
+
else:
|
| 133 |
+
seg_overlay = img.copy()
|
| 134 |
+
else:
|
| 135 |
+
seg_overlay = img.copy()
|
| 136 |
+
|
| 137 |
+
# --- Draw Red Dots on Top ---
|
| 138 |
+
final_img = seg_overlay.copy()
|
| 139 |
+
draw = ImageDraw.Draw(final_img)
|
| 140 |
+
for pt in st.session_state.coords_list:
|
| 141 |
+
cx, cy = int(pt["x"]), int(pt["y"])
|
| 142 |
draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")
|
| 143 |
|
| 144 |
+
|
| 145 |
+
|
| 146 |
# Use the interactive component as the display canvas, showing the image with all dots.
|
| 147 |
new_coord = streamlit_image_coordinates(img_with_dots, key="click_img")
|
| 148 |
|
|
|
|
| 164 |
else:
|
| 165 |
st.session_state.is_removing_dot = False
|
| 166 |
|
| 167 |
+
st.write("Stored coordinates:", st.session_state.coords_list)
|
| 168 |
|
| 169 |
|
| 170 |
|