Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| ROI-aware compression server (FastAPI) | |
| - Uploads a video and prompt | |
| - Runs YOLO / RT-DETR detection (default weights: yolov8s.pt) + simple IoU tracking | |
| - Produces 3 outputs: overlay (tracking), compressed, ROI-preserved | |
| - Serves MJPEG stream of live overlay | |
| Endpoints: | |
| POST /track/async | |
| POST /process/compress/{job_id} | |
| GET /process/status/{job_id} | |
| GET /process/video/overlay/{job_id} | |
| GET /process/video/compressed/{job_id} | |
| GET /process/video/roi/{job_id} | |
| GET /detect/stream/{job_id} | |
| """ | |
import math
import os
import shutil
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from ultralytics import RTDETR, YOLO
# --- Runtime configuration (each value can be overridden via the environment) ---

# Detector weights and the first directory searched when resolving them.
DEFAULT_WEIGHTS = os.getenv("YOLO_WEIGHTS", "yolov8s.pt")
WEIGHTS_DIR = os.getenv("WEIGHTS_DIR", os.path.dirname(__file__))

# Detection confidence threshold and inference device; with "auto" no explicit
# device is passed to the model.
DEFAULT_CONF = float(os.getenv("YOLO_CONF", "0.25"))
DEFAULT_DEVICE = os.getenv("YOLO_DEVICE", "auto")

# Fast mode: factor used to downscale frames before detection, and the
# inference image size handed to the model.
FAST_DETECT_SCALE = float(os.getenv("FAST_DETECT_SCALE", "0.65"))
FAST_DETECT_IMGSZ = int(os.getenv("FAST_DETECT_IMGSZ", "512"))

# Working directories for uploaded inputs and generated outputs.
DATA_DIR = os.getenv("DATA_DIR", "/tmp/roi_demo")
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
OUTPUT_DIR = os.path.join(DATA_DIR, "outputs")
# FastAPI application. CORS is fully open (any origin, method and header) but
# credentials are not allowed.
app = FastAPI(title="ROI Compression Server", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): the route decorator was missing from this handler, so it was
# never registered with the app. "/" is inferred from the function name --
# confirm against deployed clients / health checks.
@app.get("/")
def root():
    """Liveness probe: return a static payload identifying the service."""
    return {"status": "ok", "service": "roi-compression"}
| _model_lock = threading.Lock() | |
| _models: Dict[str, Any] = {} | |
| def _infer_model_type(weights: str) -> str: | |
| name = os.path.basename(str(weights or "")).lower() | |
| if name.startswith("rtdetr"): | |
| return "rtdetr" | |
| return "yolo" | |
| def _resolve_weights_path(weights: str) -> (str, List[str]): | |
| if not weights: | |
| return DEFAULT_WEIGHTS, [] | |
| w = str(weights).strip() | |
| if not w: | |
| return DEFAULT_WEIGHTS, [] | |
| if os.path.isabs(w) and os.path.exists(w): | |
| return os.path.abspath(w), [os.path.abspath(w)] | |
| if os.path.exists(w): | |
| return os.path.abspath(w), [os.path.abspath(w)] | |
| search_dirs: List[str] = [] | |
| if WEIGHTS_DIR: | |
| search_dirs.append(WEIGHTS_DIR) | |
| search_dirs.extend([ | |
| os.getcwd(), | |
| os.path.dirname(__file__), | |
| os.path.abspath(os.path.dirname(__file__)), | |
| os.path.abspath(os.path.join(os.path.dirname(__file__), "..")), | |
| DATA_DIR, | |
| "/home/user/app", | |
| "/app", | |
| "/workspace", | |
| "/data", | |
| ]) | |
| checked: List[str] = [] | |
| for base in search_dirs: | |
| if not base: | |
| continue | |
| cand = os.path.join(base, w) | |
| checked.append(cand) | |
| if os.path.exists(cand): | |
| return os.path.abspath(cand), checked | |
| return w, checked | |
def get_model(weights: str) -> Any:
    """Return a cached detector for ``weights``, loading it on first use.

    The weights are resolved with ``_resolve_weights_path`` (falling back to
    ``DEFAULT_WEIGHTS``), and the model class (``RTDETR`` or ``YOLO``) is
    chosen from the resolved file name.

    Raises:
        RuntimeError: the resolved name ends in ``.pt`` but no such file
            exists locally.
    """
    resolved, searched = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
    kind = _infer_model_type(resolved)

    missing_local_pt = str(resolved).endswith(".pt") and not os.path.exists(resolved)
    if missing_local_pt:
        search_list = ", ".join(searched) if searched else "(no local paths searched)"
        raise RuntimeError(
            f"Weights not found locally: {weights}. Searched: {search_list}. "
            f"Set WEIGHTS_DIR or upload the weights to the app directory."
        )

    cache_key = f"{kind}:{resolved}"
    # Hold the lock across construction so each model is built only once.
    with _model_lock:
        if cache_key not in _models:
            loader = RTDETR if kind == "rtdetr" else YOLO
            _models[cache_key] = loader(resolved)
        return _models[cache_key]
| def _parse_queries(q: str) -> List[str]: | |
| if not q: | |
| return [] | |
| parts = [p.strip().lower() for p in q.replace("\n", ",").split(",")] | |
| return [p for p in parts if p] | |
| def _keep_det(label: str, queries: List[str]) -> bool: | |
| if not queries: | |
| return True | |
| l = (label or "").strip().lower() | |
| if not l: | |
| return False | |
| return any((q == l) or (q in l) or (l in q) for q in queries) | |
def _yolo_detect_frame(
    model: Any,
    frame_bgr: np.ndarray,
    conf: float,
    queries: List[str],
    device: str,
    fast_mode: bool = False,
) -> List[Dict[str, Any]]:
    """Run the detector on one BGR frame and return the filtered detections.

    Args:
        model: Object exposing ``predict(img, **kwargs)``.
        frame_bgr: Frame in BGR channel order; converted to RGB for inference.
        conf: Confidence threshold forwarded to ``predict``.
        queries: Label filter applied through ``_keep_det``.
        device: Device passed to ``predict`` unless empty or ``"auto"``.
        fast_mode: Downscale the frame by ``FAST_DETECT_SCALE``, pass
            ``FAST_DETECT_IMGSZ`` as ``imgsz`` and request half precision
            unless the device is ``"cpu"``.

    Returns:
        A list of ``{"bbox_xyxy", "label", "score"}`` dicts, with boxes
        expressed in the coordinates of the original frame.
    """
    # Factors mapping detector-space boxes back to the original frame.
    scale_x = scale_y = 1.0
    src = frame_bgr
    if fast_mode:
        scale = max(0.1, min(1.0, float(FAST_DETECT_SCALE)))
        if scale < 1.0:
            h, w = frame_bgr.shape[:2]
            sw, sh = max(64, int(w * scale)), max(64, int(h * scale))
            src = cv2.resize(frame_bgr, (sw, sh), interpolation=cv2.INTER_AREA)
            # The 64 px floor and int truncation make the real resize ratio
            # differ from ``scale`` (and per axis), so derive the inverse
            # mapping from the actual sizes rather than from 1 / scale.
            scale_x = w / float(sw)
            scale_y = h / float(sh)
    img = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)

    device_name = str(device).lower()
    pred_kwargs = {"conf": conf, "verbose": False}
    if fast_mode:
        pred_kwargs["imgsz"] = FAST_DETECT_IMGSZ
    if device and device_name != "auto":
        pred_kwargs["device"] = device
    if fast_mode and device_name != "cpu":
        pred_kwargs["half"] = True

    try:
        res = model.predict(img, **pred_kwargs)
    except Exception as exc:
        # Retry on CPU only for CUDA-related failures on a non-CPU device.
        if "cuda" not in str(exc).lower() or device_name == "cpu":
            raise
        pred_kwargs["device"] = "cpu"
        # Half precision was requested for the accelerator; do not carry it
        # over to the CPU retry.
        pred_kwargs.pop("half", None)
        res = model.predict(img, **pred_kwargs)

    if not res:
        return []
    r0 = res[0]
    if r0.boxes is None:
        return []
    names = getattr(r0, "names", None) or getattr(model, "names", None) or {}

    dets: List[Dict[str, Any]] = []
    for b in r0.boxes:
        try:
            x1, y1, x2, y2 = b.xyxy[0].cpu().numpy().tolist()
            score = float(b.conf[0].cpu().numpy())
            cls_i = int(b.cls[0].cpu().numpy())
            label = str(names.get(cls_i, cls_i))
        except Exception:
            # Best effort: skip a box that cannot be decoded, keep the rest.
            continue
        if not _keep_det(label, queries):
            continue
        dets.append({
            "bbox_xyxy": [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y],
            "label": label,
            "score": score,
        })
    return dets
| def _draw_boxes(frame_bgr: np.ndarray, dets: List[Dict[str, Any]]) -> np.ndarray: | |
| out = frame_bgr.copy() | |
| for d in dets: | |
| b = d.get("bbox_xyxy") | |
| if not (isinstance(b, (list, tuple)) and len(b) == 4): | |
| continue | |
| x1, y1, x2, y2 = [int(max(0, v)) for v in b] | |
| label = str(d.get("label", "")) | |
| score = d.get("score", None) | |
| tid = d.get("track_id", None) | |
| tag = f"#{tid}" if isinstance(tid, int) else "" | |
| txt = f"{label}{tag} {score:.2f}" if isinstance(score, (float, int)) else f"{label}{tag}" | |
| cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2) | |
| if txt: | |
| cv2.putText(out, txt, (x1, max(12, y1 - 6)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA) | |
| return out | |
| def _iou_xyxy(a: List[float], b: List[float]) -> float: | |
| ax1, ay1, ax2, ay2 = a | |
| bx1, by1, bx2, by2 = b | |
| inter_x1 = max(ax1, bx1) | |
| inter_y1 = max(ay1, by1) | |
| inter_x2 = min(ax2, bx2) | |
| inter_y2 = min(ay2, by2) | |
| if inter_x2 <= inter_x1 or inter_y2 <= inter_y1: | |
| return 0.0 | |
| inter = (inter_x2 - inter_x1) * (inter_y2 - inter_y1) | |
| area_a = max(0.0, (ax2 - ax1)) * max(0.0, (ay2 - ay1)) | |
| area_b = max(0.0, (bx2 - bx1)) * max(0.0, (by2 - by1)) | |
| denom = area_a + area_b - inter | |
| if denom <= 0: | |
| return 0.0 | |
| return float(inter / denom) | |
| def _assign_tracks(dets: List[Dict[str, Any]], tracker: Dict[str, Any], iou_thresh: float = 0.3) -> List[Dict[str, Any]]: | |
| prev = tracker.get("tracks", []) | |
| used_prev = set() | |
| out = [] | |
| for d in dets: | |
| b = d.get("bbox_xyxy") | |
| label = str(d.get("label", "")) | |
| best_i = None | |
| best_iou = 0.0 | |
| if isinstance(b, (list, tuple)) and len(b) == 4: | |
| for i, tr in enumerate(prev): | |
| if i in used_prev: | |
| continue | |
| if label and tr.get("label") and tr.get("label") != label: | |
| continue | |
| iou = _iou_xyxy(b, tr.get("bbox_xyxy", [0, 0, 0, 0])) | |
| if iou > best_iou: | |
| best_iou = iou | |
| best_i = i | |
| if best_i is not None and best_iou >= iou_thresh: | |
| d["track_id"] = int(prev[best_i].get("id")) | |
| used_prev.add(best_i) | |
| else: | |
| d["track_id"] = int(tracker.get("next_id", 1)) | |
| tracker["next_id"] = int(d["track_id"]) + 1 | |
| out.append(d) | |
| tracker["tracks"] = [ | |
| {"id": int(d.get("track_id")), "bbox_xyxy": d.get("bbox_xyxy"), "label": d.get("label", "")} | |
| for d in out | |
| ] | |
| return out | |
| def _ensure_even(v: int, min_v: int = 64) -> int: | |
| v = max(min_v, int(v)) | |
| return v - (v % 2) | |
| def _fit_aspect(w: int, h: int, target_w: int, target_h: int) -> Optional[List[int]]: | |
| if w <= 0 or h <= 0: | |
| return None | |
| if target_w and target_h: | |
| scale = min(float(target_w) / float(w), float(target_h) / float(h)) | |
| elif target_w: | |
| scale = float(target_w) / float(w) | |
| elif target_h: | |
| scale = float(target_h) / float(h) | |
| else: | |
| return None | |
| if not math.isfinite(scale) or scale <= 0: | |
| return None | |
| return [int(w * scale), int(h * scale)] | |
def _compute_target_params(w: int, h: int, fps: float, bandwidth_kbps: int, target_fps: int, target_w: int, target_h: int, scale: float) -> Tuple[int, int, int, int]:
    """Derive output resolution, frame rate and frame sampling step.

    The requested size (explicit ``target_w``/``target_h`` fitted to the
    source aspect ratio, otherwise ``scale`` clamped to [0.1, 1.0]) is capped
    by a bandwidth heuristic. The source bitrate is estimated by scaling
    2500 kbps at 1280x720@30 with the pixel rate, and both dimensions are
    shrunk by ``sqrt(budget / estimate)`` when the budget is smaller.

    Args:
        w, h, fps: Source width, height and frame rate.
        bandwidth_kbps: Bitrate budget; falsy means 1500, minimum 100.
        target_fps: Requested frame rate; falsy means the source rate.
        target_w, target_h: Requested size; 0 leaves an axis unconstrained.
        scale: Resolution factor, used only when no target size is given.

    Returns:
        ``(width, height, fps, frame_step)``. Width and height are even and at
        least 64, and ``fps`` is the integer rate obtained by keeping every
        ``frame_step``-th source frame.
    """
    fps = max(1.0, float(fps or 1.0))
    budget = max(100, int(bandwidth_kbps or 1500))

    base_kbps_720p30 = 2500.0
    base_kbps_orig = base_kbps_720p30 * (float(w) * float(h) * fps) / (1280.0 * 720.0 * 30.0)
    if not math.isfinite(base_kbps_orig) or base_kbps_orig <= 0:
        base_kbps_orig = base_kbps_720p30

    if target_w or target_h:
        fitted = _fit_aspect(w, h, int(target_w or 0), int(target_h or 0))
        if fitted:
            tw, th = fitted
        else:
            tw, th = int(target_w or w), int(target_h or h)
    else:
        scale = min(1.0, max(0.1, float(scale or 1.0)))
        tw, th = int(w * scale), int(h * scale)

    # Shrink both axes by the same factor when the budget is below the
    # estimated source bitrate; never upscale.
    scale_r = min(1.0, math.sqrt(budget / base_kbps_orig))
    tw = _ensure_even(max(64, min(tw, int(w * scale_r))))
    th = _ensure_even(max(64, min(th, int(h * scale_r))))

    tfps = max(1, min(int(fps), int(target_fps or fps)))
    frame_step = max(1, int(round(fps / max(1, tfps))))
    # Frames are dropped in whole steps, so the achievable rate is
    # fps / frame_step, which can differ from the request (30 -> 20 fps gives
    # step 2, i.e. 15 fps). Report the rate that sampling really produces so
    # that writers fed with it keep the original playback speed.
    tfps = max(1, int(round(fps / frame_step)))
    return tw, th, tfps, frame_step
def _open_writer(path: str, w: int, h: int, fps: float) -> Optional[cv2.VideoWriter]:
    """Open an ``mp4v`` video writer, or return ``None`` when that fails.

    ``None`` is returned for a non-positive frame size, when OpenCV raises
    while creating the writer, or when the writer does not report itself
    open. A falsy ``fps`` falls back to 30.
    """
    if w <= 0 or h <= 0:
        return None
    # Force a software-friendly codec to avoid hardware H.264 failures on
    # some systems.
    try:
        writer = cv2.VideoWriter(
            path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(fps or 30.0),
            (int(w), int(h)),
        )
        opened = writer is not None and writer.isOpened()
    except Exception:
        return None
    return writer if opened else None
| def _ffmpeg_available() -> bool: | |
| return shutil.which("ffmpeg") is not None | |
| def _transcode_h264(src_path: str) -> Optional[str]: | |
| if not src_path or not os.path.exists(src_path): | |
| return None | |
| if not _ffmpeg_available(): | |
| return None | |
| dst_path = os.path.splitext(src_path)[0] + "_h264.mp4" | |
| cmd = [ | |
| "ffmpeg", | |
| "-y", | |
| "-i", | |
| src_path, | |
| "-c:v", | |
| "libx264", | |
| "-preset", | |
| "veryfast", | |
| "-pix_fmt", | |
| "yuv420p", | |
| dst_path, | |
| ] | |
| try: | |
| subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | |
| if os.path.exists(dst_path) and os.path.getsize(dst_path) > 1024: | |
| return dst_path | |
| except Exception: | |
| return None | |
| return None | |
def _apply_roi_overlay(frame_bgr: np.ndarray, dets: List[Dict[str, Any]], target_w: int, target_h: int) -> np.ndarray:
    """Build a frame with a degraded background and full-quality ROIs.

    The whole frame is downscaled to ``target_w`` x ``target_h`` and scaled
    back up to its original size. Each detection box, padded by 0.5% of the
    shorter side (at least 2 px) and clipped to the frame, is then pasted
    back from the original frame. The result has the same size as the input.
    """
    h, w = frame_bgr.shape[:2]
    low_res = cv2.resize(frame_bgr, (int(target_w), int(target_h)), interpolation=cv2.INTER_AREA)
    out = cv2.resize(low_res, (int(w), int(h)), interpolation=cv2.INTER_LINEAR)

    pad = max(2, int(min(w, h) * 0.005))
    for det in dets:
        box = det.get("bbox_xyxy")
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            continue
        left, top, right, bottom = (int(v) for v in box)
        left, top = max(0, left - pad), max(0, top - pad)
        right, bottom = min(w, right + pad), min(h, bottom + pad)
        if right > left and bottom > top:
            out[top:bottom, left:right] = frame_bgr[top:bottom, left:right]
    return out
# The class body relies on dataclasses.field() and is constructed with keyword
# arguments, so it must be a dataclass; without the decorator Job(id=...)
# raises TypeError and the field() defaults are never materialised.
@dataclass
class Job:
    """State of one uploaded video across the tracking and compression passes.

    Worker threads and request handlers share instances; mutable fields are
    read and written while holding ``lock``.
    """

    id: str
    video_path: str  # uploaded source video on disk
    created: float = field(default_factory=time.time)
    # Lifecycle values used in this file:
    # "tracking" -> "tracked" -> "compressing" -> "completed", or "error".
    status: str = "tracking"
    error: Optional[str] = None  # message of the exception that failed the job

    # Source properties, filled in by the tracking worker.
    fps: float = 30.0
    w: int = 0
    h: int = 0
    frame_step: int = 1  # detection runs on every frame_step-th frame

    # Requested (later: effective) output parameters.
    target_fps: int = 15
    target_width: int = 0
    target_height: int = 0
    bandwidth_kbps: int = 1500

    # Detector settings.
    conf: float = DEFAULT_CONF
    weights: str = DEFAULT_WEIGHTS
    device: str = DEFAULT_DEVICE
    fast_mode: bool = False
    queries: List[str] = field(default_factory=list)

    # Output files; None until produced.
    overlay_video_path: Optional[str] = None
    compressed_video_path: Optional[str] = None
    roi_video_path: Optional[str] = None

    # Detections keyed by source frame index.
    det_by_frame: Dict[int, List[Dict[str, Any]]] = field(default_factory=dict)

    # Latest JPEG-encoded preview frame for each MJPEG stream.
    latest_jpeg: Optional[bytes] = None
    latest_compressed_jpeg: Optional[bytes] = None
    latest_roi_jpeg: Optional[bytes] = None

    lock: threading.Lock = field(default_factory=threading.Lock)
    tracker_state: Dict[str, Any] = field(default_factory=lambda: {"next_id": 1, "tracks": []})


# In-memory job registry keyed by job id. Entries are never removed in this
# file, so it grows for the lifetime of the process.
jobs: Dict[str, Job] = {}
def _process_job(job: Job):
    """Background worker: detect and track objects through the whole video.

    Detection runs on every ``frame_step``-th frame, and the frames in
    between reuse the latest detections. Every frame is drawn with its boxes,
    published as the job's latest preview JPEG and written to the overlay
    video, which is transcoded to H.264 when ffmpeg is available.

    On success ``job.status`` becomes ``"tracked"``. Any exception is caught,
    and the job is marked ``"error"`` with the exception message.
    """
    try:
        model = get_model(job.weights)
        cap = cv2.VideoCapture(job.video_path)
        overlay_writer = None
        # The capture and writer are released in ``finally`` so a failure in
        # the middle of the video does not leak their handles.
        try:
            if not cap.isOpened():
                raise RuntimeError("Could not open video.")
            fps = float(cap.get(cv2.CAP_PROP_FPS) or 30.0)
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
            width_ratio = (job.target_width / w) if (job.target_width and w) else 1.0
            tw, th, tfps, frame_step = _compute_target_params(
                w=w,
                h=h,
                fps=fps,
                bandwidth_kbps=job.bandwidth_kbps,
                target_fps=job.target_fps,
                target_w=job.target_width,
                target_h=job.target_height,
                scale=max(0.25, min(1.0, width_ratio)),
            )

            os.makedirs(OUTPUT_DIR, exist_ok=True)
            overlay_path = os.path.join(OUTPUT_DIR, f"{job.id}_overlay.mp4")
            # The overlay keeps the source size and frame rate.
            overlay_writer = _open_writer(overlay_path, w, h, fps)

            with job.lock:
                job.fps = fps
                job.w = w
                job.h = h
                job.frame_step = frame_step
                job.target_fps = tfps
                job.target_width = tw
                job.target_height = th
                job.overlay_video_path = overlay_path if overlay_writer is not None else None
                job.status = "tracking"

            frame_idx = 0
            tracker = job.tracker_state
            last_dets: List[Dict[str, Any]] = []
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                if frame_idx % frame_step == 0:
                    dets = _yolo_detect_frame(
                        model,
                        frame,
                        conf=job.conf,
                        queries=job.queries,
                        device=job.device,
                        fast_mode=job.fast_mode,
                    )
                    if dets and not any("track_id" in d for d in dets):
                        dets = _assign_tracks(dets, tracker)
                    elif dets:
                        # Detections already carry ids: mirror them into the
                        # tracker and keep next_id ahead of the largest one.
                        tracker["tracks"] = [
                            {"id": int(d.get("track_id")), "bbox_xyxy": d.get("bbox_xyxy"), "label": d.get("label", "")}
                            for d in dets
                        ]
                        max_id = max((int(d.get("track_id", 0)) for d in dets), default=0)
                        tracker["next_id"] = max(tracker.get("next_id", 1), max_id + 1)
                    with job.lock:
                        job.det_by_frame[int(frame_idx)] = dets
                    last_dets = dets
                else:
                    dets = last_dets

                overlay = _draw_boxes(frame, dets or [])
                ok2, jpg = cv2.imencode(".jpg", overlay, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
                if ok2:
                    with job.lock:
                        job.latest_jpeg = jpg.tobytes()
                if overlay_writer is not None:
                    overlay_writer.write(overlay)
                frame_idx += 1
        finally:
            cap.release()
            if overlay_writer is not None:
                try:
                    overlay_writer.release()
                except Exception:
                    pass

        h264_overlay = _transcode_h264(overlay_path) if overlay_writer is not None else None
        with job.lock:
            if h264_overlay:
                job.overlay_video_path = h264_overlay
            job.status = "tracked"
    except Exception as e:
        with job.lock:
            job.status = "error"
            job.error = str(e)
def _compress_job(job: Job, bandwidth_kbps: int, target_fps: int, target_w: int, target_h: int, resolution_scale: float):
    """Background worker: write the compressed and ROI-preserved videos.

    Every ``frame_step``-th source frame is (a) resized to the target size
    for the compressed video and (b) rendered at source size with a degraded
    background and original-quality detection boxes for the ROI video. The
    latest frame of each is published as a preview JPEG. Both outputs are
    transcoded to H.264 when ffmpeg is available.

    On success ``job.status`` becomes ``"completed"``. Any exception is
    caught, and the job is marked ``"error"`` with the exception message.
    """
    try:
        cap = cv2.VideoCapture(job.video_path)
        compressed_writer = None
        roi_writer = None
        # The capture and writers are released in ``finally`` so a failure
        # part-way through does not leak their handles.
        try:
            if not cap.isOpened():
                raise RuntimeError("Could not open video.")
            fps = float(cap.get(cv2.CAP_PROP_FPS) or 30.0)
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
            tw, th, tfps, frame_step = _compute_target_params(
                w=w,
                h=h,
                fps=fps,
                bandwidth_kbps=bandwidth_kbps,
                target_fps=target_fps,
                target_w=target_w,
                target_h=target_h,
                scale=resolution_scale,
            )

            os.makedirs(OUTPUT_DIR, exist_ok=True)
            compressed_path = os.path.join(OUTPUT_DIR, f"{job.id}_compressed_rt.mp4")
            roi_path = os.path.join(OUTPUT_DIR, f"{job.id}_roi_rt.mp4")
            compressed_writer = _open_writer(compressed_path, tw, th, tfps)
            roi_writer = _open_writer(roi_path, w, h, tfps)

            with job.lock:
                job.status = "compressing"
                job.bandwidth_kbps = int(bandwidth_kbps)
                job.target_fps = int(tfps)
                job.target_width = int(tw)
                job.target_height = int(th)

            jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
            frame_idx = -1
            last_dets: List[Dict[str, Any]] = []
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                frame_idx += 1

                # Detections were stored on the tracking pass's own frame
                # grid, which need not match this pass's step. Remember the
                # newest ones even on frames that are dropped here, so kept
                # frames always use the most recent boxes at or before them.
                stored = job.det_by_frame.get(int(frame_idx))
                if stored is not None:
                    last_dets = stored
                if frame_idx % frame_step != 0:
                    continue
                dets = last_dets

                compressed_frame = None
                roi_frame = None
                if compressed_writer is not None:
                    try:
                        compressed_frame = cv2.resize(frame, (tw, th), interpolation=cv2.INTER_AREA)
                        compressed_writer.write(compressed_frame)
                    except Exception:
                        compressed_frame = None
                if roi_writer is not None:
                    try:
                        roi_frame = _apply_roi_overlay(frame, dets, tw, th)
                        roi_writer.write(roi_frame)
                    except Exception:
                        roi_frame = None

                # Preview publishing is best effort and must not stop the job.
                try:
                    if compressed_frame is not None:
                        okc, jpgc = cv2.imencode(".jpg", compressed_frame, jpeg_params)
                        if okc:
                            with job.lock:
                                job.latest_compressed_jpeg = jpgc.tobytes()
                    if roi_frame is not None:
                        okr, jpgr = cv2.imencode(".jpg", roi_frame, jpeg_params)
                        if okr:
                            with job.lock:
                                job.latest_roi_jpeg = jpgr.tobytes()
                except Exception:
                    pass
        finally:
            cap.release()
            for wtr in (compressed_writer, roi_writer):
                if wtr is not None:
                    try:
                        wtr.release()
                    except Exception:
                        pass

        h264_compressed = _transcode_h264(compressed_path) if compressed_writer is not None else None
        h264_roi = _transcode_h264(roi_path) if roi_writer is not None else None
        with job.lock:
            # Prefer the H.264 file, then the raw writer output, and otherwise
            # keep whatever a previous run produced.
            if h264_compressed:
                job.compressed_video_path = h264_compressed
            elif os.path.exists(compressed_path):
                job.compressed_video_path = compressed_path
            if h264_roi:
                job.roi_video_path = h264_roi
            elif os.path.exists(roi_path):
                job.roi_video_path = roi_path
            job.status = "completed"
    except Exception as e:
        with job.lock:
            job.status = "error"
            job.error = str(e)
# The route decorator was missing; the path is taken from the endpoint list in
# the module docstring.
@app.post("/track/async")
async def track_async(
    video: UploadFile = File(...),
    queries: str = Form(""),
    conf: float = Form(DEFAULT_CONF),
    weights: str = Form(DEFAULT_WEIGHTS),
    device: str = Form(""),
    fast_mode: bool = Form(False),
    bandwidth_kbps: int = Form(1500),
    target_fps: int = Form(15),
    target_width: int = Form(0),
    target_height: int = Form(0),
    resolution_scale: float = Form(1.0),
):
    """Store an uploaded video, create a job and start tracking in a thread.

    Before returning, the first frame is run through the detector so the
    MJPEG stream has a preview immediately; failures there are ignored.
    ``resolution_scale`` is accepted but not used by this handler.

    Returns:
        JSON with the ``job_id`` and the URLs of the status, stream and video
        endpoints for that job.
    """
    job_id = uuid.uuid4().hex[:12]
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # basename() drops any directory components from the client's file name.
    dst = os.path.join(UPLOAD_DIR, f"{job_id}_{os.path.basename(video.filename or 'input.mp4')}")

    # Copy the upload in chunks so large videos are not held in memory whole.
    with open(dst, "wb") as f:
        while True:
            chunk = await video.read(1024 * 1024)
            if not chunk:
                break
            f.write(chunk)

    job = Job(
        id=job_id,
        video_path=dst,
        status="tracking",
        conf=float(conf),
        weights=str(weights),
        device=str(device).strip() or DEFAULT_DEVICE,
        queries=_parse_queries(queries),
        fast_mode=bool(fast_mode),
        target_fps=int(target_fps or 15),
        bandwidth_kbps=int(bandwidth_kbps or 1500),
        target_width=int(target_width or 0),
        target_height=int(target_height or 0),
    )
    jobs[job_id] = job

    # Fast preview for MJPEG (best effort: real errors surface in the worker).
    # NOTE(review): model loading and inference run synchronously here and
    # block the event loop for their duration.
    try:
        cap = cv2.VideoCapture(dst)
        try:
            ok, frame0 = cap.read()
        finally:
            cap.release()
        if ok and frame0 is not None:
            model = get_model(job.weights)
            det0 = _yolo_detect_frame(
                model,
                frame0,
                conf=job.conf,
                queries=job.queries,
                device=job.device,
                fast_mode=job.fast_mode,
            )
            det0 = _assign_tracks(det0, job.tracker_state)
            with job.lock:
                job.det_by_frame[0] = det0
            vis0 = _draw_boxes(frame0, det0)
            ok2, jpg = cv2.imencode(".jpg", vis0, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
            if ok2:
                with job.lock:
                    job.latest_jpeg = jpg.tobytes()
    except Exception:
        pass

    worker = threading.Thread(target=_process_job, args=(job,), daemon=True)
    worker.start()

    return JSONResponse({
        "job_id": job_id,
        "status_url": f"/process/status/{job_id}",
        "stream_url": f"/detect/stream/{job_id}",
        "overlay_video_url": f"/process/video/overlay/{job_id}",
        "compressed_video_url": f"/process/video/compressed/{job_id}",
        "roi_video_url": f"/process/video/roi/{job_id}",
    })
# The route decorator was missing; the path is taken from the endpoint list in
# the module docstring.
@app.post("/process/compress/{job_id}")
async def process_compress(
    job_id: str,
    bandwidth_kbps: int = Form(1500),
    target_fps: int = Form(15),
    target_width: int = Form(0),
    target_height: int = Form(0),
    resolution_scale: float = Form(1.0),
):
    """Start the compression pass for a job whose tracking has finished.

    Raises:
        HTTPException: 404 for an unknown job; 409 while the job is still
            tracking or compressing, or when it is in any state other than
            ``"tracked"`` / ``"completed"``.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    with job.lock:
        if job.status in ("tracking", "compressing"):
            raise HTTPException(status_code=409, detail="Job still running")
        if job.status not in ("tracked", "completed"):
            raise HTTPException(status_code=409, detail="Tracking not ready")
        # Claim the job while still holding the lock. Otherwise two
        # concurrent requests could both pass the checks and start two
        # workers writing the same output files.
        job.status = "compressing"

    worker = threading.Thread(
        target=_compress_job,
        args=(job, int(bandwidth_kbps), int(target_fps), int(target_width), int(target_height), float(resolution_scale)),
        daemon=True,
    )
    worker.start()
    return JSONResponse({"job_id": job_id, "status": "compressing"})
