# ISR / app.py
# Zhen Ye — commit f0b6460: "added async first frame/video detection"
# (Repository-viewer residue: raw · history · blame · 11.9 kB)
import asyncio
import logging
import os
import shutil
import tempfile
import uuid
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
import cv2
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
import uvicorn
from inference import process_first_frame, run_inference, run_segmentation
from jobs.background import process_video_async
from jobs.models import JobInfo, JobStatus
from jobs.storage import (
get_first_frame_path,
get_input_video_path,
get_job_directory,
get_job_storage,
get_output_video_path,
)
logging.basicConfig(level=logging.INFO)
async def _periodic_cleanup() -> None:
while True:
await asyncio.sleep(600)
get_job_storage().cleanup_expired(timedelta(hours=1))
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run the periodic job-cleanup task while the application is serving.

    The task starts before the app begins serving (before ``yield``). On
    shutdown it is cancelled and then awaited, so it has really stopped
    before shutdown completes. Awaiting also consumes the task's result:
    - the expected ``CancelledError`` is swallowed;
    - any other exception the task ended with is logged instead of being
      raised out of the shutdown path.
    """
    cleanup_task = asyncio.create_task(_periodic_cleanup())
    try:
        yield
    finally:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass  # expected: we just cancelled it
        except Exception:
            logging.exception("Periodic cleanup task ended with an error.")
app = FastAPI(lifespan=lifespan, title="Video Object Detection")

# Cross-origin access is fully open: every origin, method and header is
# allowed, and credentials are permitted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended before adding cookie/auth-based access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Detection modes accepted by the /detect and /detect/async endpoints.
VALID_MODES = {
    "object_detection",
    "segmentation",
    "drone_detection",
}
def _save_upload_to_tmp(upload: UploadFile) -> str:
"""Save uploaded file to temporary location."""
suffix = Path(upload.filename or "upload.mp4").suffix or ".mp4"
fd, path = tempfile.mkstemp(prefix="input_", suffix=suffix, dir="/tmp")
os.close(fd)
with open(path, "wb") as buffer:
data = upload.file.read()
buffer.write(data)
return path
def _save_upload_to_path(upload: UploadFile, path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "wb") as buffer:
data = upload.file.read()
buffer.write(data)
def _safe_delete(path: str) -> None:
"""Safely delete a file, ignoring errors."""
try:
os.remove(path)
except FileNotFoundError:
return
except Exception:
logging.exception("Failed to remove temporary file: %s", path)
def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
    """Queue a best-effort delete of *path* to run after the response is sent.

    The delete goes through ``_safe_delete``, so a missing file or a failed
    removal never raises inside the background task.
    """
    background_tasks.add_task(_safe_delete, path)
def _default_queries_for_mode(mode: str) -> list[str]:
if mode == "segmentation":
return ["object"]
if mode == "drone_detection":
return ["drone"]
return ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
@app.get("/", response_class=HTMLResponse)
async def demo_page() -> str:
    """Return the contents of ``demo.html`` located next to this module.

    If that file does not exist, a one-line placeholder page is returned
    instead.
    """
    html_file = Path(__file__).with_name("demo.html")
    try:
        page = html_file.read_text(encoding="utf-8")
    except FileNotFoundError:
        page = "<h1>Demo page missing</h1>"
    return page
@app.post("/detect")
async def detect_endpoint(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("hf_yolov8"),
    segmenter: str = Form("sam3"),
):
    """
    Main (blocking) detection endpoint: process the whole video and return it.

    Args:
        video: Video file to process.
        mode: Detection mode (object_detection, segmentation, drone_detection).
        queries: Comma-separated object classes / prompts. When empty,
            segmentation uses ["object"], drone_detection uses ["drone"],
            and object_detection passes an empty list to run_inference.
        detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino).
            Ignored for drone_detection, which uses the dedicated drone_yolo
            model, and for segmentation.
        segmenter: Segmentation model to use (sam3). Only used for segmentation.

    Returns:
        - For object_detection: processed video with bounding boxes ("processed.mp4")
        - For segmentation: processed video with masks rendered ("segmented.mp4")
        - For drone_detection: processed video with bounding boxes ("processed.mp4")

    Errors:
        - 400 for an unknown mode.
        - 500 if the upload cannot be saved.
        - 500 if processing raises ValueError (HTTPException carrying the message).
        - 500 with a JSON body {"error": ...} for any other processing failure.
    """
    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            # sorted(): set iteration order is arbitrary, so keep the message stable.
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(sorted(VALID_MODES))}",
        )
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")

    try:
        input_path = _save_upload_to_tmp(video)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    try:
        fd, temp_output_path = tempfile.mkstemp(prefix="output_", suffix=".mp4", dir="/tmp")
    except Exception:
        _safe_delete(input_path)
        raise
    os.close(fd)

    query_list = [q.strip() for q in queries.split(",") if q.strip()]

    # Per-mode choices; everything else (save, run, clean up, respond) is shared.
    if mode == "segmentation":
        if not query_list:
            query_list = ["object"]
        processor = run_segmentation
        processor_kwargs = {"segmenter_name": segmenter}
        value_error_log = "Segmentation processing failed."
        other_error_log = "Segmentation inference failed."
        download_name = "segmented.mp4"
    else:
        if mode == "drone_detection" and not query_list:
            query_list = ["drone"]
        processor = run_inference
        processor_kwargs = {
            "detector_name": "drone_yolo" if mode == "drone_detection" else detector,
        }
        value_error_log = "Video processing failed."
        other_error_log = "Inference failed."
        download_name = "processed.mp4"

    try:
        # Inference is synchronous and long-running. Running it directly in
        # this coroutine would freeze the event loop, and with it every
        # other request (including /detect/status polling), until the video
        # is done. So it runs in a worker thread.
        # NOTE(review): this allows concurrent inference calls; confirm
        # run_inference/run_segmentation tolerate that (e.g. shared model caches).
        output_path = await asyncio.to_thread(
            processor,
            input_path,
            temp_output_path,
            query_list,
            **processor_kwargs,
        )
    except ValueError as exc:
        logging.exception(value_error_log)
        _safe_delete(input_path)
        _safe_delete(temp_output_path)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except Exception as exc:
        logging.exception(other_error_log)
        _safe_delete(input_path)
        _safe_delete(temp_output_path)
        return JSONResponse(status_code=500, content={"error": str(exc)})

    _schedule_cleanup(background_tasks, input_path)
    _schedule_cleanup(background_tasks, output_path)
    if str(output_path) != temp_output_path:
        # The processor wrote somewhere else; without this, the mkstemp
        # placeholder would never be removed.
        _schedule_cleanup(background_tasks, temp_output_path)

    return FileResponse(
        path=output_path,
        media_type="video/mp4",
        filename=download_name,
    )
# Strong references to in-flight background video jobs. The event loop keeps
# only weak references to tasks, so a task nobody references can be garbage
# collected before it finishes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


@app.post("/detect/async")
async def detect_async_endpoint(
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("hf_yolov8"),
    segmenter: str = Form("sam3"),
):
    """Start a background detection job and return first-frame results at once.

    Steps:
    1. Save the upload into the job's storage.
    2. Process the first frame and write it to disk.
    3. Register the job as PROCESSING and hand the full video to
       ``process_video_async`` in a background task.

    Args:
        video: Video file to process.
        mode: Detection mode (object_detection, segmentation, drone_detection).
        queries: Comma-separated classes / prompts. When empty, the per-mode
            defaults from ``_default_queries_for_mode`` are used.
        detector: Detector name. Replaced by "drone_yolo" for drone_detection.
        segmenter: Segmenter name forwarded to first-frame and job processing.

    Returns:
        A dict with job_id, first_frame_url, status_url, video_url, status
        and first_frame_detections.

    Raises:
        HTTPException: 400 for an unknown mode. 500 if the upload cannot be
            saved or the first frame cannot be processed or written. In both
            500 cases the job directory is removed.
    """
    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            # sorted(): set iteration order is arbitrary, so keep the message stable.
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(sorted(VALID_MODES))}",
        )
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")

    job_id = uuid.uuid4().hex
    job_dir = get_job_directory(job_id)
    input_path = get_input_video_path(job_id)
    output_path = get_output_video_path(job_id)
    first_frame_path = get_first_frame_path(job_id)

    try:
        _save_upload_to_path(video, input_path)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        # Same cleanup as the first-frame failure path below: no orphaned job dir.
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    query_list = [q.strip() for q in queries.split(",") if q.strip()]
    if not query_list:
        query_list = _default_queries_for_mode(mode)
    detector_name = "drone_yolo" if mode == "drone_detection" else detector

    try:
        # Synchronous model inference; run it off the event loop so other
        # requests are not stalled while the first frame is processed.
        processed_frame, detections = await asyncio.to_thread(
            process_first_frame,
            str(input_path),
            query_list,
            mode=mode,
            detector_name=detector_name,
            segmenter_name=segmenter,
        )
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(str(first_frame_path), processed_frame):
            raise RuntimeError(f"cv2.imwrite could not write {first_frame_path}")
    except Exception as exc:
        logging.exception("First-frame processing failed.")
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail="Failed to process first frame.") from exc

    job = JobInfo(
        job_id=job_id,
        status=JobStatus.PROCESSING,
        mode=mode,
        queries=query_list,
        detector_name=detector_name,
        segmenter_name=segmenter,
        input_video_path=str(input_path),
        output_video_path=str(output_path),
        first_frame_path=str(first_frame_path),
        first_frame_detections=detections,
    )
    get_job_storage().create(job)

    task = asyncio.create_task(process_video_async(job_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return {
        "job_id": job_id,
        "first_frame_url": f"/detect/first-frame/{job_id}",
        "status_url": f"/detect/status/{job_id}",
        "video_url": f"/detect/video/{job_id}",
        "status": job.status.value,
        "first_frame_detections": detections,
    }
@app.get("/detect/status/{job_id}")
async def detect_status(job_id: str):
    """Describe the current state of an async job.

    Responds with 404 when the job is unknown or has expired. Otherwise
    returns its id, status, ISO-8601 creation time, completion time (null
    while unfinished) and error message.
    """
    record = get_job_storage().get(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found or expired.")
    finished = record.completed_at
    return {
        "job_id": record.job_id,
        "status": record.status.value,
        "created_at": record.created_at.isoformat(),
        "completed_at": None if not finished else finished.isoformat(),
        "error": record.error,
    }
@app.get("/detect/first-frame/{job_id}")
async def detect_first_frame(job_id: str):
    """Serve the processed first-frame image stored for an async job.

    Responds with 404 when the job is unknown or its image file is not on
    disk.
    """
    record = get_job_storage().get(job_id)
    frame_available = bool(record) and Path(record.first_frame_path).exists()
    if not frame_available:
        raise HTTPException(status_code=404, detail="First frame not found.")
    return FileResponse(
        filename="first_frame.jpg",
        media_type="image/jpeg",
        path=record.first_frame_path,
    )
@app.get("/detect/video/{job_id}")
async def detect_video(job_id: str):
    """Download the processed video of an async job once it is ready.

    Responses:
    - 404 for an unknown or expired job.
    - 500 carrying the job's error if it failed.
    - 202 with a JSON body while it is still processing.
    - 404 if the output file is missing.
    - Otherwise the MP4 file.
    """
    record = get_job_storage().get(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found or expired.")

    state = record.status
    if state == JobStatus.FAILED:
        raise HTTPException(status_code=500, detail=f"Job failed: {record.error}")
    if state == JobStatus.PROCESSING:
        still_running = {"detail": "Video still processing", "status": "processing"}
        return JSONResponse(status_code=202, content=still_running)

    video_file = record.output_video_path
    if not (video_file and Path(video_file).exists()):
        raise HTTPException(status_code=404, detail="Video file not found.")
    return FileResponse(
        filename="processed.mp4",
        media_type="video/mp4",
        path=video_file,
    )
if __name__ == "__main__":
    # Launch uvicorn on all interfaces, port 7860, with auto-reload off,
    # loading the app through the "app:app" import string.
    uvicorn.run(
        "app:app",
        reload=False,
        host="0.0.0.0",
        port=7860,
    )