"""
Hugging Face Spaces deployment for SAM2 Auto Annotation API.

This file serves as the entry point for the FastAPI application on Hugging Face Spaces.

Endpoints:
    GET  /               - service banner
    GET  /health         - liveness probe
    POST /segment        - bounding box prompt -> polygon
    POST /segment/point  - point prompt -> polygon
    POST /auto-annotate  - prompt-free detection of every object in the image
"""
import sys
import os
import threading

# Add sam2 folder to path to import from local sam2 directory
_current_dir = os.path.dirname(os.path.abspath(__file__))
_sam2_dir = os.path.join(_current_dir, "sam2")

# Add sam2 directory to sys.path if not already there
abs_sam2_dir = os.path.abspath(_sam2_dir)
if abs_sam2_dir not in sys.path:
    sys.path.insert(0, abs_sam2_dir)

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import cv2
import numpy as np
import torch
import psutil
import PIL.Image
from requests.exceptions import Timeout, RequestException

# Import sam2 from local folder
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from model.sam_model import predict_polygon, predict_polygon_from_point
from model.utils import load_image_from_url, mask_to_polygon
from model.sam2_detection_function import SAM2AutoAnnotation, create_sam2_auto_annotation

# Hugging Face model ID for SAM2.1 Hiera Large model
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Global SAM2 auto annotation (built lazily on the first /auto-annotate request)
sam2_auto_annotation_global = None
# (points_per_side, points_per_batch, min_area) the cached instance was built with.
_sam2_auto_annotation_config = None
# Sync `def` endpoints are executed in a threadpool, so lazy initialisation of
# the cached instance has to be serialised.
_sam2_auto_annotation_lock = threading.Lock()

# Detail strings returned when an ImportError escapes an endpoint.
_SEGMENT_IMPORT_ERROR_DETAIL = (
    "Segment Anything library not installed. "
    "Please run: pip install -e . in segment-anything directory"
)
_AUTO_ANNOTATE_IMPORT_ERROR_DETAIL = (
    "Segment Anything library not installed. "
    "Please ensure 'sam2' and 'huggingface_hub' are installed."
)

app = FastAPI(
    title="SAM Auto Annotation API (BBox ➜ Polygon)",
    description="AI-powered auto-annotation API using Meta's Segment Anything Model (SAM)",
    version="1.0.0"
)

# Add CORS middleware to handle preflight OPTIONS requests
# NOTE(review): the FastAPI CORS documentation states that wildcard origins
# cannot be combined with allow_credentials=True. Behaviour is left unchanged
# here because existing clients may depend on it; list explicit origins if
# credentialed requests are really needed, otherwise drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods including OPTIONS
    allow_headers=["*"],  # Allows all headers
)


# --------------------------------------------------------------------------- #
# Shared request helpers
# --------------------------------------------------------------------------- #

def _require_field(data: dict, name: str):
    """Return ``data[name]`` or raise HTTP 400 when the field is absent."""
    if name not in data:
        raise HTTPException(status_code=400, detail=f"Missing required field: {name}")
    return data[name]


def _validate_prompt(value, name: str, required_keys: list) -> None:
    """Check that a prompt is a dict holding ``required_keys`` or a list of that length.

    Used for both ``bbox`` (x, y, width, height) and ``point`` (x, y).

    Raises:
        HTTPException: 400 when the shape of ``value`` is not acceptable.
    """
    if isinstance(value, dict):
        if not all(key in value for key in required_keys):
            raise HTTPException(
                status_code=400,
                detail=f"{name} dict must contain: {required_keys}"
            )
    elif isinstance(value, list):
        if len(value) != len(required_keys):
            raise HTTPException(
                status_code=400,
                detail=(
                    f"{name} list must contain exactly {len(required_keys)} values: "
                    f"[{', '.join(required_keys)}]"
                )
            )
    else:
        raise HTTPException(
            status_code=400,
            detail=f"{name} must be either a dict or a list"
        )


