Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import logging | |
# Fix: point the Hugging Face cache at a writable location.
# In containerized environments, /.cache may not be writable, so fall back
# to /tmp unless the operator has already chosen a location.
if "HF_HOME" not in os.environ:
    os.environ["HF_HOME"] = "/tmp/huggingface"
    print(f"Set HF_HOME to {os.environ['HF_HOME']}")

# Debug/Fix: unset CUDA_VISIBLE_DEVICES to ensure all GPUs are visible.
# Some environments (like HF Spaces) might set this to "0" by default.
# This runs before `import torch` below.
if "CUDA_VISIBLE_DEVICES" in os.environ:
    # Use print because the logging configuration might not be set yet.
    print(f"Found CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}. Unsetting it to enable all GPUs.")
    del os.environ["CUDA_VISIBLE_DEVICES"]
else:
    print("CUDA_VISIBLE_DEVICES not set. All GPUs should be visible.")

import torch

# Best-effort startup diagnostics: report the torch build and what CUDA
# exposes. Any failure is printed rather than allowed to abort startup.
try:
    print(f"Startup Diagnostics: Torch version {torch.__version__}, CUDA available: {torch.cuda.is_available()}, Device count: {torch.cuda.device_count()}")
except Exception as e:
    print(f"Startup Diagnostics Error: {e}")