# The route decorator was missing; the path is taken from the endpoint list in
# the module docstring.
@app.get("/process/status/{job_id}")
def process_status(job_id: str):
    """Return the job's status, error message and effective target parameters.

    Raises:
        HTTPException: 404 when ``job_id`` is unknown.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    # Read all fields under the lock so the snapshot is consistent.
    with job.lock:
        return {
            "job_id": job.id,
            "status": job.status,
            "error": job.error,
            "target_width": job.target_width,
            "target_height": job.target_height,
            "target_fps": job.target_fps,
            "bandwidth_kbps": job.bandwidth_kbps,
        }
def _mjpeg_frames(job: Job, jpeg_attr: str):
    """Yield multipart MJPEG chunks from the job attribute named ``jpeg_attr``.

    About every 0.15 s the latest JPEG (if any) is emitted as one part with
    boundary ``--frame``. The stream stops immediately once the job has an
    error, or 0.5 s after the job status is seen as ``"completed"`` or
    ``"error"``.
    """
    boundary = b"--frame"
    while True:
        # Snapshot under the lock; yield and sleep outside it.
        with job.lock:
            jpg = getattr(job, jpeg_attr)
            status = job.status
            err = job.error
        if err:
            break
        if jpg:
            yield boundary + b"\r\n"
            yield b"Content-Type: image/jpeg\r\n"
            yield f"Content-Length: {len(jpg)}\r\n\r\n".encode("ascii")
            yield jpg + b"\r\n"
        time.sleep(0.15)
        if status in ("completed", "error"):
            time.sleep(0.5)
            break


def _mjpeg_generator(job: Job):
    """MJPEG stream of the tracking overlay preview."""
    yield from _mjpeg_frames(job, "latest_jpeg")


def _mjpeg_generator_compressed(job: Job):
    """MJPEG stream of the compressed-output preview."""
    yield from _mjpeg_frames(job, "latest_compressed_jpeg")


def _mjpeg_generator_roi(job: Job):
    """MJPEG stream of the ROI-preserved preview."""
    yield from _mjpeg_frames(job, "latest_roi_jpeg")
# The route decorators were missing from these handlers. The first path is
# taken from the endpoint list in the module docstring.
@app.get("/detect/stream/{job_id}")
def detect_stream(job_id: str):
    """Stream the live tracking overlay as MJPEG; 404 for an unknown job."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    return StreamingResponse(_mjpeg_generator(job), media_type="multipart/x-mixed-replace; boundary=frame")


