"""Gradio demo: click-to-isolate objects with Segment Anything (SAM).

Upload an image, click on an object, and the app returns an RGBA cutout
of the clicked object plus a tinted overlay preview of the predicted mask.
"""

import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

# --- CONFIG ---
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
)
SAM_MODEL_TYPE = "vit_h"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BLUR_RADIUS = 10  # Gaussian sigma used to feather the mask edge
# --------------

# --- Ensure the SAM checkpoint is present before loading the model ---
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    # urllib raises on failure (unlike the previous unchecked `wget`
    # subprocess), so a broken download fails here instead of surfacing
    # later as a confusing checkpoint-load error. Also removes the
    # dependency on a `wget` binary being installed.
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, SAM_CHECKPOINT)
# ----------------------------------------------------------------------

# Load SAM once at startup; the predictor is reused for every click.
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)


def soft_alpha(mask_uint8, blur_radius=10):
    """Convert a binary uint8 mask (0/255) into a soft float alpha in [0, 1].

    Args:
        mask_uint8: HxW uint8 mask, 255 on the foreground.
        blur_radius: Gaussian sigma; larger values give a softer edge.

    Returns:
        HxW float32 array clipped to [0.0, 1.0].
    """
    blurred = cv2.GaussianBlur(
        mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius
    )
    return (blurred.astype(np.float32) / 255.0).clip(0.0, 1.0)


def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    """Segment the object under the clicked point.

    Args:
        image: the uploaded PIL image (may be None if nothing was uploaded).
        evt: Gradio select event carrying the clicked (x, y) pixel coordinates.

    Returns:
        Tuple of (RGBA cutout cropped to the mask's padded bounding box,
        RGB overlay image with the mask tinted purple), or (None, None)
        when there is no image or the predicted mask is empty.
    """
    if image is None:
        # Click event fired without an uploaded image — nothing to segment.
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    predictor.set_image(img_rgb)

    # Single positive (foreground) click prompt at the selected pixel.
    input_point = np.array([[evt.index[0], evt.index[1]]])
    input_label = np.array([1])

    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    # Keep the highest-scoring of the candidate masks.
    best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255
    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)

    ys, xs = np.where(best_mask == 255)
    if len(xs) == 0 or len(ys) == 0:
        return None, None  # empty mask: nothing to cut out

    # Pad the tight bounding box by 2% of the longer image side,
    # clamped to the image bounds.
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    pad = int(max(img_rgb.shape[:2]) * 0.02)
    x0 = max(0, x0 - pad)
    x1 = min(img_rgb.shape[1] - 1, x1 + pad)
    y0 = max(0, y0 - pad)
    y1 = min(img_rgb.shape[0] - 1, y1 + pad)

    fg_rgb = img_rgb[y0:y1 + 1, x0:x1 + 1]
    fg_alpha = alpha[y0:y1 + 1, x0:x1 + 1]
    rgba = np.dstack((fg_rgb, (fg_alpha * 255).astype(np.uint8)))
    cutout = Image.fromarray(rgba)

    # Build overlay preview (purple tint on the masked pixels).
    overlay = img_rgb.copy()
    tint = np.array([180, 0, 180], dtype=np.uint8)  # purple
    sel = best_mask == 255
    overlay[sel] = (0.6 * overlay[sel] + 0.4 * tint).astype(np.uint8)
    overlay_img = Image.fromarray(overlay)

    return cutout, overlay_img


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("### SAM Object Isolation\nUpload an image, then click on the object to isolate it.")

    inp = gr.Image(type="pil", label="Upload image", interactive=True)
    out_cutout = gr.Image(type="pil", label="Isolated cutout (RGBA)")
    out_overlay = gr.Image(type="pil", label="Segmentation overlay preview")

    inp.select(isolate_with_click, inputs=[inp], outputs=[out_cutout, out_overlay])

    # Demo example at the bottom
    gr.Examples(
        examples=["demo.png"],  # make sure demo.png is in your repo
        inputs=inp,
        label="Try with demo image",
    )

if __name__ == "__main__":
    # Guarded so importing this module (e.g. in tests) does not start a server.
    demo.launch(share=True)