Spaces:
Paused
Paused
| # | |
| from flask import Flask, render_template, request, jsonify, Response, send_from_directory | |
| from flask_socketio import SocketIO, emit | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import json | |
| import uuid | |
| import threading | |
| import queue | |
| import torch | |
| import time | |
| from datetime import datetime | |
| from collections import Counter | |
| import base64 | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from utils.frame_selector import GlobalBufferManager | |
| from rfdetr import RFDETRMedium | |
| import supervision as sv | |
| from gradio_client import Client, handle_file | |
| # --- WebRTC Imports ---- | |
| import asyncio | |
| from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack, RTCIceCandidate | |
| from aiortc.contrib.media import MediaRelay | |
| import av | |
# Flask application plus the Socket.IO server used for live updates.
app = Flask(__name__)
# Prefer a secret supplied through the environment so the signing key does not
# have to live in the source tree; the historical placeholder stays as the
# fallback so existing deployments keep starting unchanged.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'your-secret-key-here')
app.config['SESSION_TYPE'] = 'filesystem'
# NOTE(review): cors_allowed_origins="*" accepts Socket.IO connections from any
# origin — restrict it to the real front-end origin(s) before public exposure.
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading', manage_session=False)
| # ── Directories ───────────────────────────────────────────────────────────────- | |
# Absolute locations for uploaded inputs and generated result images; both live
# under this module's own static/ tree and are created on import if missing.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static/uploads')
RESULTS_FOLDER = os.path.join(BASE_DIR, 'static/results')
for _folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
    # exist_ok=True: importing the module again must not fail.
    os.makedirs(_folder, exist_ok=True)
| # ── HF Download ───────────────────────────────────────────────────────────────- | |
| # /data is the HF Spaces persistent writable volume (enable in Space Settings). | |
| # Fallback to /tmp if /data is not available (no persistent storage enabled). | |
| from huggingface_hub import snapshot_download | |
# Choose the model cache: the /data volume when it is mounted, /tmp otherwise.
if not os.path.isdir("/data"):
    MODEL_DIR = "/tmp/CV_MODELS"
    print("[WARNING] /data not available — using /tmp. Models will re-download on every restart.")
else:
    MODEL_DIR = "/data/CV_MODELS"
os.makedirs(MODEL_DIR, exist_ok=True)

# Only an empty cache directory triggers a download; any content skips it.
if os.listdir(MODEL_DIR):
    print(f"[INIT] Models already present at {MODEL_DIR}, skipping download.")
else:
    print(f"[INIT] Downloading models to {MODEL_DIR} ...")
    snapshot_download(
        repo_id="WebAshlarWA/CV_MODELS",
        local_dir=MODEL_DIR,
        token=os.getenv("HF_TOKEN"),
    )
    print("[INIT] Download complete.")

# Debug: list every cached file so checkpoint paths can be verified on boot.
print("[INIT] Files found in MODEL_DIR:")
for root, _, files in os.walk(MODEL_DIR):
    for f in files:
        print(f" {os.path.join(root, f)}")
def find_file(root, filename):
    """Return the path of the first file named *filename* below *root*.

    Directories are visited in ``os.walk`` order. ``None`` is returned when no
    directory in the tree contains the file (or *root* does not exist).
    """
    matches = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, names in os.walk(root)
        if filename in names
    )
    return next(matches, None)
# Locate each checkpoint anywhere inside MODEL_DIR, log what was found, and
# abort start-up if any of the three is missing.
RIDER_WEIGHTS = find_file(MODEL_DIR, "checkpoint0039.pth")
HELMET_HEAD_WEIGHTS = find_file(MODEL_DIR, "checkpoint_best_ema_hel.pth")
PLATE_WEIGHTS = find_file(MODEL_DIR, "checkpoint_best_ema_plate.pth")

_RESOLVED_WEIGHTS = (
    ("RIDER_WEIGHTS", "checkpoint0039.pth", RIDER_WEIGHTS),
    ("HELMET_HEAD_WEIGHTS", "checkpoint_best_ema_hel.pth", HELMET_HEAD_WEIGHTS),
    ("PLATE_WEIGHTS", "checkpoint_best_ema_plate.pth", PLATE_WEIGHTS),
)
# Log all three paths first so a failure below still shows the full picture.
for _name, _fname, _path in _RESOLVED_WEIGHTS:
    print(f"[INIT] {_name}: {_path}")
for _name, _fname, _path in _RESOLVED_WEIGHTS:
    if not _path:
        raise FileNotFoundError(f"{_fname} not found — check MODEL_DIR file listing above.")
| # Class labels used only for logging / debug | |
| # RIDER_CLASSES = ["rider"] # Stage-1 is a pure rider detector (1 class) | |
| # If your rider checkpoint was trained on 2 classes keep the line below instead: | |
| # # ── Model Weights ────────────────────────────────────────────────────────────-- | |
| # # Stage 1 – Rider / motorbike detector (finds riders in the full frame) | |
| # RIDER_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoints/checkpoint0039.pth") | |
| # # Stage 2 – Dedicated head / helmet model | |
| # # General-purpose: detects every helmet / bare-head in a crop. | |
| # # We constrain it to each rider bounding box so it never fires outside. | |
| # HELMET_HEAD_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoints/checkpoint_best_ema_hel.pth") | |
| # # Stage 3 – License plate detector (runs within rider box on violation) | |
| # PLATE_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoints/checkpoint_best_ema_plate.pth") | |
# Class labels of the Stage-1 rider detector. The list length is what gets
# passed as num_classes when the rider model is loaded.
# A single-class rider checkpoint would instead use:
#   RIDER_CLASSES = ["rider"]
# The active value below is for a rider checkpoint trained on 2 classes.
RIDER_CLASSES = ["motorbike and helmet", "motorbike and no helmet"]

# ── Stage-2 helmet model class mapping (Sync with camera_processor_LT.py) ─────
#   class 1 → helmet     (SAFE)
#   class 2 → no-helmet  (VIOLATION)
#   (class 0 is often 'all' or 'rider' depending on model)
HELMET_CLASS_ID = 1
NO_HELMET_CLASS_ID = 2
NUM_HELMET_CLASSES = 3  # num_classes passed when loading the helmet model

# ══════════════════════════════════════════════════════════════════════════════
# Per-stage detection confidence cut-offs passed to each model's predict().
CONF_RIDER = 0.25  # Stage-1: minimum score to enter tracker
CONF_HELMET_HEAD = 0.25  # Stage-2: minimum score for helmet model
CONF_PLATE = 0.50  # Stage-3: plate detection score

# Decision engine (Unified)
# NOTE: CONF_THRESHOLD is the cutoff used by the consensus/voting stage.
# It should be lower than your model's single-frame "strong" cutoff.
# A previous symmetric sensitivity setting was 0.35 (lowered from 0.95).
# NOTE(review): evaluate_helmet_state uses 0.60 as its strong single-frame
# cutoff, so 0.80 here is *higher* than that cutoff, contrary to the note
# above — confirm which value is intended.
CONF_THRESHOLD = 0.80
DECISION_MARGIN = 1.2  # confidence multiplier for margin decision
HISTORY_WINDOW = 30  # max per-track history entries kept
RECENT_WINDOW = 10  # size of the history tail treated as "recent"
SAFE_RECENT_H_HITS = 2  # recent confident helmet hits needed to declare SAFE
# NOTE(review): despite the SAFE_ prefix, evaluate_helmet_state compares this
# against recent confident *no-helmet* hits to declare a VIOLATION.
SAFE_RECENT_NH_HITS = 2
MIN_FRAMES_BEFORE_DECIDE = 12  # history length required before margin voting
VIOLATION_PERSIST_FRAMES = 18  # max violation age for anti-flicker persistence

# ── Bounding-box colours (Standardized) ──────────────────────────────────────
# All tuples are in OpenCV's BGR channel order.
# Stage-1 boxes: Single uniform color for all riders
COLOR_RIDER_CYAN = (255, 255, 0)  # Cyan in BGR
COLOR_SAFE = (0, 255, 0)  # Green (Helmet)
COLOR_VIOLATION = (0, 0, 255)  # Red (No-Helmet)
COLOR_PLATE_BOX = (255, 0, 255)  # Magenta/Purple
COLOR_ANALYZING = (0, 200, 200)  # Dark Cyan

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INIT] Device: {device}")
| # ── Model Loading ───────────────────────────────────────────────────────────── | |
def _load_model(label, weights_path, num_classes):
    """Build an RFDETRMedium detector from a checkpoint and try to optimise it.

    Args:
        label: Human-readable model name, used only in log lines.
        weights_path: Checkpoint path handed over as ``pretrain_weights``.
        num_classes: Class count handed over as ``num_classes``.

    Returns:
        The RFDETRMedium instance. Optimisation is best-effort: when any step
        of it raises, a warning is printed and the model is returned in
        whatever state it reached.
    """
    print(f"[INIT] Loading {label} → {weights_path}")
    detector = RFDETRMedium(num_classes=num_classes, pretrain_weights=weights_path)
    try:
        detector.optimize_for_inference(compile=True, batch_size=1)
        # Half precision is only requested when running on CUDA.
        if device.type == "cuda":
            detector.model.half()
    except Exception as exc:
        print(f"[WARNING] {label} optimisation failed: {exc}")
    return detector
