Inam65 committed on
Commit
d90f6b4
·
verified ·
1 Parent(s): 1c45d68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -46
app.py CHANGED
@@ -1,50 +1,66 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
 
4
  from PIL import Image
5
  from transformers import SamModel, SamProcessor
6
 
7
- # 1. Load the Model and Processor
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
10
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def segment_object(image_data):
13
- # image_data is a dictionary from the Gradio ImageEditor
14
- if image_data is None or "composite" not in image_data:
15
  return None
16
 
 
17
  raw_image = image_data["background"].convert("RGB")
18
 
19
- # Get the bounding box from the editor's layers
20
- # Gradio returns shapes in a list. We look for the rectangle.
21
  layers = image_data.get("layers", [])
22
  if not layers:
23
- return raw_image # Return original if no box drawn
24
-
25
- # For simplicity, we take the first box found
26
- # In a real app, you'd iterate to find the 'crop' or 'rect' layer
27
- # Here we use the composite mask logic for a beginner-friendly approach
28
-
29
- # Convert image for model
30
- inputs = processor(raw_image, return_tensors="pt").to(device)
31
- image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
32
 
33
- # In this simple version, we'll use the 'mask' drawn by the user
34
- # to find the object. If you use the 'brush' or 'rect' tool:
35
- mask = image_data["layers"][0].split()[-1] # Alpha channel of the drawing layer
36
- mask = np.array(mask)
37
 
38
- # Find the coordinates of the drawn rectangle
39
- coords = np.argwhere(mask > 0)
40
  if coords.size == 0:
41
- return raw_image
42
-
 
43
  y0, x0 = coords.min(axis=0)
44
  y1, x1 = coords.max(axis=0)
45
  input_boxes = [[[x0, y0, x1, y1]]]
46
 
47
- # 2. Predict the mask
 
 
 
48
  inputs = processor(raw_image, input_boxes=[input_boxes], return_tensors="pt").to(device)
49
  inputs.pop("pixel_values", None)
50
  inputs["image_embeddings"] = image_embeddings
@@ -52,42 +68,57 @@ def segment_object(image_data):
52
  with torch.no_grad():
53
  outputs = model(**inputs)
54
 
55
- # 3. Process the results
56
  masks = processor.image_processor.post_process_masks(
57
  outputs.pred_masks.cpu(),
58
  inputs.original_sizes.cpu(),
59
  inputs.reshaped_input_sizes.cpu()
60
  )
61
-
62
- # Take the first mask (best guess)
63
  best_mask = masks[0][0][0].numpy()
64
 
65
- # 4. Create High-Quality White Background
 
 
 
 
66
  raw_np = np.array(raw_image)
67
- # Create an image where the background is white [255, 255, 255]
68
  white_bg = np.ones_like(raw_np) * 255
69
 
70
- # Place object on white background
71
- # We use the mask to choose between original pixels and white pixels
72
- final_img = np.where(best_mask[..., None], raw_np, white_bg)
73
 
74
- return Image.fromarray(final_img.astype('uint8'))
75
 
76
- # 3. Create the Gradio Interface
77
- with gr.Blocks() as demo:
78
- gr.Markdown("# 🖌️ Object Extractor to White Background")
79
- gr.Markdown("1. Upload an image. 2. Use the **Box** or **Brush** tool to highlight the object. 3. Click Submit.")
80
 
81
  with gr.Row():