def _validate_image_size(image_size) -> None:
    """Check the optional ``imageSize`` field.

    Accepted forms are ``None`` (field omitted), a dict with ``width`` and
    ``height`` keys, or a two-item ``[width, height]`` list.

    Raises:
        HTTPException: 400 for any other shape.
    """
    if image_size is None:
        return
    if isinstance(image_size, dict):
        if not ("width" in image_size and "height" in image_size):
            raise HTTPException(
                status_code=400,
                detail="imageSize dict must contain 'width' and 'height'"
            )
    elif isinstance(image_size, list):
        if len(image_size) != 2:
            raise HTTPException(
                status_code=400,
                detail="imageSize list must contain exactly 2 values: [width, height]"
            )
    else:
        raise HTTPException(
            status_code=400,
            detail="imageSize must be either a dict or a list"
        )


def _parse_bounded_int(value, name: str, minimum: int, maximum=None) -> int:
    """Coerce ``value`` to ``int`` and enforce ``minimum`` (and ``maximum`` if given).

    Raises:
        HTTPException: 400 when the value is not integer-like or is out of range.
    """
    if maximum is None:
        type_detail = f"{name} must be an integer"
    else:
        type_detail = f"{name} must be an integer between {minimum} and {maximum}"
    try:
        number = int(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")), which a JSON body can carry as Infinity.
        raise HTTPException(status_code=400, detail=type_detail) from None
    if number < minimum:
        raise HTTPException(status_code=400, detail=f"{name} must be >= {minimum}")
    if maximum is not None and number > maximum:
        raise HTTPException(status_code=400, detail=f"{name} must be <= {maximum}")
    return number


def _parse_confidence(value) -> float:
    """Coerce ``minConfidence`` to a float inside ``[0.0, 1.0]`` or raise HTTP 400."""
    try:
        confidence = float(value)
    except (ValueError, TypeError):
        raise HTTPException(
            status_code=400,
            detail="minConfidence must be a float between 0.0 and 1.0"
        ) from None
    # NaN fails both comparisons, so it is rejected here as well.
    if not (0.0 <= confidence <= 1.0):
        raise HTTPException(status_code=400, detail="minConfidence must be between 0.0 and 1.0")
    return confidence


def _load_rgb_image(image_url):
    """Download the image and convert it with ``cv2.COLOR_BGR2RGB``.

    Assumes ``load_image_from_url`` returns a BGR array (as OpenCV decoders do).
    """
    img_bgr = load_image_from_url(image_url)
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)


def _as_http_exception(exc: Exception, import_error_detail: str) -> HTTPException:
    """Map an exception that escaped an endpoint body to the HTTP error to return.

    The checks run in the same order as the original ``except`` chain, which
    matters for exception types that inherit from more than one of these bases
    (for example ``requests`` errors that are also ``ValueError``).
    """
    if isinstance(exc, KeyError):
        return HTTPException(status_code=400, detail=f"Missing required field: {str(exc)}")
    if isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=str(exc))
    if isinstance(exc, FileNotFoundError):
        return HTTPException(status_code=500, detail=str(exc))
    if isinstance(exc, ImportError):
        return HTTPException(status_code=500, detail=import_error_detail)
    if isinstance(exc, Timeout):
        return HTTPException(
            status_code=504,
            detail=f"Image download timeout: {str(exc)}. The image server may be slow or unreachable. Please try again or use a different image URL."
        )
    if isinstance(exc, RequestException):
        return HTTPException(
            status_code=502,
            detail=f"Failed to fetch image from URL: {str(exc)}. Please check the image URL and try again."
        )
    return HTTPException(status_code=500, detail=f"Internal server error: {str(exc)}")


# --------------------------------------------------------------------------- #
# /auto-annotate helpers
# --------------------------------------------------------------------------- #

def _resize_to_max_dimension(image, max_dimension: int):
    """Shrink ``image`` so its longer side equals ``max_dimension``, keeping aspect ratio.

    Returns:
        ``(image, was_resized, [scale_x, scale_y])`` where the scale maps
        processed coordinates back to original coordinates. Images that
        already fit are returned untouched with a scale of ``[1.0, 1.0]``.
    """
    height, width = image.shape[:2]
    if max(height, width) <= max_dimension:
        return image, False, [1.0, 1.0]

    if height > width:
        new_h = max_dimension
        new_w = int(width * (max_dimension / height))
    else:
        new_w = max_dimension
        new_h = int(height * (max_dimension / width))
    # Extreme aspect ratios can truncate the short side to 0, which cv2.resize rejects.
    new_w = max(1, new_w)
    new_h = max(1, new_h)

    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    return resized, True, [width / new_w, height / new_h]


