Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| import os | |
| import shutil | |
| import tempfile | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from datetime import timedelta | |
| from pathlib import Path | |
| import cv2 | |
| from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse | |
| import uvicorn | |
| from inference import process_first_frame, run_inference, run_segmentation | |
| from jobs.background import process_video_async | |
| from jobs.models import JobInfo, JobStatus | |
| from jobs.storage import ( | |
| get_first_frame_path, | |
| get_input_video_path, | |
| get_job_directory, | |
| get_job_storage, | |
| get_output_video_path, | |
| ) | |
# Configure root logging at INFO level (basicConfig is a no-op when the root
# logger already has handlers).
logging.basicConfig(level="INFO")
async def _periodic_cleanup(
    interval_seconds: float = 600,
    max_age: timedelta = timedelta(hours=1),
) -> None:
    """Sweep expired jobs out of job storage forever, once per interval.

    Runs until the task is cancelled (the app lifespan cancels it on shutdown).
    A failure in one sweep is logged and the loop keeps going; previously an
    exception here ended the task silently and no cleanup ever ran again.

    Args:
        interval_seconds: Seconds to sleep before each sweep (default 600).
        max_age: Passed to ``cleanup_expired``; presumably jobs older than this
            are removed - confirm against jobs.storage.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            get_job_storage().cleanup_expired(max_age)
        except Exception:
            # asyncio.CancelledError is a BaseException (3.8+), so shutdown
            # cancellation is not swallowed here.
            logging.exception("Periodic job cleanup failed.")
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan: keep the periodic job cleanup running while up.

    Starts ``_periodic_cleanup`` as a background task on startup. On shutdown
    it cancels that task and waits for it to finish, so the task is not left
    pending when the event loop closes. Wrapped with ``asynccontextmanager``,
    which is the form FastAPI's ``lifespan=`` argument expects.
    """
    cleanup_task = asyncio.create_task(_periodic_cleanup())
    try:
        yield
    finally:
        cleanup_task.cancel()
        # return_exceptions=True hands the task's CancelledError back as a
        # result instead of raising it here.
        await asyncio.gather(cleanup_task, return_exceptions=True)
app = FastAPI(lifespan=lifespan, title="Video Object Detection")

# Wide-open CORS: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS docs advise against combining wildcard
# origins/methods/headers with allow_credentials=True. Nothing visible here
# uses cookies or auth, so allow_credentials=False may be the safer setting -
# confirm no cross-origin client relies on credentialed requests first.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Detection modes accepted by detect_endpoint and detect_async_endpoint.
VALID_MODES = {"object_detection", "segmentation", "drone_detection"}
def _save_upload_to_tmp(upload: "UploadFile") -> str:
    """Copy an uploaded file into a new temporary file under /tmp.

    The temp file keeps the upload's extension (".mp4" when the upload has no
    filename or no extension). Data is streamed in chunks rather than read
    into memory in one go, and the temp file is removed again if copying fails.

    Args:
        upload: The uploaded file; only ``filename`` and ``file`` are used.

    Returns:
        Path of the new temporary file. The caller owns it and must delete it.

    Raises:
        Whatever reading the upload or writing the temp file raises.
    """
    suffix = Path(upload.filename or "upload.mp4").suffix or ".mp4"
    fd, path = tempfile.mkstemp(prefix="input_", suffix=suffix, dir="/tmp")
    try:
        # Reuse mkstemp's descriptor instead of closing it and reopening by name.
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(upload.file, buffer)
    except BaseException:
        # Don't leave a partial/empty temp file behind; the error is re-raised.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