# Instantiate the three pipeline models once, at import time:
#   Stage 1 – rider detector (class count taken from RIDER_CLASSES)
#   Stage 2 – helmet / no-helmet head model
#   Stage 3 – single-class licence-plate detector
rider_model = _load_model("Rider Detection", RIDER_WEIGHTS, len(RIDER_CLASSES))
helmet_head_model = _load_model("Helmet/NoHelmet", HELMET_HEAD_WEIGHTS, NUM_HELMET_CLASSES)
plate_model = _load_model("Plate Detection", PLATE_WEIGHTS, 1)
| # ── Shared session data ─────────────────────────────────────────────────────── | |
# Process-wide state for the uploaded-video pipeline, shared between the frame
# generator and the OCR worker threads. Keys are tracker IDs unless noted.
current_session_data = {
    "violations": {},  # track ID -> violation record (dict serialised to JSON)
    "safe_tracks": set(),  # Track IDs confirmed as SAFE
    "total_riders": set(),  # All unique track IDs seen
    # Pending OCR tasks: (track_id, plate_path, session_id, socket_sid) tuples
    "ocr_queue": queue.Queue(),
    "track_plate_cache": {},  # track ID -> current consensus plate text
    "track_capture_count": {},  # track ID -> number of plate snapshots flushed
    "track_ocr_history": {},  # track ID -> list of accepted raw OCR readings
    "ocr_in_progress": set(),  # track IDs an OCR worker is handling right now
    # NOTE(review): not referenced in the code visible here — confirm usage.
    "track_violation_age": {},
}
| # ── Per-session best-frame buffer (video pipeline) ──────────────────────────── | |
# Buffer manager used by the video pipeline to pick the best plate crop per
# track before it is sent to OCR.
video_buf_manager = GlobalBufferManager()
# Live-camera session state, keyed by the publisher's Socket.IO sid.
live_camera_sessions = {}
# Guards updates to the violation / rider bookkeeping and the JSON result file.
json_lock = threading.Lock()
# --- WebRTC Global State ---
pcs = set()  # presumably all open peer connections — TODO confirm
active_pcs = {}  # SID -> PeerConnection
relay = MediaRelay()
publisher_tracks = {}  # Mapping session_id -> track
# Dedicated asyncio event loop, run forever on its own daemon thread.
loop = asyncio.new_event_loop()