| import asyncio | |
| import json | |
| import shutil | |
| import tempfile | |
| import time | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from datetime import timedelta | |
| from pathlib import Path | |
| from typing import Optional | |
| import cv2 | |
| from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import uvicorn | |
| from inference import run_inference, run_grounded_sam2_tracking | |
| from models.depth_estimators.model_loader import list_depth_estimators | |
| from jobs.background import process_video_async | |
| from jobs.models import JobInfo, JobStatus | |
| from jobs.streaming import get_stream, get_stream_event | |
| from jobs.storage import ( | |
| get_depth_output_path, | |
| get_first_frame_depth_path, | |
| get_input_video_path, | |
| get_job_directory, | |
| get_job_storage, | |
| get_output_video_path, | |
| ) | |
| from models.segmenters.model_loader import get_segmenter_detector | |
| from pydantic import BaseModel | |
| from inspection.router import router as inspection_router | |
# Root logger reports at INFO; chatty third-party loggers are capped at
# WARNING so application messages stay readable.
logging.basicConfig(level=logging.INFO)
for _noisy_logger in ("httpx", "huggingface_hub", "transformers"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger
async def _periodic_cleanup(
    interval_seconds: float = 600,
    max_age: timedelta = timedelta(hours=1),
) -> None:
    """Purge expired jobs from job storage on a fixed interval, forever.

    Args:
        interval_seconds: Seconds to sleep between cleanup passes.
        max_age: Age threshold handed to ``cleanup_expired``.

    The coroutine only ends when its task is cancelled. A failure in one
    cleanup pass is logged and the loop keeps going; previously a single
    exception ended the task and cleanup silently stopped for the rest of
    the process lifetime.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            get_job_storage().cleanup_expired(max_age)
        except Exception:
            # asyncio.CancelledError is not an Exception subclass, so task
            # cancellation still propagates out of the loop.
            logging.exception("Periodic job cleanup failed; retrying on the next cycle")
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan: run the periodic job cleanup while the app is up.

    The cleanup task is started before the application begins serving and is
    cancelled on shutdown. The cancelled task is awaited so it is fully
    finished before shutdown continues.

    ``asynccontextmanager`` is applied explicitly: the function is passed as
    ``lifespan=`` to ``FastAPI``, which expects an async context manager.
    """
    cleanup_task = asyncio.create_task(_periodic_cleanup())
    try:
        yield
    finally:
        cleanup_task.cancel()
        # return_exceptions=True collects the task's CancelledError instead of
        # re-raising it into the shutdown path.
        await asyncio.gather(cleanup_task, return_exceptions=True)
# Application instance. `lifespan` is passed so the periodic job cleanup task
# runs for as long as the app does.
app = FastAPI(title="Video Object Detection", lifespan=lifespan)
app.include_router(inspection_router)

# CORS is fully open: any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests; confirm this is intended, or restrict
# allow_origins to the known frontends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| from fastapi import Request | |
async def add_no_cache_header(request: Request, call_next):
    """Disable browser caching for the frontend (important for HF Spaces updates).

    Runs the downstream handler first, then stamps no-cache headers on the
    response when the request path is the root page or starts with ``/app``.

    NOTE(review): no registration decorator is visible in this view of the
    file; presumably this is installed as HTTP middleware. Confirm.
    """
    response = await call_next(request)
    path = request.url.path
    serves_frontend = path.startswith("/app") or path == "/"
    if serves_frontend:
        for header_name, header_value in (
            ("Cache-Control", "no-cache, no-store, must-revalidate"),
            ("Pragma", "no-cache"),
            ("Expires", "0"),
        ):
            response.headers[header_name] = header_value
    return response
# The frontend lives in a "frontend" directory next to this file. It is
# mounted under /app only when that directory exists.
_FRONTEND_DIR = Path(__file__).with_name("frontend")
if _FRONTEND_DIR.exists():
    app.mount("/app", StaticFiles(directory=_FRONTEND_DIR, html=True), name="frontend")

# Valid detection modes accepted by the detect endpoints' `mode` form field.
VALID_MODES = {"object_detection", "segmentation", "drone_detection"}
| # ── Chat endpoint ────────────────────────────────────────────── | |
class ChatRequest(BaseModel):
    """Request body consumed by the ``chat`` handler."""

    # The operator's new message; sent as the final user turn.
    message: str
    # Mission objective; added to the system prompt when non-empty.
    mission: str = ""
    # Details of the currently selected track, if any; read with ``.get``.
    track_context: Optional[dict] = None
    # Earlier turns; entries are expected to carry "role" and "content".
    history: list[dict] = []
def _build_chat_system_prompt(mission: str, track_context) -> str:
    """Assemble the system prompt from the mission and the selected track."""
    system_parts = [
        "You are a mission analyst for an ISR (Intelligence, Surveillance, Reconnaissance) platform. "
        "You help operators understand tracked objects, their mission relevance, and assessment results. "
        "Be concise and direct."
    ]
    if mission:
        system_parts.append(f"\nCurrent mission objective: {mission}")
    if track_context:
        tc = track_context
        system_parts.append(
            f"\nCurrently selected track context:"
            f"\n- Track ID: {tc.get('id', 'unknown')}"
            f"\n- Label: {tc.get('label', 'unknown')}"
            f"\n- Confidence: {tc.get('score', 'N/A')}"
            f"\n- Mission Relevant: {tc.get('mission_relevant', 'not assessed')}"
            f"\n- Satisfies Mission: {tc.get('satisfies', 'not assessed')}"
            f"\n- Assessment Status: {tc.get('assessment_status', 'UNASSESSED')}"
            f"\n- Reason: {tc.get('reason', 'none')}"
            f"\n- Speed: {tc.get('speed_kph', 'N/A')} kph"
            f"\n- Bounding Box: {tc.get('bbox', 'N/A')}"
        )
        features = tc.get("features")
        # track_context is client-supplied; only a mapping can be itemised.
        # A non-dict value used to raise AttributeError before the try block.
        if isinstance(features, dict) and features:
            feat_str = "\n".join(f"  - {k}: {v}" for k, v in features.items())
            system_parts.append(f"- Observable Features:\n{feat_str}")
        gpt_raw = tc.get("gpt_raw")
        if gpt_raw:
            system_parts.append(f"- Raw GPT Assessment: {json.dumps(gpt_raw)}")
    return "\n".join(system_parts)


def _build_chat_messages(system_msg: str, history: list, message: str) -> list:
    """Return the message list: system prompt, recent history, then the new user turn."""
    messages = [{"role": "system", "content": system_msg}]
    # Only the 20 most recent turns are forwarded; entries with an unexpected
    # role or empty content are dropped.
    for h in history[-20:]:
        if h.get("role") in ("user", "assistant") and h.get("content"):
            messages.append({"role": h["role"], "content": h["content"]})
    messages.append({"role": "user", "content": message})
    return messages


async def chat(req: ChatRequest):
    """Answer an operator question using the OpenAI chat completions API.

    Args:
        req: Message, optional mission text, optional selected-track context
            and prior conversation history.

    Returns:
        ``{"response": <assistant text>}``.

    Raises:
        HTTPException: 500 when ``OPENAI_API_KEY`` is not set, or when the
            completion request fails (the original error is chained).

    NOTE(review): no route decorator is visible in this view of the file.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured")

    import openai

    client = openai.OpenAI(api_key=api_key)
    system_msg = _build_chat_system_prompt(req.mission, req.track_context)
    messages = _build_chat_messages(system_msg, req.history, req.message)

    try:
        # The client call is blocking; run it in a worker thread so the event
        # loop stays responsive.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=512,
            temperature=0.3,
        )
        return {"response": response.choices[0].message.content}
    except Exception as e:
        logging.exception("Chat endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
def _save_upload_to_tmp(upload: UploadFile) -> str:
    """Save an uploaded file to a new temporary file under /tmp.

    The file extension is taken from the upload's filename, defaulting to
    ``.mp4``. The upload is streamed to disk in chunks rather than read into
    memory in one piece, which matters for large videos.

    Args:
        upload: The uploaded file; its ``file`` object is read from its
            current position.

    Returns:
        Path of the temporary file. The caller is responsible for deleting it.

    Raises:
        Exception: Any error from creating or writing the file is re-raised
            after the partially written temp file has been removed.
    """
    suffix = Path(upload.filename or "upload.mp4").suffix or ".mp4"
    fd, path = tempfile.mkstemp(prefix="input_", suffix=suffix, dir="/tmp")
    try:
        # Reuse the descriptor from mkstemp instead of closing and reopening.
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(upload.file, buffer)
    except Exception:
        # Do not leave a half-written temp file behind.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
def _save_upload_to_path(upload: UploadFile, path: Path) -> None:
    """Save an uploaded file to ``path``, creating parent directories.

    The upload is streamed to disk in chunks rather than read into memory in
    one piece.

    Args:
        upload: The uploaded file; its ``file`` object is read from its
            current position.
        path: Destination file path.

    Raises:
        Exception: Any error from writing is re-raised after the partially
            written destination file has been removed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        with open(path, "wb") as buffer:
            shutil.copyfileobj(upload.file, buffer)
    except Exception:
        # Do not leave a truncated file where a complete video is expected.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
| def _safe_delete(path: str) -> None: | |
| """Safely delete a file, ignoring errors.""" | |
| try: | |
| os.remove(path) | |
| except FileNotFoundError: | |
| return | |
| except Exception: | |
| logging.exception("Failed to remove temporary file: %s", path) | |
def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
    """Queue deletion of ``path`` as a background task on ``background_tasks``.

    The path is captured at scheduling time and deleted with ``_safe_delete``,
    so a missing file or a failed delete does not raise.
    """
    background_tasks.add_task(_safe_delete, path)
| def _default_queries_for_mode(mode: str) -> list[str]: | |
| if mode == "segmentation": | |
| return ["object"] | |
| if mode == "drone_detection": | |
| return ["drone"] | |
| return ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"] | |
async def demo_page():
    """Send the client to the Mission Console frontend entry page."""
    console_url = "/app/index.html"
    return RedirectResponse(url=console_url)
async def detect_endpoint(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("yolo11"),
    segmenter: str = Form("GSAM2-L"),
    enable_depth: bool = Form(False),
):
    """
    Main detection endpoint: process an uploaded video and return the result.

    Args:
        video: Video file to process
        mode: Detection mode (object_detection, segmentation, drone_detection)
        queries: Comma-separated object classes; when empty, the defaults for
            the mode are used
        detector: Model to use (yolo11, detr_resnet50, grounding_dino); in
            drone_detection mode an empty value falls back to yolov8_visdrone
        segmenter: Segmentation model to use (GSAM2-S/B/L, YSAM2-S/B/L)
        enable_depth: Whether to run legacy depth estimation (default: False)

    Returns:
        - For object_detection: Processed video with bounding boxes
        - For segmentation: Processed video with masks rendered
        - For drone_detection: Processed video with bounding boxes
        - On an unexpected inference error: JSON ``{"error": ...}`` with
          status 500

    Raises:
        HTTPException: 400 for an invalid mode or missing video; 500 when the
            upload cannot be saved or processing raises ``ValueError``.

    The upload save and the inference run in worker threads so this coroutine
    does not block the event loop while a video is processed.
    """
    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            # sorted() keeps the message stable; set iteration order is not.
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(sorted(VALID_MODES))}"
        )
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")

    # Save uploaded video
    try:
        input_path = await asyncio.to_thread(_save_upload_to_tmp, video)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    # Create output path
    fd, output_path = tempfile.mkstemp(prefix="output_", suffix=".mp4", dir="/tmp")
    os.close(fd)

    # Parse queries; the per-mode defaults are never empty.
    query_list = [q.strip() for q in queries.split(",") if q.strip()] or _default_queries_for_mode(mode)
    segmenting = mode == "segmentation"

    # Run inference
    try:
        if segmenting:
            result_path = await asyncio.to_thread(
                run_grounded_sam2_tracking,
                input_path,
                output_path,
                query_list,
                segmenter_name=segmenter,
                num_maskmem=7,
            )
        else:
            detector_name = (detector or "yolov8_visdrone") if mode == "drone_detection" else detector
            result_path, _ = await asyncio.to_thread(
                run_inference,
                input_path,
                output_path,
                query_list,
                detector_name=detector_name,
                depth_estimator_name="depth" if enable_depth else None,
                depth_scale=25.0,
            )
    except ValueError as exc:
        logging.exception("Segmentation processing failed." if segmenting else "Video processing failed.")
        _safe_delete(input_path)
        _safe_delete(output_path)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except Exception as exc:
        logging.exception("Segmentation inference failed." if segmenting else "Inference failed.")
        _safe_delete(input_path)
        _safe_delete(output_path)
        return JSONResponse(status_code=500, content={"error": str(exc)})

    # Schedule cleanup; background tasks run after the response is sent.
    _schedule_cleanup(background_tasks, input_path)
    _schedule_cleanup(background_tasks, output_path)
    if str(result_path) != output_path:
        # The pipeline reported a different output location; remove it too.
        _schedule_cleanup(background_tasks, str(result_path))

    # Return processed video
    return FileResponse(
        path=result_path,
        media_type="video/mp4",
        filename="segmented.mp4" if segmenting else "processed.mp4",
    )
