Spaces:
Sleeping
Sleeping
| # src/api.py | |
| # FastAPI application using an async task queue. | |
| # | |
| # Important: | |
| # The API does NOT receive the full orthomosaic. | |
| # It receives a small georeferenced crop GeoTIFF uploaded as multipart/form-data. | |
| import asyncio | |
| import logging | |
| import os | |
| import tempfile | |
| import time | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| from fastapi import ( | |
| FastAPI, | |
| HTTPException, | |
| status, | |
| UploadFile, | |
| File, | |
| Form, | |
| ) | |
| from fastapi.responses import FileResponse | |
| from starlette.background import BackgroundTask | |
| from src.infer import run_geoglyph_sam2_on_crop | |
logger = logging.getLogger("api")

# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------

# Scratch directory holding uploaded crops and the GeoPackages produced for
# them. It is created at import time if missing.
RESULTS_DIR = Path(tempfile.gettempdir()).joinpath("geoglyph_sam2_api")
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

# In-memory task registry: task_id -> metadata dict. Nothing is persisted,
# so the registry starts empty on every process start.
TASKS: Dict[str, Dict[str, Any]] = dict()

# One bounded queue drained by a single worker, so only one SAM2 inference
# runs at a time; this avoids GPU OOM when several users submit jobs.
QUEUE: asyncio.Queue = asyncio.Queue(20)
| # --------------------------------------------------------------------------- | |
| # Internal job object | |
| # --------------------------------------------------------------------------- | |
@dataclass
class InferenceJob:
    """Parameters for one queued SAM2 inference run on an uploaded crop.

    Instances are built by the submit endpoint and consumed by the queue
    worker, which forwards every field to ``run_geoglyph_sam2_on_crop``
    (``crop_path`` is passed as ``crop_tif_path``).

    ``@dataclass`` is required: callers construct this with keyword
    arguments, and a plain class with bare annotations would reject them.
    """

    # Path of the temporary crop GeoTIFF saved from the upload.
    crop_path: str
    # Path where the inference should write its GeoPackage.
    output_gpkg: str
    # "cuda", "cpu", or None (the submit endpoint only accepts these values).
    device: Optional[str] = None
    # CLAHE preprocessing options (interpreted by src.infer).
    use_clahe: bool = True
    clahe_clip: float = 4.0
    clahe_grid: int = 6
    # SAM2 automatic mask generator options (interpreted by src.infer).
    sam2_points_per_side: int = 32
    sam2_points_per_batch: int = 32
    sam2_pred_iou_thresh: float = 0.35
    sam2_stability_score_thresh: float = 0.65
    # Post-filter thresholds applied to generated masks (interpreted by src.infer).
    filter_min_area_px: int = 1000
    filter_max_area_frac: float = 0.20
    filter_min_iou: float = 0.35
    filter_min_stability: float = 0.65
    filter_border_margin: int = 10
    # Upper bound on the crop's side length, forwarded to src.infer.
    max_crop_side: int = 4096
