# SAM_Gardio / app.py
# Author: sajabdoli — commit 662e85c ("Upload 3 files").
import gradio as gr
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import numpy as np
import torch
# Load the SAM ViT-B model once, at import time, so every request reuses it.
# The checkpoint file is not shipped with the code: upload it manually to the
# Hugging Face Space next to this script.
sam_checkpoint = "sam_vit_b.pth"
model_type = "vit_b"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
# Shared predictor wrapping the loaded model.
predictor = SamPredictor(sam)
def segment_with_sam(image):
    """Segment the object under the image centre with SAM and tint it red.

    The prompt is a single foreground point (label 1) at the image centre,
    and one mask is requested (``multimask_output=False``).

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            (Gradio passes ``None`` when the input is empty).

    Returns:
        A PIL RGB image in which the predicted mask pixels are blended 50/50
        with red and all other pixels are unchanged, or ``None`` when *image*
        is ``None``.
    """
    # Nothing to segment; returning None clears the output component.
    if image is None:
        return None

    image_np = np.array(image.convert("RGB"))
    predictor.set_image(image_np)

    # One foreground point at the centre, given in (x, y) order.
    h, w = image_np.shape[:2]
    point_coords = np.array([[w // 2, h // 2]])
    point_labels = np.array([1])
    masks, _, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return Image.fromarray(_overlay_mask(image_np, masks[0]))


def _overlay_mask(image_np, mask, color=(255, 0, 0), alpha=0.5):
    """Return a uint8 copy of *image_np* with *mask* pixels blended toward *color*.

    Pixels outside *mask* are copied unchanged, so the background is not
    dimmed. *image_np* is (H, W, 3) uint8 and *mask* is (H, W).
    """
    overlay = image_np.copy()
    region = mask.astype(bool)
    blended = (1 - alpha) * image_np[region] + alpha * np.asarray(color)
    overlay[region] = blended.astype(np.uint8)
    return overlay
# Web UI: one PIL image in, the red mask overlay (also PIL) out.
iface = gr.Interface(
    fn=segment_with_sam,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Segment Anything with SAM",
    description="Simple SAM demo using Gradio.",
)
# Start the server as soon as the module runs.
iface.launch()