def _display_scale_factors(image_size, processed_w: int, processed_h: int):
    """Return ``(scale_x, scale_y)`` mapping processed-image pixels to display size.

    The factors are ``processed / display``; ``mask_to_polygon`` is expected to
    divide polygon coordinates by them. Without ``imageSize``, or with a
    non-positive display dimension, the factor is ``1.0``.

    Raises:
        HTTPException: 400 when a dimension cannot be converted to a number.
    """
    if image_size is None:
        return 1.0, 1.0
    try:
        if isinstance(image_size, dict):
            display_w = float(image_size.get("width", processed_w))
            display_h = float(image_size.get("height", processed_h))
        else:
            display_w, display_h = float(image_size[0]), float(image_size[1])
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=400,
            detail="imageSize width and height must be numbers"
        ) from None
    scale_x = processed_w / display_w if display_w > 0 else 1.0
    scale_y = processed_h / display_h if display_h > 0 else 1.0
    return scale_x, scale_y


def _get_auto_annotator(points_per_side: int, points_per_batch: int, min_area: int):
    """Return a SAM2 auto-annotation instance built for the requested settings.

    The instance is cached in ``sam2_auto_annotation_global``. It is rebuilt
    when the requested settings differ from those it was created with, so that
    per-request ``pointsPerSide`` / ``pointsPerBatch`` / ``minArea`` values are
    honoured instead of silently reusing the first request's configuration.
    Rebuilding may be expensive (presumably it reloads the model), so clients
    should keep these settings stable where possible.

    Raises:
        HTTPException: 500 when the instance cannot be created.
    """
    global sam2_auto_annotation_global, _sam2_auto_annotation_config

    config = (points_per_side, points_per_batch, min_area)
    with _sam2_auto_annotation_lock:
        if sam2_auto_annotation_global is not None and _sam2_auto_annotation_config == config:
            return sam2_auto_annotation_global

        # Release the stale instance before building the replacement so the
        # two do not have to be resident at the same time. Requests that are
        # already running keep their own local reference.
        sam2_auto_annotation_global = None
        _sam2_auto_annotation_config = None
        try:
            annotator = create_sam2_auto_annotation(
                points_per_side=points_per_side,
                points_per_batch=points_per_batch,
                pred_iou_thresh=0.88,
                stability_score_thresh=0.95,
                min_mask_region_area=min_area,
            )
        except ImportError as e:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to import required modules. Please ensure 'sam2' and 'huggingface_hub' are installed. Error: {str(e)}"
            ) from e
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to load SAM2 Auto Annotation from Hugging Face ({HUGGINGFACE_MODEL_ID}). Error: {str(e)}"
            ) from e

        sam2_auto_annotation_global = annotator
        _sam2_auto_annotation_config = config
        return annotator


def _select_masks(mask_results, min_area: int, min_confidence: float,
                  filter_objects_only: bool, total_image_area: int) -> list:
    """Filter raw mask results and shape them for the API response.

    A mask is kept when its polygon has at least three points and it passes
    the area / confidence thresholds. ``area`` and ``confidence`` are treated
    as optional: a missing value skips the corresponding check instead of
    raising. With ``filter_objects_only`` set, masks covering 80% or more of
    the processed image are dropped as likely background.
    """
    selected = []
    for mask_result in mask_results:
        polygon = mask_result.get("polygon")
        score = mask_result.get("confidence")
        area = mask_result.get("area")

        if area is not None and area < min_area:
            continue
        if score is not None and score < min_confidence:
            continue

        if filter_objects_only and area is not None and total_image_area > 0:
            if area / total_image_area >= 0.8:  # likely background
                continue

        # Flattened [x1, y1, x2, y2, ...] needs at least 3 points.
        if polygon and len(polygon) >= 6:
            mask_obj = {"polygon": polygon}
            if score is not None:
                mask_obj["confidence"] = score
            if area is not None:
                mask_obj["area"] = area
            selected.append(mask_obj)
    return selected


# --------------------------------------------------------------------------- #
# Endpoints
# --------------------------------------------------------------------------- #

@app.get("/")
def root():
    """Root endpoint - API information."""
    return {
        "status": "Service is up and running!",
        "message": "Backend service is active",
        "api": "SAM Auto Annotation API",
        "version": "1.0.0"
    }


