# app.py — SAM-based object extractor (Gradio demo)
# Origin: Hugging Face Space, commit d90f6b4
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from transformers import SamModel, SamProcessor
# 1. Load the model and processor once at startup (base SAM variant for speed).
_CHECKPOINT = "facebook/sam-vit-base"

# Prefer GPU when available; `device` is reused by segment_object below.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained(_CHECKPOINT).to(device)
processor = SamProcessor.from_pretrained(_CHECKPOINT)
def refine_mask(mask):
    """Keep only the largest connected blob of *mask* and smooth its edges.

    Parameters
    ----------
    mask : np.ndarray
        Boolean (or 0/1) foreground mask.

    Returns
    -------
    np.ndarray
        Boolean mask containing the single largest component, lightly smoothed.
    """
    # connectedComponentsWithStats wants an 8-bit single-channel image.
    binary = mask.astype(np.uint8) * 255

    # Label every connected blob (8-connectivity).
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    if num_labels > 1:
        # Label 0 is the background; pick the largest of the remaining blobs.
        areas = stats[1:, cv2.CC_STAT_AREA]
        biggest = int(np.argmax(areas)) + 1
        keep = (labels == biggest).astype(np.uint8)
    else:
        # No foreground component found — fall through with the (empty) mask.
        keep = binary / 255

    # Light Gaussian smoothing, then re-binarize at 0.5.
    smoothed = cv2.GaussianBlur(keep.astype(float), (3, 3), 0)
    return smoothed > 0.5
def segment_object(image_data):
    """Segment the object inside the user's drawn box and paste it on white.

    Parameters
    ----------
    image_data : dict | None
        Gradio ImageEditor value: ``{"background": PIL.Image, "layers": [...]}``.

    Returns
    -------
    PIL.Image.Image | None
        The extracted object on a white background; the untouched background
        image when nothing was drawn; ``None`` when there is no input image.
    """
    # `background` can be missing OR present-but-None when no image is loaded.
    if image_data is None or image_data.get("background") is None:
        return None

    # Load the background image
    raw_image = image_data["background"].convert("RGB")

    # Extract the user's drawing from the layers: the alpha channel of the
    # first layer marks where the user drew.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image

    mask_layer = np.array(layers[0].split()[-1])  # Alpha channel
    coords = np.argwhere(mask_layer > 0)
    if coords.size == 0:
        return raw_image  # Return original if no selection made

    # Bounding box of the drawn region. np.argwhere yields (row, col) = (y, x).
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # SamProcessor expects input_boxes nested as (batch, num_boxes, 4) —
    # this triple-nested list is passed as-is below (the previous code wrapped
    # it in one more list, producing an invalid 4-level nesting).
    input_boxes = [[[float(x0), float(y0), float(x1), float(y1)]]]

    # --- AI PREDICTION ---
    # Compute the (expensive) image embeddings once, then run the
    # prompt-conditioned decoder with the box prompt.
    inputs = processor(raw_image, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
    inputs.pop("pixel_values", None)  # replaced by the precomputed embeddings
    inputs["image_embeddings"] = image_embeddings
    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-res predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    best_mask = masks[0][0][0].numpy()

    # --- REFINEMENT STEP ---
    # Removes stray speckles by keeping only the largest component.
    final_mask = refine_mask(best_mask)

    # --- CREATE FINAL IMAGE ---
    raw_np = np.array(raw_image)
    white_bg = np.ones_like(raw_np) * 255
    # Where the mask is True take the original pixel, otherwise white.
    output_np = np.where(final_mask[..., None], raw_np, white_bg)
    return Image.fromarray(output_np.astype("uint8"))
# 3. Build the Gradio UI
# Left column: editable image input + action button; right column: result.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛠️ High-Quality Object Extractor")
    gr.Markdown("Upload an image and **draw a tight rectangle** around the object you want to keep.")
    with gr.Row():
        with gr.Column():
            # The ImageEditor allows users to draw rectangles
            img_input = gr.ImageEditor(
                label="Input Image (Draw a Box)",
                type="pil",
                layers=True,
                sources=["upload", "clipboard"],
                canvas_size=(712, 712)
            )
            submit_btn = gr.Button("Extract & Clean Mask", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Result (White Background)", type="pil")
    # Wire the button to the segmentation pipeline defined above.
    submit_btn.click(
        fn=segment_object,
        inputs=[img_input],
        outputs=[img_output]
    )
    gr.Markdown("---")
    gr.Markdown("### 💡 Tips for better results:")
    gr.Markdown("- Draw your rectangle as **close to the object edges** as possible.")
    gr.Markdown("- If there are still spots, try using the **brush tool** instead of the rectangle to 'paint' exactly what you want.")
# Start the local Gradio server.
demo.launch()