sam / app.py
AxZyzzz's picture
Update app.py
572e785 verified
Raw
History Blame Contribute Delete
4.13 kB
import os
import torch
import numpy as np
import cv2
import requests
import gradio as gr
import time
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv
# ------------------------------
# 1. Setup & Model Loading
# ------------------------------
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device used: {device}")
model_type = "vit_b"
checkpoint_path = "sam_vit_b_01ec64.pth"


def _download_checkpoint(url, dest, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` into ``dest`` without ever leaving a partial file there.

    The body is written to ``dest + ".part"`` and renamed onto ``dest`` only
    after the whole response has been received. Otherwise an interrupted or
    failed download would leave a truncated checkpoint at ``dest``, and the
    ``os.path.exists`` check below would accept it on every later start.

    Args:
        url: Location of the checkpoint.
        dest: Final path of the checkpoint file.
        chunk_size: Bytes read per iteration while streaming.
        timeout: Seconds for the connection and for each read; without it a
            stalled server would hang startup indefinitely.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        Any error from the request or the file write is re-raised after the
        temporary file has been removed.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            # Without this, an error page would be saved as the checkpoint.
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size):
                    if chunk:
                        f.write(chunk)
        # Only a complete download ever appears under the final name.
        os.replace(tmp_path, dest)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave the partial file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Download the model on first start; later starts reuse the local file.
if not os.path.exists(checkpoint_path):
    print("Downloading SAM checkpoint...")
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    _download_checkpoint(url, checkpoint_path)

sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device)
# Note: We kept the model in float32.
# Using sam.half() often causes runtime type mismatch errors with SamPredictor
# unless the input image is explicitly cast to half-precision manually.
predictor = SamPredictor(sam)
# ------------------------------
# 2. The API Function with Progress
# ------------------------------
def _fetch_image(image_url, timeout=30):
    """Download ``image_url`` and return it decoded as an RGB image array.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds for the connection and for each read; without it an
            unresponsive host would block the request indefinitely.

    Raises:
        gr.Error: if the request fails, the server answers with a non-2xx
            status, or OpenCV cannot decode the payload.
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # ``with`` releases the connection when done. ``resp.content`` (unlike
        # ``resp.raw``) undoes gzip/deflate Content-Encoding, so compressed
        # responses still decode.
        with requests.get(image_url, headers=headers, timeout=timeout) as resp:
            resp.raise_for_status()
            payload = resp.content
        image = cv2.imdecode(np.frombuffer(payload, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Could not decode image.")
        # OpenCV decodes to BGR channel order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Error downloading image: {e}")
        raise gr.Error(f"Failed to load image: {e}") from e


def _parse_box(box_coords):
    """Validate ``box_coords`` and return it as a float array of shape (4,).

    Raises:
        gr.Error: if ``box_coords`` is not exactly four finite numbers.
    """
    try:
        box = np.asarray(box_coords, dtype=np.float64)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Invalid box {box_coords!r}: expected [x1, y1, x2, y2].") from e
    # Catches None, wrong length, nested lists, NaN/inf.
    if box.shape != (4,) or not np.isfinite(box).all():
        raise gr.Error(f"Invalid box {box_coords!r}: expected [x1, y1, x2, y2].")
    return box


# We add `progress=gr.Progress()` as a default argument.
# Gradio automatically handles this when passed to the interface.
def run_sam_api(image_url, box_coords, progress=gr.Progress()):
    """Segment the object inside ``box_coords`` in the image at ``image_url``.

    Args:
        image_url: URL of the image to segment.
        box_coords: Box prompt ``[x1, y1, x2, y2]`` in pixel coordinates of
            the downloaded image.
        progress: Gradio progress tracker, supplied by Gradio when this
            function is wired into an interface.

    Returns:
        RGB image with the predicted mask and its bounding box drawn in red.

    Raises:
        gr.Error: if the box is malformed or the image cannot be downloaded
            or decoded.
    """
    total_steps = 5

    # --- LOGGING HELPER ---
    def log_step(step_num, message):
        """Print a timestamped step line and advance the Gradio progress bar."""
        timestamp = time.strftime("%H:%M:%S")
        print(f"[{timestamp}] Step {step_num}/{total_steps}: {message}")
        # Update Gradio UI bar (0.0 to 1.0)
        progress(step_num / total_steps, desc=message)

    log_step(1, f"Starting request for {image_url}")
    print(f"Received Box: {box_coords}")
    # Validate the prompt before the slow download/embedding so bad input
    # fails fast with a readable message.
    input_box = _parse_box(box_coords)

    # 1. Download the image.
    image = _fetch_image(image_url)

    # 2. Set image for SAM (the heavy step): computing the image embedding
    # can take a few seconds.
    log_step(2, "Generating Image Embeddings (This may take a moment)...")
    predictor.set_image(image)

    # 3. Decode one mask for the box prompt; scores/logits are not used.
    log_step(3, "Decoding Masks...")
    masks, _scores, _logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )

    # 4. Annotate / visualize.
    log_step(4, "Annotating Image...")
    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks),
        mask=masks,
    )
    # Fix for missing class_id: SAM output carries no class, so tag every
    # detection as class 0 for the annotators.
    detections.class_id = np.zeros(len(detections), dtype=int)
    mask_annotator = sv.MaskAnnotator(color=sv.Color.RED)
    box_annotator = sv.BoxAnnotator(color=sv.Color.RED)
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)

    log_step(5, "Done!")
    return annotated_image
# Gradio UI / API definition. The two input components feed run_sam_api's
# (image_url, box_coords) parameters in order; the JSON field is pre-filled
# with an example box.
demo = gr.Interface(
    title="SAM API via Gradio",
    description="Send an image URL and bounding box coordinates to segment objects.",
    api_name="predict_api",
    fn=run_sam_api,
    inputs=[
        gr.Textbox(placeholder="http://...", label="Image URL"),
        gr.JSON(value=[2434, 666, 3737, 1756], label="Box Coords [x1, y1, x2, y2]"),
    ],
    outputs=gr.Image(label="Segmented Output", type="numpy"),
)

if __name__ == "__main__":
    # If running on Hugging Face Spaces, just use launch().
    # Enable Gradio's request queue on the app, then start the server.
    demo.queue()
    demo.launch()