# src/api.py
# FastAPI application using an async task queue.
#
# Important:
# The API does NOT receive the full orthomosaic.
# It receives a small georeferenced crop GeoTIFF uploaded as multipart/form-data.

import asyncio
import logging
import os
import tempfile
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Any

from fastapi import FastAPI, HTTPException, status, UploadFile, File, Form
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask

from src.infer import run_geoglyph_sam2_on_crop

logger = logging.getLogger("api")

# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------

# Working directory for uploaded crops and generated GeoPackages. It lives
# under the system temp directory and is created at import time if missing.
RESULTS_DIR = Path(tempfile.gettempdir()).joinpath("geoglyph_sam2_api")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# In-memory task registry: task_id -> metadata dict (status, paths, results).
TASKS: Dict[str, Dict[str, Any]] = {}

# Single queue. One worker means one SAM2 inference at a time.
# This avoids GPU OOM when several users submit jobs.
QUEUE: asyncio.Queue = asyncio.Queue(maxsize=20)

# Uploads are copied to disk in pieces of this many bytes (1 MiB).
_UPLOAD_CHUNK_BYTES = 1024 * 1024


# ---------------------------------------------------------------------------
# Internal job object
# ---------------------------------------------------------------------------

@dataclass
class InferenceJob:
    """
    Everything the worker needs to run one inference.

    Every field except ``crop_path`` is forwarded under its own name as a
    keyword argument to ``run_geoglyph_sam2_on_crop``; ``crop_path`` is passed
    as ``crop_tif_path``. The defaults mirror the form defaults of ``/process``.
    """

    crop_path: str      # uploaded crop on disk; deleted by the worker afterwards
    output_gpkg: str    # destination path handed to the inference function
    device: Optional[str] = None  # /process only accepts None, "cuda" or "cpu"

    use_clahe: bool = True
    clahe_clip: float = 4.0
    clahe_grid: int = 6

    sam2_points_per_side: int = 32
    sam2_points_per_batch: int = 32
    sam2_pred_iou_thresh: float = 0.35
    sam2_stability_score_thresh: float = 0.65

    filter_min_area_px: int = 1000
    filter_max_area_frac: float = 0.20
    filter_min_iou: float = 0.35
    filter_min_stability: float = 0.65
    filter_border_margin: int = 10

    max_crop_side: int = 4096


# ---------------------------------------------------------------------------
# Worker
# ---------------------------------------------------------------------------

async def queue_worker():
    """
    Background worker.

    It processes tasks sequentially to avoid GPU/CPU contention and GPU OOM.
    Heavy SAM2 inference runs in a separate thread.

    Each queue item is a ``(task_id, InferenceJob)`` tuple. The matching
    ``TASKS`` entry is moved to "processing" and then to "completed" or
    "failed". The uploaded crop is deleted after every job, whatever the
    outcome. The loop only ends when the task is cancelled.
    """
    logger.info("Task queue worker started.")

    while True:
        try:
            task_id, job = await QUEUE.get()
        except asyncio.CancelledError:
            logger.info("Task queue worker cancelled.")
            break
        except Exception as exc:
            logger.error("Error retrieving task from queue: %s", exc, exc_info=True)
            continue

        try:
            # The registry entry may be gone already; the ``finally`` clause
            # below still removes the crop and acknowledges the queue item.
            if task_id not in TASKS:
                continue

            TASKS[task_id]["status"] = "processing"
            TASKS[task_id]["started_at"] = time.time()

            logger.info("Worker started task_id=%s", task_id)

            # Run the blocking inference in a thread so the event loop keeps
            # serving HTTP requests meanwhile.
            result = await asyncio.to_thread(
                run_geoglyph_sam2_on_crop,
                crop_tif_path=job.crop_path,
                output_gpkg=job.output_gpkg,
                device=job.device,
                use_clahe=job.use_clahe,
                clahe_clip=job.clahe_clip,
                clahe_grid=job.clahe_grid,
                sam2_points_per_side=job.sam2_points_per_side,
                sam2_points_per_batch=job.sam2_points_per_batch,
                sam2_pred_iou_thresh=job.sam2_pred_iou_thresh,
                sam2_stability_score_thresh=job.sam2_stability_score_thresh,
                filter_min_area_px=job.filter_min_area_px,
                filter_max_area_frac=job.filter_max_area_frac,
                filter_min_iou=job.filter_min_iou,
                filter_min_stability=job.filter_min_stability,
                filter_border_margin=job.filter_border_margin,
                max_crop_side=job.max_crop_side,
            )

            TASKS[task_id].update(
                {
                    "status": "completed",
                    "finished_at": time.time(),
                    "n_masks": result["n_masks"],
                    "output_exists": result["output_exists"],
                    "result": result,
                    "download_url": f"/download/{task_id}" if result["output_exists"] else None,
                }
            )

            logger.info(
                "Worker completed task_id=%s n_masks=%d output_exists=%s",
                task_id,
                result["n_masks"],
                result["output_exists"],
            )

        except Exception as exc:
            logger.error(
                "Worker failed task_id=%s: %s",
                task_id,
                exc,
                exc_info=True,
            )

            if task_id in TASKS:
                TASKS[task_id].update(
                    {
                        "status": "failed",
                        "finished_at": time.time(),
                        "error": str(exc),
                    }
                )

            # Best effort only: torch may be missing or the cache call may
            # fail, and neither must take the worker down.
            try:
                import torch

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("Cleared CUDA cache after task failure.")
            except Exception:
                pass

        finally:
            # Crop is no longer needed after processing.
            try:
                if os.path.exists(job.crop_path):
                    os.remove(job.crop_path)
                    logger.info("Deleted temporary crop for task_id=%s", task_id)
            except Exception as exc:
                logger.warning(
                    "Could not delete temporary crop for task_id=%s: %s",
                    task_id,
                    exc,
                )

            QUEUE.task_done()


# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the queue worker before serving and cancel it on shutdown."""
    worker_task = asyncio.create_task(queue_worker())

    yield

    worker_task.cancel()

    # Wait for the worker to unwind; a CancelledError here is the expected
    # result of the cancel() above.
    try:
        await worker_task
    except asyncio.CancelledError:
        pass


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

app = FastAPI(
    title="GeoGlyph SAM2 API",
    description=(
        "Backend API for geoglyph detection using SAM2. "
        "Receives small georeferenced crop GeoTIFFs, not full orthomosaics."
    ),
    version="3.0.0",
    lifespan=lifespan,
)


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

# Liveness probe: reports the current queue depth and the results directory.
@app.get("/health", status_code=status.HTTP_200_OK)
async def health_check():
    return {
        "status": "ok",
        "message": "GeoGlyph SAM2 API is running.",
        "queue_size": QUEUE.qsize(),
        "results_dir": str(RESULTS_DIR),
    }


def _discard_file(path) -> None:
    # Remove ``path`` if it exists; a failure is logged, never raised.
    try:
        Path(path).unlink()
    except FileNotFoundError:
        pass
    except OSError as exc:
        logger.warning("Could not delete %s: %s", path, exc)


async def _save_upload(crop: UploadFile, destination: Path) -> None:
    # Copy the upload to ``destination`` chunk by chunk; always close the upload.
    try:
        with open(destination, "wb") as f:
            while chunk := await crop.read(_UPLOAD_CHUNK_BYTES):
                f.write(chunk)
    finally:
        await crop.close()


# Error responses produced by /process:
#   429 - the queue is full, either on arrival or once the upload has been saved
#   400 - ``device`` is not None/"cuda"/"cpu", or the uploaded crop is empty
#   500 - the upload could not be written to disk
# No crop file or TASKS entry is left behind when the request is rejected.
@app.post("/process", status_code=status.HTTP_202_ACCEPTED)
async def process_geoglyphs(
    crop: UploadFile = File(...),

    device: Optional[str] = Form(None),

    use_clahe: bool = Form(True),
    clahe_clip: float = Form(4.0),
    clahe_grid: int = Form(6),

    sam2_points_per_side: int = Form(32),
    sam2_points_per_batch: int = Form(32),
    sam2_pred_iou_thresh: float = Form(0.35),
    sam2_stability_score_thresh: float = Form(0.65),

    filter_min_area_px: int = Form(1000),
    filter_max_area_frac: float = Form(0.20),
    filter_min_iou: float = Form(0.35),
    filter_min_stability: float = Form(0.65),
    filter_border_margin: int = Form(10),

    max_crop_side: int = Form(4096),
):
    """
    Submit a SAM2 inference job.

    The client uploads a georeferenced crop GeoTIFF.
    The API never receives the original orthomosaic.
    """
    # Fail fast so a client does not stream a whole crop into a full queue.
    if QUEUE.full():
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Task queue is full. Try again later.",
        )

    if device not in {None, "cuda", "cpu"}:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid device. Expected 'cuda', 'cpu', or omitted.",
        )

    task_id = uuid.uuid4().hex[:12]

    crop_path = RESULTS_DIR / f"{task_id}_crop.tif"
    output_gpkg = RESULTS_DIR / f"{task_id}.gpkg"

    try:
        await _save_upload(crop, crop_path)
    except Exception as exc:
        # Do not leave a partially written crop behind.
        _discard_file(crop_path)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Could not save uploaded crop: {exc}",
        ) from exc

    if not crop_path.is_file() or crop_path.stat().st_size == 0:
        _discard_file(crop_path)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded crop is empty or could not be saved.",
        )

    job = InferenceJob(
        crop_path=str(crop_path),
        output_gpkg=str(output_gpkg),
        device=device,
        use_clahe=use_clahe,
        clahe_clip=clahe_clip,
        clahe_grid=clahe_grid,
        sam2_points_per_side=sam2_points_per_side,
        sam2_points_per_batch=sam2_points_per_batch,
        sam2_pred_iou_thresh=sam2_pred_iou_thresh,
        sam2_stability_score_thresh=sam2_stability_score_thresh,
        filter_min_area_px=filter_min_area_px,
        filter_max_area_frac=filter_max_area_frac,
        filter_min_iou=filter_min_iou,
        filter_min_stability=filter_min_stability,
        filter_border_margin=filter_border_margin,
        max_crop_side=max_crop_side,
    )

    # Register the task before enqueueing: the worker skips queue items that
    # have no registry entry.
    TASKS[task_id] = {
        "status": "pending",
        "created_at": time.time(),
        "original_filename": crop.filename,
        "crop_path": str(crop_path),
        "output_gpkg": str(output_gpkg),
        "n_masks": None,
        "output_exists": None,
        "error": None,
    }

    # Other requests may have filled the queue while this upload was being
    # read. put_nowait() rejects immediately instead of parking the request
    # until the worker frees a slot.
    try:
        QUEUE.put_nowait((task_id, job))
    except asyncio.QueueFull:
        TASKS.pop(task_id, None)
        _discard_file(crop_path)
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Task queue is full. Try again later.",
        ) from None

    logger.info(
        "Enqueued task_id=%s filename=%s queue_size=%d",
        task_id,
        crop.filename,
        QUEUE.qsize(),
    )

    return {
        "task_id": task_id,
        "status": "pending",
        "queue_size": QUEUE.qsize(),
    }


