Spaces:
Sleeping
Sleeping
Zhen Ye
commited on
Commit
·
b30e7a3
1
Parent(s):
1ea465c
initial commit
Browse files- .gitignore +8 -0
- Dockerfile +21 -0
- app.py +208 -0
- coco_classes.py +163 -0
- demo.html +618 -0
- inference.py +182 -0
- models/detectors/base.py +19 -0
- models/detectors/detr.py +48 -0
- models/detectors/grounding_dino.py +56 -0
- models/detectors/owlv2.py +56 -0
- models/detectors/yolov8.py +69 -0
- models/model_loader.py +43 -0
- models/segmenters/__init__.py +10 -0
- models/segmenters/base.py +29 -0
- models/segmenters/model_loader.py +44 -0
- models/segmenters/sam3.py +134 -0
- requirements.txt +13 -0
- utils/video.py +79 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.venv/
|
| 3 |
+
*.mp4
|
| 4 |
+
*.log
|
| 5 |
+
*.tmp
|
| 6 |
+
.DS_Store
|
| 7 |
+
.env
|
| 8 |
+
*.md
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim

# Unbuffered stdout/stderr so logs appear immediately; skip .pyc generation.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Copy only requirements.txt first so the dependency layer below stays cached
# across source-code changes.
COPY requirements.txt ./
# System packages: libgl1/libglib2.0-0 presumably for an OpenCV-style video
# stack and ffmpeg for video encoding — confirm against requirements.txt.
# apt lists are removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/* \
    && pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

COPY . .

# Port the uvicorn server below listens on (matches app.py's __main__ runner).
EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
| 9 |
+
import uvicorn
|
| 10 |
+
|
| 11 |
+
from inference import run_inference, run_segmentation
|
| 12 |
+
|
| 13 |
+
# Root logger at INFO so request/processing logs are visible in the container.
logging.basicConfig(level=logging.INFO)

app = FastAPI(title="Video Object Detection")
# NOTE(review): CORS is wide open (any origin, with credentials) — fine for a
# public demo, but tighten allow_origins before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Valid detection modes
VALID_MODES = {"object_detection", "segmentation", "drone_detection"}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _save_upload_to_tmp(upload: UploadFile) -> str:
|
| 29 |
+
"""Save uploaded file to temporary location."""
|
| 30 |
+
suffix = Path(upload.filename or "upload.mp4").suffix or ".mp4"
|
| 31 |
+
fd, path = tempfile.mkstemp(prefix="input_", suffix=suffix, dir="/tmp")
|
| 32 |
+
os.close(fd)
|
| 33 |
+
with open(path, "wb") as buffer:
|
| 34 |
+
data = upload.file.read()
|
| 35 |
+
buffer.write(data)
|
| 36 |
+
return path
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _safe_delete(path: str) -> None:
|
| 40 |
+
"""Safely delete a file, ignoring errors."""
|
| 41 |
+
try:
|
| 42 |
+
os.remove(path)
|
| 43 |
+
except FileNotFoundError:
|
| 44 |
+
return
|
| 45 |
+
except Exception:
|
| 46 |
+
logging.exception("Failed to remove temporary file: %s", path)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
    """Queue deletion of *path* to run once the response has been sent."""
    # add_task binds the argument now, so no closure is needed to capture path.
    background_tasks.add_task(_safe_delete, path)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.get("/", response_class=HTMLResponse)
async def demo_page() -> str:
    """Serve the bundled single-page demo UI (or a stub if the file is gone)."""
    page = Path(__file__).with_name("demo.html")
    try:
        html = page.read_text(encoding="utf-8")
    except FileNotFoundError:
        # Keep the endpoint alive even when demo.html was not deployed.
        html = "<h1>Demo page missing</h1>"
    return html
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@app.post("/detect")
async def detect_endpoint(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
    mode: str = Form(...),
    queries: str = Form(""),
    detector: str = Form("owlv2_base"),
    segmenter: str = Form("sam3"),
):
    """
    Main detection endpoint.

    Args:
        video: Video file to process
        mode: Detection mode (object_detection, segmentation, drone_detection)
        queries: Comma-separated object classes / text prompts
        detector: Model to use (owlv2_base, hf_yolov8, detr_resnet50, grounding_dino)
        segmenter: Segmentation model to use (sam3)

    Returns:
        - For object_detection: Processed video with bounding boxes
        - For segmentation: Processed video with masks rendered
        - For drone_detection: JSON with "coming_soon" status
    """
    # Validate mode
    if mode not in VALID_MODES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{mode}'. Must be one of: {', '.join(VALID_MODES)}"
        )

    # Drone detection is not implemented; short-circuit before touching the
    # upload at all.
    if mode == "drone_detection":
        return JSONResponse(
            status_code=200,
            content={
                "status": "coming_soon",
                "message": "Drone detection mode is under development. Stay tuned!",
                "mode": "drone_detection"
            }
        )

    # Shared ingest for both video-processing modes (this block used to be
    # duplicated verbatim in the segmentation and object-detection branches).
    input_path = await _ingest_upload(video)
    fd, output_path = tempfile.mkstemp(prefix="output_", suffix=".mp4", dir="/tmp")
    os.close(fd)

    # Parse queries
    query_list = [q.strip() for q in queries.split(",") if q.strip()]

    if mode == "segmentation":
        # Segmentation needs at least one text prompt; fall back to a generic
        # one.  (Object detection deliberately passes an empty list through.)
        if not query_list:
            query_list = ["object"]
        run = run_segmentation
        run_kwargs = {"segmenter_name": segmenter}
        download_name = "segmented.mp4"
    else:  # object_detection
        run = run_inference
        run_kwargs = {"detector_name": detector}
        download_name = "processed.mp4"

    try:
        output_path = run(input_path, output_path, query_list, **run_kwargs)
    except ValueError as exc:
        # Known bad-input/configuration errors surface as a clean HTTP error.
        logging.exception("Video processing failed.")
        _safe_delete(input_path)
        _safe_delete(output_path)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except Exception as exc:
        # Unexpected model failures: report the message but keep the server up.
        logging.exception("Inference failed.")
        _safe_delete(input_path)
        _safe_delete(output_path)
        return JSONResponse(status_code=500, content={"error": str(exc)})

    # Temp files are deleted only after the response body has been streamed.
    _schedule_cleanup(background_tasks, input_path)
    _schedule_cleanup(background_tasks, output_path)

    # Return processed video
    return FileResponse(
        path=output_path,
        media_type="video/mp4",
        filename=download_name,
    )


async def _ingest_upload(video: UploadFile) -> str:
    """Save the upload to /tmp, always closing the handle; raise 500 on failure."""
    try:
        return _save_upload_to_tmp(video)
    except Exception:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.")
    finally:
        await video.close()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# Local development entry point; in the container uvicorn is launched via the
# Dockerfile CMD instead.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
|
coco_classes.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import difflib
|
| 4 |
+
import re
|
| 5 |
+
from typing import Dict, Tuple
|
| 6 |
+
|
| 7 |
+
# The 80 object-class names of the COCO detection dataset (80 entries below).
# Kept as a tuple so the catalog is immutable and safely shareable.
COCO_CLASSES: Tuple[str, ...] = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def coco_class_catalog() -> str:
    """Join every COCO class name into one comma-separated catalog string."""
    return ", ".join(COCO_CLASSES)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _normalize(label: str) -> str:
|
| 98 |
+
return re.sub(r"[^a-z0-9]+", " ", label.lower()).strip()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Normalized COCO name -> canonical spelling (identity mapping over COCO_CLASSES,
# keyed by _normalize output so lookups match normalized user input).
_CANONICAL_LOOKUP: Dict[str, str] = {_normalize(name): name for name in COCO_CLASSES}
# Common colloquial aliases mapped to their canonical COCO class name.
_COCO_SYNONYMS: Dict[str, str] = {
    "people": "person",
    "man": "person",
    "woman": "person",
    "men": "person",
    "women": "person",
    "motorbike": "motorcycle",
    "motor bike": "motorcycle",
    "bike": "bicycle",
    "aircraft": "airplane",
    "plane": "airplane",
    "jet": "airplane",
    "aeroplane": "airplane",
    "pickup": "truck",
    "pickup truck": "truck",
    "semi": "truck",
    "lorry": "truck",
    "tractor trailer": "truck",
    "coach": "bus",
    "television": "tv",
    "tv monitor": "tv",
    "mobile phone": "cell phone",
    "smartphone": "cell phone",
    "cellphone": "cell phone",
    "dinner table": "dining table",
    "sofa": "couch",
    "cooker": "oven",
}
# Same alias table re-keyed by its normalized form so alias lookups and
# canonical lookups use identical keys.
_ALIAS_LOOKUP: Dict[str, str] = {_normalize(alias): canonical for alias, canonical in _COCO_SYNONYMS.items()}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def canonicalize_coco_name(value: str | None) -> str | None:
    """Map an arbitrary string to the closest COCO class name if possible.

    Fallback chain, cheapest and most exact first:
      1. exact match against a canonical name, then a known alias;
      2. substring containment (aliases first, then canonical names);
      3. token-by-token exact lookup;
      4. fuzzy match via difflib (cutoff 0.82) against canonical names only.
    Returns None when nothing plausible is found.
    """

    if not value:
        return None
    normalized = _normalize(value)
    if not normalized:
        return None
    # Step 1: exact matches.
    if normalized in _CANONICAL_LOOKUP:
        return _CANONICAL_LOOKUP[normalized]
    if normalized in _ALIAS_LOOKUP:
        return _ALIAS_LOOKUP[normalized]

    # Step 2: substring containment, e.g. "red sports car" -> "car".
    # NOTE(review): this is a raw `in` test, so short aliases can false-positive
    # on unrelated words (e.g. "coach" inside "coaching") — worth confirming
    # this is acceptable for the prompt vocabulary.
    for alias_norm, canonical in _ALIAS_LOOKUP.items():
        if alias_norm and alias_norm in normalized:
            return canonical
    for canonical_norm, canonical in _CANONICAL_LOOKUP.items():
        if canonical_norm and canonical_norm in normalized:
            return canonical

    # Step 3: per-token exact lookup catches cases the substring pass missed.
    tokens = normalized.split()
    for token in tokens:
        if token in _CANONICAL_LOOKUP:
            return _CANONICAL_LOOKUP[token]
        if token in _ALIAS_LOOKUP:
            return _ALIAS_LOOKUP[token]

    # Step 4: last resort — fuzzy match against canonical names.
    close = difflib.get_close_matches(normalized, list(_CANONICAL_LOOKUP.keys()), n=1, cutoff=0.82)
    if close:
        return _CANONICAL_LOOKUP[close[0]]
    return None
|
demo.html
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
|
| 3 |
+
<html lang="en">
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8">
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 7 |
+
<title>Video Object Detection</title>
|
| 8 |
+
<style>
|
| 9 |
+
* {
|
| 10 |
+
margin: 0;
|
| 11 |
+
padding: 0;
|
| 12 |
+
box-sizing: border-box;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
body {
|
| 16 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica', 'Arial', sans-serif;
|
| 17 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 18 |
+
min-height: 100vh;
|
| 19 |
+
padding: 20px;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
.container {
|
| 23 |
+
max-width: 1200px;
|
| 24 |
+
margin: 0 auto;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
h1 {
|
| 28 |
+
color: white;
|
| 29 |
+
text-align: center;
|
| 30 |
+
margin-bottom: 30px;
|
| 31 |
+
font-size: 2.5rem;
|
| 32 |
+
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
.main-card {
|
| 36 |
+
background: white;
|
| 37 |
+
border-radius: 16px;
|
| 38 |
+
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
| 39 |
+
padding: 40px;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
.section {
|
| 43 |
+
margin-bottom: 30px;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
.section-title {
|
| 47 |
+
font-size: 1.2rem;
|
| 48 |
+
font-weight: 600;
|
| 49 |
+
color: #333;
|
| 50 |
+
margin-bottom: 15px;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
/* Mode selector */
|
| 54 |
+
.mode-selector {
|
| 55 |
+
display: grid;
|
| 56 |
+
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
| 57 |
+
gap: 15px;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
.mode-card {
|
| 61 |
+
position: relative;
|
| 62 |
+
padding: 20px;
|
| 63 |
+
border: 2px solid #e0e0e0;
|
| 64 |
+
border-radius: 12px;
|
| 65 |
+
cursor: pointer;
|
| 66 |
+
transition: all 0.3s ease;
|
| 67 |
+
text-align: center;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
.mode-card:hover {
|
| 71 |
+
border-color: #667eea;
|
| 72 |
+
transform: translateY(-2px);
|
| 73 |
+
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.2);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.mode-card.selected {
|
| 77 |
+
border-color: #667eea;
|
| 78 |
+
background: #f0f4ff;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.mode-card.disabled {
|
| 82 |
+
opacity: 0.5;
|
| 83 |
+
cursor: not-allowed;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
.mode-card input[type="radio"] {
|
| 87 |
+
position: absolute;
|
| 88 |
+
opacity: 0;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.mode-icon {
|
| 92 |
+
font-size: 2rem;
|
| 93 |
+
margin-bottom: 10px;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
.mode-title {
|
| 97 |
+
font-weight: 600;
|
| 98 |
+
color: #333;
|
| 99 |
+
margin-bottom: 5px;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
.mode-badge {
|
| 103 |
+
display: inline-block;
|
| 104 |
+
padding: 4px 8px;
|
| 105 |
+
background: #ffc107;
|
| 106 |
+
color: white;
|
| 107 |
+
font-size: 0.7rem;
|
| 108 |
+
border-radius: 4px;
|
| 109 |
+
font-weight: 600;
|
| 110 |
+
margin-top: 8px;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
/* Input fields */
|
| 114 |
+
.input-group {
|
| 115 |
+
margin-bottom: 20px;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.input-group label {
|
| 119 |
+
display: block;
|
| 120 |
+
font-weight: 500;
|
| 121 |
+
color: #555;
|
| 122 |
+
margin-bottom: 8px;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
.input-group input[type="text"],
|
| 126 |
+
.input-group select {
|
| 127 |
+
width: 100%;
|
| 128 |
+
padding: 12px;
|
| 129 |
+
border: 2px solid #e0e0e0;
|
| 130 |
+
border-radius: 8px;
|
| 131 |
+
font-size: 1rem;
|
| 132 |
+
transition: border-color 0.3s;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.input-group input[type="text"]:focus,
|
| 136 |
+
.input-group select:focus {
|
| 137 |
+
outline: none;
|
| 138 |
+
border-color: #667eea;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
.file-input-wrapper {
|
| 142 |
+
position: relative;
|
| 143 |
+
display: inline-block;
|
| 144 |
+
width: 100%;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
.file-input-label {
|
| 148 |
+
display: block;
|
| 149 |
+
padding: 15px;
|
| 150 |
+
background: #f8f9fa;
|
| 151 |
+
border: 2px dashed #ccc;
|
| 152 |
+
border-radius: 8px;
|
| 153 |
+
text-align: center;
|
| 154 |
+
cursor: pointer;
|
| 155 |
+
transition: all 0.3s;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
.file-input-label:hover {
|
| 159 |
+
border-color: #667eea;
|
| 160 |
+
background: #f0f4ff;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
.file-input-label.has-file {
|
| 164 |
+
border-color: #28a745;
|
| 165 |
+
background: #d4edda;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
input[type="file"] {
|
| 169 |
+
position: absolute;
|
| 170 |
+
opacity: 0;
|
| 171 |
+
width: 0;
|
| 172 |
+
height: 0;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
/* Buttons */
|
| 176 |
+
.btn {
|
| 177 |
+
padding: 14px 28px;
|
| 178 |
+
font-size: 1rem;
|
| 179 |
+
font-weight: 600;
|
| 180 |
+
border: none;
|
| 181 |
+
border-radius: 8px;
|
| 182 |
+
cursor: pointer;
|
| 183 |
+
transition: all 0.3s;
|
| 184 |
+
width: 100%;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.btn-primary {
|
| 188 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 189 |
+
color: white;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
.btn-primary:hover:not(:disabled) {
|
| 193 |
+
transform: translateY(-2px);
|
| 194 |
+
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
.btn:disabled {
|
| 198 |
+
opacity: 0.5;
|
| 199 |
+
cursor: not-allowed;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
/* Results */
|
| 203 |
+
.results-grid {
|
| 204 |
+
display: grid;
|
| 205 |
+
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
|
| 206 |
+
gap: 20px;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
.video-card {
|
| 210 |
+
border: 1px solid #e0e0e0;
|
| 211 |
+
border-radius: 8px;
|
| 212 |
+
overflow: hidden;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.video-card-header {
|
| 216 |
+
background: #f8f9fa;
|
| 217 |
+
padding: 12px 16px;
|
| 218 |
+
font-weight: 600;
|
| 219 |
+
color: #333;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
.video-card-body {
|
| 223 |
+
padding: 16px;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
video {
|
| 227 |
+
width: 100%;
|
| 228 |
+
border-radius: 8px;
|
| 229 |
+
background: #000;
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
.download-btn {
|
| 233 |
+
margin-top: 12px;
|
| 234 |
+
padding: 10px 16px;
|
| 235 |
+
background: #28a745;
|
| 236 |
+
color: white;
|
| 237 |
+
text-decoration: none;
|
| 238 |
+
border-radius: 6px;
|
| 239 |
+
display: inline-block;
|
| 240 |
+
font-size: 0.9rem;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
.download-btn:hover {
|
| 244 |
+
background: #218838;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/* Loading spinner */
|
| 248 |
+
.loading {
|
| 249 |
+
display: none;
|
| 250 |
+
text-align: center;
|
| 251 |
+
padding: 20px;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
.loading.show {
|
| 255 |
+
display: block;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
.spinner {
|
| 259 |
+
border: 4px solid #f3f3f3;
|
| 260 |
+
border-top: 4px solid #667eea;
|
| 261 |
+
border-radius: 50%;
|
| 262 |
+
width: 40px;
|
| 263 |
+
height: 40px;
|
| 264 |
+
animation: spin 1s linear infinite;
|
| 265 |
+
margin: 0 auto 10px;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
@keyframes spin {
|
| 269 |
+
0% { transform: rotate(0deg); }
|
| 270 |
+
100% { transform: rotate(360deg); }
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
.hidden {
|
| 274 |
+
display: none;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
/* Modal */
|
| 278 |
+
.modal {
|
| 279 |
+
display: none;
|
| 280 |
+
position: fixed;
|
| 281 |
+
z-index: 1000;
|
| 282 |
+
left: 0;
|
| 283 |
+
top: 0;
|
| 284 |
+
width: 100%;
|
| 285 |
+
height: 100%;
|
| 286 |
+
background: rgba(0,0,0,0.5);
|
| 287 |
+
align-items: center;
|
| 288 |
+
justify-content: center;
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
.modal.show {
|
| 292 |
+
display: flex;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
.modal-content {
|
| 296 |
+
background: white;
|
| 297 |
+
padding: 30px;
|
| 298 |
+
border-radius: 12px;
|
| 299 |
+
max-width: 500px;
|
| 300 |
+
text-align: center;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.modal-content h2 {
|
| 304 |
+
margin-bottom: 15px;
|
| 305 |
+
color: #333;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
.modal-content p {
|
| 309 |
+
margin-bottom: 20px;
|
| 310 |
+
color: #666;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
.modal-btn {
|
| 314 |
+
padding: 10px 24px;
|
| 315 |
+
background: #667eea;
|
| 316 |
+
color: white;
|
| 317 |
+
border: none;
|
| 318 |
+
border-radius: 6px;
|
| 319 |
+
cursor: pointer;
|
| 320 |
+
font-size: 1rem;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
.modal-btn:hover {
|
| 324 |
+
background: #5568d3;
|
| 325 |
+
}
|
| 326 |
+
</style>
|
| 327 |
+
</head>
|
| 328 |
+
<body>
|
| 329 |
+
<div class="container">
|
| 330 |
+
<h1>🎥 Video Object Detection</h1>
|
| 331 |
+
|
| 332 |
+
<div class="main-card">
|
| 333 |
+
<!-- Mode Selection -->
|
| 334 |
+
<div class="section">
|
| 335 |
+
<div class="section-title">1. Select Detection Mode</div>
|
| 336 |
+
<div class="mode-selector">
|
| 337 |
+
<label class="mode-card selected">
|
| 338 |
+
<input type="radio" name="mode" value="object_detection" checked>
|
| 339 |
+
<div class="mode-icon">🎯</div>
|
| 340 |
+
<div class="mode-title">Object Detection</div>
|
| 341 |
+
</label>
|
| 342 |
+
|
| 343 |
+
<label class="mode-card">
|
| 344 |
+
<input type="radio" name="mode" value="segmentation">
|
| 345 |
+
<div class="mode-icon">🎨</div>
|
| 346 |
+
<div class="mode-title">Segmentation</div>
|
| 347 |
+
</label>
|
| 348 |
+
|
| 349 |
+
<label class="mode-card disabled">
|
| 350 |
+
<input type="radio" name="mode" value="drone_detection">
|
| 351 |
+
<div class="mode-icon">🚁</div>
|
| 352 |
+
<div class="mode-title">Drone Detection</div>
|
| 353 |
+
<span class="mode-badge">COMING SOON</span>
|
| 354 |
+
</label>
|
| 355 |
+
</div>
|
| 356 |
+
</div>
|
| 357 |
+
|
| 358 |
+
<!-- Text Prompts Input (for all modes) -->
|
| 359 |
+
<div class="section" id="queriesSection">
|
| 360 |
+
<div class="input-group">
|
| 361 |
+
<label for="queries" id="queriesLabel">Text Prompts (comma-separated)</label>
|
| 362 |
+
<input
|
| 363 |
+
type="text"
|
| 364 |
+
id="queries"
|
| 365 |
+
placeholder="person, car, dog, bicycle"
|
| 366 |
+
>
|
| 367 |
+
<small id="queriesHint" style="color: #666; display: block; margin-top: 5px;">
|
| 368 |
+
Enter objects to detect or segment
|
| 369 |
+
</small>
|
| 370 |
+
</div>
|
| 371 |
+
</div>
|
| 372 |
+
|
| 373 |
+
<!-- Detector Selection -->
|
| 374 |
+
<div class="section" id="detectorSection">
|
| 375 |
+
<div class="input-group">
|
| 376 |
+
<label for="detector">2. Select Detection Model</label>
|
| 377 |
+
<select id="detector">
|
| 378 |
+
<option value="owlv2_base">OWLv2 (Open-vocabulary, Default)</option>
|
| 379 |
+
<option value="hf_yolov8">YOLOv8 (Fast, COCO classes)</option>
|
| 380 |
+
<option value="detr_resnet50">DETR ResNet-50 (Transformer-based)</option>
|
| 381 |
+
<option value="grounding_dino">Grounding DINO (Open-vocabulary)</option>
|
| 382 |
+
</select>
|
| 383 |
+
</div>
|
| 384 |
+
</div>
|
| 385 |
+
|
| 386 |
+
<!-- Segmenter Selection -->
|
| 387 |
+
<div class="section hidden" id="segmenterSection">
|
| 388 |
+
<div class="input-group">
|
| 389 |
+
<label for="segmenter">2. Select Segmentation Model</label>
|
| 390 |
+
<select id="segmenter">
|
| 391 |
+
<option value="sam3">SAM3 (Segment Anything Model 3)</option>
|
| 392 |
+
</select>
|
| 393 |
+
</div>
|
| 394 |
+
</div>
|
| 395 |
+
|
| 396 |
+
<!-- Video Upload -->
|
| 397 |
+
<div class="section">
|
| 398 |
+
<div class="input-group">
|
| 399 |
+
<label>3. Upload Video</label>
|
| 400 |
+
<div class="file-input-wrapper">
|
| 401 |
+
<label class="file-input-label" id="fileLabel" for="videoFile">
|
| 402 |
+
📁 Click to select video file (MP4)
|
| 403 |
+
</label>
|
| 404 |
+
<input type="file" id="videoFile" accept="video/*">
|
| 405 |
+
</div>
|
| 406 |
+
</div>
|
| 407 |
+
</div>
|
| 408 |
+
|
| 409 |
+
<!-- Process Button -->
|
| 410 |
+
<div class="section">
|
| 411 |
+
<button class="btn btn-primary" id="processBtn" disabled>
|
| 412 |
+
🚀 Process Video
|
| 413 |
+
</button>
|
| 414 |
+
</div>
|
| 415 |
+
|
| 416 |
+
<!-- Loading -->
|
| 417 |
+
<div class="loading" id="loading">
|
| 418 |
+
<div class="spinner"></div>
|
| 419 |
+
<p>Processing video... This may take a while depending on video length.</p>
|
| 420 |
+
</div>
|
| 421 |
+
|
| 422 |
+
<!-- Results -->
|
| 423 |
+
<div class="section hidden" id="resultsSection">
|
| 424 |
+
<div class="section-title">Results</div>
|
| 425 |
+
<div class="results-grid">
|
| 426 |
+
<div class="video-card">
|
| 427 |
+
<div class="video-card-header">Original Video</div>
|
| 428 |
+
<div class="video-card-body">
|
| 429 |
+
<video id="originalVideo" controls></video>
|
| 430 |
+
</div>
|
| 431 |
+
</div>
|
| 432 |
+
<div class="video-card">
|
| 433 |
+
<div class="video-card-header">Processed Video</div>
|
| 434 |
+
<div class="video-card-body">
|
| 435 |
+
<video id="processedVideo" controls autoplay loop></video>
|
| 436 |
+
<a id="downloadBtn" class="download-btn" download="processed.mp4">
|
| 437 |
+
⬇️ Download Processed Video
|
| 438 |
+
</a>
|
| 439 |
+
</div>
|
| 440 |
+
</div>
|
| 441 |
+
</div>
|
| 442 |
+
</div>
|
| 443 |
+
</div>
|
| 444 |
+
</div>
|
| 445 |
+
|
| 446 |
+
<!-- Coming Soon Modal -->
|
| 447 |
+
<div class="modal" id="comingSoonModal">
|
| 448 |
+
<div class="modal-content">
|
| 449 |
+
<h2>Coming Soon!</h2>
|
| 450 |
+
<p id="modalMessage"></p>
|
| 451 |
+
<button class="modal-btn" id="modalClose">Got it</button>
|
| 452 |
+
</div>
|
| 453 |
+
</div>
|
| 454 |
+
|
| 455 |
+
<script>
    // ---- UI state ----
    let selectedMode = 'object_detection';
    let videoFile = null;
    // Blob URLs currently attached to the <video> elements. Tracked so each
    // can be revoked before being replaced: the previous version called
    // URL.createObjectURL on every upload / processing run without ever
    // revoking, leaking memory over long sessions.
    let originalBlobUrl = null;
    let processedBlobUrl = null;

    // ---- Element handles ----
    const modeCards = document.querySelectorAll('.mode-card');
    const queriesSection = document.getElementById('queriesSection');
    const queriesLabel = document.getElementById('queriesLabel');
    const queriesHint = document.getElementById('queriesHint');
    const detectorSection = document.getElementById('detectorSection');
    const segmenterSection = document.getElementById('segmenterSection');
    const fileInput = document.getElementById('videoFile');
    const fileLabel = document.getElementById('fileLabel');
    const processBtn = document.getElementById('processBtn');
    const loading = document.getElementById('loading');
    const resultsSection = document.getElementById('resultsSection');
    const originalVideo = document.getElementById('originalVideo');
    const processedVideo = document.getElementById('processedVideo');
    const downloadBtn = document.getElementById('downloadBtn');
    const modal = document.getElementById('comingSoonModal');
    const modalMessage = document.getElementById('modalMessage');
    const modalClose = document.getElementById('modalClose');

    // Apply the per-mode query labels and section visibility. Shared by the
    // mode-card click handler and the modal-close reset path (previously
    // duplicated inline in both places).
    function applyModeUI(mode) {
        if (mode === 'object_detection') {
            queriesLabel.textContent = 'Objects to Detect (comma-separated)';
            queriesHint.textContent = 'Example: person, car, dog, bicycle';
            detectorSection.classList.remove('hidden');
            segmenterSection.classList.add('hidden');
        } else if (mode === 'segmentation') {
            queriesLabel.textContent = 'Objects to Segment (comma-separated)';
            queriesHint.textContent = 'Example: person, car, building, tree';
            detectorSection.classList.add('hidden');
            segmenterSection.classList.remove('hidden');
        } else if (mode === 'drone_detection') {
            queriesLabel.textContent = 'Drone Types to Detect (comma-separated)';
            queriesHint.textContent = 'Example: quadcopter, fixed-wing, drone';
            detectorSection.classList.add('hidden');
            segmenterSection.classList.add('hidden');
        }
    }

    // Mode selection handler
    modeCards.forEach(card => {
        card.addEventListener('click', (e) => {
            const input = card.querySelector('input[type="radio"]');
            const mode = input.value;

            // Disabled modes only show the "coming soon" modal.
            if (card.classList.contains('disabled')) {
                e.preventDefault();
                showComingSoonModal(mode);
                return;
            }

            // Update selected state
            modeCards.forEach(c => c.classList.remove('selected'));
            card.classList.add('selected');
            selectedMode = mode;
            applyModeUI(mode);

            // Always show queries section
            queriesSection.classList.remove('hidden');
        });
    });

    // File input handler
    fileInput.addEventListener('change', (e) => {
        videoFile = e.target.files[0];
        if (videoFile) {
            fileLabel.textContent = `✅ ${videoFile.name}`;
            fileLabel.classList.add('has-file');
            processBtn.disabled = false;

            // Preview original video; revoke the previous preview URL first
            // so repeated selections don't leak blob memory.
            if (originalBlobUrl) {
                URL.revokeObjectURL(originalBlobUrl);
            }
            originalBlobUrl = URL.createObjectURL(videoFile);
            originalVideo.src = originalBlobUrl;
        }
    });

    // Process button handler
    processBtn.addEventListener('click', async () => {
        if (!videoFile) {
            alert('Please select a video file first.');
            return;
        }

        // Show loading
        processBtn.disabled = true;
        loading.classList.add('show');
        resultsSection.classList.add('hidden');

        // Prepare form data
        const formData = new FormData();
        formData.append('video', videoFile);
        formData.append('mode', selectedMode);
        formData.append('queries', document.getElementById('queries').value);
        formData.append('detector', document.getElementById('detector').value);
        formData.append('segmenter', document.getElementById('segmenter').value);

        try {
            const response = await fetch('/detect', {
                method: 'POST',
                body: formData
            });

            if (response.ok) {
                const contentType = response.headers.get('content-type');

                if (contentType && contentType.includes('application/json')) {
                    // A JSON body on 2xx means the mode is "coming soon".
                    const data = await response.json();
                    showComingSoonModal(data.mode);
                } else {
                    // Binary video payload; revoke the previous result URL
                    // before binding the new one.
                    const blob = await response.blob();
                    if (processedBlobUrl) {
                        URL.revokeObjectURL(processedBlobUrl);
                    }
                    processedBlobUrl = URL.createObjectURL(blob);
                    processedVideo.src = processedBlobUrl;
                    downloadBtn.href = processedBlobUrl;
                    resultsSection.classList.remove('hidden');
                }
            } else {
                const error = await response.json();
                alert(`Error: ${error.detail || error.error || 'Processing failed'}`);
            }
        } catch (error) {
            console.error('Error:', error);
            alert('Network error: ' + error.message);
        } finally {
            loading.classList.remove('show');
            processBtn.disabled = false;
        }
    });

    // Coming soon modal
    function showComingSoonModal(mode) {
        const messages = {
            'drone_detection': 'Drone detection mode is under development. Stay tuned for specialized UAV and aerial object detection!'
        };
        modalMessage.textContent = messages[mode] || 'This feature is coming soon!';
        modal.classList.add('show');
    }

    modalClose.addEventListener('click', () => {
        modal.classList.remove('show');
        // Reset to object detection
        document.querySelector('input[value="object_detection"]').checked = true;
        modeCards.forEach(c => c.classList.remove('selected'));
        document.querySelector('input[value="object_detection"]').closest('.mode-card').classList.add('selected');
        selectedMode = 'object_detection';
        applyModeUI('object_detection');
    });

    // Close modal on background click
    modal.addEventListener('click', (e) => {
        if (e.target === modal) {
            modalClose.click();
        }
    });
</script>
|
| 617 |
+
</body>
|
| 618 |
+
</html>
|
inference.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from models.model_loader import load_detector
|
| 7 |
+
from models.segmenters.model_loader import load_segmenter
|
| 8 |
+
from utils.video import extract_frames, write_video
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def draw_boxes(frame: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return a copy of *frame* with each box drawn as a green rectangle.

    Boxes are consumed as (x1, y1, x2, y2) rows; ``None`` yields an
    unmodified copy of the frame.
    """
    annotated = frame.copy()
    if boxes is None:
        return annotated
    for box in boxes:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), thickness=2)
    return annotated
|
| 20 |
+
|
| 21 |
+
def draw_masks(frame: np.ndarray, masks: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Return a copy of *frame* with each mask alpha-blended in a cycling color.

    Masks may be (H, W) or (1, H, W); a mask whose size differs from the
    frame is resized with nearest-neighbor interpolation before blending.
    ``None`` entries are skipped.
    """
    palette = (
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (0, 255, 255),
        (255, 0, 255),
    )
    blended = frame.copy()
    if masks is None or len(masks) == 0:
        return blended
    for position, raw_mask in enumerate(masks):
        if raw_mask is None:
            continue
        mask = raw_mask[0] if raw_mask.ndim == 3 else raw_mask
        if mask.shape[:2] != blended.shape[:2]:
            mask = cv2.resize(
                mask,
                (blended.shape[1], blended.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        selected = mask.astype(bool)
        layer = np.zeros_like(blended, dtype=np.uint8)
        layer[selected] = palette[position % len(palette)]
        blended = cv2.addWeighted(blended, 1.0, layer, alpha, 0)
    return blended
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _build_detection_records(
|
| 48 |
+
boxes: np.ndarray,
|
| 49 |
+
scores: Sequence[float],
|
| 50 |
+
labels: Sequence[int],
|
| 51 |
+
queries: Sequence[str],
|
| 52 |
+
label_names: Optional[Sequence[str]] = None,
|
| 53 |
+
) -> List[Dict[str, Any]]:
|
| 54 |
+
detections: List[Dict[str, Any]] = []
|
| 55 |
+
for idx, box in enumerate(boxes):
|
| 56 |
+
if label_names is not None and idx < len(label_names):
|
| 57 |
+
label = label_names[idx]
|
| 58 |
+
else:
|
| 59 |
+
label_idx = int(labels[idx]) if idx < len(labels) else -1
|
| 60 |
+
if 0 <= label_idx < len(queries):
|
| 61 |
+
label = queries[label_idx]
|
| 62 |
+
else:
|
| 63 |
+
label = f"label_{label_idx}"
|
| 64 |
+
detections.append(
|
| 65 |
+
{
|
| 66 |
+
"label": label,
|
| 67 |
+
"score": float(scores[idx]) if idx < len(scores) else 0.0,
|
| 68 |
+
"bbox": [int(coord) for coord in box.tolist()],
|
| 69 |
+
}
|
| 70 |
+
)
|
| 71 |
+
return detections
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> tuple[np.ndarray, List[Dict[str, Any]]]:
    """Detect *queries* in one frame; return (annotated frame, detection dicts).

    An empty query list falls back to the generic prompt ["object"].
    Exceptions from the detector are logged with context and re-raised.
    """
    detector = load_detector(detector_name)
    prompts = list(queries) or ["object"]
    try:
        result = detector.predict(frame, prompts)
        records = _build_detection_records(
            result.boxes, result.scores, result.labels, prompts, result.label_names
        )
    except Exception:
        logging.exception("Inference failed for queries %s", prompts)
        raise
    return draw_boxes(frame, result.boxes), records
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def infer_segmentation_frame(
    frame: np.ndarray,
    text_queries: Optional[List[str]] = None,
    segmenter_name: Optional[str] = None,
) -> tuple[np.ndarray, Any]:
    """Segment one frame; return (mask-overlaid frame, raw segmenter result)."""
    model = load_segmenter(segmenter_name)
    outcome = model.predict(frame, text_prompts=text_queries)
    return draw_masks(frame, outcome.masks), outcome
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def run_inference(
    input_video_path: str,
    output_video_path: str,
    queries: List[str],
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
) -> str:
    """
    Run object detection inference on a video.

    Args:
        input_video_path: Path to input video
        output_video_path: Path to write processed video
        queries: List of object classes to detect (e.g., ["person", "car"])
        max_frames: Optional frame limit for testing
        detector_name: Detector to use (default: owlv2_base)

    Returns:
        Path to processed output video

    Raises:
        ValueError: if the input video cannot be decoded.
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:  # fix: dropped the unused `as exc` binding
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    # Use provided queries or default to common objects
    if not queries:
        queries = ["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]
        logging.info("No queries provided, using defaults: %s", queries)

    logging.info("Detection queries: %s", queries)

    # Resolve the detector name once; every frame reuses the same cached model.
    active_detector = detector_name or "owlv2_base"
    logging.info("Using detector: %s", active_detector)

    # Process frames, honoring the optional max_frames cap.
    processed_frames: List[np.ndarray] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        annotated, _ = infer_frame(frame, queries, detector_name=active_detector)
        processed_frames.append(annotated)

    # Write output video
    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Processed video written to: %s", output_video_path)

    return output_video_path
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def run_segmentation(
    input_video_path: str,
    output_video_path: str,
    queries: List[str],
    max_frames: Optional[int] = None,
    segmenter_name: Optional[str] = None,
) -> str:
    """
    Run text-prompted segmentation on a video.

    Args:
        input_video_path: Path to input video
        output_video_path: Path to write the mask-overlaid video
        queries: Text prompts describing what to segment
        max_frames: Optional frame limit for testing
        segmenter_name: Segmenter to use (default: sam3)

    Returns:
        Path to processed output video

    Raises:
        ValueError: if the input video cannot be decoded.
    """
    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:  # fix: dropped the unused `as exc` binding
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    # Resolve the segmenter name once; every frame reuses the same cached model.
    active_segmenter = segmenter_name or "sam3"
    logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)

    # Process frames, honoring the optional max_frames cap.
    processed_frames: List[np.ndarray] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        overlaid, _ = infer_segmentation_frame(
            frame, text_queries=queries, segmenter_name=active_segmenter
        )
        processed_frames.append(overlaid)

    write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
    logging.info("Segmented video written to: %s", output_video_path)

    return output_video_path
|
models/detectors/base.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import NamedTuple, Optional, Sequence
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DetectionResult(NamedTuple):
    """Uniform output shared by every detector backend.

    ``boxes`` rows are consumed downstream as (x1, y1, x2, y2) pixel
    coordinates; ``scores`` and ``labels`` are per-box confidences and
    integer class ids, and ``label_names`` optionally carries
    human-readable names aligned index-wise with ``boxes``.
    """

    boxes: np.ndarray
    scores: Sequence[float]
    labels: Sequence[int]
    label_names: Optional[Sequence[str]] = None
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ObjectDetector:
    """Detector interface to keep inference agnostic to model details."""

    # Short backend identifier (e.g. "owlv2_base"); set by each subclass.
    name: str

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects described by *queries* in one image frame.

        Args:
            frame: image array (H, W, C). Color order depends on the caller
                (frames presumably come from OpenCV upstream, i.e. BGR —
                confirm against utils.video).
            queries: text class names; closed-set backends may ignore them.

        Returns:
            DetectionResult with boxes, scores, labels and optional names.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
models/detectors/detr.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Sequence
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import DetrForObjectDetection, DetrImageProcessor
|
| 7 |
+
|
| 8 |
+
from models.detectors.base import DetectionResult, ObjectDetector
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DetrDetector(ObjectDetector):
    """Wrapper around facebook/detr-resnet-50 for mission-aligned detection.

    DETR is a closed-set detector: the text *queries* passed to ``predict``
    are not used for conditioning; labels come from the model's own
    id2label table.
    """

    MODEL_NAME = "facebook/detr-resnet-50"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Load processor and weights onto CUDA when available, else CPU."""
        self.name = "detr_resnet50"
        self.score_threshold = score_threshold
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)
        self.model = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects in one frame; boxes are scaled to frame size."""
        batch = self.processor(images=frame, return_tensors="pt")
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
        with torch.no_grad():
            raw_outputs = self.model(**batch)
        sizes = torch.tensor([frame.shape[:2]], device=self.device)
        detections = self.processor.post_process_object_detection(
            raw_outputs,
            threshold=self.score_threshold,
            target_sizes=sizes,
        )[0]
        class_ids = detections["labels"].cpu().tolist()
        readable = []
        for class_id in class_ids:
            readable.append(
                self.model.config.id2label.get(int(class_id), f"class_{class_id}")
            )
        return DetectionResult(
            boxes=detections["boxes"].cpu().numpy(),
            scores=detections["scores"].cpu().tolist(),
            labels=class_ids,
            label_names=readable,
        )
)
|
models/detectors/grounding_dino.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Sequence
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
|
| 7 |
+
|
| 8 |
+
from models.detectors.base import DetectionResult, ObjectDetector
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GroundingDinoDetector(ObjectDetector):
    """IDEA-Research Grounding DINO-B detector for open-vocabulary missions."""

    MODEL_NAME = "IDEA-Research/grounding-dino-base"

    def __init__(self, box_threshold: float = 0.35, text_threshold: float = 0.25) -> None:
        """Load Grounding DINO onto CUDA when available, else CPU."""
        self.name = "grounding_dino"
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
        self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()

    def _build_prompt(self, queries: Sequence[str]) -> str:
        """Build the dot-terminated text prompt format Grounding DINO expects.

        Empty/whitespace-only queries are dropped; when nothing remains, the
        generic prompt "object." is used.
        """
        terms = []
        for raw in queries:
            if not raw:
                continue
            cleaned = raw.strip()
            if cleaned:
                terms.append(cleaned)
        if not terms:
            return "object."
        return " ".join(term + "." for term in terms)

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Ground the text prompt in *frame*; return matched boxes/phrases."""
        prompt = self._build_prompt(queries)
        batch = self.processor(images=frame, text=prompt, return_tensors="pt")
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
        with torch.no_grad():
            raw_outputs = self.model(**batch)
        sizes = torch.tensor([frame.shape[:2]], device=self.device)
        grounded = self.processor.post_process_grounded_object_detection(
            raw_outputs,
            batch["input_ids"],
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            target_sizes=sizes,
        )[0]
        # Matched phrases double as label names; synthetic sequential ids
        # stand in for class indices since the vocabulary is open.
        phrases = list(grounded.get("labels") or [])
        return DetectionResult(
            boxes=grounded["boxes"].cpu().numpy(),
            scores=grounded["scores"].cpu().tolist(),
            labels=list(range(len(phrases))),
            label_names=phrases,
        )
|
models/detectors/owlv2.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Sequence
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import Owlv2ForObjectDetection, Owlv2Processor
|
| 7 |
+
|
| 8 |
+
from models.detectors.base import DetectionResult, ObjectDetector
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Owlv2Detector(ObjectDetector):
    """Open-vocabulary OWLv2 detector (google/owlv2-base-patch32).

    Runs in float16 on CUDA and float32 on CPU.
    """

    MODEL_NAME = "google/owlv2-base-patch32"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Load processor and model.

        Args:
            score_threshold: minimum detection confidence (previously a
                hard-coded 0.3 inside predict; same default, now tunable).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = Owlv2Processor.from_pretrained(self.MODEL_NAME)
        # Remember the weight dtype so predict() can cast image tensors to
        # match: fp16 weights fed float32 pixel_values fail on CUDA.
        self.torch_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model = Owlv2ForObjectDetection.from_pretrained(
            self.MODEL_NAME, torch_dtype=self.torch_dtype
        )
        self.model.to(self.device)
        self.model.eval()
        self.name = "owlv2_base"
        self.score_threshold = score_threshold

    def _move_inputs(self, inputs):
        """Move processor outputs to the model device, casting floating
        tensors (pixel_values) to the model dtype.

        Bug fix: the previous version only moved tensors to the device, so
        on CUDA the float32 pixel values hit float16 weights and raised a
        dtype mismatch at runtime.
        """
        moved = {}
        for key, value in inputs.items():
            if hasattr(value, "to"):
                if hasattr(value, "is_floating_point") and value.is_floating_point():
                    moved[key] = value.to(self.device, self.torch_dtype)
                else:
                    moved[key] = value.to(self.device)
            else:
                moved[key] = value
        return moved

    @staticmethod
    def _as_list(values):
        """Normalize a tensor / ndarray / sequence into a plain Python list."""
        if hasattr(values, "cpu"):
            return values.cpu().numpy().tolist()
        if isinstance(values, np.ndarray):
            return values.tolist()
        return list(values)

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect *queries* in one frame; boxes are scaled to frame size."""
        inputs = self.processor(text=queries, images=frame, return_tensors="pt")
        inputs = self._move_inputs(inputs)
        with torch.no_grad():
            outputs = self.model(**inputs)
        processed = self.processor.post_process_object_detection(
            outputs, threshold=self.score_threshold, target_sizes=[frame.shape[:2]]
        )[0]
        boxes = processed["boxes"]
        boxes_np = boxes.cpu().numpy() if hasattr(boxes, "cpu") else np.asarray(boxes)
        return DetectionResult(
            boxes=boxes_np,
            scores=self._as_list(processed.get("scores", [])),
            labels=self._as_list(processed.get("labels", [])),
        )
|
models/detectors/yolov8.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Sequence
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
from ultralytics import YOLO
|
| 8 |
+
|
| 9 |
+
from models.detectors.base import DetectionResult, ObjectDetector
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HuggingFaceYoloV8Detector(ObjectDetector):
    """YOLOv8 detector whose weights are fetched from the Hugging Face Hub."""

    REPO_ID = "Ultralytics/YOLOv8"
    WEIGHT_FILE = "yolov8s.pt"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Download (or reuse cached) yolov8s weights and load them."""
        self.name = "hf_yolov8"
        self.score_threshold = score_threshold
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(
            "Loading Hugging Face YOLOv8 weights %s/%s onto %s",
            self.REPO_ID,
            self.WEIGHT_FILE,
            self.device,
        )
        checkpoint = hf_hub_download(repo_id=self.REPO_ID, filename=self.WEIGHT_FILE)
        self.model = YOLO(checkpoint)
        self.model.to(self.device)
        self.class_names = self.model.names

    def _filter_indices(self, label_names: Sequence[str], queries: Sequence[str]) -> List[int]:
        """Return indices of detections whose class name matches a query.

        NOTE(review): when queries are supplied but none match, every index
        is returned — apparently a deliberate show-everything fallback;
        confirm this is intended.
        """
        everything = list(range(len(label_names)))
        if not queries:
            return everything
        wanted = {query.lower().strip() for query in queries if query}
        matching = [pos for pos, name in enumerate(label_names) if name.lower() in wanted]
        return matching if matching else everything

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run YOLOv8 on one frame and keep only query-matching classes."""
        target_device = 0 if self.device.startswith("cuda") else "cpu"
        prediction = self.model.predict(
            source=frame,
            device=target_device,
            conf=self.score_threshold,
            verbose=False,
        )[0]
        detected = prediction.boxes
        if detected is None or detected.xyxy is None:
            no_boxes = np.empty((0, 4), dtype=np.float32)
            return DetectionResult(no_boxes, [], [], [])

        coords = detected.xyxy.cpu().numpy()
        confidences = detected.conf.cpu().numpy().tolist()
        class_ids = detected.cls.cpu().numpy().astype(int).tolist()
        names = [self.class_names.get(cid, f"class_{cid}") for cid in class_ids]
        selected = self._filter_indices(names, queries)
        coords = coords[selected] if len(coords) else coords
        return DetectionResult(
            boxes=coords,
            scores=[confidences[i] for i in selected],
            labels=[class_ids[i] for i in selected],
            label_names=[names[i] for i in selected],
        )
|
| 69 |
+
|
models/model_loader.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
from typing import Callable, Dict, Optional
|
| 4 |
+
|
| 5 |
+
from models.detectors.base import ObjectDetector
|
| 6 |
+
from models.detectors.detr import DetrDetector
|
| 7 |
+
from models.detectors.grounding_dino import GroundingDinoDetector
|
| 8 |
+
from models.detectors.owlv2 import Owlv2Detector
|
| 9 |
+
from models.detectors.yolov8 import HuggingFaceYoloV8Detector
|
| 10 |
+
|
| 11 |
+
DEFAULT_DETECTOR = "owlv2_base"
|
| 12 |
+
|
| 13 |
+
_REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
|
| 14 |
+
"owlv2_base": Owlv2Detector,
|
| 15 |
+
"hf_yolov8": HuggingFaceYoloV8Detector,
|
| 16 |
+
"detr_resnet50": DetrDetector,
|
| 17 |
+
"grounding_dino": GroundingDinoDetector,
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _create_detector(name: str) -> ObjectDetector:
    """Instantiate the detector registered under *name*.

    Raises:
        ValueError: for unknown names (chained from the registry KeyError).
    """
    try:
        builder = _REGISTRY[name]
    except KeyError as err:
        options = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"Unknown detector '{name}'. Available: {options}") from err
    return builder()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@lru_cache(maxsize=None)
def _get_cached_detector(name: str) -> ObjectDetector:
    """Return one detector instance per name for the process lifetime.

    Model weights are expensive to load, so the cache deliberately never
    evicts.
    """
    return _create_detector(name)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_detector(name: Optional[str] = None) -> ObjectDetector:
    """Return a cached detector instance selected via arg or OBJECT_DETECTOR env.

    Resolution order: explicit *name*, then the OBJECT_DETECTOR environment
    variable, then the module default.
    """
    chosen = name if name else os.getenv("OBJECT_DETECTOR", DEFAULT_DETECTOR)
    return _get_cached_detector(chosen)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Backwards compatibility for existing callers.
def load_model():
    """Deprecated alias: return the default detector via load_detector()."""
    return load_detector()
|
models/segmenters/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import Segmenter, SegmentationResult
|
| 2 |
+
from .model_loader import load_segmenter
|
| 3 |
+
from .sam3 import SAM3Segmenter
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"Segmenter",
|
| 7 |
+
"SegmentationResult",
|
| 8 |
+
"load_segmenter",
|
| 9 |
+
"SAM3Segmenter",
|
| 10 |
+
]
|
models/segmenters/base.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import NamedTuple, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SegmentationResult(NamedTuple):
    """Result from segmentation inference.

    ``scores`` and ``boxes`` are optional per-instance metadata, presumably
    index-aligned with ``masks`` — confirm against the concrete segmenter.
    """

    masks: np.ndarray  # NxHxW binary or soft masks
    scores: Optional[np.ndarray] = None  # Confidence scores
    boxes: Optional[np.ndarray] = None  # Bounding boxes (xyxy)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Segmenter:
    """Base interface for segmentation models."""

    # Short backend identifier (e.g. "sam3"); set by each subclass.
    name: str

    def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
        """
        Run segmentation on a single frame.

        Args:
            frame: Input image as numpy array (HxWxC)
            text_prompts: Optional list of text prompts for segmentation

        Returns:
            SegmentationResult with masks and optional metadata

        Raises:
            NotImplementedError: always, in this base class; subclasses
                must override.
        """
        raise NotImplementedError
|
models/segmenters/model_loader.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
from typing import Callable, Dict, Optional
|
| 4 |
+
|
| 5 |
+
from .base import Segmenter
|
| 6 |
+
from .sam3 import SAM3Segmenter
|
| 7 |
+
|
| 8 |
+
# Backend used when neither the caller nor the SEGMENTER env var picks one.
DEFAULT_SEGMENTER = "sam3"

# Maps segmenter names to zero-argument factories producing fresh instances.
_REGISTRY: Dict[str, Callable[[], Segmenter]] = {
    "sam3": SAM3Segmenter,
}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _create_segmenter(name: str) -> Segmenter:
    """Instantiate a fresh segmenter for the given registry name.

    Raises:
        ValueError: If *name* is not registered, listing the valid choices.
    """
    try:
        factory = _REGISTRY[name]
    except KeyError as exc:
        choices = ", ".join(sorted(_REGISTRY))
        message = f"Unknown segmenter '{name}'. Available: {choices}"
        raise ValueError(message) from exc
    return factory()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Get or create cached segmenter instance.

    Keyed on *name*, so each backend (and its model weights) is constructed
    at most once per process. The registry is a small fixed set, so the
    unbounded cache cannot grow without limit.
    """
    return _create_segmenter(name)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """Resolve and return a (cached) segmenter instance.

    Args:
        name: Explicit segmenter name. When omitted or empty, the
            ``SEGMENTER`` environment variable is consulted, falling back
            to ``DEFAULT_SEGMENTER``.

    Returns:
        Cached segmenter instance.
    """
    if name:
        selected = name
    else:
        selected = os.getenv("SEGMENTER", DEFAULT_SEGMENTER)
    return _get_cached_segmenter(selected)
|
models/segmenters/sam3.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import Sam3Model, Sam3Processor
|
| 8 |
+
|
| 9 |
+
from .base import Segmenter, SegmentationResult
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs text-prompted instance segmentation on single images using the
    facebook/sam3 checkpoint from HuggingFace Transformers. Weights and the
    processor are loaded once in __init__ and reused across predict() calls.
    """

    # Registry key used by models/segmenters/model_loader.py.
    name = "sam3"

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize SAM3 segmenter.

        Args:
            model_id: HuggingFace model ID to load weights/processor from.
            device: Device to run on (cuda/cpu), auto-detected if None.
            threshold: Confidence threshold for filtering instances.
            mask_threshold: Threshold for binarizing masks.

        Raises:
            Exception: Re-raises whatever ``from_pretrained`` raises
                (network error, missing checkpoint, ...) after logging it.
        """
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.threshold = threshold
        self.mask_threshold = mask_threshold

        logging.info(
            "Loading SAM3 model %s on device %s", model_id, self.device
        )

        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            self.model.eval()
        except Exception:
            # Log with traceback, then propagate: the app cannot work
            # without a usable segmenter.
            logging.exception("Failed to load SAM3 model")
            raise

        logging.info("SAM3 model loaded successfully")

    def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
        """
        Run SAM3 segmentation on a frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB). uint8 is passed
                through; any other dtype is assumed to hold floats in
                [0, 1] and is scaled by 255 — TODO confirm with callers.
            text_prompts: List of text prompts for segmentation; falls back
                to ["object"] when None or empty.

        Returns:
            SegmentationResult with instance masks; an empty result
            (zero masks) if post-processing fails.
        """
        # Convert numpy array to PIL Image
        if frame.dtype == np.uint8:
            pil_image = Image.fromarray(frame)
        else:
            # Normalize to 0-255 if needed
            # (assumes float input in [0, 1]; values above 1 would wrap
            # after the uint8 cast — TODO confirm)
            frame_uint8 = (frame * 255).astype(np.uint8)
            pil_image = Image.fromarray(frame_uint8)

        # Use default prompts if none provided
        if not text_prompts:
            text_prompts = ["object"]

        # Process image with text prompts
        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process to get instance masks
        try:
            # NOTE(review): assumes the processor output contains
            # "original_sizes"; .tolist() would raise on None, which is
            # caught by the broad except below.
            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )[0]

            # Extract results
            masks = results.get("masks", [])
            scores = results.get("scores", None)
            boxes = results.get("boxes", None)

            # Convert to numpy arrays
            if len(masks) > 0:
                # Stack masks: list of (H, W) -> (N, H, W)
                masks_array = np.stack([m.cpu().numpy() for m in masks])
            else:
                # No objects detected
                masks_array = np.zeros(
                    (0, frame.shape[0], frame.shape[1]), dtype=bool
                )

            scores_array = (
                scores.cpu().numpy() if scores is not None else None
            )
            boxes_array = (
                boxes.cpu().numpy() if boxes is not None else None
            )

            return SegmentationResult(
                masks=masks_array,
                scores=scores_array,
                boxes=boxes_array,
            )

        except Exception:
            # Deliberate best-effort: a post-processing failure degrades to
            # "nothing segmented" rather than crashing the request.
            logging.exception("SAM3 post-processing failed")
            # Return empty result
            return SegmentationResult(
                masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
            )
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
torch
|
| 4 |
+
transformers
|
| 5 |
+
opencv-python-headless
|
| 6 |
+
python-multipart
|
| 7 |
+
accelerate
|
| 8 |
+
pillow
|
| 9 |
+
scipy
|
| 10 |
+
huggingface-hub
|
| 11 |
+
ultralytics
|
| 12 |
+
timm
|
| 13 |
+
ffmpeg-python
|
utils/video.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import subprocess
|
| 5 |
+
import tempfile
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def extract_frames(video_path: str) -> Tuple[List[np.ndarray], float, int, int]:
    """Decode every frame of a video file into memory.

    Args:
        video_path: Path to the input video.

    Returns:
        Tuple of (frames as numpy arrays, fps, width, height).

    Raises:
        ValueError: If the file cannot be opened or yields no frames.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError("Unable to open video.")

    fps = capture.get(cv2.CAP_PROP_FPS) or 0.0
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

    frames: List[np.ndarray] = []
    while True:
        ok, frame = capture.read()
        if not ok:
            break
        frames.append(frame)

    capture.release()

    if not frames:
        raise ValueError("Video decode produced zero frames.")

    return frames, fps, width, height
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _transcode_with_ffmpeg(src_path: str, dst_path: str) -> None:
|
| 36 |
+
cmd = [
|
| 37 |
+
"ffmpeg",
|
| 38 |
+
"-y",
|
| 39 |
+
"-i",
|
| 40 |
+
src_path,
|
| 41 |
+
"-c:v",
|
| 42 |
+
"libx264",
|
| 43 |
+
"-preset",
|
| 44 |
+
"veryfast",
|
| 45 |
+
"-pix_fmt",
|
| 46 |
+
"yuv420p",
|
| 47 |
+
"-movflags",
|
| 48 |
+
"+faststart",
|
| 49 |
+
dst_path,
|
| 50 |
+
]
|
| 51 |
+
process = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
|
| 52 |
+
if process.returncode != 0:
|
| 53 |
+
raise RuntimeError(process.stderr.decode("utf-8", errors="ignore"))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def write_video(frames: List[np.ndarray], output_path: str, fps: float, width: int, height: int) -> None:
    """Encode frames to an MP4 at *output_path*.

    Frames are first written with OpenCV's MP4V codec to a temporary file,
    then transcoded to H.264 with ffmpeg for browser compatibility. If
    ffmpeg is missing or fails, the MP4V file is moved into place as-is.

    Args:
        frames: Frames to encode; each is expected to match (height, width)
            — NOTE(review): cv2.VideoWriter silently drops mismatched
            frames, confirm callers pass consistent sizes.
        output_path: Destination path for the final MP4.
        fps: Frames per second; falsy values fall back to 1.0.
        width: Frame width in pixels.
        height: Frame height in pixels.

    Raises:
        ValueError: If *frames* is empty or the VideoWriter cannot open.
    """
    if not frames:
        raise ValueError("No frames available for writing.")
    temp_fd, temp_path = tempfile.mkstemp(prefix="raw_", suffix=".mp4")
    os.close(temp_fd)
    writer = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*"mp4v"), fps or 1.0, (width, height))
    if not writer.isOpened():
        os.remove(temp_path)
        raise ValueError("Failed to open VideoWriter.")

    # Fix: release the writer even if a frame write raises, so the encoder
    # handle is not leaked and the container is properly flushed.
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()

    try:
        _transcode_with_ffmpeg(temp_path, output_path)
        logging.debug("Transcoded video to H.264 for browser compatibility.")
        os.remove(temp_path)
    except FileNotFoundError:
        logging.warning("ffmpeg not found; serving fallback MP4V output.")
        shutil.move(temp_path, output_path)
    except RuntimeError as exc:
        logging.warning("ffmpeg transcode failed (%s); serving fallback MP4V output.", exc)
        shutil.move(temp_path, output_path)
|