| """
|
| Hugging Face Spaces deployment for SAM2 Auto Annotation API.
|
| This file serves as the entry point for the FastAPI application on Hugging Face Spaces.
|
| """
|
| import sys
|
| import os
|
|
|
|
|
# The SAM2 package ships as a vendored `sam2/` subdirectory next to this file
# rather than as an installed dependency, so it must be put on sys.path
# BEFORE the `from sam2...` import below.
_current_dir = os.path.dirname(os.path.abspath(__file__))

_sam2_dir = os.path.join(_current_dir, "sam2")

# Normalize to an absolute path so the membership test below is reliable,
# and only prepend once (keeps sys.path clean across re-imports).
abs_sam2_dir = os.path.abspath(_sam2_dir)

if abs_sam2_dir not in sys.path:

    sys.path.insert(0, abs_sam2_dir)
|
|
|
| from fastapi import FastAPI, HTTPException
|
| from fastapi.middleware.cors import CORSMiddleware
|
| import cv2
|
| import numpy as np
|
| import torch
|
| import psutil
|
| import PIL.Image
|
| from requests.exceptions import Timeout, RequestException
|
|
|
|
|
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| from model.sam_model import predict_polygon, predict_polygon_from_point
|
| from model.utils import load_image_from_url, mask_to_polygon
|
| from model.sam2_detection_function import SAM2AutoAnnotation, create_sam2_auto_annotation
|
|
|
|
|
# Hugging Face checkpoint used by the /auto-annotate mask generator.
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"

# Prefer GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Lazily-created SAM2 auto-annotation wrapper. It stays None until the first
# /auto-annotate request, then is cached for the process lifetime because
# loading the model is expensive.
sam2_auto_annotation_global = None
|
|
|
# FastAPI application instance; served directly by Hugging Face Spaces.
app = FastAPI(

    title="SAM Auto Annotation API (BBox ➜ Polygon)",

    description="AI-powered auto-annotation API using Meta's Segment Anything Model (SAM)",

    version="1.0.0"

)


# Allow browser clients hosted on other origins to call this API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — confirm this is intended for production use.
app.add_middleware(

    CORSMiddleware,

    allow_origins=["*"],

    allow_credentials=True,

    allow_methods=["*"],

    allow_headers=["*"],

)
|
|
|
|
|
@app.get("/")
def root():
    """Liveness/info endpoint: report service name, version, and status."""
    info = {
        "status": "Service is up and running!",
        "message": "Backend service is active",
        "api": "SAM Auto Annotation API",
        "version": "1.0.0",
    }
    return info
|
|
|
|
|
@app.get("/health")
def health_check():
    """Health check endpoint for uptime monitors and load balancers."""
    # Fix: the service name was garbled ("same model segmenticAPI");
    # report the actual API name used by the root endpoint.
    return {"status": "healthy", "service": "SAM Auto Annotation API"}