def start_async_loop(loop):
    """Bind *loop* to the calling thread and run it until it is stopped."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


_loop_thread = threading.Thread(target=start_async_loop, args=(loop,), daemon=True)
_loop_thread.start()


def run_async(coro):
    """Run *coro* on the background loop and block until it completes.

    Helper for calling async code from synchronous Socket.IO handlers.
    Returns the coroutine's result; an exception raised inside the coroutine
    is re-raised in the calling thread.
    """
    pending = asyncio.run_coroutine_threadsafe(coro, loop)
    return pending.result()
| # --- Custom WebRTC Video Track for processing --- | |
class VideoProcessTrack(MediaStreamTrack):
    """WebRTC video track that runs the live detection pipeline on each frame.

    Wraps an incoming track. Every received frame is analysed with
    ``process_live_frame`` and the resulting violations and counters are
    emitted to the ``session_<session_id>`` Socket.IO room. The frame itself
    is returned unmodified.
    """

    kind = "video"

    def __init__(self, track, session_id, socket_sid):
        """Store the wrapped track and the identifiers used for routing.

        Args:
            track: Source track whose frames are pulled in ``recv``.
            session_id: Session identifier; selects the Socket.IO room name.
            socket_sid: Socket.IO sid used as key into ``live_camera_sessions``.
        """
        super().__init__()
        self.track = track
        self.session_id = session_id
        self.socket_sid = socket_sid

    async def recv(self):
        """Pull the next frame, process it if a session exists, and return it."""
        frame = await self.track.recv()

        # Single lookup instead of "in" followed by indexing: if the entry is
        # removed by another thread between the two steps, indexing would
        # raise KeyError and break the media stream.
        session = live_camera_sessions.get(self.socket_sid)
        if session is not None:
            # Convert to numpy/opencv only when there is a session to feed.
            img = frame.to_ndarray(format="bgr24")

            # The annotated image is not forwarded from here; only the
            # violations and the session counters are broadcast.
            _annotated, new_violations = process_live_frame(
                img, session, self.session_id, self.socket_sid)

            # Emit results to the SESSION ROOM (so Admin on other device sees it)
            room_name = f"session_{self.session_id}"
            socketio.emit('processed_frame_relay', {
                'violations': new_violations,
                'stats': {
                    'total_riders': len(session.get('total_riders', set())),
                    'safe_count': len(session.get('safe_tracks', set())),
                    'violation_count': len(session.get('violations', {}))
                }
            }, room=room_name)
        return frame
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # HELPERS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def get_best_consensus(results):
    """Merge several OCR readings of the same plate into one best guess.

    Each reading has newlines replaced by spaces and is stripped. Placeholder
    readings (``"API_ERROR"``, ``"PENDING..."`` and empty strings) are then
    discarded. With several usable readings the result is built by a
    per-position majority vote over their characters; ties go to the reading
    that appears first.

    Args:
        results: Iterable of raw OCR strings.

    Returns:
        The consensus text, the single usable reading, or ``"PENDING..."``
        when no usable reading exists.
    """
    placeholders = {"API_ERROR", "PENDING...", ""}
    cleaned = []
    for raw in results:
        text = raw.replace("\n", " ").strip()
        # Filter *after* normalising so whitespace-only readings or padded
        # placeholders (e.g. "API_ERROR\n") cannot slip through as votes.
        if text not in placeholders:
            cleaned.append(text)

    if not cleaned:
        return "PENDING..."
    if len(cleaned) == 1:
        return cleaned[0]

    longest = max(len(text) for text in cleaned)
    voted = []
    for pos in range(longest):
        # At least one reading reaches every position below `longest`.
        column = [text[pos] for text in cleaned if pos < len(text)]
        voted.append(Counter(column).most_common(1)[0][0])
    return "".join(voted).strip()
def clamp_box(box, w, h):
    """Truncate box coordinates to ints and clip them to a ``w`` x ``h`` frame.

    Returns ``[x1, y1, x2, y2]`` where x1/y1 are raised to at least 0 and
    x2/y2 are lowered to at most ``w - 1`` / ``h - 1``.
    """
    left, top, right, bottom = (int(value) for value in box)
    return [
        max(0, left),
        max(0, top),
        min(w - 1, right),
        min(h - 1, bottom),
    ]
def expand_box_for_plate(box, w, h):
    """Shift & expand box downward to include the number plate area.

    The box is widened by 10 % of its width on each side and moved down by
    40 % of its height, then clipped to the ``w`` x ``h`` frame by
    ``clamp_box``.
    """
    left, top, right, bottom = (int(value) for value in box)
    pad_x = (right - left) * 0.1
    drop_y = (bottom - top) * 0.4
    shifted = [left - pad_x, top + drop_y, right + pad_x, bottom + drop_y]
    return clamp_box(shifted, w, h)
def parse_preds(preds, W, H, debug_tag=""):
    """Extract boxes/scores/labels from RFDETRMedium predictions.

    Args:
        preds: Prediction object; read only when it has an ``xyxy`` attribute
            (together with ``confidence`` and ``class_id``).
        W, H: Image width and height, used to rescale boxes whose largest
            coordinate is <= 1.01 (treated as normalised).
        debug_tag: When non-empty, every detection is printed with this tag.

    Returns:
        ``(boxes, scores, labels)`` as numpy arrays; three empty arrays when
        *preds* carries no ``xyxy`` attribute.
    """

    def _as_numpy(value):
        # ndarrays pass through; anything else is moved to CPU and converted.
        if isinstance(value, np.ndarray):
            return value
        return value.cpu().numpy()

    boxes = np.array([])
    scores = np.array([])
    labels = np.array([])
    if hasattr(preds, "xyxy"):
        boxes = _as_numpy(preds.xyxy)
        scores = _as_numpy(preds.confidence)
        labels = _as_numpy(preds.class_id)

    has_boxes = boxes.size > 0
    if has_boxes and boxes.max() <= 1.01:
        # Work on a copy so the prediction object's own array is not altered.
        boxes = boxes.copy()
        boxes[:, [0, 2]] *= W
        boxes[:, [1, 3]] *= H

    if debug_tag and has_boxes:
        for idx, conf in enumerate(scores):
            print(f"[PARSE:{debug_tag}] det#{idx} cls={int(labels[idx])} conf={conf:.3f}")
    return boxes, scores, labels
def run_helmet_on_crop(frame_bgr, rx1, ry1, rx2, ry2):
    """Run Stage-2 (helmet head model) on a single rider crop.

    Args:
        frame_bgr: Full BGR frame.
        rx1, ry1, rx2, ry2: Rider box in frame coordinates; it is clamped to
            the frame before cropping.

    Returns:
        ``(best_cls, best_conf, detections)``. ``detections`` is a list of
        ``{"box", "score", "class"}`` dicts in full-frame coordinates.
        ``(-1, 0.0, [])`` is returned when the clamped crop is smaller than
        20 px on either side or the model finds nothing. A helmet detection
        takes precedence over a no-helmet one; if neither class is present
        the highest-scoring detection of any class is reported.
    """
    frame_h, frame_w = frame_bgr.shape[:2]

    # Boundary check: clamp the crop to the frame.
    rx1, ry1 = max(0, rx1), max(0, ry1)
    rx2, ry2 = min(frame_w, rx2), min(frame_h, ry2)
    crop_w, crop_h = rx2 - rx1, ry2 - ry1
    if min(crop_w, crop_h) < 20:  # skip tiny crops
        return -1, 0.0, []

    crop_rgb = cv2.cvtColor(frame_bgr[ry1:ry2, rx1:rx2], cv2.COLOR_BGR2RGB)
    with torch.no_grad():
        hpreds = helmet_head_model.predict(crop_rgb, conf=CONF_HELMET_HEAD, iou=0.45)
    head_boxes, head_scores, head_labels = parse_preds(
        hpreds, crop_w, crop_h, debug_tag="helmet_head")
    if head_boxes.size == 0:
        return -1, 0.0, []

    # Shift detections back to main-frame coordinates (cascade logic).
    offset = (rx1, ry1, rx1, ry1)
    shifted_detections = [
        {
            "box": [int(coord + shift) for coord, shift in zip(box, offset)],
            "score": float(score),
            "class": int(label),
        }
        for box, score, label in zip(head_boxes, head_scores, head_labels)
    ]

    # Prioritise HELMET over NO-HELMET for state evaluation.
    for wanted in (HELMET_CLASS_ID, NO_HELMET_CLASS_ID):
        candidates = np.where(head_labels == wanted)[0]
        if candidates.size > 0:
            top = candidates[np.argmax(head_scores[candidates])]
            return wanted, float(head_scores[top]), shifted_detections

    # Neither head class present: fall back to the best detection overall.
    top = int(np.argmax(head_scores))
    return int(head_labels[top]), float(head_scores[top]), shifted_detections
| # def evaluate_helmet_state(hist, cls_idx, confidence, prev_violation, violation_age, rider_cls=-1): | |
| # """ | |
| # Enhanced decision engine: sensitive to both Helmet and No-Helmet detections. | |
| # Uses a weighted consensus rather than a single-hit override. | |
| # """ | |
| # total = len(hist) | |
| # recent = hist[-RECENT_WINDOW:] | |
| # r_total = len(recent) | |
| # # Counts based on lowered sensitivity thresholds | |
| # f_h = sum(1 for h in hist if h['class'] == 0 and h['conf'] >= CONF_HELMET_CONFIRM) | |
| # f_nh = sum(1 for h in hist if h['class'] != 0 and h['conf'] >= CONF_NO_HELMET_TRIGGER) | |
| # r_h = sum(1 for h in recent if h['class'] == 0 and h['conf'] >= CONF_HELMET_CONFIRM) | |
| # r_nh = sum(1 for h in recent if h['class'] != 0 and h['conf'] >= CONF_NO_HELMET_TRIGGER) | |
| # r_nh_frac = r_nh / max(r_total, 1) | |
| # # ── SAFE checks (Helmet detected) ──────────────────────────────────────── | |
| # # 1. Stage-1 Rider detector specifically says "with helmet" (High Priority) | |
| # if rider_cls == 0: | |
| # return False, True, "safe:rider_H" | |
| # # 2. Strong current evidence of a helmet | |
| # if cls_idx == 0 and confidence >= 0.50: | |
| # return False, True, f"safe:cur_H_strong({confidence:.2f})" | |
| # # 3. Decision by majority/consistency (Balanced sensitivity) | |
| # # If we have significantly more helmet hits than no-helmet hits, it's safe. | |
| # if f_h > f_nh and f_h >= 2: | |
| # return False, True, f"safe:vote_H({f_h} vs {f_nh})" | |
| # # ── VIOLATION gate (No-Helmet detected) ────────────────────────────────── | |
| # # We trigger a violation if no-helmet hits dominate, even if there was a | |
| # # sporadic/low-conf helmet hit (reduces false 'Safe' from noise). | |
| # violation_gate = ( | |
| # total >= MIN_FRAMES_BEFORE_VIOLATION | |
| # and f_nh > f_h * 2 # No-helmet hits must clearly outweigh helmet hits | |
| # and r_nh >= RECENT_NH_MIN | |
| # and r_nh_frac >= RECENT_NH_FRAC | |
| # and f_nh >= MIN_NH_HITS_FULL | |
| # and rider_cls != 0 | |
| # ) | |
| # if violation_gate: | |
| # return True, False, f"violation(NH={f_nh}, H={f_h})" | |
| # # ── Soft persistence ────────────────────────────────────────────────────── | |
| # if prev_violation and violation_age <= VIOLATION_PERSIST_FRAMES and r_h == 0: | |
| # return True, False, f"persist(age={violation_age})" | |
| # # ── Default State ───────────────────────────────────────────────────────── | |
| # debug = (f"analyzing({total}fr)" if total < MIN_FRAMES_BEFORE_VIOLATION | |
| # else f"neutral(H={f_h},NH={f_nh})") | |
| # # If we can't decide but have some helmet hits, default to Safe for UX stability | |
| # if f_h >= 1 and not violation_gate: | |
| # return False, True, f"safe:default_H({f_h})" | |
| # return False, False, debug | |
def evaluate_helmet_state(hist, cls_idx, confidence,
                          prev_violation, violation_age,
                          rider_cls=-1):
    """Decide a track's helmet state from the current frame and its history.

    Checks are applied in this order, and the first that fires wins:
      1. Stage-1 override: ``rider_cls == 0`` is always SAFE.
      2. Strong current frame: confidence >= 0.60 for helmet / no-helmet.
      3. Recent-window counts of confident hits (helmet is checked first).
      4. Warm-up: fewer than MIN_FRAMES_BEFORE_DECIDE entries is undecided.
      5. Margin vote on confidence-weighted history scores.
      6. Persistence of a previous violation while no recent helmet hit.

    Args:
        hist: List of ``{"class", "conf"}`` entries for the track.
        cls_idx: Stage-2 class of the current frame.
        confidence: Stage-2 confidence of the current frame.
        prev_violation: Whether the track was a violation on the last frame.
        violation_age: Number of consecutive frames it has been a violation.
        rider_cls: Stage-1 class of the rider detection (-1 when unknown).

    Returns:
        ``(is_no_helmet, is_safe, debug_reason)``; both flags are False when
        the state is still undecided.
    """
    total = len(hist)
    recent = hist[-RECENT_WINDOW:]

    # Stage-1 override (high priority): Stage-1 class 0 means "with helmet".
    if rider_cls == 0:
        return False, True, "safe:rider_H"

    # Strong current-frame shortcut for very confident detections.
    STRONG_SINGLE_FRAME = 0.60
    if confidence >= STRONG_SINGLE_FRAME:
        if cls_idx == HELMET_CLASS_ID:
            return False, True, f"safe:cur_H_strong({confidence:.2f})"
        if cls_idx == NO_HELMET_CLASS_ID:
            return True, False, f"violation:cur_NH_strong({confidence:.2f})"

    def _confident(entries, wanted_cls):
        # Confidences of entries of one class that reach CONF_THRESHOLD.
        return [e['conf'] for e in entries
                if e['class'] == wanted_cls and e['conf'] >= CONF_THRESHOLD]

    # Confidence-weighted scores over the full history.
    score_h = sum(_confident(hist, HELMET_CLASS_ID))
    score_nh = sum(_confident(hist, NO_HELMET_CLASS_ID))

    # Hit counts inside the recent window.
    r_h = len(_confident(recent, HELMET_CLASS_ID))
    r_nh = len(_confident(recent, NO_HELMET_CLASS_ID))
    if r_h >= SAFE_RECENT_H_HITS:
        return False, True, f"safe:recent_H(count={r_h})"
    if r_nh >= SAFE_RECENT_NH_HITS:
        return True, False, f"violation:recent_NH(count={r_nh})"

    # Require a minimum number of frames before the margin vote.
    if total < MIN_FRAMES_BEFORE_DECIDE:
        return False, False, f"warming({total}fr)"

    # Margin-based decision: one side must beat the other by DECISION_MARGIN.
    if score_h > score_nh * DECISION_MARGIN:
        return False, True, f"safe:margin({score_h:.2f} vs {score_nh:.2f})"
    if score_nh > score_h * DECISION_MARGIN:
        return True, False, f"violation:margin({score_nh:.2f} vs {score_h:.2f})"

    # Soft persistence (anti-flicker): keep an ongoing violation alive unless
    # a confident helmet hit appears in the recent window.
    if prev_violation and violation_age <= VIOLATION_PERSIST_FRAMES and r_h == 0:
        return True, False, f"persist(age={violation_age})"

    # Neutral / uncertain state.
    return False, False, f"uncertain(H={score_h:.2f},NH={score_nh:.2f})"
