# High-Quality Object Extractor — Gradio app that segments an object with
# SAM (facebook/sam-vit-base) from a user-drawn box and pastes it on white.
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from transformers import SamModel, SamProcessor
# 1. Load the model and processor once at startup.
#    The "base" SAM checkpoint is the smallest variant, chosen for speed;
#    inference runs on GPU when one is available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
def refine_mask(mask):
    """Post-process a binary mask: retain only the largest connected
    component and lightly smooth its boundary.

    Returns a boolean array with the same shape as ``mask``.
    """
    # OpenCV wants an 8-bit, single-channel image (values 0 / 255).
    binary = mask.astype(np.uint8) * 255

    # Label every connected blob; label 0 is always the background.
    count, labels, stats, _centroids = cv2.connectedComponentsWithStats(
        binary, connectivity=8
    )

    if count <= 1:
        # Only background present — keep the (empty) mask unchanged.
        cleaned = binary / 255
    else:
        # Pick the foreground label covering the largest pixel area.
        # The offset of 1 skips row 0 of `stats`, which is the background.
        biggest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        cleaned = (labels == biggest).astype(np.uint8)

    # A small Gaussian blur followed by a 0.5 threshold softens jagged edges.
    blurred = cv2.GaussianBlur(cleaned.astype(float), (3, 3), 0)
    return blurred > 0.5
def segment_object(image_data):
    """Segment the object inside the user-drawn box and paste it on white.

    Parameters
    ----------
    image_data : dict | None
        Gradio ImageEditor payload containing a ``"background"`` PIL image
        and a list of drawing ``"layers"``.

    Returns
    -------
    PIL.Image.Image | None
        The extracted object on a white background, the untouched input
        image when no selection was drawn, or None when there is no image.
    """
    # Robustness fix: the "background" key can be present but None
    # (e.g. a cleared editor), which previously crashed on .convert().
    if image_data is None or image_data.get("background") is None:
        return None

    raw_image = image_data["background"].convert("RGB")

    # The user's rectangle lives in the alpha channel of the first layer.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image

    mask_layer = np.array(layers[0].split()[-1])  # alpha channel
    coords = np.argwhere(mask_layer > 0)
    if coords.size == 0:
        return raw_image  # nothing drawn — return the original image

    # Bounding box of everything the user drew: [x0, y0, x1, y1].
    # Cast numpy int64s to plain ints for the HF processor.
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    input_boxes = [[[int(x0), int(y0), int(x1), int(y1)]]]

    # --- AI PREDICTION ---
    # Compute the image embeddings once, then feed the box prompt.
    inputs = processor(raw_image, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    # Bug fix: pass `input_boxes` directly — it already has the expected
    # [batch, num_boxes, 4] nesting; the previous `[input_boxes]` produced
    # a 4-level list, one level deeper than SamProcessor expects.
    inputs = processor(
        raw_image, input_boxes=input_boxes, return_tensors="pt"
    ).to(device)
    inputs.pop("pixel_values", None)  # replaced by the cached embeddings
    inputs["image_embeddings"] = image_embeddings

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize/threshold the low-res predicted masks back to the input size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs.original_sizes.cpu(),
        inputs.reshaped_input_sizes.cpu(),
    )
    best_mask = masks[0][0][0].numpy()

    # --- REFINEMENT STEP ---
    # Drops stray speckles by keeping only the main connected object.
    final_mask = refine_mask(best_mask)

    # --- CREATE FINAL IMAGE ---
    raw_np = np.array(raw_image)
    white_bg = np.ones_like(raw_np) * 255
    # Where the mask is True keep the original pixel, otherwise white.
    output_np = np.where(final_mask[..., None], raw_np, white_bg)
    return Image.fromarray(output_np.astype("uint8"))
# 3. Build the Gradio UI: an editor for drawing the selection box on the
#    left, the extracted result on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛠️ High-Quality Object Extractor")
    gr.Markdown("Upload an image and **draw a tight rectangle** around the object you want to keep.")
    with gr.Row():
        with gr.Column():
            # The ImageEditor allows users to draw rectangles
            # (segment_object reads the drawing from its "layers").
            img_input = gr.ImageEditor(
                label="Input Image (Draw a Box)",
                type="pil",
                layers=True,
                sources=["upload", "clipboard"],
                canvas_size=(712, 712)
            )
            submit_btn = gr.Button("Extract & Clean Mask", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Result (White Background)", type="pil")
    # Wire the button: run segmentation on click and show the result.
    submit_btn.click(
        fn=segment_object,
        inputs=[img_input],
        outputs=[img_output]
    )
    gr.Markdown("---")
    gr.Markdown("### 💡 Tips for better results:")
    gr.Markdown("- Draw your rectangle as **close to the object edges** as possible.")
    gr.Markdown("- If there are still spots, try using the **brush tool** instead of the rectangle to 'paint' exactly what you want.")
# Start the app (blocking call).
demo.launch()