def _save_upload_to_path(upload: "UploadFile", path: Path) -> None:
    """Copy an uploaded file to *path*, creating parent directories as needed.

    The data is streamed in chunks rather than read into memory in one go.
    If copying fails, the partially written file at *path* is removed and
    the error is re-raised.

    Args:
        upload: The uploaded file; only its ``file`` attribute is read.
        path: Destination file path; any existing file there is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    buffer = open(path, "wb")
    try:
        with buffer:
            shutil.copyfileobj(upload.file, buffer)
    except BaseException:
        # The file was created/truncated by the open() above and is closed by
        # now; drop the partial copy so no truncated video is left behind.
        try:
            path.unlink()
        except OSError:
            pass
        raise
def _safe_delete(path: str) -> None:
    """Remove *path* without ever raising.

    A file that is already gone is silently ignored; any other failure is
    logged with its traceback and then swallowed so cleanup cannot break
    the caller.
    """
    try:
        os.remove(path)
    except Exception as err:
        # FileNotFoundError is the expected "nothing to do" case.
        if not isinstance(err, FileNotFoundError):
            logging.exception("Failed to remove temporary file: %s", path)
def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
    """Queue deletion of *path* to run after the response has been sent.

    Deletion goes through ``_safe_delete``, so a missing file is ignored and
    other errors are logged rather than raised.
    """
    background_tasks.add_task(_safe_delete, path)
def _default_queries_for_mode(mode: str) -> list[str]:
    """Return the fallback query list used when the caller supplies none.

    segmentation -> ["object"], drone_detection -> ["drone"]; any other mode
    gets a fixed list of people/vehicle classes. A fresh list is returned on
    every call, so callers may mutate it.
    """
    per_mode = {
        "segmentation": ("object",),
        "drone_detection": ("drone",),
    }
    general = ("person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane")
    return list(per_mode.get(mode, general))
async def demo_page() -> str:
    """Return the HTML of the bundled demo page.

    The page is read (UTF-8) from ``demo.html`` next to this module; if that
    file does not exist, a short placeholder heading is returned instead.
    """
    html_file = Path(__file__).parent / "demo.html"
    try:
        with html_file.open(encoding="utf-8") as handle:
            return handle.read()
    except FileNotFoundError:
        return "<h1>Demo page missing</h1>"
async def detect_endpoint(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("hf_yolov8"),
    segmenter: str = Form("sam3"),
):
    """
    Main detection endpoint: process a whole video and return the result.

    All modes share one pipeline: save the upload to /tmp, allocate an output
    temp file, run the model, return the result and delete both temp files
    once the response has been sent. Modes differ only in the inference call,
    the default queries and the download filename.

    Args:
        background_tasks: Used to delete the temporary input/output files
            after the response is sent.
        video: Video file to process.
        mode: Detection mode (object_detection, segmentation, drone_detection).
        queries: Comma-separated object classes / prompts. When none are given,
            segmentation and drone_detection fall back to
            _default_queries_for_mode; object_detection passes an empty list.
        detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino).
            drone_detection ignores it and uses the dedicated drone_yolo model.
        segmenter: Segmentation model to use (sam3); segmentation mode only.

    Returns:
        - For segmentation: FileResponse with the masked video ("segmented.mp4").
        - For object_detection / drone_detection: FileResponse with the
          box-annotated video ("processed.mp4").
        - JSONResponse {"error": ...} with status 500 if inference raises
          anything other than ValueError.

    Raises:
        HTTPException: 400 for an unknown mode or missing video; 500 if the
            upload cannot be saved or inference raises ValueError.
    """
    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            # sorted(): plain set iteration order varies between processes.
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(sorted(VALID_MODES))}",
        )
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")

    is_segmentation = mode == "segmentation"

    try:
        input_path = _save_upload_to_tmp(video)
    except Exception:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.")
    finally:
        await video.close()

    try:
        fd, output_path = tempfile.mkstemp(prefix="output_", suffix=".mp4", dir="/tmp")
    except Exception:
        # Otherwise the already-saved input would be orphaned in /tmp.
        _safe_delete(input_path)
        raise
    os.close(fd)

    query_list = [q.strip() for q in queries.split(",") if q.strip()]
    if not query_list and mode != "object_detection":
        # segmentation -> ["object"], drone_detection -> ["drone"].
        # NOTE(review): object_detection keeps an empty list here, whereas
        # detect_async_endpoint falls back to the default class list for it
        # too - confirm the difference is intentional.
        query_list = _default_queries_for_mode(mode)

    # NOTE(review): run_segmentation / run_inference are called directly in
    # this coroutine. If they block (likely for model inference - confirm),
    # the event loop stalls for the whole request. Moving them to a thread
    # would also make inference run concurrently across requests, so check
    # the models tolerate that before changing it.
    try:
        if is_segmentation:
            output_path = run_segmentation(
                input_path,
                output_path,
                query_list,
                segmenter_name=segmenter,
            )
        else:
            detector_name = "drone_yolo" if mode == "drone_detection" else detector
            output_path = run_inference(
                input_path,
                output_path,
                query_list,
                detector_name=detector_name,
            )
    except ValueError as exc:
        logging.exception(
            "Segmentation processing failed." if is_segmentation else "Video processing failed."
        )
        _safe_delete(input_path)
        _safe_delete(output_path)
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        logging.exception(
            "Segmentation inference failed." if is_segmentation else "Inference failed."
        )
        _safe_delete(input_path)
        _safe_delete(output_path)
        return JSONResponse(status_code=500, content={"error": str(exc)})

    _schedule_cleanup(background_tasks, input_path)
    _schedule_cleanup(background_tasks, output_path)
    return FileResponse(
        path=output_path,
        media_type="video/mp4",
        filename="segmented.mp4" if is_segmentation else "processed.mp4",
    )
# Strong references to in-flight video jobs. The event loop keeps only weak
# references to tasks, so a task created and then dropped can be
# garbage-collected before it finishes.
_video_tasks: set[asyncio.Task] = set()


async def detect_async_endpoint(
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("hf_yolov8"),
    segmenter: str = Form("sam3"),
):
    """Start an asynchronous detection job and return its first-frame result.

    Saves the upload into the job's storage paths, processes the first frame
    synchronously (written to the job's first-frame path), registers a
    PROCESSING job and schedules ``process_video_async`` to handle the rest of
    the video in the background.

    Args:
        video: Video file to process.
        mode: Detection mode (object_detection, segmentation, drone_detection).
        queries: Comma-separated classes / prompts; when empty,
            _default_queries_for_mode(mode) supplies them.
        detector: Detector name; replaced by drone_yolo for drone_detection.
        segmenter: Segmentation model name.

    Returns:
        dict with job_id, first_frame_url, status_url, video_url, status and
        first_frame_detections.

    Raises:
        HTTPException: 400 for an unknown mode or missing video; 500 if the
            upload cannot be saved or first-frame processing fails (the job
            directory is removed in both failure cases).
    """
    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            # sorted(): plain set iteration order varies between processes.
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(sorted(VALID_MODES))}",
        )
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")

    job_id = uuid.uuid4().hex
    job_dir = get_job_directory(job_id)
    input_path = get_input_video_path(job_id)
    output_path = get_output_video_path(job_id)
    first_frame_path = get_first_frame_path(job_id)

    try:
        _save_upload_to_path(video, input_path)
    except Exception:
        logging.exception("Failed to save uploaded file.")
        # Same cleanup as the first-frame failure path: don't leave a
        # half-populated job directory behind.
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.")
    finally:
        await video.close()

    query_list = [q.strip() for q in queries.split(",") if q.strip()]
    if not query_list:
        query_list = _default_queries_for_mode(mode)
    detector_name = "drone_yolo" if mode == "drone_detection" else detector

    # NOTE(review): process_first_frame runs directly in this coroutine; if
    # it blocks (likely for model inference - confirm), the event loop stalls
    # until it returns.
    try:
        processed_frame, detections = process_first_frame(
            str(input_path),
            query_list,
            mode=mode,
            detector_name=detector_name,
            segmenter_name=segmenter,
        )
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(first_frame_path), processed_frame):
            raise RuntimeError(f"cv2.imwrite could not write {first_frame_path}")
    except Exception:
        logging.exception("First-frame processing failed.")
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail="Failed to process first frame.")

    job = JobInfo(
        job_id=job_id,
        status=JobStatus.PROCESSING,
        mode=mode,
        queries=query_list,
        detector_name=detector_name,
        segmenter_name=segmenter,
        input_video_path=str(input_path),
        output_video_path=str(output_path),
        first_frame_path=str(first_frame_path),
        first_frame_detections=detections,
    )
    get_job_storage().create(job)

    task = asyncio.create_task(process_video_async(job_id))
    _video_tasks.add(task)
    task.add_done_callback(_video_tasks.discard)

    return {
        "job_id": job_id,
        "first_frame_url": f"/detect/first-frame/{job_id}",
        "status_url": f"/detect/status/{job_id}",
        "video_url": f"/detect/video/{job_id}",
        "status": job.status.value,
        "first_frame_detections": detections,
    }
async def detect_status(job_id: str):
    """Report the current state of an async detection job.

    Returns a dict with job_id, status, created_at, completed_at (ISO-8601
    string, or None while unfinished) and error.

    Raises:
        HTTPException: 404 if the job is unknown or expired.
    """
    record = get_job_storage().get(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found or expired.")
    finished = record.completed_at
    payload = {
        "job_id": record.job_id,
        "status": record.status.value,
        "created_at": record.created_at.isoformat(),
        "completed_at": None if not finished else finished.isoformat(),
        "error": record.error,
    }
    return payload
async def detect_first_frame(job_id: str):
    """Serve the processed first frame saved for an async job.

    Raises:
        HTTPException: 404 if the job is unknown/expired or its first-frame
            image is not on disk.
    """
    job = get_job_storage().get(job_id)
    available = bool(job) and Path(job.first_frame_path).exists()
    if not available:
        raise HTTPException(status_code=404, detail="First frame not found.")
    # NOTE(review): the media type is fixed to JPEG, which assumes
    # get_first_frame_path() yields a .jpg path - confirm.
    frame_path = job.first_frame_path
    return FileResponse(
        media_type="image/jpeg",
        path=frame_path,
        filename="first_frame.jpg",
    )
async def detect_video(job_id: str):
    """Return the processed video of an async job once it is ready.

    Outcomes:
        404 - job unknown/expired, or its output file is missing.
        500 - the job failed; the detail carries the stored error.
        202 - still processing (JSON body with status "processing").
        otherwise - the MP4 served as "processed.mp4" (any status other than
            FAILED/PROCESSING falls through to the file check).
    """
    job = get_job_storage().get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    state = job.status
    if state == JobStatus.FAILED:
        raise HTTPException(status_code=500, detail=f"Job failed: {job.error}")
    if state == JobStatus.PROCESSING:
        return JSONResponse(
            status_code=202,
            content={"detail": "Video still processing", "status": "processing"},
        )

    video_file = job.output_video_path
    ready = bool(video_file) and Path(video_file).exists()
    if not ready:
        raise HTTPException(status_code=404, detail="Video file not found.")
    return FileResponse(path=video_file, media_type="video/mp4", filename="processed.mp4")
if __name__ == "__main__":
    # Serve the `app` object defined in this module. The import string
    # "app:app" made uvicorn import this file a second time as module "app"
    # and required the file to be named app.py; passing the object directly
    # is fine because auto-reload is off.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)