82
- input_img = gr.ImageEditor(
83
- label="Input Image",
84
- type="pil",
85
- layers=True,
86
- canvas_size=(512, 512)
87
- )
88
- output_img = gr.Image(label="Extracted Object", type="pil")
89
-
90
- submit_btn = gr.Button("Extract Object")
91
- submit_btn.click(segment_object, inputs=[input_img], outputs=[output_img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ import cv2
5
  from PIL import Image
6
  from transformers import SamModel, SamProcessor
7
 
8
# 1. Load the Model and Processor (using the base model for speed)
# NOTE: both checkpoints are downloaded at import time, so the first
# launch of the Space can be slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
# SAM (Segment Anything) base checkpoint, moved to GPU when one is available.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
# Matching processor: handles image resizing/normalization and box prompts.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
12
 
13
def refine_mask(mask):
    """
    Clean up a binary segmentation mask.

    Keeps only the largest connected foreground component (dropping the
    small speckle "spots" SAM sometimes produces) and lightly smooths the
    edges before re-thresholding.

    Args:
        mask: 2-D boolean (or 0/1 numeric) array, e.g. the output of
            SAM's post-processing step.

    Returns:
        2-D boolean array with the same shape as ``mask``.
    """
    # Local import: scipy is only needed here. Using scipy.ndimage instead
    # of OpenCV removes the heavyweight cv2 dependency whose only use in
    # this app was these two small operations.
    from scipy import ndimage

    binary = mask.astype(bool)

    # Label connected blobs with 8-connectivity (a full 3x3 structuring
    # element), matching cv2.connectedComponentsWithStats(connectivity=8).
    structure = np.ones((3, 3), dtype=bool)
    labels, num_features = ndimage.label(binary, structure=structure)

    if num_features > 1:
        # Several blobs: keep only the one with the largest pixel area.
        areas = ndimage.sum(binary, labels, index=range(1, num_features + 1))
        largest = 1 + int(np.argmax(areas))
        refined = (labels == largest).astype(float)
    else:
        # Zero or one component: nothing to prune. (Both branches now
        # yield a float array, unlike the earlier uint8/float mix.)
        refined = binary.astype(float)

    # Slight Gaussian smoothing, then re-threshold to a clean boolean mask.
    # sigma=0.8 is what OpenCV derives for a (3, 3) kernel with sigma=0:
    # 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8.
    refined = ndimage.gaussian_filter(refined, sigma=0.8)
    return refined > 0.5
34
+
35
def segment_object(image_data):
    """
    Extract the user-selected object and place it on a white background.

    Args:
        image_data: dict produced by gr.ImageEditor, with keys
            "background" (PIL image) and "layers" (list of RGBA PIL
            drawing layers). May be None when nothing was uploaded.

    Returns:
        A PIL.Image with the segmented object on white, the original
        image when no selection was drawn, or None when no image was
        provided.
    """
    if image_data is None or "background" not in image_data:
        return None

    # Load the background image
    raw_image = image_data["background"].convert("RGB")

    # Extract the user's drawing from the layers: the drawing lives in the
    # alpha channel of the first layer.
    # NOTE(review): assumes the layer has the same pixel size as the
    # background image — confirm against the ImageEditor output format.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image

    mask_layer = np.array(layers[0].split()[-1])  # alpha channel
    coords = np.argwhere(mask_layer > 0)
    if coords.size == 0:
        return raw_image  # Return original if no selection made

    # Bounding box of the drawn region. np.argwhere yields (row, col),
    # i.e. (y, x), hence the unpack order below.
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # Plain Python floats in the (batch, boxes, 4) nesting SamProcessor
    # expects: [[[x0, y0, x1, y1]]].
    input_boxes = [[[float(x0), float(y0), float(x1), float(y1)]]]

    # --- AI PREDICTION ---
    inputs = processor(raw_image, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    # BUGFIX: pass the already-nested box list directly. The previous
    # call wrapped it once more (input_boxes=[input_boxes]), producing a
    # 4-level nesting that does not match the SamProcessor API.
    inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
    inputs.pop("pixel_values", None)  # reuse the precomputed embeddings
    inputs["image_embeddings"] = image_embeddings

    with torch.no_grad():
        outputs = model(**inputs)

    # Convert the model output back into a full-resolution binary mask.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs.original_sizes.cpu(),
        inputs.reshaped_input_sizes.cpu(),
    )
    best_mask = masks[0][0][0].numpy()

    # --- REFINEMENT STEP ---
    # Removes the stray "spots" a raw SAM mask can contain.
    final_mask = refine_mask(best_mask)

    # --- CREATE FINAL IMAGE ---
    raw_np = np.array(raw_image)
    # Pure white background, same shape as the original image.
    white_bg = np.ones_like(raw_np) * 255

    # Blend: if the mask is set take the original pixel, otherwise white.
    output_np = np.where(final_mask[..., None], raw_np, white_bg)

    return Image.fromarray(output_np.astype('uint8'))
92
 
93
# 3. Build the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛠High-Quality Object Extractor")
    gr.Markdown("Upload an image and **draw a tight rectangle** around the object you want to keep.")

    # Left column: editor + button; right column: result image.
    with gr.Row():
        with gr.Column():
            # The ImageEditor allows users to draw rectangles
            img_input = gr.ImageEditor(
                label="Input Image (Draw a Box)",
                type="pil",
                layers=True,
                sources=["upload", "clipboard"],
                canvas_size=(712, 712)
            )
            submit_btn = gr.Button("Extract & Clean Mask", variant="primary")

        with gr.Column():
            img_output = gr.Image(label="Result (White Background)", type="pil")

    # Wire the button to the segmentation function defined above.
    submit_btn.click(
        fn=segment_object,
        inputs=[img_input],
        outputs=[img_output]
    )

    gr.Markdown("---")
    gr.Markdown("### 💡 Tips for better results:")
    gr.Markdown("- Draw your rectangle as **close to the object edges** as possible.")
    gr.Markdown("- If there are still spots, try using the **brush tool** instead of the rectangle to 'paint' exactly what you want.")

# Start the app (blocking call).
demo.launch()