import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from transformers import SamModel, SamProcessor

# 1. Load the model and processor once at startup (base variant for speed).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


def refine_mask(mask):
    """Clean up a binary mask by keeping only the largest connected object
    and lightly smoothing the edges.

    Args:
        mask: 2-D boolean (or 0/1) array produced by SAM post-processing.

    Returns:
        2-D boolean array of the same shape.
    """
    # Convert the boolean mask to an 8-bit image (0 and 255) for OpenCV.
    mask_8bit = mask.astype(np.uint8) * 255

    # Label all connected 'blobs' in the mask.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        mask_8bit, connectivity=8
    )

    if num_labels > 1:
        # Ignore index 0 (the background) and keep the largest remaining blob.
        largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        refined_mask = (labels == largest_label).astype(float)
    else:
        # No foreground blobs found; keep the (empty) mask as-is.
        refined_mask = mask_8bit.astype(float) / 255.0

    # Smooth the edges slightly with a small Gaussian blur, then re-threshold.
    refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0)
    return refined_mask > 0.5


def segment_object(image_data):
    """Segment the object inside the user-drawn box and paste it on white.

    Args:
        image_data: dict from gr.ImageEditor containing "background"
            (PIL image) and "layers" (list of PIL layers the user drew on).

    Returns:
        PIL.Image of the segmented object on a white background; the
        original image when no selection was made; None on empty input.
    """
    if image_data is None or "background" not in image_data:
        return None

    # Load the background image.
    raw_image = image_data["background"].convert("RGB")

    # The user's drawing lives in the alpha channel of the first layer.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image

    mask_layer = np.array(layers[0].split()[-1])  # alpha channel
    coords = np.argwhere(mask_layer > 0)
    if coords.size == 0:
        return raw_image  # Return original if no selection made.

    # Bounding box of the drawing as plain Python ints, [x0, y0, x1, y1].
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # SAM's processor expects input_boxes nested as [batch][boxes_per_image][4].
    input_boxes = [[[int(x0), int(y0), int(x1), int(y1)]]]

    # --- AI PREDICTION ---
    # Compute the image embedding once, then prompt with the box.  The second
    # processor call recreates pixel_values, so we drop them and inject the
    # cached embedding instead.  no_grad() covers the embedding pass too, so
    # no autograd graph is built for pure inference.
    with torch.no_grad():
        inputs = processor(raw_image, return_tensors="pt").to(device)
        image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

        # BUG FIX: pass input_boxes directly -- wrapping it in another list
        # produced a 4-deep nesting the processor does not accept.
        inputs = processor(
            raw_image, input_boxes=input_boxes, return_tensors="pt"
        ).to(device)
        inputs.pop("pixel_values", None)
        inputs["image_embeddings"] = image_embeddings

        outputs = model(**inputs)

    # Convert the low-resolution logits into full-resolution binary masks.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # SAM returns several candidate masks per prompt; pick the one with the
    # highest predicted IoU instead of blindly taking the first.
    candidate_masks = masks[0][0]  # (num_masks, H, W) for our single box
    best_idx = int(outputs.iou_scores.cpu()[0, 0].argmax())
    best_mask = candidate_masks[best_idx].numpy()

    # --- REFINEMENT STEP ---
    # This removes the stray "spots" left outside the main object.
    final_mask = refine_mask(best_mask)

    # --- CREATE FINAL IMAGE ---
    raw_np = np.array(raw_image)
    # Create a pure white background.
    white_bg = np.ones_like(raw_np) * 255
    # Blend: where the mask is set keep the original pixel, elsewhere white.
    output_np = np.where(final_mask[..., None], raw_np, white_bg)
    return Image.fromarray(output_np.astype('uint8'))


# 3. Build the Gradio UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛠️ High-Quality Object Extractor")
    gr.Markdown("Upload an image and **draw a tight rectangle** around the object you want to keep.")

    with gr.Row():
        with gr.Column():
            # The ImageEditor lets users draw rectangles on a separate layer.
            img_input = gr.ImageEditor(
                label="Input Image (Draw a Box)",
                type="pil",
                layers=True,
                sources=["upload", "clipboard"],
                canvas_size=(712, 712),
            )
            submit_btn = gr.Button("Extract & Clean Mask", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Result (White Background)", type="pil")

    submit_btn.click(
        fn=segment_object,
        inputs=[img_input],
        outputs=[img_output],
    )

    gr.Markdown("---")
    gr.Markdown("### 💡 Tips for better results:")
    gr.Markdown("- Draw your rectangle as **close to the object edges** as possible.")
    gr.Markdown("- If there are still spots, try using the **brush tool** instead of the rectangle to 'paint' exactly what you want.")

demo.launch()