# Strong references to in-flight job tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_ACTIVE_JOB_TASKS: set = set()


async def detect_async_endpoint(
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("yolo11"),
    segmenter: str = Form("GSAM2-L"),
    depth_estimator: str = Form("depth"),
    depth_scale: float = Form(25.0),
    enable_depth: bool = Form(False),
    step: int = Form(7),
    mission: Optional[str] = Form(None),
):
    """Create a background video-processing job and return its URLs.

    All request parameters are validated before the upload is written to
    disk, so a rejected request leaves no orphaned file in job storage.

    Args:
        video: Video file to process.
        mode: One of ``VALID_MODES``.
        queries: Comma-separated labels; per-mode defaults are used when empty.
        detector: Detector name; ignored in segmentation mode, where the
            segmenter registry owns detector selection.
        segmenter: Segmenter name; validated in segmentation mode.
        depth_estimator: Must be one of ``list_depth_estimators()``; only
            used when ``enable_depth`` is true.
        depth_scale: Stored on the job as a float.
        enable_depth: Whether to run legacy depth estimation.
        step: Stored on the job unchanged.
        mission: Optional mission text stored on the job.

    Returns:
        Dict with ``job_id``, the status/video/depth-video/stream URLs and
        the initial job status.

    Raises:
        HTTPException: 400 for an invalid mode, segmenter or depth estimator,
            or a missing video; 500 when the upload cannot be saved.
    """
    _ttfs_t0 = time.perf_counter()

    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(sorted(VALID_MODES))}",
        )
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")

    # --- Model selection (validated before anything touches the disk) ---
    detector_name = detector
    if mode == "drone_detection":
        detector_name = detector or "yolov8_visdrone"
    elif mode == "segmentation":
        # Segmenter registry owns detector selection (GSAM2→GDINO, YSAM2→YOLO).
        # detector_name=None so the job doesn't forward it (avoids duplicate kwarg).
        try:
            get_segmenter_detector(segmenter)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        detector_name = None

    available_depth_estimators = set(list_depth_estimators())
    if depth_estimator not in available_depth_estimators:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Invalid depth estimator '{depth_estimator}'. "
                f"Must be one of: {', '.join(sorted(available_depth_estimators))}"
            ),
        )

    job_id = uuid.uuid4().hex
    # Return value unused; the call is kept in case it prepares the job
    # directory. TODO confirm against jobs.storage.
    get_job_directory(job_id)
    input_path = get_input_video_path(job_id)
    output_path = get_output_video_path(job_id)
    depth_output_path = get_depth_output_path(job_id)
    first_frame_depth_path = get_first_frame_depth_path(job_id)

    try:
        # File copy is blocking; keep it off the event loop.
        await asyncio.to_thread(_save_upload_to_path, video, input_path)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()
    logging.info("[TTFS:%s] +%.1fs upload_saved", job_id, time.perf_counter() - _ttfs_t0)

    # --- Query Parsing ---
    query_list = [q.strip() for q in queries.split(",") if q.strip()] or _default_queries_for_mode(mode)
    logging.info("[TTFS:%s] +%.1fs queries_parsed", job_id, time.perf_counter() - _ttfs_t0)

    # Determine active depth estimator (Legacy)
    active_depth = depth_estimator if enable_depth else None

    job = JobInfo(
        job_id=job_id,
        status=JobStatus.PROCESSING,
        mode=mode,
        queries=query_list,
        detector_name=detector_name,
        segmenter_name=segmenter,
        input_video_path=str(input_path),
        output_video_path=str(output_path),
        depth_estimator_name=active_depth,
        depth_scale=float(depth_scale),
        depth_output_path=str(depth_output_path),
        first_frame_depth_path=str(first_frame_depth_path),
        step=step,
        ttfs_t0=_ttfs_t0,
        mission=mission,
    )
    get_job_storage().create(job)

    task = asyncio.create_task(process_video_async(job_id))
    _ACTIVE_JOB_TASKS.add(task)
    task.add_done_callback(_ACTIVE_JOB_TASKS.discard)

    return {
        "job_id": job_id,
        "status_url": f"/detect/status/{job_id}",
        "video_url": f"/detect/video/{job_id}",
        "depth_video_url": f"/detect/depth-video/{job_id}",
        "stream_url": f"/detect/stream/{job_id}",
        "status": job.status.value,
    }