@app.get("/health")
def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "service": "same model segmenticAPI"}


@app.post("/segment")
def segment(data: dict):
    """
    Segment image using SAM2 model to convert bounding box to polygon (CVAT-style).
    Bbox is used as a prompt to identify the object, not as a constraint.

    **Input:**
    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "bbox": {"x": 494.97, "y": 187.22, "width": 137.99, "height": 98.00, "label": "Object"},
        "imageSize": {"width": 663.07, "height": 442}
    }
    ```

    OR

    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "bbox": [494.97, 187.22, 137.99, 98.00],  // [x, y, width, height]
        "imageSize": [663.07, 442]  // [width, height]
    }
    ```

    **Output:**
    ```json
    {
        "polygon": [x1, y1, x2, y2, x3, y3, ...],  // CVAT format: flattened coordinates
        "confidence": 0.96
    }
    ```
    """
    try:
        # Validate input
        image_url = _require_field(data, "imageUrl")
        bbox = _require_field(data, "bbox")
        image_size = data.get("imageSize")  # Optional: for coordinate scaling

        _validate_prompt(bbox, "bbox", ["x", "y", "width", "height"])
        _validate_image_size(image_size)

        img_rgb = _load_rgb_image(image_url)

        # Predict polygon using SAM2 (bbox as prompt, CVAT-style)
        mask, confidence, scale_factors = predict_polygon(img_rgb, bbox, image_size)

        # Convert mask to polygon (CVAT-style)
        polygon = mask_to_polygon(mask, scale_factors)
        if not polygon:
            raise HTTPException(status_code=400, detail="No polygon found in mask")

        return {
            "polygon": polygon,  # CVAT format: flattened coordinates
            "confidence": confidence
        }
    except HTTPException:
        raise
    except Exception as e:
        raise _as_http_exception(e, _SEGMENT_IMPORT_ERROR_DETAIL) from e


@app.post("/segment/point")
def segment_from_point(data: dict):
    """
    Segment image using SAM2 model with a point click to select object.
    The point identifies which object to segment.

    **Input:**
    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "point": {"x": 494.97, "y": 187.22},
        "imageSize": {"width": 663.07, "height": 442}
    }
    ```

    OR

    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "point": [494.97, 187.22],  // [x, y]
        "imageSize": [663.07, 442]  // [width, height]
    }
    ```

    **Output:**
    ```json
    {
        "polygon": [x1, y1, x2, y2, x3, y3, ...],  // CVAT format: flattened coordinates
        "confidence": 0.96
    }
    ```
    """
    try:
        # Validate input
        image_url = _require_field(data, "imageUrl")
        point = _require_field(data, "point")
        image_size = data.get("imageSize")  # Optional: for coordinate scaling

        _validate_prompt(point, "point", ["x", "y"])
        _validate_image_size(image_size)

        img_rgb = _load_rgb_image(image_url)

        # Predict polygon using SAM2 (point click as prompt)
        mask, confidence, scale_factors = predict_polygon_from_point(img_rgb, point, image_size)

        # Convert mask to polygon (CVAT-style)
        polygon = mask_to_polygon(mask, scale_factors)
        if not polygon:
            raise HTTPException(status_code=400, detail="No polygon found in mask. Try clicking on a different point.")

        return {
            "polygon": polygon,  # CVAT format: flattened coordinates
            "confidence": confidence
        }
    except HTTPException:
        raise
    except Exception as e:
        raise _as_http_exception(e, _SEGMENT_IMPORT_ERROR_DETAIL) from e


