Spaces:
Sleeping
Sleeping
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# 1. Load the Segment Anything model and its processor once at startup.
#    The ViT-Base checkpoint is the smallest SAM variant, chosen for speed.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
def refine_mask(mask):
    """
    Post-process a binary segmentation mask.

    Keeps only the largest connected foreground component (discarding
    small speckles) and lightly smooths the component's edges.

    Args:
        mask: 2-D boolean (or 0/1 integer) array from the segmentation model.

    Returns:
        A 2-D boolean array of the same shape as ``mask``.
    """
    # connectedComponentsWithStats expects an 8-bit single-channel image.
    binary = mask.astype(np.uint8) * 255

    # Label every connected blob; label 0 is always the background.
    n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(
        binary, connectivity=8
    )

    if n_labels <= 1:
        # No foreground blobs found — carry the (rescaled) input through.
        cleaned = binary / 255
    else:
        # Keep only the foreground label with the largest pixel area.
        areas = stats[1:, cv2.CC_STAT_AREA]
        biggest = int(np.argmax(areas)) + 1
        cleaned = (label_map == biggest).astype(np.uint8)

    # A small Gaussian blur followed by re-thresholding softens jagged edges.
    smoothed = cv2.GaussianBlur(cleaned.astype(float), (3, 3), 0)
    return smoothed > 0.5
def segment_object(image_data):
    """
    Extract the object inside the user-drawn box and place it on white.

    Args:
        image_data: Gradio ImageEditor payload — a dict with a "background"
            PIL image and a list of "layers" holding the user's drawing.

    Returns:
        A PIL image with the segmented object on a white background, the
        untouched input image if nothing was drawn, or None on empty input.
    """
    # The editor can deliver the key with a None value, so check the value
    # itself rather than mere key presence.
    if image_data is None or image_data.get("background") is None:
        return None

    raw_image = image_data["background"].convert("RGB")

    # The user's drawing lives in the alpha channel of the first layer.
    layers = image_data.get("layers", [])
    if not layers:
        return raw_image

    mask_layer = np.array(layers[0].split()[-1])  # alpha channel
    coords = np.argwhere(mask_layer > 0)
    if coords.size == 0:
        return raw_image  # nothing drawn — return the original

    # Bounding box of the drawn region, as [x0, y0, x1, y1].
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # SamProcessor expects floats nested as [batch][boxes_per_image][4 coords].
    input_boxes = [[[float(x0), float(y0), float(x1), float(y1)]]]

    # --- AI PREDICTION ---
    # Embed the image once, then run the cheap prompt decoder with the box.
    inputs = processor(raw_image, return_tensors="pt").to(device)
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

    # BUGFIX: input_boxes was previously wrapped in an extra list
    # ([input_boxes]), producing 4 levels of nesting where the processor
    # expects 3 — pass the correctly nested list directly.
    inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)
    inputs.pop("pixel_values", None)  # reuse the precomputed embeddings instead
    inputs["image_embeddings"] = image_embeddings

    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution mask logits back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs.original_sizes.cpu(),
        inputs.reshaped_input_sizes.cpu(),
    )
    # SAM proposes several candidate masks per box; keep the one the model
    # itself scores highest instead of blindly taking the first candidate.
    scores = outputs.iou_scores.cpu()[0, 0]
    best_mask = masks[0][0][int(scores.argmax())].numpy()

    # --- REFINEMENT STEP ---
    # Removes stray speckles by keeping only the largest connected blob.
    final_mask = refine_mask(best_mask)

    # --- CREATE FINAL IMAGE ---
    raw_np = np.array(raw_image)
    white_bg = np.ones_like(raw_np) * 255
    # Where the mask is True keep the original pixel, otherwise paint white.
    output_np = np.where(final_mask[..., None], raw_np, white_bg)
    return Image.fromarray(output_np.astype("uint8"))
# 3. Build the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛠️ High-Quality Object Extractor")
    gr.Markdown("Upload an image and **draw a tight rectangle** around the object you want to keep.")

    with gr.Row():
        with gr.Column():
            # The ImageEditor lets the user sketch the bounding box directly
            # on top of the uploaded picture.
            editor = gr.ImageEditor(
                label="Input Image (Draw a Box)",
                type="pil",
                layers=True,
                sources=["upload", "clipboard"],
                canvas_size=(712, 712),
            )
            run_btn = gr.Button("Extract & Clean Mask", variant="primary")
        with gr.Column():
            result_img = gr.Image(label="Result (White Background)", type="pil")

    # Wire the button to the segmentation pipeline.
    run_btn.click(fn=segment_object, inputs=[editor], outputs=[result_img])

    gr.Markdown("---")
    gr.Markdown("### 💡 Tips for better results:")
    gr.Markdown("- Draw your rectangle as **close to the object edges** as possible.")
    gr.Markdown("- If there are still spots, try using the **brush tool** instead of the rectangle to 'paint' exactly what you want.")

demo.launch()