AItool's picture
Update app.py
b6780bc verified
raw
history blame
3.04 kB
import os
import subprocess
# --- Ensure SAM checkpoint is present ---
# The model weights are fetched once with wget and reused on later starts.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
)
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    # Download under a temporary name and rename only after wget succeeds.
    # Writing straight to the final name would let an interrupted transfer
    # leave a truncated file that the existence check above accepts forever.
    _partial_checkpoint = SAM_CHECKPOINT + ".part"
    try:
        # check=True surfaces a failed download here instead of as a
        # confusing error when the weights are loaded later.
        subprocess.run(
            ["wget", "-O", _partial_checkpoint, SAM_CHECKPOINT_URL],
            check=True,
        )
        os.replace(_partial_checkpoint, SAM_CHECKPOINT)
    finally:
        # After a successful rename the partial file no longer exists; this
        # only cleans up after a failed or aborted download.
        if os.path.exists(_partial_checkpoint):
            os.remove(_partial_checkpoint)
# ----------------------------------------
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
# --- CONFIG ---
# Path of the SAM weights file handed to the model constructor below.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
# Key used to look up the model constructor in sam_model_registry.
SAM_MODEL_TYPE = "vit_h"
# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Gaussian sigma passed to soft_alpha() when feathering the mask edge.
BLUR_RADIUS = 10
# --------------
# Build the model and its predictor a single time at import so every click
# handler call reuses the same loaded weights.
_sam_constructor = sam_model_registry[SAM_MODEL_TYPE]
sam = _sam_constructor(checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)
def soft_alpha(mask_uint8, blur_radius=10):
    """Convert a hard mask into a feathered float alpha matte.

    Args:
        mask_uint8: 2-D mask array on a 0..255 scale (the caller in this
            file passes a uint8 array containing only 0 and 255).
        blur_radius: Gaussian sigma, in pixels, applied on both axes. A value
            of 0 or less disables feathering and the mask is only rescaled.

    Returns:
        A float32 array of the same shape with values clipped to [0.0, 1.0].
    """
    # With ksize=(0, 0) OpenCV derives the kernel size from sigma, so a
    # non-positive sigma yields an invalid kernel and raises; skip the blur.
    if blur_radius > 0:
        mask_uint8 = cv2.GaussianBlur(
            mask_uint8, (0, 0), sigmaX=blur_radius, sigmaY=blur_radius
        )
    return (mask_uint8.astype(np.float32) / 255.0).clip(0.0, 1.0)
def _padded_bbox(mask, height, width):
    """Return the padded inclusive bounds (x0, y0, x1, y1) of the 255 pixels in mask, or None if there are none."""
    ys, xs = np.where(mask == 255)
    if xs.size == 0:
        return None
    # Pad by 2% of the longer image side, clamped to the image borders.
    pad = int(max(height, width) * 0.02)
    x0 = max(0, xs.min() - pad)
    x1 = min(width - 1, xs.max() + pad)
    y0 = max(0, ys.min() - pad)
    y1 = min(height - 1, ys.max() + pad)
    return x0, y0, x1, y1


def _tint_overlay(img_rgb, selected):
    """Return a copy of img_rgb with the pixels where selected is True blended 60/40 with purple."""
    overlay = img_rgb.copy()
    tint = np.array([180, 0, 180], dtype=np.uint8)  # purple
    overlay[selected] = (0.6 * overlay[selected] + 0.4 * tint).astype(np.uint8)
    return overlay


def isolate_with_click(image: Image.Image, evt: gr.SelectData):
    """Segment the object under a click and return a cutout plus a preview.

    The clicked point is given to the module-level SAM ``predictor`` as a
    single foreground prompt; of the candidate masks returned, the one with
    the highest score is used.

    Args:
        image: The PIL image shown in the input component, or None when no
            image is loaded.
        evt: Gradio select event; ``evt.index[0]`` and ``evt.index[1]`` are
            used as the x and y of the prompt point.

    Returns:
        A ``(cutout, overlay)`` pair of PIL images: ``cutout`` is an RGBA crop
        around the mask with a feathered alpha channel, ``overlay`` is the
        full image with the mask tinted purple. ``(None, None)`` is returned
        when there is no image or the chosen mask is empty.
    """
    # Nothing to segment if the event arrives without an image.
    if image is None:
        return None, None

    img_rgb = np.array(image.convert("RGB"))
    predictor.set_image(img_rgb)

    click_xy = np.array([[evt.index[0], evt.index[1]]])
    click_label = np.array([1])  # foreground
    masks, scores, _ = predictor.predict(
        point_coords=click_xy,
        point_labels=click_label,
        multimask_output=True,
    )
    best_mask = masks[np.argmax(scores)].astype(np.uint8) * 255

    # Check for an empty mask before blurring so no work is thrown away.
    # Bounds are clamped against the image size; this assumes the mask has
    # the same height and width as the image, as the slicing below requires.
    height, width = img_rgb.shape[:2]
    bbox = _padded_bbox(best_mask, height, width)
    if bbox is None:
        return None, None
    x0, y0, x1, y1 = bbox

    alpha = soft_alpha(best_mask, blur_radius=BLUR_RADIUS)
    rows = slice(y0, y1 + 1)
    cols = slice(x0, x1 + 1)
    alpha_u8 = (alpha[rows, cols] * 255).astype(np.uint8)
    rgba = np.dstack((img_rgb[rows, cols], alpha_u8))
    cutout = Image.fromarray(rgba)

    # Build overlay preview (purple tint on mask)
    overlay_img = Image.fromarray(_tint_overlay(img_rgb, best_mask == 255))
    return cutout, overlay_img
# --- Gradio UI ---
# One input image and two result images; clicking the input runs the
# segmentation handler and fills both results.
with gr.Blocks() as demo:
    gr.Markdown(
        "### SAM Object Isolation\nUpload an image, then click on the object to isolate it."
    )

    inp = gr.Image(label="Upload image", type="pil", interactive=True)
    out_cutout = gr.Image(label="Isolated cutout (RGBA)", type="pil")
    out_overlay = gr.Image(label="Segmentation overlay preview", type="pil")

    inp.select(
        isolate_with_click,
        inputs=[inp],
        outputs=[out_cutout, out_overlay],
    )

    # Demo example at the bottom
    gr.Examples(
        label="Try with demo image",
        inputs=inp,
        examples=["demo.png"],  # make sure demo.png is in your repo
    )

demo.launch(share=True)