# NOTE(review): this path is not documented anywhere in the file; it is
# inferred from the /process/video/<kind>/{job_id} naming pattern -- confirm
# against the client.
@app.get("/process/stream/compressed/{job_id}")
def process_stream_compressed(job_id: str):
    """Stream the compressed-output preview as MJPEG; 404 for an unknown job."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    return StreamingResponse(_mjpeg_generator_compressed(job), media_type="multipart/x-mixed-replace; boundary=frame")


# NOTE(review): inferred path, same caveat as the compressed stream.
@app.get("/process/stream/roi/{job_id}")
def process_stream_roi(job_id: str):
    """Stream the ROI-preserved preview as MJPEG; 404 for an unknown job."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    return StreamingResponse(_mjpeg_generator_roi(job), media_type="multipart/x-mixed-replace; boundary=frame")
def _output_or_source(job: Job, produced: Optional[str]) -> str:
    """Return ``produced`` if it exists and exceeds 1024 bytes, else the upload."""
    if produced and os.path.exists(produced) and os.path.getsize(produced) > 1024:
        return produced
    return job.video_path


# The route decorators were missing from these handlers; the paths are taken
# from the endpoint list in the module docstring. Each handler serves the
# original upload when its output is not available, still with media type
# video/mp4.
@app.get("/process/video/overlay/{job_id}")
def process_video_overlay(job_id: str):
    """Serve the tracking overlay video; 404 for an unknown job."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    path = _output_or_source(job, job.overlay_video_path)
    return FileResponse(path, media_type="video/mp4")


@app.get("/process/video/compressed/{job_id}")
def process_video_compressed(job_id: str):
    """Serve the compressed video; 404 for an unknown job."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    path = _output_or_source(job, job.compressed_video_path)
    return FileResponse(path, media_type="video/mp4")


@app.get("/process/video/roi/{job_id}")
def process_video_roi(job_id: str):
    """Serve the ROI-preserved video; 404 for an unknown job."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Unknown job_id")
    path = _output_or_source(job, job.roi_video_path)
    return FileResponse(path, media_type="video/mp4")
if __name__ == "__main__":
    import argparse
    import uvicorn

    # Command-line entry point: parse options, preload the model, run uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=8000, type=int)
    parser.add_argument("--weights", default=DEFAULT_WEIGHTS)
    parser.add_argument("--device", default=DEFAULT_DEVICE)
    args = parser.parse_args()

    # NOTE(review): rebinding these globals here does not change defaults that
    # were already captured when the module body ran (the Form(...) defaults
    # and the Job field defaults), so --weights / --device only reach code
    # that reads the globals at call time.
    DEFAULT_WEIGHTS = args.weights
    DEFAULT_DEVICE = args.device

    # Load the model up front; missing weights abort startup.
    get_model(args.weights)

    # HOST / PORT from the environment take precedence over the CLI values.
    host = os.environ.get("HOST", args.host or "0.0.0.0")
    port = int(os.environ.get("PORT", args.port))
    uvicorn.run(app, host=host, port=port)