| # --------------------------------------------------------------------------- | |
| # Worker | |
| # --------------------------------------------------------------------------- | |
def _delete_task_file(path: str, task_id: str, label: str) -> None:
    """Best-effort removal of one per-task file; failures are logged, not raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        return
    except OSError as exc:
        logger.warning("Could not delete %s for task_id=%s: %s", label, task_id, exc)
    else:
        logger.info("Deleted %s for task_id=%s", label, task_id)


def _release_cuda_memory() -> None:
    """Best-effort ``torch.cuda.empty_cache()`` after a failed task.

    torch is imported inside the try, so a missing or broken torch install
    is tolerated. Any error is logged at debug level and otherwise ignored,
    so cleanup never masks the original task failure.
    """
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("Cleared CUDA cache after task failure.")
    except Exception as exc:
        logger.debug("Skipped CUDA cache cleanup: %s", exc)


async def queue_worker():
    """
    Background worker that drains QUEUE, one SAM2 inference at a time.

    Each queue item is a ``(task_id, InferenceJob)`` tuple. The task's TASKS
    entry moves from "pending" to "processing" and then to "completed" or
    "failed". Processing tasks sequentially avoids GPU/CPU contention and
    GPU OOM. The blocking inference call runs in a separate thread via
    ``asyncio.to_thread`` so the event loop stays responsive.

    After each task the uploaded crop is deleted. On failure, any (possibly
    partial) GeoPackage is deleted too, because the download endpoint only
    serves completed tasks and so would never clean it up.

    The loop ends when the worker task is cancelled while waiting for work.
    """
    logger.info("Task queue worker started.")
    while True:
        try:
            task_id, job = await QUEUE.get()
        except asyncio.CancelledError:
            logger.info("Task queue worker cancelled.")
            break
        except Exception as exc:
            # Nothing was dequeued, so there is no task_done() to balance.
            logger.error("Error retrieving task from queue: %s", exc, exc_info=True)
            continue

        try:
            if task_id not in TASKS:
                # Defensive: no metadata to update; the finally still cleans up.
                continue

            TASKS[task_id]["status"] = "processing"
            TASKS[task_id]["started_at"] = time.time()
            logger.info("Worker started task_id=%s", task_id)

            # The inference call blocks, so run it off the event loop.
            result = await asyncio.to_thread(
                run_geoglyph_sam2_on_crop,
                crop_tif_path=job.crop_path,
                output_gpkg=job.output_gpkg,
                device=job.device,
                use_clahe=job.use_clahe,
                clahe_clip=job.clahe_clip,
                clahe_grid=job.clahe_grid,
                sam2_points_per_side=job.sam2_points_per_side,
                sam2_points_per_batch=job.sam2_points_per_batch,
                sam2_pred_iou_thresh=job.sam2_pred_iou_thresh,
                sam2_stability_score_thresh=job.sam2_stability_score_thresh,
                filter_min_area_px=job.filter_min_area_px,
                filter_max_area_frac=job.filter_max_area_frac,
                filter_min_iou=job.filter_min_iou,
                filter_min_stability=job.filter_min_stability,
                filter_border_margin=job.filter_border_margin,
                max_crop_side=job.max_crop_side,
            )

            n_masks = result["n_masks"]
            output_exists = result["output_exists"]
            TASKS[task_id].update(
                {
                    "status": "completed",
                    "finished_at": time.time(),
                    "n_masks": n_masks,
                    "output_exists": output_exists,
                    "result": result,
                    "download_url": f"/download/{task_id}" if output_exists else None,
                }
            )
            # %s rather than %d: a non-int n_masks must not break the log call.
            logger.info(
                "Worker completed task_id=%s n_masks=%s output_exists=%s",
                task_id,
                n_masks,
                output_exists,
            )
        except Exception as exc:
            logger.error(
                "Worker failed task_id=%s: %s",
                task_id,
                exc,
                exc_info=True,
            )
            if task_id in TASKS:
                TASKS[task_id].update(
                    {
                        "status": "failed",
                        "finished_at": time.time(),
                        "error": str(exc),
                    }
                )
            # A failed task can never be downloaded, so nothing else would
            # remove a GeoPackage it may have partially written.
            _delete_task_file(job.output_gpkg, task_id, "partial output")
            _release_cuda_memory()
        finally:
            # The crop is no longer needed once processing has ended.
            _delete_task_file(job.crop_path, task_id, "temporary crop")
            QUEUE.task_done()
| # --------------------------------------------------------------------------- | |
| # Lifespan | |
| # --------------------------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the task-queue worker for the lifetime of the application.

    Starts ``queue_worker`` as a background task on startup. On shutdown it
    cancels the worker and waits for the cancellation to complete.
    ``@asynccontextmanager`` turns this generator into the async context
    manager factory that FastAPI's ``lifespan=`` parameter expects. The
    ``finally`` ensures the worker is cancelled even when shutdown is
    entered through an exception raised at the ``yield``.
    """
    worker_task = asyncio.create_task(queue_worker())
    try:
        yield
    finally:
        worker_task.cancel()
        try:
            await worker_task
        except asyncio.CancelledError:
            pass