async def detect_status(job_id: str):
    """Report the status, timestamps and error of a job.

    Raises:
        HTTPException: 404 when the job is not in storage.
    """
    job = get_job_storage().get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    finished_at = job.completed_at.isoformat() if job.completed_at else None
    payload = {
        "job_id": job.job_id,
        "status": job.status.value,
        "created_at": job.created_at.isoformat(),
        "completed_at": finished_at,
        "error": job.error,
    }
    return payload
async def get_track_summary_endpoint(job_id: str):
    """Return per-frame detection counts for timeline heatmap.

    The frame count and fps are read from the job's output video when it can
    be opened. Otherwise fps stays at 30.0 and the frame count is derived
    from the highest frame key in the summary.

    Raises:
        HTTPException: 404 when the job is not in storage.
    """
    from jobs.storage import get_track_summary, get_job_storage
    import cv2

    job = get_job_storage().get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    summary = get_track_summary(job_id)
    total_frames, fps = 0, 30.0

    video_path = job.output_video_path
    if video_path:
        capture = cv2.VideoCapture(video_path)
        if capture.isOpened():
            total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            # A reported fps of 0 falls back to 30.0.
            fps = capture.get(cv2.CAP_PROP_FPS) or 30.0
        capture.release()

    if total_frames == 0 and summary:
        # No usable video metadata: infer the length from the summary keys.
        total_frames = max(summary.keys()) + 1

    return {"total_frames": total_frames, "fps": fps, "frames": summary}
