mjohanes commited on
Commit
c7b01f0
·
1 Parent(s): c676f9b
Files changed (1) hide show
  1. app.py +32 -7
app.py CHANGED
@@ -73,7 +73,7 @@ def load_models():
73
  return fastsam_model, sd_pipe
74
 
75
  # Initialize the models (this will run only once thanks to caching)
76
- # fastsam_model, sd_pipe = load_models()
77
 
78
  # Ensure we have a state for removing_dots
79
  if "is_removing_dot" not in st.session_state:
@@ -111,13 +111,38 @@ else:
111
  if "coords_list" not in st.session_state:
112
  st.session_state.coords_list = []
113
 
114
- # Create a copy of the image and draw red dots for every stored coordinate.
115
- img_with_dots = img.copy()
116
- draw = ImageDraw.Draw(img_with_dots)
117
- for coord in st.session_state.coords_list:
118
- cx, cy = int(coord["x"]), int(coord["y"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")
120
 
 
 
121
  # Use the interactive component as the display canvas, showing the image with all dots.
122
  new_coord = streamlit_image_coordinates(img_with_dots, key="click_img")
123
 
@@ -139,7 +164,7 @@ else:
139
  else:
140
  st.session_state.is_removing_dot = False
141
 
142
- st.write("Coordinates:", st.session_state.get("coords"))
143
 
144
 
145
 
 
73
  return fastsam_model, sd_pipe
74
 
75
  # Initialize the models (this will run only once thanks to caching)
76
+ fastsam_model, sd_pipe = load_models()
77
 
78
  # Ensure we have a state for removing_dots
79
  if "is_removing_dot" not in st.session_state:
 
111
  if "coords_list" not in st.session_state:
112
  st.session_state.coords_list = []
113
 
114
+ # --- Compute Segmentation Overlay ---
115
+ # If any points have been stored, run segmentation with FastSAM.
116
+ if st.session_state.coords_list:
117
+ points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
118
+ labels = [1] * len(points)
119
+ results = fastsam_model(img, points=points, labels=labels)
120
+ # Assume results[0].masks.data is a tensor with shape (N, H, W)
121
+ masks_tensor = results[0].masks.data
122
+ masks = masks_tensor.cpu().numpy()
123
+ if masks.ndim == 3 and masks.shape[0] > 0:
124
+ # Combine masks (logical OR via max)
125
+ combined_mask = np.max(masks, axis=0)
126
+ combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
127
+ # Create a red overlay with transparency
128
+ overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
129
+ base = img.convert("RGBA")
130
+ mask_for_overlay = combined_mask_img.convert("L")
131
+ seg_overlay = Image.composite(overlay, base, mask_for_overlay)
132
+ else:
133
+ seg_overlay = img.copy()
134
+ else:
135
+ seg_overlay = img.copy()
136
+
137
+ # --- Draw Red Dots on Top ---
138
+ final_img = seg_overlay.copy()
139
+ draw = ImageDraw.Draw(final_img)
140
+ for pt in st.session_state.coords_list:
141
+ cx, cy = int(pt["x"]), int(pt["y"])
142
  draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")
143
 
144
+
145
+
146
  # Use the interactive component as the display canvas, showing the image with all dots.
147
  new_coord = streamlit_image_coordinates(img_with_dots, key="click_img")
148
 
 
164
  else:
165
  st.session_state.is_removing_dot = False
166
 
167
+ st.write("Stored coordinates:", st.session_state.coords_list)
168
 
169
 
170