mjohanes committed on
Commit
aa64939
·
1 Parent(s): 11f70e4
Files changed (1) hide show
  1. app.py +36 -26
app.py CHANGED
@@ -21,6 +21,18 @@ st.set_page_config(page_title="Inpainting Demo", layout="wide")
21
  FASTSAM_CHECKPOINT = "FastSAM-x.pt" # file name of the FastSAM model weights
22
  SD_MODEL_ID = "runwayml/stable-diffusion-inpainting" # HF Hub model for SD Inpainting v1.5
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Ensure FastSAM model weights are available (download if not present)
25
  if not os.path.exists(FASTSAM_CHECKPOINT):
26
  # Download FastSAM weights (if not already in the repo)
@@ -68,22 +80,27 @@ canvas = st.empty()
68
  picture = st.camera_input("Take a picture")
69
 
70
  if picture is not None:
71
- # When an image is captured, display it and allow point selection
72
- img = Image.open(picture) # read image as PIL
 
73
  canvas.image(img, caption="Captured Image", use_container_width=True)
74
 
75
- # Let user click a point on the image. This returns a dict with 'x' and 'y'.
76
- coords = streamlit_image_coordinates(img, key="click_img")
77
- if coords:
78
- # If a point was clicked, mark it on the image for user feedback
79
- cx, cy = int(coords['x']), int(coords['y'])
80
- # Draw a small red circle on the image copy to show selected point
81
- img_with_dot = img.copy()
82
- draw = ImageDraw.Draw(img_with_dot)
83
- draw.ellipse((cx-5, cy-5, cx+5, cy+5), fill='red')
84
- canvas.image(img_with_dot, caption=f"Selected Point: ({cx}, {cy})", use_container_width=True)
85
- else:
86
- cx = cy = None
 
 
 
 
87
 
88
  # Prompt input for inpainting
89
  prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")
@@ -92,21 +109,14 @@ if picture is not None:
92
  if coords and prompt:
93
  cx, cy = int(coords['x']), int(coords['y'])
94
  st.write("Generating mask with FastSAM...")
95
- # Run FastSAM segmentation with the selected point as prompt
96
- # Using the Ultralytics API: points=[[x,y]] and labels=[1] for a positive point prompt
97
  results = fastsam_model(img, points=[[cx, cy]], labels=[1])
98
- # The results object holds masks; extract the first mask (closest object to the point)
99
- mask_data = results[0].masks.data[0] # mask tensor (H x W)
100
- mask_array = mask_data.cpu().numpy() # convert to numpy array
101
- # Create a PIL Image for the mask: convert 1.0 to 255 (white), 0.0 to 0 (black)
102
  mask_image = Image.fromarray((mask_array * 255).astype(np.uint8))
103
 
104
- # For debugging, we can display the mask – uncomment if needed
105
- # st.image(mask_image, caption="Segmentation Mask", use_column_width=True)
106
-
107
  st.write("Running Stable Diffusion Inpainting...")
108
- # Run the Stable Diffusion inpainting pipeline
109
  result = sd_pipe(prompt=prompt, image=img, mask_image=mask_image).images[0]
110
 
111
- # Display the final inpainted image
112
- canvas.image(result, caption="Inpainted Image", use_container_width=True)
 
 
21
  FASTSAM_CHECKPOINT = "FastSAM-x.pt" # file name of the FastSAM model weights
22
  SD_MODEL_ID = "runwayml/stable-diffusion-inpainting" # HF Hub model for SD Inpainting v1.5
23
 
24
+ # Helper function: center-crop the image to a square, then resize it to size x size (default 512x512)
25
+ def crop_resize_image(image, size=512):
26
+ width, height = image.size
27
+ if width != height:
28
+ min_dim = min(width, height)
29
+ left = (width - min_dim) // 2
30
+ top = (height - min_dim) // 2
31
+ right = left + min_dim
32
+ bottom = top + min_dim
33
+ image = image.crop((left, top, right, bottom))
34
+ return image.resize((size, size))
35
+
36
  # Ensure FastSAM model weights are available (download if not present)
37
  if not os.path.exists(FASTSAM_CHECKPOINT):
38
  # Download FastSAM weights (if not already in the repo)
 
80
  picture = st.camera_input("Take a picture")
81
 
82
  if picture is not None:
83
+ # Open, crop & resize the captured image to 512x512
84
+ img = Image.open(picture)
85
+ img = crop_resize_image(img, size=512)
86
  canvas.image(img, caption="Captured Image", use_container_width=True)
87
 
88
+ # Use the canvas container for all related UI elements
89
+ with canvas.container():
90
+ st.image(img, caption="Captured (512x512) Image", use_container_width=True)
91
+
92
+ # Place the interactive component inside the same container
93
+ coords = streamlit_image_coordinates(img, key="click_img")
94
+
95
+ if coords:
96
+ cx, cy = int(coords['x']), int(coords['y'])
97
+ # Draw a red circle on the image to indicate the selected point
98
+ img_with_dot = img.copy()
99
+ draw = ImageDraw.Draw(img_with_dot)
100
+ draw.ellipse((cx-5, cy-5, cx+5, cy+5), fill='red')
101
+ st.image(img_with_dot, caption=f"Selected Point: ({cx}, {cy})", use_container_width=True)
102
+ else:
103
+ cx = cy = None
104
 
105
  # Prompt input for inpainting
106
  prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")
 
109
  if coords and prompt:
110
  cx, cy = int(coords['x']), int(coords['y'])
111
  st.write("Generating mask with FastSAM...")
 
 
112
  results = fastsam_model(img, points=[[cx, cy]], labels=[1])
113
+ mask_data = results[0].masks.data[0]
114
+ mask_array = mask_data.cpu().numpy()
 
 
115
  mask_image = Image.fromarray((mask_array * 255).astype(np.uint8))
116
 
 
 
 
117
  st.write("Running Stable Diffusion Inpainting...")
 
118
  result = sd_pipe(prompt=prompt, image=img, mask_image=mask_image).images[0]
119
 
120
+ # Finally, update the same canvas with the inpainted image
121
+ with canvas.container():
122
+ st.image(result, caption="Inpainted Image", use_container_width=True)