async def get_frame_tracks(job_id: str, frame_idx: int):
    """Retrieve detections (with tracking info) for a specific frame.

    Per-frame data is looked up through ``jobs.storage.get_track_data``.
    When nothing is stored for the job/frame, an empty list is returned
    rather than an error.
    """
    # Design note kept from the original: the inference code keeps its
    # detections map local, so per-frame data has to come from a shared
    # store; that store is what get_track_data reads.
    from jobs.storage import get_track_data

    return get_track_data(job_id, frame_idx) or []
async def cancel_job(job_id: str):
    """Request cancellation of a job that is still processing.

    A job in any other state is left untouched and its current status is
    reported back.

    Raises:
        HTTPException: 404 when the job is not in storage.
    """
    job = get_job_storage().get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    if job.status == JobStatus.PROCESSING:
        get_job_storage().update(job_id, status=JobStatus.CANCELLED)
        return {
            "message": "Job cancellation requested",
            "status": "cancelled",
        }

    current = job.status.value
    return {"message": f"Job already {current}", "status": current}
async def detect_video(job_id: str):
    """Return the processed video of a finished job.

    Returns:
        The output file as ``video/mp4``, or a 202 JSON body while the job is
        still processing.

    Raises:
        HTTPException: 404 for an unknown job or a missing output file, 500
            for a failed job, 410 for a cancelled job.
    """
    job = get_job_storage().get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    status = job.status
    if status == JobStatus.FAILED:
        raise HTTPException(status_code=500, detail=f"Job failed: {job.error}")
    elif status == JobStatus.CANCELLED:
        raise HTTPException(status_code=410, detail="Job was cancelled")
    elif status == JobStatus.PROCESSING:
        still_running = {"detail": "Video still processing", "status": "processing"}
        return JSONResponse(status_code=202, content=still_running)

    video_path = job.output_video_path
    if not (video_path and Path(video_path).exists()):
        raise HTTPException(status_code=404, detail="Video file not found.")

    return FileResponse(path=video_path, media_type="video/mp4", filename="processed.mp4")
async def detect_depth_video(job_id: str):
    """Return depth estimation video.

    Returns:
        The depth file as ``video/mp4``, or a 202 JSON body while the job is
        still processing.

    Raises:
        HTTPException: 404 for an unknown job, a job without a depth output
            path, or a missing file; 500 for a failed job; 410 for a
            cancelled job.
    """
    job = get_job_storage().get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    depth_path = job.depth_output_path
    if not depth_path:
        detail = "No depth video for this job."
        # Partial success: the job finished but the depth stage failed.
        if job.partial_success and job.depth_error:
            detail = f"Depth unavailable: {job.depth_error}"
        raise HTTPException(status_code=404, detail=detail)

    status = job.status
    if status == JobStatus.FAILED:
        raise HTTPException(status_code=500, detail=f"Job failed: {job.error}")
    elif status == JobStatus.CANCELLED:
        raise HTTPException(status_code=410, detail="Job was cancelled")
    elif status == JobStatus.PROCESSING:
        still_running = {"detail": "Video still processing", "status": "processing"}
        return JSONResponse(status_code=202, content=still_running)

    if not Path(depth_path).exists():
        raise HTTPException(status_code=404, detail="Depth video file not found.")

    return FileResponse(path=depth_path, media_type="video/mp4", filename="depth.mp4")