def _draw_overlay(frame, x1, y1, x2, y2, tid, display_name, confidence, color, plate_text="", reason=""):
    """Draw one tracked rider's box and text labels onto *frame* in place.

    Always draws the box and an ``ID / state / confidence`` label above it.
    When *reason* is non-empty it is drawn on a second line above the label;
    when *plate_text* is non-empty it is drawn below the box in the plate
    colour.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

    # Each entry: (text, origin, font scale, colour, thickness).
    captions = [
        (f"ID:{tid} {display_name} {confidence:.2f}", (x1, max(y1 - 10, 10)), 0.5, color, 2),
    ]
    if reason:
        captions.append((f"({reason})", (x1, max(y1 - 25, 10)), 0.4, color, 1))
    if plate_text:
        captions.append((f"Plate: {plate_text}", (x1, y2 + 20), 0.5, COLOR_PLATE_BOX, 2))

    for text, origin, scale, text_color, thickness in captions:
        cv2.putText(frame, text, origin, font, scale, text_color, thickness)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # OCR WORKER | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
# Makes the "is this track already being OCR'd?" check-and-claim atomic across
# the OCR worker threads.
_OCR_CLAIM_LOCK = threading.Lock()


def _connect_ocr_client(retries=3, delay=2):
    """Return a connected Gradio OCR client, or None after *retries* failures."""
    for attempt in range(1, retries + 1):
        try:
            client = Client("WebashalarForML/demo-glm-ocr")
        except Exception as e:
            print(f"[OCR] Connection attempt {attempt} failed: {e}")
            if attempt == retries:
                print("[OCR] Max retries reached. Worker exiting.")
                return None
            time.sleep(delay)
        else:
            print("[OCR] Gradio Client Connected")
            return client
    return None


def _merge_ocr_reading(store, track_id, plate_text):
    """Add a reading to *store*'s OCR history and refresh the cached consensus.

    Returns ``(consensus_text, history_list)`` for the track.
    """
    history = store["track_ocr_history"].setdefault(track_id, [])
    if plate_text not in ["API_ERROR", ""]:
        history.append(plate_text)
    final = get_best_consensus(history)
    store["track_plate_cache"][track_id] = final
    return final, history


def _publish_ocr_result(track_id, plate_text, session_id, socket_sid):
    """Store an OCR reading in the matching session and notify consumers.

    Live sessions (``session_id`` starting with ``live_``) update the entry in
    ``live_camera_sessions`` and emit ``ocr_update`` to the socket's room.
    All other sessions update ``current_session_data`` and rewrite the
    session's JSON result file.
    """
    if session_id.startswith('live_'):
        session = live_camera_sessions.get(socket_sid) if socket_sid else None
        if session is None:
            return
        final, history = _merge_ocr_reading(session, track_id, plate_text)
        with json_lock:
            violation = session["violations"].get(track_id)
            # Only emit when a violation record exists: the payload includes
            # the record itself.
            if violation is not None:
                violation["plate_number"] = final
                violation["ocr_attempts"] = history
                socketio.emit('ocr_update', {
                    'track_id': track_id,
                    'plate_number': final,
                    'violation': violation
                }, room=socket_sid)
        return

    final, history = _merge_ocr_reading(current_session_data, track_id, plate_text)
    with json_lock:
        violation = current_session_data["violations"].get(track_id)
        if violation is not None:
            violation["plate_number"] = final
            violation["ocr_attempts"] = history
        json_path = os.path.join(RESULTS_FOLDER, f"session_{session_id}.json")
        with open(json_path, 'w') as f:
            json.dump(list(current_session_data["violations"].values()), f, indent=4)


def background_ocr_worker():
    """Process OCR tasks from the shared queue until the process exits.

    Connects to the Gradio OCR service (returning early if that fails three
    times), then loops forever. Each task is a
    ``(track_id, plate_path, session_id, socket_sid)`` tuple. A task is
    skipped when it is ``None``, when its track is already being processed by
    another worker, or when the plate image no longer exists. Failed OCR
    calls are recorded as ``"API_ERROR"``, which is never added to the
    track's reading history.

    Every dequeued task is acknowledged with ``task_done`` exactly once, and
    a track claimed by this worker is always released again, even when
    processing raises.
    """
    print("[OCR] Worker Thread Started")
    client = _connect_ocr_client()
    if client is None:
        return

    while True:
        # Look the queue up on every iteration rather than caching it, and
        # acknowledge on the same object the task was taken from.
        ocr_queue = current_session_data["ocr_queue"]
        try:
            task = ocr_queue.get(timeout=1)
        except queue.Empty:
            continue

        claimed_id = None  # set only once this worker owns the track
        try:
            if task is None:
                continue
            track_id, plate_path, session_id, socket_sid = task

            # Atomic check-and-claim so two workers never OCR the same track.
            with _OCR_CLAIM_LOCK:
                in_progress = current_session_data["ocr_in_progress"]
                already_running = track_id in in_progress
                if not already_running:
                    in_progress.add(track_id)
            if already_running:
                print(f"[OCR] ID {track_id} already in progress, skipping")
                continue
            claimed_id = track_id

            if not os.path.exists(plate_path):
                continue

            plate_text = "API_ERROR"
            try:
                result = client.predict(image=handle_file(plate_path),
                                        api_name="/proses_intelijen")
                plate_text = str(result).strip()
                print(f"[OCR] ID {track_id}: {plate_text}")
            except Exception as e:
                print(f"[OCR] API error for ID {track_id}: {e}")

            _publish_ocr_result(track_id, plate_text, session_id, socket_sid)
        except Exception as e:
            # Top-level boundary of the worker: log and keep the thread alive.
            print(f"[OCR] Loop Error: {e}")
            import traceback
            traceback.print_exc()
        finally:
            if claimed_id is not None:
                current_session_data["ocr_in_progress"].discard(claimed_id)
            ocr_queue.task_done()
# Start the pool of daemon OCR worker threads at import time.
NUM_OCR_WORKERS = 3
for _i in range(NUM_OCR_WORKERS):
    _t = threading.Thread(
        name=f"OCR-Worker-{_i+1}",
        target=background_ocr_worker,
        daemon=True,
    )
    _t.start()
    print(f"[INIT] Started OCR Worker {_i+1}")
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # THREE-MODEL PIPELINE – VIDEO (generator) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| def process_video_gen(video_path, session_id): | |
| cap = cv2.VideoCapture(video_path) | |
| tracker = sv.ByteTrack() | |
| track_class_history = {} | |
| track_violation_memory= {} | |
| track_last_seen = {} | |
| track_violation_age = {} | |
| dead_ids = set() | |
| frame_idx = 0 | |
| prev_frame = None # for motion scoring | |
| video_buf_manager.reset_all() # clean slate for this video | |
| while cap.isOpened(): | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| frame_idx += 1 | |
| h_orig, w_orig = frame.shape[:2] | |
| # ── Expire old tracks & force-flush their buffers ───────────────────── | |
| newly_dead = [t for t, last in track_last_seen.items() | |
| if frame_idx - last > 50 and t not in dead_ids] | |
| for tid in newly_dead: | |
| dead_ids.add(tid) | |
| # Force-flush dead tracks (short-track safety net) | |
| dead_flushes = video_buf_manager.force_flush_dead(set(newly_dead)) | |
| for tid, best_entry in dead_flushes.items(): | |
| if best_entry.plate_crop is not None and best_entry.plate_crop.size > 0: | |
| pname = f"viol_{session_id}_{tid}_plate_best.jpg" | |
| ppath = os.path.join(RESULTS_FOLDER, pname) | |
| cv2.imwrite(ppath, best_entry.plate_crop) | |
| with json_lock: | |
| if tid in current_session_data["violations"]: | |
| current_session_data["violations"][tid]["plate_image_url"] = ( | |
| f"/static/results/{pname}") | |
| current_session_data["ocr_queue"].put((tid, ppath, session_id, None)) | |
| print(f"[BUFFER] Dead-track flush: tid={tid} frame={best_entry.frame_idx} score={best_entry.score:.3f}") | |
| rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| # ══ STAGE 1 – Rider Detection ═════════════════════════════════════════ | |
| with torch.no_grad(): | |
| rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.45) | |
| r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig, debug_tag="rider") | |
| if r_boxes.size > 0: | |
| detections = sv.Detections( | |
| xyxy=r_boxes.astype(np.float32), | |
| confidence=r_scores.astype(np.float32), | |
| class_id=r_labels.astype(np.int32) | |
| ) | |
| # [USER BUG FIX] Apply NMS to prevent overlapping rider boxes | |
| detections = detections.with_nms(threshold=0.5) | |
| else: | |
| detections = sv.Detections.empty() | |
| detections = tracker.update_with_detections(detections) | |
| for (xyxy, _mask, rider_conf, rider_cls, tracker_id, _data) in detections: | |
| if tracker_id is None: | |
| continue | |
| tid = int(tracker_id) | |
| track_last_seen[tid] = frame_idx | |
| x1, y1, x2, y2 = map(int, xyxy) | |
| # ══ STAGE 2 – Helmet / No-Helmet (within rider crop) ═════════════ | |
| h_cls, h_conf, h_dets = run_helmet_on_crop(frame, x1, y1, x2, y2) | |
| # Draw individual child detections (Cascade Coordinate Shifting) | |
| for det in h_dets: | |
| dx1, dy1, dx2, dy2 = det["box"] | |
| d_color = COLOR_SAFE if det["class"] == HELMET_CLASS_ID else COLOR_VIOLATION | |
| cv2.rectangle(frame, (dx1, dy1), (dx2, dy2), d_color, 1) | |
| d_label = "H" if det["class"] == HELMET_CLASS_ID else "NH" | |
| cv2.putText(frame, f"{d_label}:{det['score']:.2f}", (dx1, dy1-5), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.3, d_color, 1) | |
| # --- unified history append: prefer Stage-2, fall back to Stage-1 --- | |
| if tid not in track_class_history: | |
| track_class_history[tid] = [] | |
| # debug: quickly log class disagreement for diagnosis | |
| # print(f"[DBG-HIST] tid={tid} stage2_h_cls={h_cls} stage2_h_conf={h_conf:.2f} stage1_rider_cls={rider_cls}") | |
| if h_cls >= 0: | |
| # use helmet-head model (preferred source of truth) | |
| track_class_history[tid].append({"class": h_cls, "conf": h_conf}) | |
| else: | |
| # fallback to stage-1 only when stage-2 had no head detection | |
| if int(rider_cls) == 0: | |
| track_class_history[tid].append({"class": HELMET_CLASS_ID, "conf": 0.99}) | |
| # clamp history | |
| if len(track_class_history[tid]) > HISTORY_WINDOW: | |
| track_class_history[tid].pop(0) | |
| hist = track_class_history.get(tid, []) | |
| prev_viol = track_violation_memory.get(tid, False) | |
| viol_age = track_violation_age.get(tid, 0) | |
| is_no_helmet, is_safe, dbg = evaluate_helmet_state( | |
| hist, h_cls, h_conf, prev_viol, viol_age, rider_cls=int(rider_cls)) | |
| if is_no_helmet: | |
| track_violation_memory[tid] = True | |
| track_violation_age[tid] = viol_age + 1 | |
| with json_lock: | |
| current_session_data["safe_tracks"].discard(tid) | |
| else: | |
| track_violation_memory[tid] = False | |
| track_violation_age[tid] = 0 | |
| if is_safe: | |
| with json_lock: | |
| current_session_data["safe_tracks"].add(tid) | |
| if prev_viol: | |
| current_session_data["violations"].pop(tid, None) | |
| with json_lock: | |
| current_session_data["total_riders"].add(tid) | |
| print(f"[TRACK] ID={tid} h_cls={h_cls} h_conf={h_conf:.2f} | {dbg}") | |
| # ── Display state ───────────────────────────────────────────────── | |
| total_hist = len(hist) | |
| if is_no_helmet: | |
| display_name = "VIOLATION: NO HELMET" | |
| color = COLOR_VIOLATION | |
| elif is_safe: | |
| display_name = "SAFE: HELMET" | |
| color = COLOR_SAFE | |
| elif total_hist < MIN_FRAMES_BEFORE_DECIDE: | |
| display_name = "ANALYZING..." | |
| color = COLOR_ANALYZING | |
| else: | |
| display_name = "RIDER" | |
| color = COLOR_RIDER_CYAN | |
| plate_text = current_session_data["track_plate_cache"].get(tid, "") | |
| # ══ STAGE 3 – Plate Detection (violation only) ════════════════════ | |
| if is_no_helmet and tid not in dead_ids: | |
| with json_lock: | |
| if tid not in current_session_data["violations"]: | |
| ts = datetime.now() | |
| rider_img_name = f"viol_{session_id}_{tid}_rider.jpg" | |
| cv2.imwrite(os.path.join(RESULTS_FOLDER, rider_img_name), | |
| frame[y1:y2, x1:x2]) | |
| current_session_data["violations"][tid] = { | |
| "id": tid, | |
| "timestamp": ts.strftime('%H:%M:%S'), | |
| "type": "No Helmet", | |
| "plate_number": "Scanning...", | |
| "image_url": f"/static/results/{rider_img_name}", | |
| "plate_image_url": None, | |
| "ocr_attempts": [], | |
| "raw": { | |
| "confidence": float(h_conf), | |
| "box": xyxy.tolist() | |
| } | |
| } | |
| current_session_data["track_capture_count"][tid] = 0 | |
| # ── Buffer-based best-frame plate capture ───────────────────── | |
| eb = expand_box_for_plate(xyxy, w_orig, h_orig) | |
| plate_crop_region = frame[eb[1]:eb[3], eb[0]:eb[2]] | |
| if plate_crop_region.size > 0: | |
| with torch.no_grad(): | |
| plate_preds = plate_model.predict( | |
| cv2.cvtColor(plate_crop_region, cv2.COLOR_BGR2RGB), | |
| conf=CONF_PLATE, iou=0.45) | |
| pb, ps, _pl = parse_preds(plate_preds, | |
| plate_crop_region.shape[1], | |
| plate_crop_region.shape[0]) | |
| if pb.size > 0: | |
| best_det = int(np.argmax(ps)) | |
| px1, py1, px2, py2 = map(int, pb[best_det]) | |
| plate_crop = plate_crop_region[py1:py2, px1:px2] | |
| if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20: | |
| # Build rider ROI for motion scoring | |
| curr_roi = frame[y1:y2, x1:x2] | |
| prev_roi = prev_frame[y1:y2, x1:x2] if prev_frame is not None else None | |
| # Add to quality buffer (always, even if only 1 frame) | |
| video_buf_manager.add( | |
| tid, plate_crop, (x1, y1, x2, y2), | |
| frame_idx, prev_roi, curr_roi | |
| ) | |
| print(f"[BUFFER] tid={tid} frame={frame_idx} buffered plate crop") | |
| # Try flush: post-peak or timeout trigger | |
| if video_buf_manager.should_flush(tid, (x1, y1, x2, y2), frame_idx): | |
| best_entry = video_buf_manager.flush(tid) | |
| if best_entry is not None: | |
| snap = current_session_data["track_capture_count"].get(tid, 0) + 1 | |
| pname = f"viol_{session_id}_{tid}_plate_best{snap}.jpg" | |
| ppath = os.path.join(RESULTS_FOLDER, pname) | |
| cv2.imwrite(ppath, best_entry.plate_crop) | |
| with json_lock: | |
| current_session_data["violations"][tid]["plate_image_url"] = ( | |
| f"/static/results/{pname}") | |
| current_session_data["ocr_queue"].put((tid, ppath, session_id, None)) | |
| current_session_data["track_capture_count"][tid] = snap | |
| print(f"[BUFFER] tid={tid} FLUSH → frame={best_entry.frame_idx} score={best_entry.score:.3f}") | |
| _draw_overlay(frame, x1, y1, x2, y2, tid, display_name, h_conf, color, plate_text, reason=dbg) | |
| prev_frame = frame.copy() # store for next-frame motion scoring | |
| _, buffer = cv2.imencode('.jpg', frame) | |
| yield (b'--frame\r\nContent-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n') | |
| # ── Final cleanup: flush all remaining active tracks ───────────────────── | |
| active_ids = set(track_last_seen.keys()) - dead_ids | |
| if active_ids: | |
| print(f"[VIDEO-END] Flushing {len(active_ids)} remaining tracks") | |
| # 1. Force-flush buffers for remaining active tracks | |
| final_flushes = video_buf_manager.force_flush_dead(active_ids) | |
| for tid, best_entry in final_flushes.items(): | |
| if best_entry.plate_crop is not None and best_entry.plate_crop.size > 0: | |
| pname = f"viol_{session_id}_{tid}_plate_final.jpg" | |
| ppath = os.path.join(RESULTS_FOLDER, pname) | |
| cv2.imwrite(ppath, best_entry.plate_crop) | |
| with json_lock: | |
| if tid in current_session_data["violations"]: | |
| current_session_data["violations"][tid]["plate_image_url"] = ( | |
| f"/static/results/{pname}") | |
| current_session_data["ocr_queue"].put((tid, ppath, session_id, None)) | |
| print(f"[BUFFER] End-of-video flush: tid={tid} frame={best_entry.frame_idx}") | |
| # 2. Aggressive evaluation for tracks that were visible until the end | |
| for tid in active_ids: | |
| hist = track_class_history.get(tid, []) | |
| if not hist: continue | |
| prev_viol = track_violation_memory.get(tid, False) | |
| if prev_viol: continue # already a violation | |
| # If visible >= 3 frames and has any No-Helmet signals, consider it if end-of-feed | |
| nh_hits = [h for h in hist if h['class'] == NO_HELMET_CLASS_ID and h['conf'] >= CONF_THRESHOLD] | |
| if len(hist) >= 3 and len(nh_hits) >= 1: | |
| best_nh = max(nh_hits, key=lambda x: x['conf']) | |
| with json_lock: | |
| if tid not in current_session_data["violations"]: | |
| ts = datetime.now() | |
| # Use prev_frame (the last valid frame) to save a crop if we know where it was | |
| # We don't have the last xyxy easily here unless we store it. | |
| # For simplicity, we'll mark it but the plate flush above is more important. | |
| rider_img_name = f"viol_{session_id}_{tid}_rider_final.jpg" | |
| # If we had the last xyxy, we could cv2.imwrite here. | |
| # since we don't store track_last_box, we'll just use a placeholder or the plate url. | |
| current_session_data["violations"][tid] = { | |
| "id": tid, | |
| "timestamp": ts.strftime('%H:%M:%S'), | |
| "type": "No Helmet (Final)", | |
| "plate_number": "Scanning...", | |
| "image_url": f"/static/results/{rider_img_name}", # Will be updated if possible | |
| "plate_image_url": None, | |
| "ocr_attempts": [], | |
| "raw": { | |
| "confidence": float(best_nh['conf']), | |
| "box": [] | |
| } | |
| } | |
| print(f"[VIDEO-END] Forced violation for tid={tid}") | |
| cap.release() | |
| print(f"[VIDEO-END] Processing complete for session {session_id}") | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # THREE-MODEL PIPELINE – LIVE CAMERA (socket frame) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| def process_live_frame(frame, session, session_id, socket_sid): | |
| # Run the three-stage pipeline (rider -> helmet -> plate) on one live camera frame. | |
| # | |
| # Parameters: | |
| #   frame      - BGR image (it is converted with COLOR_BGR2RGB before inference); | |
| #                detection boxes and overlays are drawn onto it in place. | |
| #   session    - per-connection state dict holding the tracker, per-track histories, | |
| #                violation records, the plate buffer manager and the previous frame. | |
| #   session_id - id used in saved image file names and in the OCR job session tag. | |
| #   socket_sid - Socket.IO sid that is forwarded with every OCR job. | |
| # | |
| # Returns (frame, new_violations): the annotated frame and the list of violation | |
| # records that were created during this call. | |
| tracker = session['tracker'] | |
| track_class_history = session['track_class_history'] | |
| track_violation_memory = session['track_violation_memory'] | |
| track_violation_age = session.setdefault('track_violation_age', {}) | |
| track_last_seen = session['track_last_seen'] | |
| dead_ids = session['dead_ids'] | |
| live_buf_mgr = session['plate_buf_manager'] # per-session buffer | |
| prev_frame = session.get('prev_frame') # for motion scoring | |
| # The frame counter lives in the session so it keeps increasing across calls. | |
| session['frame_idx'] += 1 | |
| frame_idx = session['frame_idx'] | |
| # Expire old tracks & force-flush their buffers (short-track safety net) | |
| # A track that has not been seen for more than 50 frames is marked dead once. | |
| newly_dead = [t for t, last in track_last_seen.items() | |
| if frame_idx - last > 50 and t not in dead_ids] | |
| for tid in newly_dead: | |
| dead_ids.add(tid) | |
| dead_flushes = live_buf_mgr.force_flush_dead(set(newly_dead)) | |
| for tid, best_entry in dead_flushes.items(): | |
| if best_entry.plate_crop is not None and best_entry.plate_crop.size > 0: | |
| pname = f"viol_live_{session_id}_{tid}_plate_best.jpg" | |
| ppath = os.path.join(RESULTS_FOLDER, pname) | |
| cv2.imwrite(ppath, best_entry.plate_crop) | |
| with json_lock: | |
| if tid in session['violations']: | |
| session['violations'][tid]["plate_image_url"] = ( | |
| f"/static/results/{pname}") | |
| # OCR jobs go on the shared queue as (track id, image path, session tag, socket sid). | |
| current_session_data["ocr_queue"].put( | |
| (tid, ppath, f"live_{session_id}", socket_sid)) | |
| print(f"[LIVE-BUFFER] Dead-track flush: tid={tid} frame={best_entry.frame_idx} score={best_entry.score:.3f}") | |
| h_orig, w_orig = frame.shape[:2] | |
| rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| # ══ STAGE 1 – Rider Detection ═════════════════════════════════════════════ | |
| with torch.no_grad(): | |
| rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.45) | |
| r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig, debug_tag="rider") | |
| if r_boxes.size > 0: | |
| detections = sv.Detections( | |
| xyxy=r_boxes.astype(np.float32), | |
| confidence=r_scores.astype(np.float32), | |
| class_id=r_labels.astype(np.int32) | |
| ) | |
| # Apply NMS | |
| detections = detections.with_nms(threshold=0.5) | |
| else: | |
| detections = sv.Detections.empty() | |
| # The tracker is updated on every frame, including frames with no detections. | |
| detections = tracker.update_with_detections(detections) | |
| new_violations = [] | |
| for (xyxy, _mask, rider_conf, rider_cls, tracker_id, _data) in detections: | |
| # Detections the tracker has not assigned an id to are skipped. | |
| if tracker_id is None: | |
| continue | |
| tid = int(tracker_id) | |
| track_last_seen[tid] = frame_idx | |
| x1, y1, x2, y2 = map(int, xyxy) | |
| # ══ STAGE 2 – Helmet / No-Helmet (within rider crop) ═════════════════ | |
| h_cls, h_conf, h_dets = run_helmet_on_crop(frame, x1, y1, x2, y2) | |
| # Draw child detections | |
| for det in h_dets: | |
| dx1, dy1, dx2, dy2 = det["box"] | |
| d_color = COLOR_SAFE if det["class"] == HELMET_CLASS_ID else COLOR_VIOLATION | |
| cv2.rectangle(frame, (dx1, dy1), (dx2, dy2), d_color, 1) | |
| d_label = "H" if det["class"] == HELMET_CLASS_ID else "NH" | |
| cv2.putText(frame, f"{d_label}:{det['score']:.2f}", (dx1, dy1-5), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.3, d_color, 1) | |
| # --- unified history append: prefer Stage-2, fall back to Stage-1 --- | |
| if tid not in track_class_history: | |
| track_class_history[tid] = [] | |
| if h_cls >= 0: | |
| # use helmet-head model (preferred source of truth) | |
| track_class_history[tid].append({"class": h_cls, "conf": h_conf}) | |
| else: | |
| # fallback to stage-1 only when stage-2 had no head detection | |
| # Stage-1 class 0 is recorded as a helmet observation with a fixed confidence of 0.99. | |
| # NOTE(review): presumably rider class 0 means "rider wearing a helmet" - confirm against the rider model's labels. | |
| if int(rider_cls) == 0: | |
| track_class_history[tid].append({"class": HELMET_CLASS_ID, "conf": 0.99}) | |
| # clamp history | |
| if len(track_class_history[tid]) > HISTORY_WINDOW: | |
| track_class_history[tid].pop(0) | |
| # debug: quickly log class disagreement | |
| # print(f"[LIVE-DBG-HIST] tid={tid} stage2_h_cls={h_cls} stage2_h_conf={h_conf:.2f} stage1_rider_cls={rider_cls}") | |
| hist = track_class_history.get(tid, []) | |
| prev_viol = track_violation_memory.get(tid, False) | |
| viol_age = track_violation_age.get(tid, 0) | |
| is_no_helmet, is_safe, dbg = evaluate_helmet_state( | |
| hist, h_cls, h_conf, prev_viol, viol_age, rider_cls=int(rider_cls)) | |
| # Violation memory and age are carried per track; the age restarts at 0 whenever the track is not in violation. | |
| if is_no_helmet: | |
| track_violation_memory[tid] = True | |
| track_violation_age[tid] = viol_age + 1 | |
| with json_lock: | |
| session.setdefault('safe_tracks', set()).discard(tid) | |
| else: | |
| track_violation_memory[tid] = False | |
| track_violation_age[tid] = 0 | |
| if is_safe: | |
| with json_lock: | |
| session.setdefault('safe_tracks', set()).add(tid) | |
| # A track that was previously flagged has its violation record removed again. | |
| if prev_viol: | |
| session['violations'].pop(tid, None) | |
| with json_lock: | |
| session.setdefault('total_riders', set()).add(tid) | |
| print(f"[LIVE-TRACK] ID={tid} h_cls={h_cls} h_conf={h_conf:.2f} | {dbg}") | |
| # ── Display state ───────────────────────────────────────────────────── | |
| total_hist = len(hist) | |
| if is_no_helmet: | |
| display_name = "VIOLATION: NO HELMET" | |
| color = COLOR_VIOLATION | |
| elif is_safe: | |
| display_name = "SAFE: HELMET" | |
| color = COLOR_SAFE | |
| elif total_hist < MIN_FRAMES_BEFORE_DECIDE: | |
| display_name = "ANALYZING..." | |
| color = COLOR_ANALYZING | |
| else: | |
| display_name = "NO VIOLATION" | |
| color = COLOR_SAFE | |
| plate_text = session['track_plate_cache'].get(tid, "") | |
| # ══ STAGE 3 – Plate Detection (violation only) ════════════════════════ | |
| if is_no_helmet and tid not in dead_ids: | |
| with json_lock: | |
| # The violation record and the rider crop image are created only once per track id. | |
| if tid not in session['violations']: | |
| ts = datetime.now() | |
| rider_img_name = f"viol_live_{session_id}_{tid}_rider.jpg" | |
| cv2.imwrite(os.path.join(RESULTS_FOLDER, rider_img_name), | |
| frame[y1:y2, x1:x2]) | |
| viol_record = { | |
| "id": tid, | |
| "timestamp": ts.strftime('%H:%M:%S'), | |
| "type": "No Helmet", | |
| "plate_number": "Scanning...", | |
| "image_url": f"/static/results/{rider_img_name}", | |
| "plate_image_url": None, | |
| "ocr_attempts": [], | |
| "raw": { | |
| "confidence": float(h_conf), | |
| "box": xyxy.tolist() | |
| } | |
| } | |
| session['violations'][tid] = viol_record | |
| session['track_capture_count'][tid] = 0 | |
| new_violations.append(viol_record) | |
| # Plate capture stops once 5 best-frame snapshots have been flushed for this track. | |
| if session['track_capture_count'].get(tid, 0) < 5: | |
| eb = expand_box_for_plate(xyxy, w_orig, h_orig) | |
| plate_crop_region = frame[eb[1]:eb[3], eb[0]:eb[2]] | |
| if plate_crop_region.size > 0: | |
| with torch.no_grad(): | |
| plate_preds = plate_model.predict( | |
| cv2.cvtColor(plate_crop_region, cv2.COLOR_BGR2RGB), | |
| conf=CONF_PLATE, iou=0.45) | |
| pb, ps, _pl = parse_preds(plate_preds, | |
| plate_crop_region.shape[1], | |
| plate_crop_region.shape[0]) | |
| if pb.size > 0: | |
| # Only the highest-scoring plate detection is used; its box is relative to plate_crop_region. | |
| best_det = int(np.argmax(ps)) | |
| px1, py1, px2, py2 = map(int, pb[best_det]) | |
| plate_crop = plate_crop_region[py1:py2, px1:px2] | |
| # Crops that are 10 px or less in height, or 20 px or less in width, are ignored. | |
| if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20: | |
| # Motion scoring ROIs | |
| curr_roi = frame[y1:y2, x1:x2] | |
| prev_roi = prev_frame[y1:y2, x1:x2] if prev_frame is not None else None | |
| # Buffer the scored plate crop (short-track safe) | |
| live_buf_mgr.add( | |
| tid, plate_crop, (x1, y1, x2, y2), | |
| frame_idx, prev_roi, curr_roi | |
| ) | |
| print(f"[LIVE-BUFFER] tid={tid} frame={frame_idx} buffered plate crop") | |
| # Flush trigger: post-peak or hard timeout | |
| if live_buf_mgr.should_flush(tid, (x1, y1, x2, y2), frame_idx): | |
| best_entry = live_buf_mgr.flush(tid) | |
| if best_entry is not None: | |
| # The snapshot number is part of the file name, so each flush writes a new image. | |
| snap = session['track_capture_count'].get(tid, 0) + 1 | |
| pname = f"viol_live_{session_id}_{tid}_plate_best{snap}.jpg" | |
| ppath = os.path.join(RESULTS_FOLDER, pname) | |
| cv2.imwrite(ppath, best_entry.plate_crop) | |
| with json_lock: | |
| session['violations'][tid]["plate_image_url"] = ( | |
| f"/static/results/{pname}") | |
| current_session_data["ocr_queue"].put( | |
| (tid, ppath, f"live_{session_id}", socket_sid)) | |
| session['track_capture_count'][tid] = snap | |
| print(f"[LIVE-BUFFER] tid={tid} FLUSH → frame={best_entry.frame_idx} score={best_entry.score:.3f}") | |
| _draw_overlay(frame, x1, y1, x2, y2, tid, display_name, h_conf, color, plate_text, reason=dbg) | |
| # NOTE(review): this copy is taken after boxes/overlays were drawn, so the stored frame includes annotations. | |
| session['prev_frame'] = frame.copy() # store for next-frame motion scoring | |
| return frame, new_violations | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # FLASK ROUTES | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def index():
    """Render and return the ``landing.html`` template."""
    landing_page = render_template('landing.html')
    return landing_page
def dashboard():
    """Render and return the ``dashboard.html`` template."""
    dashboard_page = render_template('dashboard.html')
    return dashboard_page
def publisher():
    """Render and return the ``publisher.html`` template."""
    publisher_page = render_template('publisher.html')
    return publisher_page
def camera_debug():
    """Render and return the ``camera_debug.html`` template."""
    debug_page = render_template('camera_debug.html')
    return debug_page
def test_simple():
    """Send the static file ``test_simple.html`` from the ``'.'`` directory."""
    directory, page_name = '.', 'test_simple.html'
    return send_from_directory(directory, page_name)
def test_socket_echo():
    """Send the static file ``test_socket_echo.html`` from the ``'.'`` directory."""
    directory, page_name = '.', 'test_socket_echo.html'
    return send_from_directory(directory, page_name)
def upload_video():
    """Store an uploaded video and reset the shared video-session state.

    Reads the multipart field ``file`` from the current request. On success
    the file is saved to ``UPLOAD_FOLDER`` as ``<session_id>_<name>`` and a
    JSON body ``{"filename": ..., "session_id": ...}`` is returned.

    Returns:
        A JSON response; ``({"error": ...}, 400)`` when the ``file`` part is
        missing or carries no usable file name.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400

    file = request.files['file']
    # The file name comes from the client: keep only its last path component
    # (after normalising Windows separators) so a name such as
    # "../../app.py" cannot make the save path escape UPLOAD_FOLDER.
    client_name = os.path.basename((file.filename or '').replace('\\', '/'))
    if client_name in ('', '.', '..'):
        return jsonify({"error": "No selected file"}), 400

    session_id = str(uuid.uuid4())[:8]

    with json_lock:
        # Start the new upload from a clean slate.
        current_session_data["violations"] = {}
        current_session_data["safe_tracks"] = set()
        current_session_data["total_riders"] = set()
        current_session_data["track_plate_cache"] = {}
        current_session_data["track_capture_count"] = {}
        current_session_data["track_ocr_history"] = {}
        current_session_data["track_violation_age"] = {}

        # Drop OCR jobs left over from the previous upload. Only the
        # "queue is empty" condition ends the drain; any other error is a
        # real problem and is no longer silently swallowed.
        pending_jobs = current_session_data["ocr_queue"]
        while True:
            try:
                pending_jobs.get_nowait()
            except queue.Empty:
                break

    filename = f"{session_id}_{client_name}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)
    return jsonify({"filename": filename, "session_id": session_id})
