Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import requests | |
| import gradio as gr | |
| import time | |
| from segment_anything import sam_model_registry, SamPredictor | |
| import supervision as sv | |
| # ------------------------------ | |
| # 1. Setup & Model Loading | |
| # ------------------------------ | |
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device used: {device}")

model_type = "vit_b"
checkpoint_path = "sam_vit_b_01ec64.pth"

# Download model if needed
if not os.path.exists(checkpoint_path):
    print("Downloading SAM checkpoint...")
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    # Stream into a side file and rename only after the transfer finished.
    # Writing straight to checkpoint_path would leave a truncated file behind
    # on an interrupted download, and the exists() check above would then
    # accept that corrupt file on every later start.
    partial_path = checkpoint_path + ".part"
    # (connect timeout, per-read timeout) in seconds, so a stalled server
    # cannot hang startup forever.
    with requests.get(url, stream=True, timeout=(10, 60)) as r:
        # Fail loudly on 4xx/5xx instead of saving an error page as the model.
        r.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(1024 * 1024):
                if chunk:
                    f.write(chunk)
    os.replace(partial_path, checkpoint_path)

sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device)

# Note: We kept the model in float32.
# Using sam.half() often causes runtime type mismatch errors with SamPredictor
# unless the input image is explicitly cast to half-precision manually.
predictor = SamPredictor(sam)
| # ------------------------------ | |
| # 2. The API Function with Progress | |
| # ------------------------------ | |
| # We add `progress=gr.Progress()` as a default argument. | |
| # Gradio automatically handles this when passed to the interface. | |
def run_sam_api(image_url, box_coords, progress=gr.Progress()):
    """Segment the boxed region of an image fetched from a URL.

    Args:
        image_url: URL of the image to download.
        box_coords: Four numbers ``[x1, y1, x2, y2]`` used as the box prompt.
        progress: Gradio progress tracker; it is called with a completion
            fraction and a description at every step.

    Returns:
        A copy of the downloaded RGB image with the predicted mask and its
        bounding box drawn in red.

    Raises:
        gr.Error: If the box is not four finite numbers, or if the image
            cannot be downloaded or decoded.
    """
    # --- LOGGING HELPER ---
    def log_step(step_num, total_steps, message):
        timestamp = time.strftime("%H:%M:%S")
        print(f"[{timestamp}] Step {step_num}/{total_steps}: {message}")
        # Update Gradio UI bar (0.0 to 1.0)
        progress(step_num / total_steps, desc=message)
    # ----------------------

    log_step(1, 5, f"Starting request for {image_url}")
    print(f"Received Box: {box_coords}")

    # Validate the box before any network or model work so that a malformed
    # request fails fast with a readable message instead of an opaque error
    # after the expensive embedding step.
    try:
        input_box = np.asarray(box_coords, dtype=float)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Invalid box coordinates: {e}") from e
    if input_box.shape != (4,) or not np.isfinite(input_box).all():
        raise gr.Error(
            "Box coordinates must be four finite numbers: [x1, y1, x2, y2]."
        )

    # 1. Download the Image
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # A timeout keeps a stalled server from blocking the worker forever.
        resp = requests.get(image_url, headers=headers, timeout=30)
        # Reject 4xx/5xx responses rather than trying to decode an error page.
        resp.raise_for_status()
        # Use .content, not .raw: .raw skips Content-Encoding handling, so a
        # gzip/deflate-compressed response would yield undecodable bytes.
        image_array = np.frombuffer(resp.content, dtype=np.uint8)
        image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Could not decode image.")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    except Exception as e:
        # Deliberately broad: this is the boundary where any download or
        # decode failure is turned into a user-visible Gradio error.
        print(f"Error downloading image: {e}")
        raise gr.Error(f"Failed to load image: {e}") from e

    # 2. Set Image for SAM (The Heavy Step)
    log_step(2, 5, "Generating Image Embeddings (This may take a moment)...")
    # Note: This specific line is where the GPU works hardest.
    # It might 'freeze' here for a few seconds.
    # NOTE(review): `predictor` is a single module-level object that stores
    # the current image; presumably requests must not run concurrently -
    # confirm the queue's concurrency setting.
    predictor.set_image(image)

    # 3. Predict (the box was already prepared and validated above)
    log_step(3, 5, "Decoding Masks...")
    masks, scores, logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False
    )

    # 4. Annotate / Visualize
    log_step(4, 5, "Annotating Image...")
    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks),
        mask=masks
    )
    # Fix for missing class_id
    detections.class_id = np.zeros(len(detections), dtype=int)

    mask_annotator = sv.MaskAnnotator(color=sv.Color.RED)
    box_annotator = sv.BoxAnnotator(color=sv.Color.RED)
    # Draw on a copy so the array handed to the predictor stays unmodified.
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)

    log_step(5, 5, "Done!")
    return annotated_image
# Build the widgets first (in the order the handler receives its arguments),
# then wire them to run_sam_api.
_url_input = gr.Textbox(label="Image URL", placeholder="http://...")
_box_input = gr.JSON(
    label="Box Coords [x1, y1, x2, y2]",
    value=[2434, 666, 3737, 1756],
)
_result_view = gr.Image(type="numpy", label="Segmented Output")

# Web UI and API endpoint ("predict_api") around the segmentation handler.
demo = gr.Interface(
    fn=run_sam_api,
    inputs=[_url_input, _box_input],
    outputs=_result_view,
    title="SAM API via Gradio",
    description="Send an image URL and bounding box coordinates to segment objects.",
    api_name="predict_api",
)
if __name__ == "__main__":
    # If running on Hugging Face Spaces, just use launch()
    # Enable the request queue first, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()