async def stream_video(job_id: str):
    """MJPEG stream of the processing video (event-driven).

    Frames are taken from the job's stream queue, JPEG-encoded in the default
    executor and emitted as ``multipart/x-mixed-replace`` parts, paced to at
    most 24 frames per second. Streaming starts once at least 5 frames are
    queued or the stream has been removed, and ends when the stream has been
    removed and the queue is drained. If the job has no stream queue the
    response body is empty.

    Errors inside the loop are logged and the loop continues after a short
    pause; previously they were swallowed with no trace.
    """
    import queue as queue_mod

    async def stream_generator():
        STREAM_FPS = 24
        FRAME_INTERVAL = 1.0 / STREAM_FPS  # ~41.7ms
        # Loop-invariant JPEG encoder settings.
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 60]
        loop = asyncio.get_running_loop()
        buffered = False
        last_yield_time = 0.0

        # TTFS instrumentation
        _first_yielded = False
        _buffer_wait_logged = False
        _job = get_job_storage().get(job_id)
        _stream_t0 = _job.ttfs_t0 if _job else None
        if _stream_t0:
            logging.info("[TTFS:%s] +%.1fs stream_subscribed", job_id, time.perf_counter() - _stream_t0)

        # Get or create the asyncio.Event for this stream (must be in async context)
        event = get_stream_event(job_id)
        # Hold a local ref to the queue so we can drain it even after remove_stream()
        q = get_stream(job_id)
        if not q:
            return
        stream_removed = False

        while True:
            try:
                # Initial Buffer: Wait until we have enough frames or job is done
                if not buffered:
                    if not _buffer_wait_logged and _stream_t0:
                        logging.info("[TTFS:%s] +%.1fs stream_buffer_wait (qsize=%d)", job_id, time.perf_counter() - _stream_t0, q.qsize())
                        _buffer_wait_logged = True
                    if q.qsize() < 5 and not stream_removed:
                        await asyncio.sleep(0.1)
                        stream_removed = get_stream(job_id) is None
                        continue
                    buffered = True
                    if _stream_t0:
                        logging.info("[TTFS:%s] +%.1fs stream_buffer_ready", job_id, time.perf_counter() - _stream_t0)

                # Try to get a frame from the queue first (non-blocking)
                frame = None
                try:
                    frame = q.get_nowait()
                except queue_mod.Empty:
                    pass

                # Only block on event when queue is actually empty
                if frame is None:
                    if stream_removed:
                        break  # stream ended and queue fully drained
                    if event is not None:
                        try:
                            await asyncio.wait_for(event.wait(), timeout=1.0)
                            event.clear()
                        except asyncio.TimeoutError:
                            stream_removed = get_stream(job_id) is None
                            continue
                    else:
                        await asyncio.sleep(FRAME_INTERVAL)
                    # After waking, try the queue again
                    try:
                        frame = q.get_nowait()
                    except queue_mod.Empty:
                        stream_removed = get_stream(job_id) is None
                        continue

                # Pace output at fixed 24fps
                now = time.perf_counter()
                wait = FRAME_INTERVAL - (now - last_yield_time)
                if wait > 0:
                    await asyncio.sleep(wait)

                # Encode in thread pool to avoid blocking the event loop
                success, buffer = await loop.run_in_executor(None, cv2.imencode, '.jpg', frame, encode_param)
                if success:
                    last_yield_time = time.perf_counter()
                    if not _first_yielded:
                        _first_yielded = True
                        if _stream_t0:
                            logging.info("[TTFS:%s] +%.1fs first_yield_to_client", job_id, time.perf_counter() - _stream_t0)
                    yield (b'--frame\r\n'
                           b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
            except Exception:
                # Keep the stream alive, but leave a trace. CancelledError and
                # GeneratorExit are not Exception subclasses and still end
                # the generator.
                logging.exception("Stream error for job %s; continuing", job_id)
                await asyncio.sleep(0.1)

    return StreamingResponse(
        stream_generator(),
        media_type="multipart/x-mixed-replace; boundary=frame"
    )
async def benchmark_endpoint(
    video: UploadFile = File(...),
    queries: str = Form("person,car,truck"),
    segmenter: str = Form("GSAM2-L"),
    step: int = Form(60),
    num_maskmem: Optional[int] = Form(None),
):
    """Run instrumented GSAM2 pipeline and return latency breakdown JSON.

    This is a long-running request (may take minutes); the pipeline itself
    runs in a worker thread. Callers should set an appropriate HTTP timeout.

    Args:
        video: Video file to benchmark.
        queries: Comma-separated object classes.
        segmenter: Segmenter name passed to the pipeline.
        step: Passed to the pipeline as ``step``.
        num_maskmem: Passed to the pipeline; reported as 7 in the response
            when not given.

    Returns:
        JSON with ``total_frames`` and ``fps`` read from the output video,
        ``num_gpus``, ``num_maskmem`` and the filled ``metrics`` dict.

    Temporary files are created with ``mkstemp`` (``mktemp`` only returns a
    name, which is open to a creation race) and are always removed.
    """
    import threading

    temp_paths: list[str] = []
    try:
        # Reserve input and output temp files atomically.
        for prefix in ("bench_in_", "bench_out_"):
            fd, reserved = tempfile.mkstemp(suffix=".mp4", prefix=prefix)
            os.close(fd)
            temp_paths.append(reserved)
        input_path, output_path = temp_paths

        # Save uploaded video to temp path
        try:
            with open(input_path, "wb") as f:
                shutil.copyfileobj(video.file, f)
        finally:
            await video.close()

        query_list = [q.strip() for q in queries.split(",") if q.strip()]
        # Shared with the pipeline through _perf_metrics, guarded by _perf_lock.
        metrics = {
            "end_to_end_ms": 0.0,
            "frame_extraction_ms": 0.0,
            "model_load_ms": 0.0,
            "init_state_ms": 0.0,
            "tracking_total_ms": 0.0,
            "gdino_total_ms": 0.0,
            "sam_image_total_ms": 0.0,
            "sam_video_total_ms": 0.0,
            "id_reconciliation_ms": 0.0,
            "render_total_ms": 0.0,
            "writer_total_ms": 0.0,
            "gpu_peak_mem_mb": 0.0,
        }
        lock = threading.Lock()

        await asyncio.to_thread(
            run_grounded_sam2_tracking,
            input_path,
            output_path,
            query_list,
            segmenter_name=segmenter,
            step=step,
            _perf_metrics=metrics,
            _perf_lock=lock,
            num_maskmem=num_maskmem,
        )

        # Read frame count and fps from output video
        total_frames = 0
        fps = 0.0
        cap = cv2.VideoCapture(output_path)
        if cap.isOpened():
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        cap.release()

        num_gpus = torch.cuda.device_count()
        return JSONResponse({
            "total_frames": total_frames,
            "fps": fps,
            "num_gpus": num_gpus,
            "num_maskmem": num_maskmem if num_maskmem is not None else 7,
            "metrics": metrics,
        })
    finally:
        for p in temp_paths:
            try:
                os.remove(p)
            except OSError:
                pass
async def gpu_monitor_endpoint(duration: int = 180, interval: int = 1):
    """Stream nvidia-smi dmon output for the given duration.

    Usage: curl 'http://.../gpu-monitor?duration=180&interval=1'
    Run this in one terminal while /benchmark runs in another.

    Args:
        duration: How long to stream, in seconds (wall clock).
        interval: Sampling interval in seconds, passed to ``dmon -d``.

    Raises:
        HTTPException: 400 when ``duration`` or ``interval`` is not positive.

    The subprocess is driven through asyncio, so waiting for output does not
    block the event loop. The cut-off is measured with a monotonic clock
    rather than by counting lines.
    """
    if duration < 1 or interval < 1:
        raise HTTPException(
            status_code=400,
            detail="duration and interval must be positive integers",
        )

    async def _stream():
        try:
            proc = await asyncio.create_subprocess_exec(
                "nvidia-smi", "dmon", "-s", "u", "-d", str(interval),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
            )
        except OSError as exc:
            # Response headers are already sent; report the failure in-band.
            yield f"Failed to start nvidia-smi: {exc}\n"
            return

        deadline = time.monotonic() + duration
        try:
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    raw = await asyncio.wait_for(proc.stdout.readline(), timeout=remaining)
                except asyncio.TimeoutError:
                    break
                if not raw:
                    break  # process exited and closed its output
                yield raw.decode("utf-8", errors="replace")
        finally:
            if proc.returncode is None:
                try:
                    proc.terminate()
                except ProcessLookupError:
                    pass  # exited between the check and the signal
            await proc.wait()

    return StreamingResponse(_stream(), media_type="text/plain")
| # --------------------------------------------------------------------------- | |
| # Benchmark Profiler & Roofline Analysis Endpoints | |
| # --------------------------------------------------------------------------- | |
async def benchmark_hardware():
    """Return the hardware specs as JSON (no video needed).

    ``get_hardware_info`` is run in a worker thread and its dataclass result
    is converted to a dict for the response.
    """
    from dataclasses import asdict
    from utils.hardware_info import get_hardware_info

    hardware = await asyncio.to_thread(get_hardware_info)
    payload = asdict(hardware)
    return JSONResponse(payload)
async def benchmark_profile(
    video: UploadFile = File(...),
    mode: str = Form("detection"),
    detector: str = Form("yolo11"),
    segmenter: str = Form("GSAM2-L"),
    queries: str = Form("person,car,truck"),
    max_frames: int = Form(100),
    warmup_frames: int = Form(5),
    step: int = Form(60),
    num_maskmem: Optional[int] = Form(None),
):
    """Run profiled inference and return per-frame timing breakdown.

    Args:
        video: Video file to profile.
        mode: "detection" or "segmentation".
        detector: Detector key (for detection mode).
        segmenter: Segmenter key (for segmentation mode).
        queries: Comma-separated object classes.
        max_frames: Maximum frames to profile.
        warmup_frames: Warmup frames (detection only).
        step: Keyframe interval (segmentation only).
        num_maskmem: SAM2 memory frames (None = model default 7).

    Returns:
        JSON of the profiling result dataclass, plus ``gsam2_metrics`` when
        the result carries a ``_gsam2_metrics`` attribute.

    Raises:
        HTTPException: 400 for an invalid mode; 500 when the upload cannot be
            saved or profiling fails.
    """
    import dataclasses
    from utils.profiler import run_profiled_detection, run_profiled_segmentation

    if mode not in ("detection", "segmentation"):
        raise HTTPException(status_code=400, detail="mode must be 'detection' or 'segmentation'")

    # Same upload handling as the detect endpoints: failures become a 500 and
    # the upload is always closed.
    try:
        input_path = await asyncio.to_thread(_save_upload_to_tmp, video)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    query_list = [q.strip() for q in queries.split(",") if q.strip()]

    try:
        if mode == "detection":
            result = await asyncio.to_thread(
                run_profiled_detection,
                input_path, detector, query_list,
                max_frames=max_frames, warmup_frames=warmup_frames,
            )
        else:
            result = await asyncio.to_thread(
                run_profiled_segmentation,
                input_path, segmenter, query_list,
                max_frames=max_frames, step=step,
                num_maskmem=num_maskmem,
            )
    except Exception as exc:
        logging.exception("Profiling failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        # Single cleanup point for both success and failure.
        _safe_delete(input_path)

    # Serialize dataclass, handling any non-serializable fields
    out = dataclasses.asdict(result)
    # Include GSAM2 metrics if present
    gsam2 = getattr(result, "_gsam2_metrics", None)
    if gsam2:
        out["gsam2_metrics"] = gsam2
    return JSONResponse(out)
async def benchmark_analysis(
    video: UploadFile = File(...),
    mode: str = Form("detection"),
    detector: str = Form("yolo11"),
    segmenter: str = Form("GSAM2-L"),
    queries: str = Form("person,car,truck"),
    max_frames: int = Form(100),
    warmup_frames: int = Form(5),
    step: int = Form(60),
    num_maskmem: Optional[int] = Form(None),
):
    """Full roofline analysis: hardware + profiling + theoretical ceilings + bottleneck ID.

    Combines hardware extraction, profiled inference, and roofline model
    to identify bottlenecks and provide actionable recommendations.

    Args:
        video: Video file to analyse.
        mode: "detection" or "segmentation".
        detector: Detector key (for detection mode).
        segmenter: Segmenter key (for segmentation mode).
        queries: Comma-separated object classes.
        max_frames: Maximum frames to profile.
        warmup_frames: Warmup frames (detection only).
        step: Keyframe interval (segmentation only).
        num_maskmem: Passed to segmentation profiling.

    Returns:
        JSON with ``hardware``, ``profiling`` and ``roofline`` sections, each
        the dict form of the corresponding dataclass.

    Raises:
        HTTPException: 400 for an invalid mode; 500 when the upload cannot be
            saved or the analysis fails.
    """
    import dataclasses
    from utils.hardware_info import get_hardware_info
    from utils.profiler import run_profiled_detection, run_profiled_segmentation
    from utils.roofline import compute_roofline

    if mode not in ("detection", "segmentation"):
        raise HTTPException(status_code=400, detail="mode must be 'detection' or 'segmentation'")

    # Same upload handling as the detect endpoints: failures become a 500 and
    # the upload is always closed.
    try:
        input_path = await asyncio.to_thread(_save_upload_to_tmp, video)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    query_list = [q.strip() for q in queries.split(",") if q.strip()]

    try:
        # Get hardware info (cached, fast)
        hardware = await asyncio.to_thread(get_hardware_info)

        # Run profiling
        if mode == "detection":
            profiling = await asyncio.to_thread(
                run_profiled_detection,
                input_path, detector, query_list,
                max_frames=max_frames, warmup_frames=warmup_frames,
            )
        else:
            profiling = await asyncio.to_thread(
                run_profiled_segmentation,
                input_path, segmenter, query_list,
                max_frames=max_frames, step=step,
                num_maskmem=num_maskmem,
            )

        # Compute roofline
        roofline = compute_roofline(hardware, profiling)
    except Exception as exc:
        logging.exception("Benchmark analysis failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        # Single cleanup point for both success and failure.
        _safe_delete(input_path)

    return JSONResponse({
        "hardware": dataclasses.asdict(hardware),
        "profiling": dataclasses.asdict(profiling),
        "roofline": dataclasses.asdict(roofline),
    })
# Script entry point: serve the app with uvicorn on all interfaces, port 7860,
# with auto-reload disabled.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)