|
|
|
|
|
@app.post("/segment")
def segment(data: dict):
    """
    Segment image using SAM2 model to convert bounding box to polygon (CVAT-style).
    Bbox is used as a prompt to identify the object, not as a constraint.

    **Input:**
    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "bbox": {"x": 494.97, "y": 187.22, "width": 137.99, "height": 98.00, "label": "Object"},
        "imageSize": {"width": 663.07, "height": 442}
    }
    ```

    OR

    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "bbox": [494.97, 187.22, 137.99, 98.00],  // [x, y, width, height]
        "imageSize": [663.07, 442]  // [width, height]
    }
    ```

    **Output:**
    ```json
    {
        "polygon": [x1, y1, x2, y2, x3, y3, ...],  // CVAT format: flattened coordinates
        "confidence": 0.96
    }
    ```
    """
    def _validate_image_size(image_size):
        """Reject malformed imageSize payloads (dict needs width/height; list needs 2 items)."""
        if isinstance(image_size, dict):
            if not ("width" in image_size and "height" in image_size):
                raise HTTPException(
                    status_code=400,
                    detail="imageSize dict must contain 'width' and 'height'"
                )
        elif isinstance(image_size, list):
            if len(image_size) != 2:
                raise HTTPException(
                    status_code=400,
                    detail="imageSize list must contain exactly 2 values: [width, height]"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail="imageSize must be either a dict or a list"
            )

    try:
        # --- request validation -------------------------------------------
        if "imageUrl" not in data:
            raise HTTPException(status_code=400, detail="Missing required field: imageUrl")
        if "bbox" not in data:
            raise HTTPException(status_code=400, detail="Missing required field: bbox")

        image_url = data["imageUrl"]
        bbox = data["bbox"]
        image_size = data.get("imageSize")

        # bbox may be either {x, y, width, height} or [x, y, w, h].
        if isinstance(bbox, dict):
            required_keys = ["x", "y", "width", "height"]
            if not all(key in bbox for key in required_keys):
                raise HTTPException(
                    status_code=400,
                    detail=f"bbox dict must contain: {required_keys}"
                )
        elif isinstance(bbox, list):
            if len(bbox) != 4:
                raise HTTPException(
                    status_code=400,
                    detail="bbox list must contain exactly 4 values: [x, y, width, height]"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail="bbox must be either a dict or a list"
            )

        if image_size is not None:
            _validate_image_size(image_size)

        # --- inference ----------------------------------------------------
        # load_image_from_url returns a BGR array (OpenCV convention);
        # the model expects RGB.
        img_bgr = load_image_from_url(image_url)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        mask, confidence, scale_factors = predict_polygon(img_rgb, bbox, image_size)

        polygon = mask_to_polygon(mask, scale_factors)
        if not polygon:
            raise HTTPException(status_code=400, detail="No polygon found in mask")

        return {
            "polygon": polygon,
            "confidence": confidence
        }
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"Missing required field: {str(e)}")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except FileNotFoundError as e:
        raise HTTPException(status_code=500, detail=str(e))
    except ImportError:
        # Plain string: the original used an f-string with no placeholders.
        raise HTTPException(
            status_code=500,
            detail="Segment Anything library not installed. Please run: pip install -e . in segment-anything directory"
        )
    except Timeout as e:
        raise HTTPException(
            status_code=504,
            detail=f"Image download timeout: {str(e)}. The image server may be slow or unreachable. Please try again or use a different image URL."
        )
    except RequestException as e:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to fetch image from URL: {str(e)}. Please check the image URL and try again."
        )
    except HTTPException:
        # Re-raise our own validation errors untouched (don't wrap as 500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