def get_violations():
    """Return the video session's violation records as JSON.

    The records are listed in reverse insertion order of the
    ``violations`` dict.
    """
    with json_lock:
        records = list(current_session_data["violations"].values())
    return jsonify(records[::-1])
def get_stats():
    """Return rider, safe and violation counts of the video session as JSON."""
    with json_lock:
        riders = current_session_data.get('total_riders', set())
        safe = current_session_data.get('safe_tracks', set())
        violations = current_session_data.get('violations', {})
        counts = {
            'total_riders': len(riders),
            'safe_count': len(safe),
            'violation_count': len(violations),
        }
        return jsonify(counts)
def video_feed(filename, session_id):
    """Stream the processed frames of an uploaded video.

    The generator returned by ``process_video_gen`` is wrapped in a
    ``multipart/x-mixed-replace`` response whose part boundary is ``frame``.
    """
    video_path = os.path.join(UPLOAD_FOLDER, filename)
    frame_stream = process_video_gen(video_path, session_id)
    return Response(
        frame_stream,
        mimetype='multipart/x-mixed-replace; boundary=frame',
    )
def mobile_node(session_id):
    """Render ``mobile.html`` with ``session_id`` passed to the template."""
    template_context = {'session_id': session_id}
    return render_template('mobile.html', **template_context)