@app.post("/auto-annotate")
def auto_annotate(data: dict):
    """
    Automatically detect and segment all objects in an image using SAM2 from Hugging Face.
    Uses SAM2AutomaticMaskGenerator (facebook/sam2.1-hiera-large) to detect all objects
    without requiring prompts (bbox or points).

    Only `imageUrl` is required. Defaults: `minArea` 100, `minConfidence` 0.5,
    `maxImageDimension` 1024 (256-4096), `pointsPerSide` 32 (8-128),
    `pointsPerBatch` 64 (16-256), `filterObjectsOnly` false.
    Changing `pointsPerSide`, `pointsPerBatch` or `minArea` between requests
    rebuilds the cached generator, which can be slow.

    **Input:**
    ```json
    {
        "imageUrl": "https://example.com/image.jpg",
        "imageSize": {"width": 663.07, "height": 442},
        "minArea": 100,
        "minConfidence": 0.5,
        "maxImageDimension": 1024,
        "pointsPerSide": 32,
        "pointsPerBatch": 64,
        "filterObjectsOnly": true
    }
    ```

    **Output:**
    ```json
    {
        "masks": [
            {
                "polygon": [x1, y1, x2, y2, x3, y3, ...],
                "confidence": 0.93,
                "area": 12345
            },
            ...
        ],
        "count": 10,
        "memoryInfo": {
            "before_mb": 512.5,
            "after_mb": 1024.3,
            "peak_mb": 1024.3,
            "estimated_mb": 800.0,
            "memory_used_mb": 511.8
        },
        "imageInfo": {
            "wasResized": true,
            "originalSize": [1920, 1080],
            "processedSize": [1024, 576],
            "resizeScale": [1.875, 1.875]
        }
    }
    ```
    """
    try:
        # Validate input (same order as the fields are documented above)
        image_url = _require_field(data, "imageUrl")
        image_size = data.get("imageSize")  # Optional: for coordinate scaling
        filter_objects_only = data.get("filterObjectsOnly", False)  # Optional: drop background masks

        _validate_image_size(image_size)
        min_area = _parse_bounded_int(data.get("minArea", 100), "minArea", 0)
        min_confidence = _parse_confidence(data.get("minConfidence", 0.5))
        max_image_dimension = _parse_bounded_int(
            data.get("maxImageDimension", 1024), "maxImageDimension", 256, 4096
        )
        # Lower values are faster
        points_per_side = _parse_bounded_int(data.get("pointsPerSide", 32), "pointsPerSide", 8, 128)
        points_per_batch = _parse_bounded_int(data.get("pointsPerBatch", 64), "pointsPerBatch", 16, 256)

        # Get memory before processing
        process = psutil.Process(os.getpid())
        memory_before = process.memory_info().rss / (1024 * 1024)  # MB

        img_rgb = _load_rgb_image(image_url)

        # Resize image if needed to reduce memory usage
        original_h, original_w = img_rgb.shape[:2]
        original_size = [original_w, original_h]
        processed_image, was_resized, resize_scale = _resize_to_max_dimension(
            img_rgb, max_image_dimension
        )
        processed_h, processed_w = processed_image.shape[:2]
        processed_size = [processed_w, processed_h]
        total_image_area = processed_w * processed_h

        # Rough memory estimate: float32 RGB input + 256 float32 channels per
        # pixel + ~100 one-byte mask planes. Reported for information only.
        estimated_mb = (
            (total_image_area * 3 * 4)
            + (total_image_area * 256 * 4)
            + (total_image_area * 100 * 1)
        ) / (1024 * 1024)

        # Scale FROM processed image TO display size (imageSize)
        scale_factor_x, scale_factor_y = _display_scale_factors(
            image_size, processed_w, processed_h
        )

        annotator = _get_auto_annotator(points_per_side, points_per_batch, min_area)

        # Polygons come back already scaled to display size.
        mask_results = annotator.generate_masks(
            image=processed_image,
            min_confidence=min_confidence,
            min_area=min_area,
            filter_blank_regions=True,
            scale_factors=(scale_factor_x, scale_factor_y)
        )

        # Get memory after processing
        memory_after = process.memory_info().rss / (1024 * 1024)  # MB
        memory_used = memory_after - memory_before

        results = _select_masks(
            mask_results, min_area, min_confidence, filter_objects_only, total_image_area
        )

        return {
            "masks": results,
            "count": len(results),
            "memoryInfo": {
                "before_mb": round(memory_before, 2),
                "after_mb": round(memory_after, 2),
                # Highest of the two RSS samples taken; the true peak during
                # inference is not measured.
                "peak_mb": round(max(memory_before, memory_after), 2),
                "estimated_mb": round(estimated_mb, 2),
                "memory_used_mb": round(memory_used, 2)
            },
            "imageInfo": {
                "wasResized": was_resized,
                "originalSize": original_size,
                "processedSize": processed_size,
                "resizeScale": resize_scale
            }
        }
    except HTTPException:
        raise
    except Exception as e:
        raise _as_http_exception(e, _AUTO_ANNOTATE_IMPORT_ERROR_DETAIL) from e