@app.post("/segment/point")
def segment_from_point(data: dict):
    """
    Segment image using SAM2 model with a point click to select object.
    The point identifies which object to segment.

    **Input:**
    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "point": {"x": 494.97, "y": 187.22},
        "imageSize": {"width": 663.07, "height": 442}
    }
    ```

    OR

    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "point": [494.97, 187.22],  // [x, y]
        "imageSize": [663.07, 442]  // [width, height]
    }
    ```

    **Output:**
    ```json
    {
        "polygon": [x1, y1, x2, y2, x3, y3, ...],  // CVAT format: flattened coordinates
        "confidence": 0.96
    }
    ```
    """
    def _validate_image_size(image_size):
        """Reject malformed imageSize payloads (dict needs width/height; list needs 2 items)."""
        if isinstance(image_size, dict):
            if not ("width" in image_size and "height" in image_size):
                raise HTTPException(
                    status_code=400,
                    detail="imageSize dict must contain 'width' and 'height'"
                )
        elif isinstance(image_size, list):
            if len(image_size) != 2:
                raise HTTPException(
                    status_code=400,
                    detail="imageSize list must contain exactly 2 values: [width, height]"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail="imageSize must be either a dict or a list"
            )

    try:
        # --- request validation -------------------------------------------
        if "imageUrl" not in data:
            raise HTTPException(status_code=400, detail="Missing required field: imageUrl")
        if "point" not in data:
            raise HTTPException(status_code=400, detail="Missing required field: point")

        image_url = data["imageUrl"]
        point = data["point"]
        image_size = data.get("imageSize")

        # point may be either {x, y} or [x, y].
        if isinstance(point, dict):
            required_keys = ["x", "y"]
            if not all(key in point for key in required_keys):
                raise HTTPException(
                    status_code=400,
                    detail=f"point dict must contain: {required_keys}"
                )
        elif isinstance(point, list):
            if len(point) != 2:
                raise HTTPException(
                    status_code=400,
                    detail="point list must contain exactly 2 values: [x, y]"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail="point must be either a dict or a list"
            )

        if image_size is not None:
            _validate_image_size(image_size)

        # --- inference ----------------------------------------------------
        # load_image_from_url returns a BGR array (OpenCV convention);
        # the model expects RGB.
        img_bgr = load_image_from_url(image_url)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        mask, confidence, scale_factors = predict_polygon_from_point(img_rgb, point, image_size)

        polygon = mask_to_polygon(mask, scale_factors)
        if not polygon:
            raise HTTPException(status_code=400, detail="No polygon found in mask. Try clicking on a different point.")

        return {
            "polygon": polygon,
            "confidence": confidence
        }
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"Missing required field: {str(e)}")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except FileNotFoundError as e:
        raise HTTPException(status_code=500, detail=str(e))
    except ImportError:
        # Plain string: the original used an f-string with no placeholders.
        raise HTTPException(
            status_code=500,
            detail="Segment Anything library not installed. Please run: pip install -e . in segment-anything directory"
        )
    except Timeout as e:
        raise HTTPException(
            status_code=504,
            detail=f"Image download timeout: {str(e)}. The image server may be slow or unreachable. Please try again or use a different image URL."
        )
    except RequestException as e:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to fetch image from URL: {str(e)}. Please check the image URL and try again."
        )
    except HTTPException:
        # Re-raise our own validation errors untouched (don't wrap as 500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