def upload_frame(session_id):
    """Acknowledge a frame upload; ``session_id`` is accepted but not used."""
    acknowledgement = {"status": "received"}
    return jsonify(acknowledgement)
def get_live_violations(session_id):
    """Return the violation records of one live camera session as JSON.

    Args:
        session_id: Session id to look up among the live camera sessions.

    Returns:
        A JSON list of the first matching session's violation records in
        reverse insertion order, or an empty JSON list when no live session
        has that id.
    """
    # Socket handlers add and remove entries of live_camera_sessions from
    # other threads; iterating a snapshot avoids "dictionary changed size
    # during iteration" in the middle of a request.
    for session in list(live_camera_sessions.values()):
        if session.get('session_id') == session_id:
            with json_lock:
                data = list(session['violations'].values())
            data.reverse()
            return jsonify(data)
    return jsonify([])
def get_active_sessions():
    """Returns a list of all currently active session IDs.

    The JSON body has the form ``{"sessions": [...]}``. Entries without a
    ``session_id`` key are left out.
    """
    # Iterate a snapshot: socket handlers mutate live_camera_sessions from
    # other threads while this request is being served.
    sessions = [
        entry['session_id']
        for entry in list(live_camera_sessions.values())
        if 'session_id' in entry
    ]
    return jsonify({"sessions": sessions})
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # SOCKET.IO – LIVE CAMERA | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def handle_connect():
    """Log the new Socket.IO client and acknowledge the connection."""
    client_sid = request.sid
    print(f"[SOCKET] Client connected: {client_sid}")
    emit('connection_response', {'status': 'connected'})
