File size: 4,424 Bytes
350a741
 
 
d90f6b4
350a741
 
 
d90f6b4
350a741
 
 
 
d90f6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350a741
d90f6b4
350a741
 
d90f6b4
350a741
 
d90f6b4
 
350a741
 
d90f6b4
350a741
d90f6b4
 
 
350a741
 
d90f6b4
 
 
350a741
 
 
 
d90f6b4
 
 
 
350a741
 
 
 
 
 
 
d90f6b4
350a741
 
 
 
 
 
 
d90f6b4
 
 
 
 
350a741
d90f6b4
350a741
 
d90f6b4
 
350a741
d90f6b4
350a741
d90f6b4
 
 
 
350a741
 
d90f6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350a741
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from transformers import SamModel, SamProcessor

# 1. Load the Model and Processor (using the base model for speed)
# Select GPU when available; all tensors below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# SAM (Segment Anything) base checkpoint. Downloads weights on first run;
# larger checkpoints ("sam-vit-large"/"sam-vit-huge") trade speed for accuracy.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
# Companion processor: resizes/normalizes images and rescales prompt boxes.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def refine_mask(mask):
    """
    Post-process a binary mask: keep only the largest connected
    component and lightly smooth its boundary.

    Returns a boolean numpy array of the same height/width as `mask`.
    """
    # Scale the boolean mask to a 0/255 8-bit image, as OpenCV expects.
    binary = mask.astype(np.uint8) * 255

    # Label every connected blob; label 0 is always the background.
    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    if n_labels <= 1:
        # No foreground blobs were found — fall back to the input mask (0/1).
        cleaned = binary / 255
    else:
        # Keep only the foreground label covering the largest pixel area
        # (+1 because stats[1:] drops the background row).
        biggest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        cleaned = (label_map == biggest).astype(np.uint8)

    # A small Gaussian blur followed by a 0.5 threshold softens jagged edges.
    smoothed = cv2.GaussianBlur(cleaned.astype(float), (3, 3), 0)
    return smoothed > 0.5

def segment_object(image_data):
    """
    Segment the object inside the user's drawn region and composite it
    onto a pure-white background.

    Parameters
    ----------
    image_data : dict | None
        Gradio ImageEditor payload with a "background" PIL image and a
        list of drawing "layers".

    Returns
    -------
    PIL.Image.Image | None
        The extracted object on white, the untouched background image if
        the user drew nothing, or None for an empty payload.
    """
    if image_data is None or "background" not in image_data:
        return None

    # Load the background image
    raw_image = image_data["background"].convert("RGB")

    # Extract the user's drawing from the layers.
    # The alpha channel of the first layer marks where the user drew.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image

    mask_layer = np.array(layers[0].split()[-1])  # Alpha channel
    coords = np.argwhere(mask_layer > 0)

    if coords.size == 0:
        return raw_image  # Return original if no selection made

    # Bounding box of the drawing. np.argwhere yields (row, col) = (y, x).
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # SamProcessor expects boxes nested as (images, boxes_per_image, 4).
    # BUGFIX: the previous version wrapped this list once more at the
    # processor call, producing a 4-level nesting the processor rejects.
    # Coordinates are cast to float to satisfy the processor's type check.
    input_boxes = [[[float(x0), float(y0), float(x1), float(y1)]]]

    # --- AI PREDICTION ---
    # A single processor/model call; SamModel computes the image embeddings
    # internally, so the separate get_image_embeddings step is unnecessary.
    inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-res predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs.original_sizes.cpu(),
        inputs.reshaped_input_sizes.cpu()
    )
    # First image, first box, highest-scoring mask proposal.
    best_mask = masks[0][0][0].numpy()

    # --- REFINEMENT STEP ---
    # Removes small disconnected "spots" and smooths the edges.
    final_mask = refine_mask(best_mask)

    # --- CREATE FINAL IMAGE ---
    raw_np = np.array(raw_image)
    # Pure white background of the same shape as the input.
    white_bg = np.ones_like(raw_np) * 255

    # Blend: where the mask is True take the original pixel, else white.
    output_np = np.where(final_mask[..., None], raw_np, white_bg)

    return Image.fromarray(output_np.astype('uint8'))

# 3. Build the Gradio UI
# Top-level script block: defines the two-column page layout, wires the
# button to `segment_object`, then starts the Gradio server (blocking).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛠️ High-Quality Object Extractor")
    gr.Markdown("Upload an image and **draw a tight rectangle** around the object you want to keep.")
    
    with gr.Row():
        with gr.Column():
            # The ImageEditor allows users to draw rectangles
            img_input = gr.ImageEditor(
                label="Input Image (Draw a Box)",
                type="pil",
                layers=True,
                sources=["upload", "clipboard"],
                canvas_size=(712, 712)
            )
            submit_btn = gr.Button("Extract & Clean Mask", variant="primary")
            
        with gr.Column():
            # Output panel: segmented object composited onto white.
            img_output = gr.Image(label="Result (White Background)", type="pil")

    # Clicking the button sends the editor payload (background + drawn
    # layers) to segment_object and shows its returned PIL image.
    submit_btn.click(
        fn=segment_object,
        inputs=[img_input],
        outputs=[img_output]
    )

    gr.Markdown("---")
    gr.Markdown("### 💡 Tips for better results:")
    gr.Markdown("- Draw your rectangle as **close to the object edges** as possible.")
    gr.Markdown("- If there are still spots, try using the **brush tool** instead of the rectangle to 'paint' exactly what you want.")

# Start the local web server; blocks until the process is stopped.
demo.launch()