| # --------------------------------------------------------------------------- | |
| # FastAPI app | |
| # --------------------------------------------------------------------------- | |
# Application object. Its lifespan starts and stops the background queue worker.
app = FastAPI(
    lifespan=lifespan,
    title="GeoGlyph SAM2 API",
    version="3.0.0",
    description=(
        "Backend API for geoglyph detection using SAM2. "
        "Receives small georeferenced crop GeoTIFFs, not full orthomosaics."
    ),
)
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): no route decorator (e.g. @app.get(...)) is visible for this
# handler in this view; confirm how it is registered on ``app``.
async def health_check():
    """Report liveness, the current queue depth and the results directory."""
    payload = dict(
        status="ok",
        message="GeoGlyph SAM2 API is running.",
        queue_size=QUEUE.qsize(),
        results_dir=str(RESULTS_DIR),
    )
    return payload
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible for this
# handler in this view; confirm how it is registered on ``app``.
async def process_geoglyphs(
    crop: UploadFile = File(...),
    device: Optional[str] = Form(None),
    use_clahe: bool = Form(True),
    clahe_clip: float = Form(4.0),
    clahe_grid: int = Form(6),
    sam2_points_per_side: int = Form(32),
    sam2_points_per_batch: int = Form(32),
    sam2_pred_iou_thresh: float = Form(0.35),
    sam2_stability_score_thresh: float = Form(0.65),
    filter_min_area_px: int = Form(1000),
    filter_max_area_frac: float = Form(0.20),
    filter_min_iou: float = Form(0.35),
    filter_min_stability: float = Form(0.65),
    filter_border_margin: int = Form(10),
    max_crop_side: int = Form(4096),
):
    """
    Submit a SAM2 inference job.
    The client uploads a georeferenced crop GeoTIFF.
    The API never receives the original orthomosaic.
    """
    # Returns {"task_id", "status": "pending", "queue_size"}.
    # Raises HTTP 429 when the queue is full, 400 for an invalid device or an
    # empty upload, and 500 when the upload cannot be written to disk.

    # Cheap early reject, so a full queue does not cost an upload. This check
    # is not authoritative: the upload below awaits and the queue can fill up
    # meanwhile, so the final enqueue uses put_nowait() and re-checks.
    if QUEUE.full():
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Task queue is full. Try again later.",
        )
    if device not in {None, "cuda", "cpu"}:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid device. Expected 'cuda', 'cpu', or omitted.",
        )

    task_id = uuid.uuid4().hex[:12]
    crop_path = RESULTS_DIR / f"{task_id}_crop.tif"
    output_gpkg = RESULTS_DIR / f"{task_id}.gpkg"

    def discard_crop() -> None:
        """Best-effort removal of the saved crop for a rejected submission."""
        try:
            crop_path.unlink(missing_ok=True)
        except OSError as exc:
            logger.warning("Could not delete rejected crop %s: %s", crop_path, exc)

    # Stream the upload to disk in 1 MiB chunks. The upload is closed even
    # if reading or writing fails, and a partial file is not left behind.
    try:
        try:
            with open(crop_path, "wb") as f:
                while True:
                    chunk = await crop.read(1024 * 1024)
                    if not chunk:
                        break
                    f.write(chunk)
        finally:
            await crop.close()
    except Exception as exc:
        discard_crop()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Could not save uploaded crop: {exc}",
        ) from exc

    if not crop_path.is_file() or crop_path.stat().st_size == 0:
        discard_crop()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded crop is empty or could not be saved.",
        )

    job = InferenceJob(
        crop_path=str(crop_path),
        output_gpkg=str(output_gpkg),
        device=device,
        use_clahe=use_clahe,
        clahe_clip=clahe_clip,
        clahe_grid=clahe_grid,
        sam2_points_per_side=sam2_points_per_side,
        sam2_points_per_batch=sam2_points_per_batch,
        sam2_pred_iou_thresh=sam2_pred_iou_thresh,
        sam2_stability_score_thresh=sam2_stability_score_thresh,
        filter_min_area_px=filter_min_area_px,
        filter_max_area_frac=filter_max_area_frac,
        filter_min_iou=filter_min_iou,
        filter_min_stability=filter_min_stability,
        filter_border_margin=filter_border_margin,
        max_crop_side=max_crop_side,
    )

    TASKS[task_id] = {
        "status": "pending",
        "created_at": time.time(),
        "original_filename": crop.filename,
        "crop_path": str(crop_path),
        "output_gpkg": str(output_gpkg),
        "n_masks": None,
        "output_exists": None,
        "error": None,
    }
    try:
        # Non-blocking: a request must not hang waiting for a free slot.
        # There is no await between registering the task and enqueuing it,
        # so the worker never sees a half-registered task.
        QUEUE.put_nowait((task_id, job))
    except asyncio.QueueFull:
        # The queue filled up while the upload was being saved; roll back.
        TASKS.pop(task_id, None)
        discard_crop()
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Task queue is full. Try again later.",
        ) from None

    logger.info(
        "Enqueued task_id=%s filename=%s queue_size=%d",
        task_id,
        crop.filename,
        QUEUE.qsize(),
    )
    return {
        "task_id": task_id,
        "status": "pending",
        "queue_size": QUEUE.qsize(),
    }
