"""Gradio front-end that segments a boxed region of a remote image with SAM.

On import this module downloads the SAM ViT-B checkpoint if it is not already
on disk, loads the model onto the available device and builds a
``gr.Interface`` (``demo``) around :func:`run_sam_api`.
"""

import os
import time

import cv2
import gradio as gr
import numpy as np
import requests
import supervision as sv
import torch
from segment_anything import sam_model_registry, SamPredictor

# ------------------------------
# 1. Setup & Model Loading
# ------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device used: {device}")

model_type = "vit_b"
checkpoint_path = "sam_vit_b_01ec64.pth"

# Per-socket-operation timeout (seconds) for every HTTP request made here, so
# a stalled server cannot hang the process or a queued request forever.
REQUEST_TIMEOUT = 30


def _download_checkpoint(url, dest):
    """Stream ``url`` to ``dest``, leaving no partial file behind on failure.

    The data is written to ``<dest>.part`` and only moved to ``dest`` once the
    download has completed, so an interrupted or failed download is never
    mistaken for a valid checkpoint on the next start.

    Raises:
        requests.RequestException: on connection errors or a non-2xx status.
    """
    tmp_path = f"{dest}.part"
    try:
        with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as resp:
            resp.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(1024 * 1024):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # Only still present if something above failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download model if needed
if not os.path.exists(checkpoint_path):
    print("Downloading SAM checkpoint...")
    _download_checkpoint(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        checkpoint_path,
    )

sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device)

# Note: We kept the model in float32.
# Using sam.half() often causes runtime type mismatch errors with SamPredictor
# unless the input image is explicitly cast to half-precision manually.

# NOTE: the predictor is a single module-level object shared by all requests.
predictor = SamPredictor(sam)


def _fetch_image(image_url):
    """Download ``image_url`` and return it as an RGB ``numpy`` image.

    Raises:
        requests.RequestException: on connection errors or a non-2xx status.
        ValueError: if the payload cannot be decoded as an image.
    """
    headers = {"User-Agent": "Mozilla/5.0"}
    resp = requests.get(image_url, headers=headers, timeout=REQUEST_TIMEOUT)
    resp.raise_for_status()
    # ``resp.content`` (unlike ``resp.raw``) has any Content-Encoding such as
    # gzip already decoded, and reading it releases the connection.
    buffer = np.frombuffer(resp.content, dtype=np.uint8)
    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("Could not decode image.")
    # OpenCV decodes to BGR; the rest of the pipeline works on RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _parse_box(box_coords):
    """Return ``box_coords`` as a float array of shape ``(4,)``.

    Raises:
        gr.Error: if the value is not exactly four finite numbers.
    """
    try:
        box = np.asarray(box_coords, dtype=float)
    except (TypeError, ValueError) as err:
        raise gr.Error(
            f"Box must be four numbers [x1, y1, x2, y2], got: {box_coords!r}"
        ) from err
    if box.shape != (4,) or not np.all(np.isfinite(box)):
        raise gr.Error(
            f"Box must be four numbers [x1, y1, x2, y2], got: {box_coords!r}"
        )
    return box


# ------------------------------
# 2. The API Function with Progress
# ------------------------------
# We add `progress=gr.Progress()` as a default argument.
# Gradio automatically handles this when passed to the interface.
def run_sam_api(image_url, box_coords, progress=gr.Progress()):
    """Segment the object inside ``box_coords`` in the image at ``image_url``.

    Args:
        image_url: URL of the image to download.
        box_coords: Box prompt as ``[x1, y1, x2, y2]``.
        progress: Gradio progress tracker; updated after each step.

    Returns:
        A copy of the downloaded RGB image with the predicted mask and its
        bounding box drawn in red.

    Raises:
        gr.Error: if the box is malformed or the image cannot be downloaded
            or decoded.
    """

    # --- LOGGING HELPER ---
    def log_step(step_num, total_steps, message):
        timestamp = time.strftime("%H:%M:%S")
        print(f"[{timestamp}] Step {step_num}/{total_steps}: {message}")
        # Update Gradio UI bar (0.0 to 1.0)
        progress(step_num / total_steps, desc=message)
    # ----------------------

    log_step(1, 5, f"Starting request for {image_url}")
    print(f"Received Box: {box_coords}")

    # Validate the cheap input first so a bad box fails before the download
    # and the expensive embedding step.
    input_box = _parse_box(box_coords)

    # 1. Download the Image
    try:
        image = _fetch_image(image_url)
    except (requests.RequestException, ValueError, cv2.error) as e:
        print(f"Error downloading image: {e}")
        raise gr.Error(f"Failed to load image: {e}") from e

    # 2. Set Image for SAM (The Heavy Step)
    log_step(2, 5, "Generating Image Embeddings (This may take a moment)...")
    # Note: This specific line is where the GPU works hardest.
    # It might 'freeze' here for a few seconds.
    predictor.set_image(image)

    # 3. Predict
    log_step(3, 5, "Decoding Masks...")
    masks, scores, logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )

    # 4. Annotate / Visualize
    log_step(4, 5, "Annotating Image...")
    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks),
        mask=masks,
    )
    # Fix for missing class_id
    detections.class_id = np.zeros(len(detections), dtype=int)

    mask_annotator = sv.MaskAnnotator(color=sv.Color.RED)
    box_annotator = sv.BoxAnnotator(color=sv.Color.RED)

    # Annotate a copy so the downloaded image itself is left untouched.
    annotated_image = mask_annotator.annotate(
        scene=image.copy(), detections=detections
    )
    annotated_image = box_annotator.annotate(
        scene=annotated_image, detections=detections
    )

    log_step(5, 5, "Done!")
    return annotated_image


demo = gr.Interface(
    fn=run_sam_api,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="http://..."),
        gr.JSON(
            label="Box Coords [x1, y1, x2, y2]",
            value=[2434, 666, 3737, 1756],
        ),
    ],
    outputs=gr.Image(type="numpy", label="Segmented Output"),
    title="SAM API via Gradio",
    description="Send an image URL and bounding box coordinates to segment objects.",
    api_name="predict_api",
)

if __name__ == "__main__":
    # If running on Hugging Face Spaces, just use launch()
    demo.queue().launch()