@app.get("/status/{task_id}", status_code=status.HTTP_200_OK)
async def get_status(task_id: str):
    """
    Check task status.

    Possible statuses:
    - pending
    - processing
    - completed
    - failed
    """
    if task_id not in TASKS:
        raise HTTPException(status_code=404, detail="Task not found")

    task_info = TASKS[task_id]
    state = task_info["status"]

    if state == "pending":
        # TASKS keeps insertion order, so the 1-based rank among the pending
        # entries is the task's position in the queue.
        pending_ids = [
            tid for tid, data in TASKS.items()
            if data["status"] == "pending"
        ]

        try:
            position = pending_ids.index(task_id) + 1
        except ValueError:
            position = 0

        return {
            "task_id": task_id,
            "status": "pending",
            "queue_position": position,
        }

    if state == "processing":
        return {
            "task_id": task_id,
            "status": "processing",
            "started_at": task_info.get("started_at"),
        }

    if state == "completed":
        return {
            "task_id": task_id,
            "status": "completed",
            "n_masks": task_info.get("n_masks"),
            "output_exists": task_info.get("output_exists"),
            "download_url": task_info.get("download_url"),
            "result": task_info.get("result"),
        }

    if state == "failed":
        return {
            "task_id": task_id,
            "status": "failed",
            "error": task_info.get("error"),
        }

    # Unrecognised status: expose the raw registry entry.
    return task_info


def cleanup_task(task_id: str):
    """
    Delete generated files and remove task metadata.

    Called after successful download. Unknown task ids are ignored, and a
    file that cannot be deleted is logged as a warning instead of raising.
    """
    task = TASKS.get(task_id, {})

    paths_to_delete = [
        task.get("output_gpkg"),
        task.get("crop_path"),
    ]

    for path_str in paths_to_delete:
        if not path_str:
            continue

        path = Path(path_str)

        if path.exists():
            try:
                path.unlink()
                logger.info("Deleted file for task_id=%s: %s", task_id, path)
            except Exception as exc:
                logger.warning(
                    "Could not delete file for task_id=%s: %s",
                    task_id,
                    exc,
                )

    TASKS.pop(task_id, None)


# NOTE(review): within this file cleanup_task() is only scheduled from here,
# so failed tasks and completed tasks that are never downloaded stay in TASKS.
# Confirm whether something else evicts them.
@app.get("/download/{task_id}")
async def download_result(task_id: str):
    """
    Download the generated GeoPackage.

    The task is cleaned after transfer.
    """
    if task_id not in TASKS:
        raise HTTPException(status_code=404, detail="Task not found")

    task = TASKS[task_id]

    if task["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail="Task is not completed yet.",
        )

    if not task.get("output_exists"):
        raise HTTPException(
            status_code=404,
            detail="Task completed but no GeoPackage was created. No masks passed the filters.",
        )

    gpkg_path = Path(task["output_gpkg"])

    if not gpkg_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Result file is missing.",
        )

    logger.info("Serving GPKG download | task_id=%s", task_id)

    # The background task runs once the response has been sent, so the file
    # still exists while it is being streamed.
    return FileResponse(
        path=str(gpkg_path),
        media_type="application/geopackage+sqlite3",
        filename=f"sam2_result_{task_id}.gpkg",
        background=BackgroundTask(cleanup_task, task_id),
    )