def handle_disconnect():
    """Release everything that belongs to the disconnecting Socket.IO client.

    Removes the client's live camera session and, when the client had
    negotiated a WebRTC peer connection, closes that connection and drops it
    from the bookkeeping containers so it does not stay open after the
    signalling socket is gone.
    """
    sid = request.sid
    print(f"[SOCKET] Client disconnected: {sid}")
    live_camera_sessions.pop(sid, None)

    pc = active_pcs.pop(sid, None)
    if pc is None:
        return
    pcs.discard(pc)
    try:
        run_async(pc.close())
    except Exception as exc:
        # Cleanup is best effort: a failing close must not break the
        # disconnect handler itself.
        print(f"[WEBRTC] Error closing peer connection for {sid}: {exc}")
def handle_start_camera(data):
    """Create the live-camera session state for the calling client.

    Args:
        data: Event payload. ``data['session_id']`` selects the session id;
            when it is missing, ``None`` or empty, a fresh 8-character id is
            generated.

    The client joins the room ``session_<session_id>``, a new state dict is
    stored under its Socket.IO sid (replacing any earlier one), and
    ``camera_session_started`` is emitted back with the session id.
    """
    from flask_socketio import join_room

    # ``dict.get(key, default)`` only covers a missing key. An explicit
    # ``None`` (as passed on by the WebRTC offer handler when the client
    # sent no id) would otherwise produce the room "session_None".
    session_id = (data or {}).get('session_id') or str(uuid.uuid4())[:8]
    print(f"[SOCKET] Starting camera session: {session_id} for {request.sid}")

    # Join the session room
    join_room(f"session_{session_id}")

    live_camera_sessions[request.sid] = {
        'session_id': session_id,
        'tracker': sv.ByteTrack(),
        'track_class_history': {},
        'track_violation_memory': {},
        'track_violation_age': {},
        'track_last_seen': {},
        'dead_ids': set(),
        'frame_idx': 0,
        'violations': {},
        'safe_tracks': set(),
        'total_riders': set(),
        'track_plate_cache': {},
        'track_capture_count': {},
        'track_ocr_history': {},
        'plate_buf_manager': GlobalBufferManager(),  # best-frame buffer
        'prev_frame': None,                          # for motion scoring
    }
    emit('camera_session_started', {'session_id': session_id})