# NOTE(review): no route decorator (e.g. @app.get(...)) is visible for this
# handler in this view; confirm how it is registered on ``app``.
async def get_status(task_id: str):
    """
    Check task status.
    Possible statuses:
    - pending
    - processing
    - completed
    - failed
    """
    # Single lookup instead of "in" followed by "[...]". cleanup_task is
    # scheduled as a sync background task, which Starlette runs in a
    # threadpool, so an entry can disappear between two separate lookups.
    task_info = TASKS.get(task_id)
    if task_info is None:
        raise HTTPException(status_code=404, detail="Task not found")

    state = task_info["status"]

    if state == "pending":
        # Iterate over a snapshot: walking the live dict while another thread
        # deletes from it can raise "dictionary changed size during iteration".
        # TASKS preserves insertion (= submission) order, so the index is the
        # 1-based position among pending tasks.
        pending_ids = [
            tid
            for tid, data in list(TASKS.items())
            if data["status"] == "pending"
        ]
        try:
            position = pending_ids.index(task_id) + 1
        except ValueError:
            position = 0
        return {
            "task_id": task_id,
            "status": "pending",
            "queue_position": position,
        }

    if state == "processing":
        return {
            "task_id": task_id,
            "status": "processing",
            "started_at": task_info.get("started_at"),
        }

    if state == "completed":
        return {
            "task_id": task_id,
            "status": "completed",
            "n_masks": task_info.get("n_masks"),
            "output_exists": task_info.get("output_exists"),
            "download_url": task_info.get("download_url"),
            "result": task_info.get("result"),
        }

    if state == "failed":
        return {
            "task_id": task_id,
            "status": "failed",
            "error": task_info.get("error"),
        }

    # Unknown status: fall back to the raw metadata.
    return task_info
def cleanup_task(task_id: str):
    """
    Delete generated files and remove task metadata.
    Called after successful download.

    Removes the task's GeoPackage and any leftover crop, then drops its
    TASKS entry. Deletion is best-effort: failures are logged as warnings
    and the metadata is removed anyway. Unknown task ids are a no-op.
    """
    task = TASKS.get(task_id, {})
    for path_str in (task.get("output_gpkg"), task.get("crop_path")):
        if not path_str:
            continue
        path = Path(path_str)
        # EAFP instead of exists()+unlink(): no window between check and delete.
        try:
            path.unlink()
        except FileNotFoundError:
            # Already gone (the queue worker normally deletes the crop itself).
            continue
        except OSError as exc:
            logger.warning(
                "Could not delete file for task_id=%s: %s",
                task_id,
                exc,
            )
        else:
            logger.info("Deleted file for task_id=%s: %s", task_id, path)

    # Atomic remove-if-present; tolerates concurrent removal of the same entry.
    TASKS.pop(task_id, None)
# NOTE(review): no route decorator is visible for this handler in this view.
# The queue worker advertises results at "/download/{task_id}", so this is
# presumably registered as GET on that path; confirm.
async def download_result(task_id: str):
    """
    Download the generated GeoPackage.
    The task is cleaned after transfer.
    """
    # Single lookup: a previous download's cleanup_task (a sync background
    # task, run by Starlette in a threadpool) may delete this entry at any
    # moment, so "in" followed by "[...]" could raise KeyError (HTTP 500).
    task = TASKS.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    if task["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail="Task is not completed yet.",
        )

    if not task.get("output_exists"):
        raise HTTPException(
            status_code=404,
            detail="Task completed but no GeoPackage was created. No masks passed the filters.",
        )

    gpkg_path = Path(task["output_gpkg"])
    if not gpkg_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Result file is missing.",
        )

    logger.info("Serving GPKG download | task_id=%s", task_id)
    # cleanup_task runs after the response has been sent, deleting the file
    # and the task metadata, so each result can be downloaded once.
    return FileResponse(
        path=str(gpkg_path),
        media_type="application/geopackage+sqlite3",
        filename=f"sam2_result_{task_id}.gpkg",
        background=BackgroundTask(cleanup_task, task_id),
    )