@app.post("/auto-annotate")
def auto_annotate(data: dict):
    """
    Automatically detect and segment all objects in an image using SAM2 from Hugging Face.
    Uses SAM2AutomaticMaskGenerator (facebook/sam2.1-hiera-large) to detect all objects without requiring prompts (bbox or points).

    **Input:**
    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "imageSize": {"width": 663.07, "height": 442},
        "minArea": 100,
        "minConfidence": 0.5,
        "maxImageDimension": 1024,
        "pointsPerSide": 32,
        "pointsPerBatch": 64,
        "filterObjectsOnly": true
    }
    ```

    **Output:**
    ```json
    {
        "masks": [
            {
                "polygon": [x1, y1, x2, y2, x3, y3, ...],
                "confidence": 0.93,
                "area": 12345
            },
            ...
        ],
        "count": 10,
        "memoryInfo": {
            "before_mb": 512.5,
            "after_mb": 1024.3,
            "peak_mb": 1024.3,
            "estimated_mb": 800.0,
            "memory_used_mb": 511.8
        },
        "imageInfo": {
            "wasResized": true,
            "originalSize": [1920, 1080],
            "processedSize": [1024, 576],
            "resizeScale": [1.875, 1.875]
        }
    }
    ```
    """
    def _validate_image_size(image_size):
        """Reject malformed imageSize payloads (dict needs width/height; list needs 2 items)."""
        if isinstance(image_size, dict):
            if not ("width" in image_size and "height" in image_size):
                raise HTTPException(
                    status_code=400,
                    detail="imageSize dict must contain 'width' and 'height'"
                )
        elif isinstance(image_size, list):
            if len(image_size) != 2:
                raise HTTPException(
                    status_code=400,
                    detail="imageSize list must contain exactly 2 values: [width, height]"
                )
        else:
            raise HTTPException(
                status_code=400,
                detail="imageSize must be either a dict or a list"
            )

    def _require_int(value, name, minimum, maximum):
        """Coerce a request parameter to int and enforce an inclusive range.

        Error messages match the original per-parameter validation exactly.
        """
        try:
            value = int(value)
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=400,
                detail=f"{name} must be an integer between {minimum} and {maximum}"
            )
        if value < minimum:
            raise HTTPException(status_code=400, detail=f"{name} must be >= {minimum}")
        if value > maximum:
            raise HTTPException(status_code=400, detail=f"{name} must be <= {maximum}")
        return value

    try:
        # --- request validation -------------------------------------------
        if "imageUrl" not in data:
            raise HTTPException(status_code=400, detail="Missing required field: imageUrl")

        image_url = data["imageUrl"]
        image_size = data.get("imageSize")
        min_area = data.get("minArea", 100)
        min_confidence = data.get("minConfidence", 0.5)
        max_image_dimension = data.get("maxImageDimension", 1024)
        points_per_side = data.get("pointsPerSide", 32)
        points_per_batch = data.get("pointsPerBatch", 64)
        filter_objects_only = data.get("filterObjectsOnly", False)

        if image_size is not None:
            _validate_image_size(image_size)

        # minArea has no upper bound and its own error wording.
        try:
            min_area = int(min_area)
        except (ValueError, TypeError):
            raise HTTPException(status_code=400, detail="minArea must be an integer")
        if min_area < 0:
            raise HTTPException(status_code=400, detail="minArea must be >= 0")

        try:
            min_confidence = float(min_confidence)
        except (ValueError, TypeError):
            raise HTTPException(status_code=400, detail="minConfidence must be a float between 0.0 and 1.0")
        if not (0.0 <= min_confidence <= 1.0):
            raise HTTPException(status_code=400, detail="minConfidence must be between 0.0 and 1.0")

        max_image_dimension = _require_int(max_image_dimension, "maxImageDimension", 256, 4096)
        points_per_side = _require_int(points_per_side, "pointsPerSide", 8, 128)
        points_per_batch = _require_int(points_per_batch, "pointsPerBatch", 16, 256)

        # --- load image and bound its size --------------------------------
        # Track process RSS so the response can report how expensive this was.
        process = psutil.Process(os.getpid())
        memory_before = process.memory_info().rss / (1024 * 1024)

        # load_image_from_url returns a BGR array (OpenCV convention);
        # the model expects RGB.
        img_bgr = load_image_from_url(image_url)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        original_h, original_w = img_rgb.shape[:2]
        original_size = [original_w, original_h]

        # Downscale large images, preserving aspect ratio, so the longest
        # side is at most max_image_dimension (bounds memory usage).
        processed_image = img_rgb
        resize_scale = [1.0, 1.0]
        was_resized = False
        if max(original_h, original_w) > max_image_dimension:
            was_resized = True
            if original_h > original_w:
                new_h = max_image_dimension
                new_w = int(original_w * (max_image_dimension / original_h))
            else:
                new_w = max_image_dimension
                new_h = int(original_h * (max_image_dimension / original_w))
            processed_image = cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            # Factors to map processed-image coordinates back to the original.
            resize_scale = [original_w / new_w, original_h / new_h]

        processed_h, processed_w = processed_image.shape[:2]
        processed_size = [processed_w, processed_h]

        # Rough working-memory estimate: float32 RGB image + a 256-channel
        # float32 embedding + ~100 single-byte masks, all at processed size.
        estimated_mb = ((processed_w * processed_h * 3 * 4) + (processed_w * processed_h * 256 * 4) + (processed_w * processed_h * 100 * 1)) / (1024 * 1024)

        # Factors to map output polygons from processed-image pixels to the
        # client's display coordinate system (when one was supplied).
        scale_factor_x, scale_factor_y = 1.0, 1.0
        if image_size is not None:
            if isinstance(image_size, dict):
                display_w = float(image_size.get("width", processed_w))
                display_h = float(image_size.get("height", processed_h))
            else:
                display_w, display_h = float(image_size[0]), float(image_size[1])
            scale_factor_x = processed_w / display_w if display_w > 0 else 1.0
            scale_factor_y = processed_h / display_h if display_h > 0 else 1.0

        total_image_area = processed_w * processed_h

        # --- lazy, process-wide generator ---------------------------------
        # NOTE(review): the generator is built once and cached; the
        # pointsPerSide / pointsPerBatch / minArea values of LATER requests
        # do not rebuild it — only the first request's values apply.
        global sam2_auto_annotation_global
        if sam2_auto_annotation_global is None:
            try:
                sam2_auto_annotation_global = create_sam2_auto_annotation(
                    points_per_side=points_per_side,
                    points_per_batch=points_per_batch,
                    pred_iou_thresh=0.88,
                    stability_score_thresh=0.95,
                    min_mask_region_area=min_area,
                )
            except ImportError as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to import required modules. Please ensure 'sam2' and 'huggingface_hub' are installed. Error: {str(e)}"
                )
            except Exception as e:
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load SAM2 Auto Annotation from Hugging Face ({HUGGINGFACE_MODEL_ID}). Error: {str(e)}"
                )

        mask_results = sam2_auto_annotation_global.generate_masks(
            image=processed_image,
            min_confidence=min_confidence,
            min_area=min_area,
            filter_blank_regions=True,
            scale_factors=(scale_factor_x, scale_factor_y)
        )

        memory_after = process.memory_info().rss / (1024 * 1024)
        memory_used = memory_after - memory_before

        # --- filter and shape the results ---------------------------------
        results = []
        for mask_result in mask_results:
            polygon = mask_result.get("polygon")
            score = mask_result.get("confidence")
            area = mask_result.get("area")

            # Fix: score/area may be absent (None) — the original unguarded
            # `area < min_area or score < min_confidence` raised a TypeError
            # that surfaced as a 500. Missing values now pass the thresholds.
            if area is not None and area < min_area:
                continue
            if score is not None and score < min_confidence:
                continue

            # Optionally drop masks covering >= 80% of the frame (likely the
            # background rather than an object). Requires a known area.
            if filter_objects_only and area is not None:
                coverage_ratio = area / total_image_area if total_image_area > 0 else 0
                if coverage_ratio >= 0.8:
                    continue

            # A valid polygon needs at least 3 points (6 flattened coords).
            if polygon and len(polygon) >= 6:
                mask_obj = {"polygon": polygon}
                if score is not None:
                    mask_obj["confidence"] = score
                if area is not None:
                    mask_obj["area"] = area
                results.append(mask_obj)

        response = {
            "masks": results,
            "count": len(results),
            "memoryInfo": {
                "before_mb": round(memory_before, 2),
                "after_mb": round(memory_after, 2),
                # No true peak tracking: peak_mb just mirrors after_mb.
                "peak_mb": round(memory_after, 2),
                "estimated_mb": round(estimated_mb, 2),
                "memory_used_mb": round(memory_used, 2)
            },
            "imageInfo": {
                "wasResized": was_resized,
                "originalSize": original_size,
                "processedSize": processed_size,
                "resizeScale": resize_scale
            }
        }
        return response

    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"Missing required field: {str(e)}")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except FileNotFoundError as e:
        raise HTTPException(status_code=500, detail=str(e))
    except ImportError:
        # Plain string: the original used an f-string with no placeholders.
        raise HTTPException(
            status_code=500,
            detail="Segment Anything library not installed. Please ensure 'sam2' and 'huggingface_hub' are installed."
        )
    except Timeout as e:
        raise HTTPException(
            status_code=504,
            detail=f"Image download timeout: {str(e)}. The image server may be slow or unreachable. Please try again or use a different image URL."
        )
    except RequestException as e:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to fetch image from URL: {str(e)}. Please check the image URL and try again."
        )
    except HTTPException:
        # Re-raise our own validation errors untouched (don't wrap as 500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|