def handle_join_remote(data):
    """Allows Admin to watch a Publisher's session results.

    Joins the caller to the room ``session_<session_id>`` and confirms with
    ``remote_session_joined``. A payload without a session id is ignored.
    """
    from flask_socketio import join_room

    session_id = data.get('session_id')
    if session_id:
        room_name = f"session_{session_id}"
        join_room(room_name)
        print(f"[SOCKET] Admin {request.sid} joined session room: {room_name}")
        emit('remote_session_joined', {'session_id': session_id})
def handle_camera_frame(data):
    """Process one base64-encoded camera frame sent over Socket.IO.

    Args:
        data: Event payload; ``data['frame']`` is either a data URL
            (``data:image/jpeg;base64,<payload>``) or the bare base64
            payload.

    The frame is decoded, run through ``process_live_frame`` and re-encoded
    as JPEG (quality 85). The result, the session's full violation list and
    its counters are emitted as ``processed_frame`` to the sender and as
    ``processed_frame_relay`` to the room ``session_<session_id>``.
    An ``error`` event is emitted when there is no active session or when
    processing raises.
    """
    sid = request.sid
    session = live_camera_sessions.get(sid)
    if session is None:
        emit('error', {'message': 'No active session'})
        return

    try:
        # Take everything after the first comma, or the whole string when
        # there is no "data:...;base64," prefix.
        frame_data = data['frame'].split(',', 1)[-1]
        frame_bytes = base64.b64decode(frame_data)
        nparr = np.frombuffer(frame_bytes, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if frame is None:
            # Undecodable image data: drop the frame silently.
            return

        session_id = session['session_id']
        processed_frame, _new_violations = process_live_frame(
            frame, session, session_id, sid)

        _, buffer = cv2.imencode(
            '.jpg', processed_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
        proc_b64 = base64.b64encode(buffer).decode('utf-8')

        # Build the payload once and read the session under json_lock, like
        # the HTTP endpoints do, so the violation list and the counters
        # describe the same moment.
        with json_lock:
            payload = {
                'frame': f'data:image/jpeg;base64,{proc_b64}',
                # Send FULL list for state sync
                'violations': list(session['violations'].values()),
                'stats': {
                    'total_riders': len(session.get('total_riders', set())),
                    'safe_count': len(session.get('safe_tracks', set())),
                    'violation_count': len(session.get('violations', {})),
                },
            }

        socketio.emit('processed_frame', payload, room=sid)
        # Also relay to any Admin viewers in this session's room
        socketio.emit('processed_frame_relay', payload,
                      room=f"session_{session_id}")
    except Exception as e:
        print(f"[SOCKET] Frame error: {e}")
        import traceback
        traceback.print_exc()
        emit('error', {'message': str(e)})
def handle_webrtc_offer(data):
    """Handles WebRTC offer from Publisher (Mobile).

    Args:
        data: Event payload with the offer's ``sdp`` and ``type`` and an
            optional ``session_id``.

    A live camera session is started for the caller when none exists. A new
    ``RTCPeerConnection`` is created and registered; an incoming video track
    is fed through ``VideoProcessTrack`` and stored in ``publisher_tracks``
    for subscribers. The SDP answer is emitted back as ``webrtc_answer``.
    """
    # The peer-connection callbacks below run later, outside this Socket.IO
    # request, where ``request.sid`` is not available. Capture it now.
    sid = request.sid

    if sid not in live_camera_sessions:
        handle_start_camera({'session_id': data.get('session_id')})

    # Fall back to the id stored in the live session so tracks are never
    # registered under a ``None`` key when the client sent no session id.
    session_id = (data.get('session_id')
                  or live_camera_sessions.get(sid, {}).get('session_id'))
    offer = RTCSessionDescription(sdp=data['sdp'], type=data['type'])

    pc = RTCPeerConnection()
    pcs.add(pc)
    active_pcs[sid] = pc

    # Callbacks must be attached to the connection to have any effect.
    # NOTE(review): aiortc may never emit "icecandidate" (it appears to
    # gather candidates while setting the local description) - confirm.
    @pc.on("icecandidate")
    def on_icecandidate(candidate):
        if candidate:
            socketio.emit('ice_candidate', {
                'candidate': candidate.candidate,
                'sdpMid': candidate.sdpMid,
                'sdpMLineIndex': candidate.sdpMLineIndex
            }, room=sid)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        print(f"[WEBRTC] Publisher Connection ID {session_id}: {pc.connectionState}")
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)
            active_pcs.pop(sid, None)
            publisher_tracks.pop(session_id, None)

    @pc.on("track")
    def on_track(track):
        if track.kind == "video":
            processor = VideoProcessTrack(relay.subscribe(track), session_id, sid)

            async def run_processor():
                # Keep pulling frames so every one is processed; stop on the
                # first error (for example when the track ends).
                while True:
                    try:
                        await processor.recv()
                    except Exception:
                        break

            asyncio.run_coroutine_threadsafe(run_processor(), loop)
            publisher_tracks[session_id] = track

    async def create_answer():
        await pc.setRemoteDescription(offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        return pc.localDescription

    local_desc = run_async(create_answer())
    emit('webrtc_answer', {'sdp': local_desc.sdp, 'type': local_desc.type})
def handle_subscriber_offer(data):
    """Handles WebRTC offer from Admin (Viewer) wanting to see a relayed stream.

    Args:
        data: Event payload with the offer's ``sdp`` and ``type`` and the
            ``session_id`` of the publisher stream to watch.

    Emits ``error`` when no publisher track is registered for the session.
    Otherwise a new ``RTCPeerConnection`` is created, a relay subscription
    of the publisher's track is added to it, and the SDP answer is emitted
    back as ``subscriber_answer``.
    """
    # The peer-connection callbacks below run later, outside this Socket.IO
    # request, where ``request.sid`` is not available. Capture it now.
    sid = request.sid

    session_id = data.get('session_id')
    if session_id not in publisher_tracks:
        emit('error', {'message': f'Stream {session_id} not available'})
        return

    offer = RTCSessionDescription(sdp=data['sdp'], type=data['type'])
    pc = RTCPeerConnection()
    pcs.add(pc)
    active_pcs[sid] = pc

    # Callbacks must be attached to the connection to have any effect.
    # NOTE(review): aiortc may never emit "icecandidate" (it appears to
    # gather candidates while setting the local description) - confirm.
    @pc.on("icecandidate")
    def on_icecandidate(candidate):
        if candidate:
            socketio.emit('ice_candidate', {
                'candidate': candidate.candidate,
                'sdpMid': candidate.sdpMid,
                'sdpMLineIndex': candidate.sdpMLineIndex
            }, room=sid)

    track = publisher_tracks[session_id]
    pc.addTrack(relay.subscribe(track))

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)
            active_pcs.pop(sid, None)

    async def create_answer():
        await pc.setRemoteDescription(offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        return pc.localDescription

    local_desc = run_async(create_answer())
    emit('subscriber_answer', {'sdp': local_desc.sdp, 'type': local_desc.type})
def handle_ice_candidate(data):
    """Add a remote ICE candidate to the caller's peer connection.

    Args:
        data: Event payload with the browser-style ``candidate`` string and
            optional ``sdpMid`` / ``sdpMLineIndex``.

    Nothing happens when the caller has no registered peer connection.
    Errors while parsing or adding the candidate are logged, not raised.
    """
    pc = active_pcs.get(request.sid)
    if pc is None:
        return

    # aiortc's RTCIceCandidate has no ``candidate=`` keyword; the SDP
    # candidate line has to be parsed with the library's helper.
    from aiortc.sdp import candidate_from_sdp

    print(f"[WEBRTC] Adding remote ICE candidate for {request.sid}")

    async def add_candidate():
        try:
            # Some clients might send null candidate to indicate end-of-candidates
            raw = data.get('candidate')
            if raw:
                # The parser expects the line without the "candidate:" prefix.
                if raw.startswith('candidate:'):
                    raw = raw.split(':', 1)[1]
                candidate = candidate_from_sdp(raw)
                candidate.sdpMid = data.get('sdpMid')
                candidate.sdpMLineIndex = data.get('sdpMLineIndex')
                await pc.addIceCandidate(candidate)
        except Exception as e:
            print(f"[WEBRTC] Error adding ICE candidate: {e}")

    run_async(add_candidate())
if __name__ == '__main__':
    # Script entry point: run the Socket.IO server on all interfaces,
    # port 7860, with the 'adhoc' SSL context.
    run_options = {
        'host': '0.0.0.0',
        'port': 7860,
        'ssl_context': 'adhoc',
    }
    socketio.run(app, **run_options)