# Traffic-Safety / app.py
# prthm11's picture
# Update app.py
# 86a897d verified
#
from flask import Flask, render_template, request, jsonify, Response, send_from_directory
from flask_socketio import SocketIO, emit
import cv2
import numpy as np
import os
import json
import uuid
import threading
import queue
import torch
import time
from datetime import datetime
from collections import Counter
import base64
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.frame_selector import GlobalBufferManager
from rfdetr import RFDETRMedium
import supervision as sv
from gradio_client import Client, handle_file
# --- WebRTC Imports ----
import asyncio
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack, RTCIceCandidate
from aiortc.contrib.media import MediaRelay
import av
# ── Flask / Socket.IO application ─────────────────────────────────────────────
app = Flask(__name__)

# The signing key must not be a value that is published with the source.
# Read it from the environment; when it is missing, fall back to a random
# per-process key (anything signed with it is invalidated on restart).
_secret_key = os.environ.get('SECRET_KEY')
if not _secret_key:
    _secret_key = os.urandom(32).hex()
    print("[WARNING] SECRET_KEY not set — using a random per-process key.")
app.config['SECRET_KEY'] = _secret_key
app.config['SESSION_TYPE'] = 'filesystem'

# NOTE(review): cors_allowed_origins="*" accepts Socket.IO connections from any
# origin — confirm this is intended for the deployment.
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading', manage_session=False)
# ── Directories ───────────────────────────────────────────────────────────────-
# Uploads and generated results live under <this file's directory>/static;
# result images are later referenced through "/static/results/..." URLs.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER, RESULTS_FOLDER = (
    os.path.join(BASE_DIR, sub) for sub in ('static/uploads', 'static/results')
)
# Create both folders up front so later writes never hit a missing directory.
for _folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
# ── HF Download ───────────────────────────────────────────────────────────────-
# /data is the HF Spaces persistent writable volume (enable in Space Settings).
# Without it the models are stored in /tmp and fetched again after a restart.
from huggingface_hub import snapshot_download

_has_persistent_volume = os.path.isdir("/data")
MODEL_DIR = "/data/CV_MODELS" if _has_persistent_volume else "/tmp/CV_MODELS"
if not _has_persistent_volume:
    print("[WARNING] /data not available — using /tmp. Models will re-download on every restart.")
os.makedirs(MODEL_DIR, exist_ok=True)

# Any content in MODEL_DIR is taken to mean the snapshot is already there.
if os.listdir(MODEL_DIR):
    print(f"[INIT] Models already present at {MODEL_DIR}, skipping download.")
else:
    print(f"[INIT] Downloading models to {MODEL_DIR} ...")
    snapshot_download(
        repo_id="WebAshlarWA/CV_MODELS",
        local_dir=MODEL_DIR,
        token=os.getenv("HF_TOKEN"),
    )
    print("[INIT] Download complete.")

# Debug aid: list every file so checkpoint paths can be verified on first boot.
print("[INIT] Files found in MODEL_DIR:")
for _dirpath, _, _filenames in os.walk(MODEL_DIR):
    for _fname in _filenames:
        print(f" {os.path.join(_dirpath, _fname)}")
def find_file(root, filename):
    """Return the path of the first file called *filename* under *root*.

    The tree is walked with ``os.walk``; ``None`` is returned when no
    directory contains a file with that exact name.
    """
    matches = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, names in os.walk(root)
        if filename in names
    )
    return next(matches, None)
# Locate the three checkpoints anywhere below MODEL_DIR.
RIDER_WEIGHTS = find_file(MODEL_DIR, "checkpoint0039.pth")
HELMET_HEAD_WEIGHTS = find_file(MODEL_DIR, "checkpoint_best_ema_hel.pth")
PLATE_WEIGHTS = find_file(MODEL_DIR, "checkpoint_best_ema_plate.pth")

print(f"[INIT] RIDER_WEIGHTS: {RIDER_WEIGHTS}")
print(f"[INIT] HELMET_HEAD_WEIGHTS: {HELMET_HEAD_WEIGHTS}")
print(f"[INIT] PLATE_WEIGHTS: {PLATE_WEIGHTS}")

# Fail fast, in the same order as the stages, when a checkpoint is missing.
for _resolved, _checkpoint in (
    (RIDER_WEIGHTS, "checkpoint0039.pth"),
    (HELMET_HEAD_WEIGHTS, "checkpoint_best_ema_hel.pth"),
    (PLATE_WEIGHTS, "checkpoint_best_ema_plate.pth"),
):
    if not _resolved:
        raise FileNotFoundError(
            f"{_checkpoint} not found — check MODEL_DIR file listing above.")
# ── Pipeline stages and class labels ──────────────────────────────────────────
# Stage 1 – Rider / motorbike detector (finds riders in the full frame).
# Stage 2 – Dedicated head / helmet model. General-purpose: detects every
#           helmet / bare-head in a crop; it is only ever run on a rider
#           bounding box, so it never fires outside one.
# Stage 3 – License plate detector (runs within the rider box on violation).
# The checkpoints used to be loaded from BASE_DIR/"Model/checkpoints/..." with
# the same file names; they are now located under MODEL_DIR.
#
# Stage-1 class labels. len(RIDER_CLASSES) is passed as num_classes when the
# rider model is loaded, so this list is not purely cosmetic. Class index 0
# ("motorbike and helmet") is treated as a helmet signal by
# evaluate_helmet_state() and by the history fallback in the video pipeline.
# For a checkpoint trained as a pure 1-class rider detector use instead:
# RIDER_CLASSES = ["rider"]
RIDER_CLASSES = ["motorbike and helmet", "motorbike and no helmet"]
# ── Stage-2 helmet model class mapping (Sync with camera_processor_LT.py) ─────
# class 1 → helmet (SAFE)
# class 2 → no-helmet (VIOLATION)
# (class 0 is often 'all' or 'rider' depending on model)
HELMET_CLASS_ID = 1
NO_HELMET_CLASS_ID = 2
# Number of classes the Stage-2 model is instantiated with (IDs 0, 1 and 2).
NUM_HELMET_CLASSES = 3
# ══════════════════════════════════════════════════════════════════════════════
# Per-stage detection thresholds (passed as `conf` to each model's predict()).
CONF_RIDER = 0.25 # Stage-1: minimum score to enter tracker
CONF_HELMET_HEAD = 0.25 # Stage-2: minimum score for helmet model
CONF_PLATE = 0.50 # Stage-3: plate detection score
# Decision engine (Unified)
# CONF_THRESHOLD is the cutoff a history entry must reach to count in the
# consensus/voting stage of evaluate_helmet_state().
# NOTE(review): the original guidance was that this should be LOWER than the
# single-frame "strong" cutoff, but evaluate_helmet_state() uses
# STRONG_SINGLE_FRAME = 0.60, which is below 0.80 — confirm which is intended.
# (A value of 0.35 was used previously.)
CONF_THRESHOLD = 0.80
DECISION_MARGIN = 1.2 # one class's score must exceed the other's by this factor
HISTORY_WINDOW = 30 # max per-track history entries kept
RECENT_WINDOW = 10 # number of latest history entries treated as "recent"
SAFE_RECENT_H_HITS = 2 # recent confident helmet hits needed to declare SAFE
# Despite the SAFE_ prefix, this is the number of recent confident no-helmet
# hits that triggers a VIOLATION.
SAFE_RECENT_NH_HITS = 2
MIN_FRAMES_BEFORE_DECIDE = 12 # history length required before margin-based decisions
VIOLATION_PERSIST_FRAMES = 18 # max violation age for which an undecided frame stays a violation
# ── Bounding-box colours (Standardized) ──────────────────────────────────────
# All tuples are in OpenCV's BGR channel order.
# Stage-1 boxes: Single uniform color for all riders
COLOR_RIDER_CYAN = (255, 255, 0) # Cyan in BGR
COLOR_SAFE = (0, 255, 0) # Green (Helmet)
COLOR_VIOLATION = (0, 0, 255) # Red (No-Helmet)
COLOR_PLATE_BOX = (255, 0, 255) # Magenta/Purple
COLOR_ANALYZING = (0, 200, 200) # Dark yellow in BGR (B=0, G=200, R=200), not cyan
# Run inference on the GPU when CUDA is usable, otherwise on the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"[INIT] Device: {device}")
# ── Model Loading ─────────────────────────────────────────────────────────────
def _load_model(label, weights_path, num_classes):
    """Build an RFDETRMedium model from a checkpoint and try to optimise it.

    Args:
        label: Human-readable model name, used only in log lines.
        weights_path: Path of the checkpoint passed as ``pretrain_weights``.
        num_classes: Number of classes the checkpoint was trained with.

    Returns:
        The model. Optimisation is best-effort: if it fails, a warning is
        printed and the model is returned as it was at the point of failure.
    """
    # The label and the path were previously printed with no separator.
    print(f"[INIT] Loading {label}: {weights_path}")
    m = RFDETRMedium(num_classes=num_classes, pretrain_weights=weights_path)
    try:
        m.optimize_for_inference(compile=True, batch_size=1)
        if device.type == "cuda":
            # Half precision is only requested on CUDA.
            m.model.half()
    except Exception as e:
        # Deliberate best-effort: a model that could not be optimised is
        # still returned to the caller.
        print(f"[WARNING] {label} optimisation failed: {e}")
    return m
# One model per pipeline stage, loaded once when the module is imported.
rider_model = _load_model(
    label="Rider Detection",
    weights_path=RIDER_WEIGHTS,
    num_classes=len(RIDER_CLASSES),
)
helmet_head_model = _load_model(
    label="Helmet/NoHelmet",
    weights_path=HELMET_HEAD_WEIGHTS,
    num_classes=NUM_HELMET_CLASSES,
)
plate_model = _load_model(
    label="Plate Detection",
    weights_path=PLATE_WEIGHTS,
    num_classes=1,
)
# ── Shared session data ───────────────────────────────────────────────────────
# Module-wide state used by the uploaded-video pipeline and the OCR workers.
current_session_data = {
    # track id -> violation record (dict that is serialised to the session JSON)
    "violations": {},
    # Track IDs confirmed as SAFE
    "safe_tracks": set(),
    # All unique track IDs seen
    "total_riders": set(),
    # Pending OCR tasks: (track_id, plate_path, session_id, socket_sid)
    "ocr_queue": queue.Queue(),
    # track id -> current consensus plate text
    "track_plate_cache": {},
    # track id -> number of plate snapshots flushed so far
    "track_capture_count": {},
    # track id -> list of OCR readings gathered for that track
    "track_ocr_history": {},
    # track ids currently being handled by an OCR worker
    "ocr_in_progress": set(),
    # not referenced in the visible part of this file
    "track_violation_age": {},
}
# ── Per-session best-frame buffer (video pipeline) ────────────────────────────
# Single shared manager; process_video_gen() resets it at the start of a video.
video_buf_manager = GlobalBufferManager()
# Live-camera state, looked up by Socket.IO sid (see VideoProcessTrack.recv).
live_camera_sessions = {}
# Guards updates to the violation dicts and the session JSON write.
json_lock = threading.Lock()
# --- WebRTC Global State ---
pcs = set()  # not referenced in the visible part of this file
active_pcs = {} # SID -> PeerConnection
relay = MediaRelay()
publisher_tracks = {} # Mapping session_id -> track
# Dedicated asyncio loop that lives on its own daemon thread; the synchronous
# handlers hand coroutines over to it.
loop = asyncio.new_event_loop()


def start_async_loop(loop):
    """Make *loop* the current loop of the calling thread and run it forever."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


_loop_thread = threading.Thread(
    target=start_async_loop,
    args=(loop,),
    daemon=True,
)
_loop_thread.start()
# Helper to run async code from sync Socket.IO handlers
def run_async(coro, timeout=None):
    """Run *coro* on the background event loop and return its result.

    Args:
        coro: Coroutine object to schedule on the module-level ``loop``.
        timeout: Optional number of seconds to wait. ``None`` (the default)
            waits indefinitely, which is the original behaviour.

    Returns:
        Whatever the coroutine returns.

    Raises:
        Any exception raised by the coroutine, or a timeout error when
        *timeout* elapses first.
    """
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result(timeout)
    except Exception:
        # On timeout, stop the coroutine instead of leaving it running
        # unattended; cancelling an already finished future is a no-op.
        future.cancel()
        raise
# --- Custom WebRTC Video Track for processing ---
class VideoProcessTrack(MediaStreamTrack):
    """Video track that runs the live pipeline on every frame of a source track.

    The incoming frame is returned unchanged; detection results are pushed
    to the session's Socket.IO room instead.
    """

    kind = "video"

    def __init__(self, track, session_id, socket_sid):
        """Wrap *track*.

        Args:
            track: Source track whose frames are processed.
            session_id: Used to build the room name ``session_<session_id>``.
            socket_sid: Key of the publisher's entry in ``live_camera_sessions``.
        """
        super().__init__()
        self.track = track
        self.session_id = session_id
        self.socket_sid = socket_sid

    async def recv(self):
        """Receive the next frame, process it if a session exists, return it."""
        frame = await self.track.recv()

        # Single lookup: the session may be removed by another thread between
        # a membership test and the subscript.
        session = live_camera_sessions.get(self.socket_sid)
        if session is None:
            return frame

        # Convert to numpy/opencv
        img = frame.to_ndarray(format="bgr24")

        # Inference is blocking work; run it in the default executor so the
        # event loop keeps servicing other coroutines meanwhile.
        # NOTE(review): the annotated image is not used — the original frame
        # is what gets returned. Confirm that this is intended.
        _annotated, new_violations = await asyncio.get_running_loop().run_in_executor(
            None, process_live_frame, img, session, self.session_id, self.socket_sid)

        # Emit results to the SESSION ROOM (so Admin on other device sees it)
        room_name = f"session_{self.session_id}"
        socketio.emit('processed_frame_relay', {
            'violations': new_violations,
            'stats': {
                'total_riders': len(session.get('total_riders', set())),
                'safe_count': len(session.get('safe_tracks', set())),
                'violation_count': len(session.get('violations', {}))
            }
        }, room=room_name)
        return frame
# ══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════════════════
def get_best_consensus(results):
    """Merge several OCR readings of one plate into a single string.

    Each reading is normalised (newlines become spaces, surrounding
    whitespace removed). Readings that are empty or equal to ``"API_ERROR"``
    or ``"PENDING..."`` after normalisation are ignored. The remaining
    readings are combined by a per-position majority vote; on a tie the
    character seen first wins.

    Args:
        results: Iterable of OCR result strings.

    Returns:
        The consensus text, or ``"PENDING..."`` when no usable reading exists.
    """
    ignored = {"API_ERROR", "PENDING...", ""}
    # Normalise BEFORE filtering so that whitespace-only or padded sentinel
    # values cannot slip through as readings.
    normalised = (r.replace("\n", " ").strip() for r in results)
    cleaned = [text for text in normalised if text not in ignored]

    if not cleaned:
        return "PENDING..."
    if len(cleaned) == 1:
        return cleaned[0]

    longest = max(len(text) for text in cleaned)
    final_chars = []
    for pos in range(longest):
        # Only readings long enough to have this position take part in the vote.
        votes = Counter(text[pos] for text in cleaned if pos < len(text))
        final_chars.append(votes.most_common(1)[0][0])
    return "".join(final_chars).strip()
def clamp_box(box, w, h):
    """Truncate *box* to integers and clip it to a ``w`` x ``h`` frame.

    Returns ``[x1, y1, x2, y2]`` where the top-left corner is at least 0 and
    the bottom-right corner is at most ``(w - 1, h - 1)``.
    """
    left, top, right, bottom = (int(value) for value in box)
    left = max(0, left)
    top = max(0, top)
    right = min(w - 1, right)
    bottom = min(h - 1, bottom)
    return [left, top, right, bottom]
def expand_box_for_plate(box, w, h):
    """Return a search region for the number plate derived from a rider box.

    The box is widened by 10 % of its width on each side and moved down by
    40 % of its height, then clipped to the frame with ``clamp_box``.
    """
    left, top, right, bottom = (int(value) for value in box)
    box_w = right - left
    box_h = bottom - top
    side_pad = box_w * 0.1
    drop = box_h * 0.4
    shifted = [left - side_pad, top + drop, right + side_pad, bottom + drop]
    return clamp_box(shifted, w, h)
def parse_preds(preds, W, H, debug_tag=""):
    """Pull boxes, scores and class ids out of a prediction object.

    Args:
        preds: Object exposing ``xyxy``, ``confidence`` and ``class_id``;
            values that are not numpy arrays are converted through
            ``.cpu().numpy()``. Without an ``xyxy`` attribute three empty
            arrays are returned.
        W, H: Frame width and height, used to scale boxes whose largest
            coordinate is at most 1.01 (treated as normalised).
        debug_tag: When non-empty, one log line is printed per detection.

    Returns:
        ``(boxes, scores, labels)`` as numpy arrays.
    """
    def _to_numpy(value):
        return value if isinstance(value, np.ndarray) else value.cpu().numpy()

    if hasattr(preds, "xyxy"):
        boxes = _to_numpy(preds.xyxy)
        scores = _to_numpy(preds.confidence)
        labels = _to_numpy(preds.class_id)
    else:
        boxes, scores, labels = np.array([]), np.array([]), np.array([])

    has_boxes = boxes.size > 0
    if has_boxes and boxes.max() <= 1.01:
        # Scale a copy so the caller's array is left untouched.
        boxes = boxes.copy()
        boxes[:, [0, 2]] *= W
        boxes[:, [1, 3]] *= H

    if debug_tag and has_boxes:
        for i, conf in enumerate(scores):
            print(f"[PARSE:{debug_tag}] det#{i} cls={int(labels[i])} conf={conf:.3f}")
    return boxes, scores, labels
def run_helmet_on_crop(frame_bgr, rx1, ry1, rx2, ry2):
    """Run the Stage-2 helmet model on one rider box of *frame_bgr*.

    Returns:
        ``(best_cls, best_conf, detections)``. ``detections`` holds one dict
        per Stage-2 detection with ``box`` in full-frame coordinates,
        ``score`` and ``class``. When the clipped crop is smaller than 20 px
        in either dimension, or nothing is detected, ``(-1, 0.0, [])`` is
        returned. A helmet detection wins over a no-helmet detection
        whatever their scores; within a class the highest score is reported.
    """
    nothing = (-1, 0.0, [])
    frame_h, frame_w = frame_bgr.shape[:2]

    # Clip the rider box to the frame before cropping.
    rx1, ry1 = max(0, rx1), max(0, ry1)
    rx2, ry2 = min(frame_w, rx2), min(frame_h, ry2)
    crop_w, crop_h = rx2 - rx1, ry2 - ry1
    if crop_h < 20 or crop_w < 20:  # Skip tiny crops
        return nothing

    crop_rgb = cv2.cvtColor(frame_bgr[ry1:ry2, rx1:rx2], cv2.COLOR_BGR2RGB)
    with torch.no_grad():
        hpreds = helmet_head_model.predict(crop_rgb, conf=CONF_HELMET_HEAD, iou=0.45)
    hb, hs, hl = parse_preds(hpreds, crop_w, crop_h, debug_tag="helmet_head")
    if hb.size == 0:
        return nothing

    # Move crop-relative boxes back into full-frame coordinates.
    shifted_detections = [
        {
            "box": [int(hb[i][0] + rx1), int(hb[i][1] + ry1),
                    int(hb[i][2] + rx1), int(hb[i][3] + ry1)],
            "score": float(hs[i]),
            "class": int(hl[i]),
        }
        for i in range(len(hb))
    ]

    # HELMET is checked first, so it takes priority over NO-HELMET.
    for wanted in (HELMET_CLASS_ID, NO_HELMET_CLASS_ID):
        candidates = np.where(hl == wanted)[0]
        if candidates.size > 0:
            top = candidates[np.argmax(hs[candidates])]
            return wanted, float(hs[top]), shifted_detections

    # Neither head class present: report the strongest detection overall.
    best = int(np.argmax(hs))
    return int(hl[best]), float(hs[best]), shifted_detections
# def evaluate_helmet_state(hist, cls_idx, confidence, prev_violation, violation_age, rider_cls=-1):
# """
# Enhanced decision engine: sensitive to both Helmet and No-Helmet detections.
# Uses a weighted consensus rather than a single-hit override.
# """
# total = len(hist)
# recent = hist[-RECENT_WINDOW:]
# r_total = len(recent)
# # Counts based on lowered sensitivity thresholds
# f_h = sum(1 for h in hist if h['class'] == 0 and h['conf'] >= CONF_HELMET_CONFIRM)
# f_nh = sum(1 for h in hist if h['class'] != 0 and h['conf'] >= CONF_NO_HELMET_TRIGGER)
# r_h = sum(1 for h in recent if h['class'] == 0 and h['conf'] >= CONF_HELMET_CONFIRM)
# r_nh = sum(1 for h in recent if h['class'] != 0 and h['conf'] >= CONF_NO_HELMET_TRIGGER)
# r_nh_frac = r_nh / max(r_total, 1)
# # ── SAFE checks (Helmet detected) ────────────────────────────────────────
# # 1. Stage-1 Rider detector specifically says "with helmet" (High Priority)
# if rider_cls == 0:
# return False, True, "safe:rider_H"
# # 2. Strong current evidence of a helmet
# if cls_idx == 0 and confidence >= 0.50:
# return False, True, f"safe:cur_H_strong({confidence:.2f})"
# # 3. Decision by majority/consistency (Balanced sensitivity)
# # If we have significantly more helmet hits than no-helmet hits, it's safe.
# if f_h > f_nh and f_h >= 2:
# return False, True, f"safe:vote_H({f_h} vs {f_nh})"
# # ── VIOLATION gate (No-Helmet detected) ──────────────────────────────────
# # We trigger a violation if no-helmet hits dominate, even if there was a
# # sporadic/low-conf helmet hit (reduces false 'Safe' from noise).
# violation_gate = (
# total >= MIN_FRAMES_BEFORE_VIOLATION
# and f_nh > f_h * 2 # No-helmet hits must clearly outweigh helmet hits
# and r_nh >= RECENT_NH_MIN
# and r_nh_frac >= RECENT_NH_FRAC
# and f_nh >= MIN_NH_HITS_FULL
# and rider_cls != 0
# )
# if violation_gate:
# return True, False, f"violation(NH={f_nh}, H={f_h})"
# # ── Soft persistence ──────────────────────────────────────────────────────
# if prev_violation and violation_age <= VIOLATION_PERSIST_FRAMES and r_h == 0:
# return True, False, f"persist(age={violation_age})"
# # ── Default State ─────────────────────────────────────────────────────────
# debug = (f"analyzing({total}fr)" if total < MIN_FRAMES_BEFORE_VIOLATION
# else f"neutral(H={f_h},NH={f_nh})")
# # If we can't decide but have some helmet hits, default to Safe for UX stability
# if f_h >= 1 and not violation_gate:
# return False, True, f"safe:default_H({f_h})"
# return False, False, debug
def evaluate_helmet_state(hist, cls_idx, confidence,
                          prev_violation, violation_age,
                          rider_cls=-1):
    """Decide a track's helmet state from the current frame and its history.

    Args:
        hist: List of ``{"class": int, "conf": float}`` entries, oldest first.
        cls_idx: Stage-2 class of the current frame.
        confidence: Stage-2 confidence of the current frame.
        prev_violation: Whether the track was a violation on the last frame.
        violation_age: Number of consecutive frames spent in violation.
        rider_cls: Stage-1 class of the rider; 0 forces a safe verdict.

    Returns:
        ``(is_no_helmet, is_safe, reason)``. Both flags are ``False`` while
        the state is undecided. Rules are applied in this order: Stage-1
        override, strong single frame, recent-hit counts, warm-up,
        confidence-weighted margin, violation persistence.
    """
    # Stage-1 override (high priority)
    if rider_cls == 0:
        return False, True, "safe:rider_H"

    # A single very confident frame settles the state immediately.
    STRONG_SINGLE_FRAME = 0.60
    if cls_idx == HELMET_CLASS_ID and confidence >= STRONG_SINGLE_FRAME:
        return False, True, f"safe:cur_H_strong({confidence:.2f})"
    if cls_idx == NO_HELMET_CLASS_ID and confidence >= STRONG_SINGLE_FRAME:
        return True, False, f"violation:cur_NH_strong({confidence:.2f})"

    def _confident(entries, wanted_cls):
        """Confidences of entries of *wanted_cls* at or above CONF_THRESHOLD."""
        return [e['conf'] for e in entries
                if e['class'] == wanted_cls and e['conf'] >= CONF_THRESHOLD]

    # Whole-history scores are confidence sums; recent figures are counts.
    score_h = sum(_confident(hist, HELMET_CLASS_ID))
    score_nh = sum(_confident(hist, NO_HELMET_CLASS_ID))
    recent = hist[-RECENT_WINDOW:]
    r_h = len(_confident(recent, HELMET_CLASS_ID))
    r_nh = len(_confident(recent, NO_HELMET_CLASS_ID))

    if r_h >= SAFE_RECENT_H_HITS:
        return False, True, f"safe:recent_H(count={r_h})"
    if r_nh >= SAFE_RECENT_NH_HITS:
        return True, False, f"violation:recent_NH(count={r_nh})"

    # Too little history for a margin-based verdict.
    total = len(hist)
    if total < MIN_FRAMES_BEFORE_DECIDE:
        return False, False, f"warming({total}fr)"

    if score_h > score_nh * DECISION_MARGIN:
        return False, True, f"safe:margin({score_h:.2f} vs {score_nh:.2f})"
    if score_nh > score_h * DECISION_MARGIN:
        return True, False, f"violation:margin({score_nh:.2f} vs {score_h:.2f})"

    # Anti-flicker: keep a young violation alive unless a helmet showed up
    # in the recent window.
    still_persisting = prev_violation and violation_age <= VIOLATION_PERSIST_FRAMES
    if still_persisting and r_h == 0:
        return True, False, f"persist(age={violation_age})"

    return False, False, f"uncertain(H={score_h:.2f},NH={score_nh:.2f})"
def _draw_overlay(frame, x1, y1, x2, y2, tid, display_name, confidence, color, plate_text="", reason=""):
    """Draw one rider's box and captions onto *frame* in place.

    The status caption sits just above the box; the optional *reason* goes
    above that and the optional *plate_text* below the box in the plate
    colour. Captions above the box are never placed higher than y = 10.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

    # (text, origin, scale, colour, thickness) in drawing order
    captions = [
        (f"ID:{tid} {display_name} {confidence:.2f}", (x1, max(y1 - 10, 10)), 0.5, color, 2),
    ]
    if reason:
        captions.append((f"({reason})", (x1, max(y1 - 25, 10)), 0.4, color, 1))
    if plate_text:
        captions.append((f"Plate: {plate_text}", (x1, y2 + 20), 0.5, COLOR_PLATE_BOX, 2))

    for text, origin, scale, text_color, thickness in captions:
        cv2.putText(frame, text, origin, font, scale, text_color, thickness)
# ══════════════════════════════════════════════════════════════════════════════
# OCR WORKER
# ══════════════════════════════════════════════════════════════════════════════
def _ocr_read_plate(client, track_id, plate_path):
    """Call the remote OCR endpoint; return its text or "API_ERROR" on failure."""
    try:
        result = client.predict(image=handle_file(plate_path),
                                api_name="/proses_intelijen")
        plate_text = str(result).strip()
        print(f"[OCR] ID {track_id}: {plate_text}")
        return plate_text
    except Exception as e:
        print(f"[OCR] API error for ID {track_id}: {e}")
        return "API_ERROR"


def _record_live_ocr(track_id, plate_text, socket_sid):
    """Store an OCR reading in a live-camera session and notify its client."""
    # Single lookup: the session may disappear between a membership test and
    # the subscript.
    session = live_camera_sessions.get(socket_sid) if socket_sid else None
    if session is None:
        return
    history = session["track_ocr_history"].setdefault(track_id, [])
    if plate_text not in ["API_ERROR", ""]:
        history.append(plate_text)
    final = get_best_consensus(history)
    session["track_plate_cache"][track_id] = final
    with json_lock:
        if track_id in session["violations"]:
            session["violations"][track_id]["plate_number"] = final
            session["violations"][track_id]["ocr_attempts"] = history
            socketio.emit('ocr_update', {
                'track_id': track_id,
                'plate_number': final,
                'violation': session["violations"][track_id]
            }, room=socket_sid)


def _record_video_ocr(track_id, plate_text, session_id):
    """Store an OCR reading for the video pipeline and rewrite the session JSON."""
    history = current_session_data["track_ocr_history"].setdefault(track_id, [])
    if plate_text not in ["API_ERROR", ""]:
        history.append(plate_text)
    final = get_best_consensus(history)
    current_session_data["track_plate_cache"][track_id] = final
    with json_lock:
        if track_id in current_session_data["violations"]:
            current_session_data["violations"][track_id]["plate_number"] = final
            current_session_data["violations"][track_id]["ocr_attempts"] = history
        # Persist while holding the lock so the file matches the in-memory state.
        json_path = os.path.join(RESULTS_FOLDER, f"session_{session_id}.json")
        with open(json_path, 'w') as f:
            json.dump(list(current_session_data["violations"].values()), f, indent=4)


def background_ocr_worker():
    """Processes OCR tasks from the shared queue asynchronously.

    Connects to the OCR service (three attempts, two seconds apart) and
    returns if that fails. Otherwise it loops forever taking
    ``(track_id, plate_path, session_id, socket_sid)`` tasks from
    ``current_session_data["ocr_queue"]``. A task is skipped when its track
    is already being processed or its image file is gone. Session ids that
    start with ``live_`` are recorded in the live session, all others in the
    shared video state.
    """
    import traceback

    print("[OCR] Worker Thread Started")
    client = None
    for attempt in range(3):
        try:
            client = Client("WebashalarForML/demo-glm-ocr")
            print("[OCR] Gradio Client Connected")
            break
        except Exception as e:
            print(f"[OCR] Connection attempt {attempt + 1} failed: {e}")
            if attempt == 2:
                print("[OCR] Max retries reached. Worker exiting.")
                return
            time.sleep(2)

    while True:
        ocr_queue = current_session_data["ocr_queue"]
        try:
            task = ocr_queue.get(timeout=1)
        except queue.Empty:
            continue

        # `claimed` records whether THIS worker put the id into
        # ocr_in_progress, so cleanup never releases an id owned by another
        # worker or left over from an earlier iteration.
        claimed = False
        track_id = None
        try:
            if task is None:
                continue
            track_id, plate_path, session_id, socket_sid = task
            if track_id in current_session_data["ocr_in_progress"]:
                print(f"[OCR] ID {track_id} already in progress, skipping")
                continue
            current_session_data["ocr_in_progress"].add(track_id)
            claimed = True

            if not os.path.exists(plate_path):
                continue

            plate_text = _ocr_read_plate(client, track_id, plate_path)
            if session_id.startswith('live_'):
                _record_live_ocr(track_id, plate_text, socket_sid)
            else:
                _record_video_ocr(track_id, plate_text, session_id)
        except Exception as e:
            # Keep the worker alive; one bad task must not stop OCR.
            print(f"[OCR] Loop Error: {e}")
            traceback.print_exc()
        finally:
            # Exactly one task_done() per successful get(), on every path.
            if claimed:
                current_session_data["ocr_in_progress"].discard(track_id)
            ocr_queue.task_done()
# Start the pool of daemon OCR workers as soon as the module is imported.
NUM_OCR_WORKERS = 3
for _i in range(NUM_OCR_WORKERS):
    _worker_name = f"OCR-Worker-{_i+1}"
    _t = threading.Thread(
        target=background_ocr_worker,
        name=_worker_name,
        daemon=True,
    )
    _t.start()
    print(f"[INIT] Started OCR Worker {_i+1}")
# ══════════════════════════════════════════════════════════════════════════════
# THREE-MODEL PIPELINE – VIDEO (generator)
# ══════════════════════════════════════════════════════════════════════════════
def _persist_flushed_plate(tid, best_entry, session_id, suffix):
    """Save a flushed plate crop, link it to the violation and queue it for OCR.

    The file is written as ``viol_<session_id>_<tid>_plate_<suffix>.jpg`` in
    RESULTS_FOLDER. Returns ``True`` when a crop was written and queued,
    ``False`` when the entry carries no usable crop.
    """
    crop = best_entry.plate_crop
    if crop is None or crop.size == 0:
        return False
    pname = f"viol_{session_id}_{tid}_plate_{suffix}.jpg"
    ppath = os.path.join(RESULTS_FOLDER, pname)
    cv2.imwrite(ppath, crop)
    with json_lock:
        if tid in current_session_data["violations"]:
            current_session_data["violations"][tid]["plate_image_url"] = (
                f"/static/results/{pname}")
    current_session_data["ocr_queue"].put((tid, ppath, session_id, None))
    return True


def process_video_gen(video_path, session_id):
    """Run the three-stage pipeline over a video and stream annotated frames.

    Per frame: detect riders (Stage 1), track them, classify helmet state on
    each rider crop (Stage 2), and for violations look for a plate (Stage 3)
    whose best crop is buffered and later queued for OCR. Results are
    recorded in ``current_session_data``.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        session_id: Identifier embedded in result file names and OCR tasks.

    Yields:
        One multipart chunk (``--frame`` boundary plus JPEG bytes) per frame.
    """
    track_expiry_frames = 50  # frames without a sighting before a track is dead

    cap = cv2.VideoCapture(video_path)
    tracker = sv.ByteTrack()
    track_class_history = {}
    track_violation_memory = {}
    track_last_seen = {}
    track_violation_age = {}
    track_nh_crop = {}  # tid -> latest rider crop with confident no-helmet evidence
    dead_ids = set()
    frame_idx = 0
    prev_frame = None  # for motion scoring
    video_buf_manager.reset_all()  # clean slate for this video

    # try/finally: the capture is released even when the consumer stops
    # iterating early and the generator is closed.
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1
            h_orig, w_orig = frame.shape[:2]

            # ── Expire old tracks & force-flush their buffers ─────────────────
            newly_dead = [t for t, last in track_last_seen.items()
                          if frame_idx - last > track_expiry_frames and t not in dead_ids]
            dead_ids.update(newly_dead)
            for tid in newly_dead:
                track_nh_crop.pop(tid, None)
            # Force-flush dead tracks (short-track safety net)
            dead_flushes = video_buf_manager.force_flush_dead(set(newly_dead))
            for tid, best_entry in dead_flushes.items():
                if _persist_flushed_plate(tid, best_entry, session_id, "best"):
                    print(f"[BUFFER] Dead-track flush: tid={tid} frame={best_entry.frame_idx} score={best_entry.score:.3f}")

            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # ══ STAGE 1 – Rider Detection ═════════════════════════════════════
            with torch.no_grad():
                rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.45)
            r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig, debug_tag="rider")
            if r_boxes.size > 0:
                detections = sv.Detections(
                    xyxy=r_boxes.astype(np.float32),
                    confidence=r_scores.astype(np.float32),
                    class_id=r_labels.astype(np.int32)
                )
                # Apply NMS to prevent overlapping rider boxes
                detections = detections.with_nms(threshold=0.5)
            else:
                detections = sv.Detections.empty()
            detections = tracker.update_with_detections(detections)

            for (xyxy, _mask, rider_conf, rider_cls, tracker_id, _data) in detections:
                if tracker_id is None:
                    continue
                tid = int(tracker_id)
                track_last_seen[tid] = frame_idx
                x1, y1, x2, y2 = map(int, xyxy)
                # Tracked boxes may extend past the frame; negative indices
                # would make numpy slices empty or wrap around.
                cx1, cy1 = max(0, x1), max(0, y1)
                cx2, cy2 = min(w_orig, x2), min(h_orig, y2)

                # ══ STAGE 2 – Helmet / No-Helmet (within rider crop) ═════════
                h_cls, h_conf, h_dets = run_helmet_on_crop(frame, x1, y1, x2, y2)
                # Draw individual child detections
                for det in h_dets:
                    dx1, dy1, dx2, dy2 = det["box"]
                    d_color = COLOR_SAFE if det["class"] == HELMET_CLASS_ID else COLOR_VIOLATION
                    cv2.rectangle(frame, (dx1, dy1), (dx2, dy2), d_color, 1)
                    d_label = "H" if det["class"] == HELMET_CLASS_ID else "NH"
                    cv2.putText(frame, f"{d_label}:{det['score']:.2f}", (dx1, dy1-5),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.3, d_color, 1)

                # View of the rider region (child boxes drawn, overlay not yet).
                rider_roi = frame[cy1:cy2, cx1:cx2]

                # --- unified history append: prefer Stage-2, fall back to Stage-1 ---
                hist = track_class_history.setdefault(tid, [])
                if h_cls >= 0:
                    # use helmet-head model (preferred source of truth)
                    hist.append({"class": h_cls, "conf": h_conf})
                elif int(rider_cls) == 0:
                    # fallback to stage-1 only when stage-2 had no head detection
                    hist.append({"class": HELMET_CLASS_ID, "conf": 0.99})
                # clamp history
                if len(hist) > HISTORY_WINDOW:
                    hist.pop(0)

                # Keep evidence for a possible end-of-video violation.
                if (h_cls == NO_HELMET_CLASS_ID and h_conf >= CONF_THRESHOLD
                        and rider_roi.size > 0):
                    track_nh_crop[tid] = rider_roi.copy()

                prev_viol = track_violation_memory.get(tid, False)
                viol_age = track_violation_age.get(tid, 0)
                is_no_helmet, is_safe, dbg = evaluate_helmet_state(
                    hist, h_cls, h_conf, prev_viol, viol_age, rider_cls=int(rider_cls))

                if is_no_helmet:
                    track_violation_memory[tid] = True
                    track_violation_age[tid] = viol_age + 1
                    with json_lock:
                        current_session_data["safe_tracks"].discard(tid)
                else:
                    track_violation_memory[tid] = False
                    track_violation_age[tid] = 0
                    if is_safe:
                        with json_lock:
                            current_session_data["safe_tracks"].add(tid)
                            # A track that turned safe no longer counts as a violation.
                            if prev_viol:
                                current_session_data["violations"].pop(tid, None)
                with json_lock:
                    current_session_data["total_riders"].add(tid)
                print(f"[TRACK] ID={tid} h_cls={h_cls} h_conf={h_conf:.2f} | {dbg}")

                # ── Display state ─────────────────────────────────────────────
                total_hist = len(hist)
                if is_no_helmet:
                    display_name = "VIOLATION: NO HELMET"
                    color = COLOR_VIOLATION
                elif is_safe:
                    display_name = "SAFE: HELMET"
                    color = COLOR_SAFE
                elif total_hist < MIN_FRAMES_BEFORE_DECIDE:
                    display_name = "ANALYZING..."
                    color = COLOR_ANALYZING
                else:
                    display_name = "RIDER"
                    color = COLOR_RIDER_CYAN
                plate_text = current_session_data["track_plate_cache"].get(tid, "")

                # ══ STAGE 3 – Plate Detection (violation only) ════════════════
                if is_no_helmet and tid not in dead_ids:
                    with json_lock:
                        if tid not in current_session_data["violations"]:
                            ts = datetime.now()
                            image_url = None
                            # An empty crop cannot be encoded; skip the image
                            # rather than abort the whole stream.
                            if rider_roi.size > 0:
                                rider_img_name = f"viol_{session_id}_{tid}_rider.jpg"
                                cv2.imwrite(os.path.join(RESULTS_FOLDER, rider_img_name),
                                            rider_roi)
                                image_url = f"/static/results/{rider_img_name}"
                            current_session_data["violations"][tid] = {
                                "id": tid,
                                "timestamp": ts.strftime('%H:%M:%S'),
                                "type": "No Helmet",
                                "plate_number": "Scanning...",
                                "image_url": image_url,
                                "plate_image_url": None,
                                "ocr_attempts": [],
                                "raw": {
                                    "confidence": float(h_conf),
                                    "box": xyxy.tolist()
                                }
                            }
                            current_session_data["track_capture_count"][tid] = 0

                    # ── Buffer-based best-frame plate capture ─────────────────
                    eb = expand_box_for_plate(xyxy, w_orig, h_orig)
                    plate_crop_region = frame[eb[1]:eb[3], eb[0]:eb[2]]
                    if plate_crop_region.size > 0:
                        with torch.no_grad():
                            plate_preds = plate_model.predict(
                                cv2.cvtColor(plate_crop_region, cv2.COLOR_BGR2RGB),
                                conf=CONF_PLATE, iou=0.45)
                        pb, ps, _pl = parse_preds(plate_preds,
                                                  plate_crop_region.shape[1],
                                                  plate_crop_region.shape[0])
                        if pb.size > 0:
                            best_det = int(np.argmax(ps))
                            px1, py1, px2, py2 = map(int, pb[best_det])
                            plate_crop = plate_crop_region[py1:py2, px1:px2]
                            if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20:
                                # Build rider ROI for motion scoring
                                curr_roi = rider_roi
                                prev_roi = (prev_frame[cy1:cy2, cx1:cx2]
                                            if prev_frame is not None else None)
                                # Add to quality buffer (always, even if only 1 frame)
                                video_buf_manager.add(
                                    tid, plate_crop, (x1, y1, x2, y2),
                                    frame_idx, prev_roi, curr_roi
                                )
                                print(f"[BUFFER] tid={tid} frame={frame_idx} buffered plate crop")
                                # Try flush: post-peak or timeout trigger
                                if video_buf_manager.should_flush(tid, (x1, y1, x2, y2), frame_idx):
                                    best_entry = video_buf_manager.flush(tid)
                                    if best_entry is not None:
                                        snap = current_session_data["track_capture_count"].get(tid, 0) + 1
                                        if _persist_flushed_plate(tid, best_entry, session_id, f"best{snap}"):
                                            current_session_data["track_capture_count"][tid] = snap
                                            print(f"[BUFFER] tid={tid} FLUSH → frame={best_entry.frame_idx} score={best_entry.score:.3f}")

                _draw_overlay(frame, x1, y1, x2, y2, tid, display_name, h_conf, color, plate_text, reason=dbg)

            prev_frame = frame.copy()  # store for next-frame motion scoring
            _, buffer = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\nContent-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')

        # ── Final cleanup: flush all remaining active tracks ─────────────────
        active_ids = set(track_last_seen.keys()) - dead_ids
        if active_ids:
            print(f"[VIDEO-END] Flushing {len(active_ids)} remaining tracks")
            # 1. Force-flush buffers for remaining active tracks
            final_flushes = video_buf_manager.force_flush_dead(active_ids)
            for tid, best_entry in final_flushes.items():
                if _persist_flushed_plate(tid, best_entry, session_id, "final"):
                    print(f"[BUFFER] End-of-video flush: tid={tid} frame={best_entry.frame_idx}")

            # 2. Aggressive evaluation for tracks that were visible until the end
            for tid in active_ids:
                hist = track_class_history.get(tid, [])
                if not hist:
                    continue
                if track_violation_memory.get(tid, False):
                    continue  # already a violation
                # At least 3 history entries and one confident no-helmet hit.
                nh_hits = [h for h in hist
                           if h['class'] == NO_HELMET_CLASS_ID and h['conf'] >= CONF_THRESHOLD]
                if len(hist) >= 3 and len(nh_hits) >= 1:
                    best_nh = max(nh_hits, key=lambda x: x['conf'])
                    with json_lock:
                        if tid not in current_session_data["violations"]:
                            ts = datetime.now()
                            # Only advertise an image that was actually written.
                            image_url = None
                            evidence = track_nh_crop.get(tid)
                            if evidence is not None and evidence.size > 0:
                                rider_img_name = f"viol_{session_id}_{tid}_rider_final.jpg"
                                cv2.imwrite(os.path.join(RESULTS_FOLDER, rider_img_name),
                                            evidence)
                                image_url = f"/static/results/{rider_img_name}"
                            current_session_data["violations"][tid] = {
                                "id": tid,
                                "timestamp": ts.strftime('%H:%M:%S'),
                                "type": "No Helmet (Final)",
                                "plate_number": "Scanning...",
                                "image_url": image_url,
                                "plate_image_url": None,
                                "ocr_attempts": [],
                                "raw": {
                                    "confidence": float(best_nh['conf']),
                                    "box": []
                                }
                            }
                            print(f"[VIDEO-END] Forced violation for tid={tid}")
    finally:
        cap.release()
    print(f"[VIDEO-END] Processing complete for session {session_id}")
# ══════════════════════════════════════════════════════════════════════════════
# THREE-MODEL PIPELINE – LIVE CAMERA (socket frame)
# ══════════════════════════════════════════════════════════════════════════════
def process_live_frame(frame, session, session_id, socket_sid):
    """Run rider -> helmet -> plate inference on one live frame and annotate it.

    Args:
        frame: Image array for the current frame. It is converted with
            ``cv2.COLOR_BGR2RGB`` before inference (so presumably BGR — confirm
            against the caller). Overlays are drawn onto it in place.
        session: Mutable per-connection state dict. Reads/writes the tracker,
            per-track class history, violation memory/age, last-seen frame,
            dead-track set, violations, plate buffer manager, frame counter
            and the previous frame.
        session_id: Identifier embedded in saved image names and in the OCR
            queue tag (``live_<session_id>``).
        socket_sid: Forwarded unchanged in every OCR queue item.

    Returns:
        ``(frame, new_violations)``: the annotated frame and the list of
        violation records that were first created during this call.

    Side effects:
        Writes rider/plate crops into ``RESULTS_FOLDER``, puts work items on
        ``current_session_data["ocr_queue"]`` and mutates ``session``.
    """
    tracker = session['tracker']
    track_class_history = session['track_class_history']
    track_violation_memory = session['track_violation_memory']
    track_violation_age = session.setdefault('track_violation_age', {})
    track_last_seen = session['track_last_seen']
    dead_ids = session['dead_ids']
    live_buf_mgr = session['plate_buf_manager'] # per-session buffer
    prev_frame = session.get('prev_frame') # for motion scoring
    session['frame_idx'] += 1
    frame_idx = session['frame_idx']
    # Expire old tracks & force-flush their buffers (short-track safety net).
    # A track counts as dead once it has not been seen for more than 50 frames.
    newly_dead = [t for t, last in track_last_seen.items()
                  if frame_idx - last > 50 and t not in dead_ids]
    for tid in newly_dead:
        dead_ids.add(tid)
    dead_flushes = live_buf_mgr.force_flush_dead(set(newly_dead))
    for tid, best_entry in dead_flushes.items():
        if best_entry.plate_crop is not None and best_entry.plate_crop.size > 0:
            pname = f"viol_live_{session_id}_{tid}_plate_best.jpg"
            ppath = os.path.join(RESULTS_FOLDER, pname)
            cv2.imwrite(ppath, best_entry.plate_crop)
            # Only attach the plate image if a violation record exists for the track.
            with json_lock:
                if tid in session['violations']:
                    session['violations'][tid]["plate_image_url"] = (
                        f"/static/results/{pname}")
            # OCR work goes through the queue stored on current_session_data,
            # tagged with a "live_" session id and the originating socket sid.
            current_session_data["ocr_queue"].put(
                (tid, ppath, f"live_{session_id}", socket_sid))
            print(f"[LIVE-BUFFER] Dead-track flush: tid={tid} frame={best_entry.frame_idx} score={best_entry.score:.3f}")
    h_orig, w_orig = frame.shape[:2]
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # ══ STAGE 1 – Rider Detection ═════════════════════════════════════════════
    with torch.no_grad():
        rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.45)
    r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig, debug_tag="rider")
    if r_boxes.size > 0:
        detections = sv.Detections(
            xyxy=r_boxes.astype(np.float32),
            confidence=r_scores.astype(np.float32),
            class_id=r_labels.astype(np.int32)
        )
        # Apply NMS
        detections = detections.with_nms(threshold=0.5)
    else:
        detections = sv.Detections.empty()
    # The tracker is updated on every frame, including frames with no detections.
    detections = tracker.update_with_detections(detections)
    new_violations = []
    for (xyxy, _mask, rider_conf, rider_cls, tracker_id, _data) in detections:
        # Detections without a tracker id are skipped entirely.
        if tracker_id is None:
            continue
        tid = int(tracker_id)
        track_last_seen[tid] = frame_idx
        x1, y1, x2, y2 = map(int, xyxy)
        # ══ STAGE 2 – Helmet / No-Helmet (within rider crop) ═════════════════
        h_cls, h_conf, h_dets = run_helmet_on_crop(frame, x1, y1, x2, y2)
        # Draw child detections
        for det in h_dets:
            dx1, dy1, dx2, dy2 = det["box"]
            d_color = COLOR_SAFE if det["class"] == HELMET_CLASS_ID else COLOR_VIOLATION
            cv2.rectangle(frame, (dx1, dy1), (dx2, dy2), d_color, 1)
            d_label = "H" if det["class"] == HELMET_CLASS_ID else "NH"
            cv2.putText(frame, f"{d_label}:{det['score']:.2f}", (dx1, dy1-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, d_color, 1)
        # --- unified history append: prefer Stage-2, fall back to Stage-1 ---
        if tid not in track_class_history:
            track_class_history[tid] = []
        if h_cls >= 0:
            # use helmet-head model (preferred source of truth)
            track_class_history[tid].append({"class": h_cls, "conf": h_conf})
        else:
            # fallback to stage-1 only when stage-2 had no head detection;
            # rider class 0 is recorded as a helmet observation with conf 0.99
            if int(rider_cls) == 0:
                track_class_history[tid].append({"class": HELMET_CLASS_ID, "conf": 0.99})
        # clamp history to the most recent HISTORY_WINDOW entries
        if len(track_class_history[tid]) > HISTORY_WINDOW:
            track_class_history[tid].pop(0)
        # debug: quickly log class disagreement
        # print(f"[LIVE-DBG-HIST] tid={tid} stage2_h_cls={h_cls} stage2_h_conf={h_conf:.2f} stage1_rider_cls={rider_cls}")
        hist = track_class_history.get(tid, [])
        prev_viol = track_violation_memory.get(tid, False)
        viol_age = track_violation_age.get(tid, 0)
        is_no_helmet, is_safe, dbg = evaluate_helmet_state(
            hist, h_cls, h_conf, prev_viol, viol_age, rider_cls=int(rider_cls))
        if is_no_helmet:
            # Remember the violation and count how many frames it has persisted.
            track_violation_memory[tid] = True
            track_violation_age[tid] = viol_age + 1
            with json_lock:
                session.setdefault('safe_tracks', set()).discard(tid)
        else:
            track_violation_memory[tid] = False
            track_violation_age[tid] = 0
            if is_safe:
                with json_lock:
                    session.setdefault('safe_tracks', set()).add(tid)
                    # A track that was previously flagged but is now judged
                    # safe has its violation record withdrawn.
                    if prev_viol:
                        session['violations'].pop(tid, None)
        with json_lock:
            session.setdefault('total_riders', set()).add(tid)
        print(f"[LIVE-TRACK] ID={tid} h_cls={h_cls} h_conf={h_conf:.2f} | {dbg}")
        # ── Display state ─────────────────────────────────────────────────────
        total_hist = len(hist)
        if is_no_helmet:
            display_name = "VIOLATION: NO HELMET"
            color = COLOR_VIOLATION
        elif is_safe:
            display_name = "SAFE: HELMET"
            color = COLOR_SAFE
        elif total_hist < MIN_FRAMES_BEFORE_DECIDE:
            display_name = "ANALYZING..."
            color = COLOR_ANALYZING
        else:
            display_name = "NO VIOLATION"
            color = COLOR_SAFE
        plate_text = session['track_plate_cache'].get(tid, "")
        # ══ STAGE 3 – Plate Detection (violation only) ════════════════════════
        if is_no_helmet and tid not in dead_ids:
            with json_lock:
                # First time this track is flagged: save the rider crop and
                # create the violation record.
                if tid not in session['violations']:
                    ts = datetime.now()
                    rider_img_name = f"viol_live_{session_id}_{tid}_rider.jpg"
                    cv2.imwrite(os.path.join(RESULTS_FOLDER, rider_img_name),
                                frame[y1:y2, x1:x2])
                    viol_record = {
                        "id": tid,
                        "timestamp": ts.strftime('%H:%M:%S'),
                        "type": "No Helmet",
                        "plate_number": "Scanning...",
                        "image_url": f"/static/results/{rider_img_name}",
                        "plate_image_url": None,
                        "ocr_attempts": [],
                        "raw": {
                            "confidence": float(h_conf),
                            "box": xyxy.tolist()
                        }
                    }
                    session['violations'][tid] = viol_record
                    session['track_capture_count'][tid] = 0
                    new_violations.append(viol_record)
            # At most 5 flushed plate snapshots are taken per track.
            if session['track_capture_count'].get(tid, 0) < 5:
                eb = expand_box_for_plate(xyxy, w_orig, h_orig)
                plate_crop_region = frame[eb[1]:eb[3], eb[0]:eb[2]]
                if plate_crop_region.size > 0:
                    with torch.no_grad():
                        plate_preds = plate_model.predict(
                            cv2.cvtColor(plate_crop_region, cv2.COLOR_BGR2RGB),
                            conf=CONF_PLATE, iou=0.45)
                    pb, ps, _pl = parse_preds(plate_preds,
                                              plate_crop_region.shape[1],
                                              plate_crop_region.shape[0])
                    if pb.size > 0:
                        # Keep only the highest-scoring plate detection.
                        best_det = int(np.argmax(ps))
                        px1, py1, px2, py2 = map(int, pb[best_det])
                        plate_crop = plate_crop_region[py1:py2, px1:px2]
                        # Ignore crops that are 10 px or less tall, or 20 px or less wide.
                        if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20:
                            # Motion scoring ROIs
                            curr_roi = frame[y1:y2, x1:x2]
                            prev_roi = prev_frame[y1:y2, x1:x2] if prev_frame is not None else None
                            # Buffer the scored plate crop (short-track safe)
                            live_buf_mgr.add(
                                tid, plate_crop, (x1, y1, x2, y2),
                                frame_idx, prev_roi, curr_roi
                            )
                            print(f"[LIVE-BUFFER] tid={tid} frame={frame_idx} buffered plate crop")
                # Flush trigger: post-peak or hard timeout
                if live_buf_mgr.should_flush(tid, (x1, y1, x2, y2), frame_idx):
                    best_entry = live_buf_mgr.flush(tid)
                    if best_entry is not None:
                        snap = session['track_capture_count'].get(tid, 0) + 1
                        pname = f"viol_live_{session_id}_{tid}_plate_best{snap}.jpg"
                        ppath = os.path.join(RESULTS_FOLDER, pname)
                        cv2.imwrite(ppath, best_entry.plate_crop)
                        with json_lock:
                            session['violations'][tid]["plate_image_url"] = (
                                f"/static/results/{pname}")
                        current_session_data["ocr_queue"].put(
                            (tid, ppath, f"live_{session_id}", socket_sid))
                        session['track_capture_count'][tid] = snap
                        print(f"[LIVE-BUFFER] tid={tid} FLUSH → frame={best_entry.frame_idx} score={best_entry.score:.3f}")
        _draw_overlay(frame, x1, y1, x2, y2, tid, display_name, h_conf, color, plate_text, reason=dbg)
    # The stored copy is taken after overlays were drawn onto `frame`.
    session['prev_frame'] = frame.copy() # store for next-frame motion scoring
    return frame, new_violations
# ══════════════════════════════════════════════════════════════════════════════
# FLASK ROUTES
# ══════════════════════════════════════════════════════════════════════════════
@app.route('/')
def index():
    """Render the ``landing.html`` template for the site root."""
    page = render_template('landing.html')
    return page
@app.route('/dashboard')
def dashboard():
    """Render the ``dashboard.html`` template."""
    page = render_template('dashboard.html')
    return page
@app.route('/publisher')
def publisher():
    """Render the ``publisher.html`` template."""
    page = render_template('publisher.html')
    return page
@app.route('/camera_debug')
def camera_debug():
    """Render the ``camera_debug.html`` template."""
    page = render_template('camera_debug.html')
    return page
@app.route('/test_simple')
def test_simple():
    """Send the static ``test_simple.html`` file from directory ``'.'``."""
    directory, page_name = '.', 'test_simple.html'
    return send_from_directory(directory, page_name)
@app.route('/test_socket_echo')
def test_socket_echo():
    """Send the static ``test_socket_echo.html`` file from directory ``'.'``."""
    directory, page_name = '.', 'test_socket_echo.html'
    return send_from_directory(directory, page_name)
@app.route('/upload', methods=['POST'])
def upload_video():
    """Store an uploaded video and reset the shared video-session state.

    Expects a multipart form with a ``file`` part.

    Returns:
        JSON ``{"filename": ..., "session_id": ...}`` on success, where
        ``filename`` is the stored name (``<session_id>_<basename>``), or a
        JSON error with HTTP 400 when no usable file was supplied.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    # The filename is client-controlled. Keep only its last path component
    # (treating both separator styles as separators) so the saved file can
    # never land outside UPLOAD_FOLDER.
    raw_name = file.filename or ''
    safe_name = os.path.basename(raw_name.replace('\\', '/'))
    if not safe_name:
        return jsonify({"error": "No selected file"}), 400
    session_id = str(uuid.uuid4())[:8]
    with json_lock:
        # Start from a clean slate for the new upload.
        current_session_data["violations"] = {}
        current_session_data["safe_tracks"] = set()
        current_session_data["total_riders"] = set()
        current_session_data["track_plate_cache"] = {}
        current_session_data["track_capture_count"] = {}
        current_session_data["track_ocr_history"] = {}
        current_session_data["track_violation_age"] = {}
        # Drop OCR jobs left over from the previous upload. Draining until
        # queue.Empty avoids the empty()/get race and no longer hides
        # unrelated errors behind a blanket except.
        pending = current_session_data["ocr_queue"]
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                break
    filename = f"{session_id}_{safe_name}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)
    return jsonify({"filename": filename, "session_id": session_id})
@app.route('/get_violations')
def get_violations():
    """Return the uploaded-video violation records as JSON, last-inserted first."""
    with json_lock:
        records = list(current_session_data["violations"].values())
    return jsonify(records[::-1])
@app.route('/get_stats')
def get_stats():
    """Return rider, safe and violation counts for the uploaded-video session."""
    with json_lock:
        riders = current_session_data.get('total_riders', set())
        safe = current_session_data.get('safe_tracks', set())
        flagged = current_session_data.get('violations', {})
        # Sizes are read while the lock is held; serialisation happens after.
        counts = {
            'total_riders': len(riders),
            'safe_count': len(safe),
            'violation_count': len(flagged)
        }
    return jsonify(counts)
@app.route('/video_feed/<filename>/<session_id>')
def video_feed(filename, session_id):
    """Stream the output of ``process_video_gen`` for an uploaded file.

    The response uses the ``multipart/x-mixed-replace`` MIME type with the
    part boundary ``frame``.
    """
    source_path = os.path.join(UPLOAD_FOLDER, filename)
    frame_stream = process_video_gen(source_path, session_id)
    return Response(
        frame_stream,
        mimetype='multipart/x-mixed-replace; boundary=frame',
    )
@app.route('/mobile/<session_id>')
def mobile_node(session_id):
    """Render ``mobile.html`` with the session id from the URL."""
    page = render_template('mobile.html', session_id=session_id)
    return page
@app.route('/upload_frame/<session_id>', methods=['POST'])
def upload_frame(session_id):
    """Acknowledge a posted frame; the request body and ``session_id`` are not used."""
    ack = {"status": "received"}
    return jsonify(ack)
@app.route('/get_live_violations/<session_id>')
def get_live_violations(session_id):
    """Return the violations of one live camera session as JSON.

    Args:
        session_id: Public session identifier (not the socket sid).

    Returns:
        A JSON list of violation records, last-inserted first, for the first
        live session whose ``session_id`` matches; an empty list otherwise.
    """
    # Socket handlers add/remove entries from other threads, so iterate over a
    # snapshot instead of the live dict (which would raise RuntimeError if it
    # changed size mid-iteration).
    for session in list(live_camera_sessions.values()):
        if session['session_id'] == session_id:
            with json_lock:
                data = list(session['violations'].values())
            data.reverse()
            return jsonify(data)
    return jsonify([])
@app.route('/api/sessions')
def get_active_sessions():
    """Returns a list of all currently active session IDs.

    The response is ``{"sessions": [...]}`` with one entry per live camera
    session that carries a ``session_id``.
    """
    # Snapshot first: the registry is mutated by socket handlers running in
    # other threads, and iterating a dict while it changes size raises.
    sessions = [
        data['session_id']
        for data in list(live_camera_sessions.values())
        if 'session_id' in data
    ]
    return jsonify({"sessions": sessions})
# ══════════════════════════════════════════════════════════════════════════════
# SOCKET.IO – LIVE CAMERA
# ══════════════════════════════════════════════════════════════════════════════
@socketio.on('connect')
def handle_connect():
    """Log the new socket connection and acknowledge it to the client."""
    message = f"[SOCKET] Client connected: {request.sid}"
    print(message)
    ack = {'status': 'connected'}
    emit('connection_response', ack)
@socketio.on('disconnect')
def handle_disconnect():
    """Log the disconnect and drop this socket's live camera session, if any."""
    message = f"[SOCKET] Client disconnected: {request.sid}"
    print(message)
    # pop() with a default: a socket that never started a session is fine.
    live_camera_sessions.pop(request.sid, None)
@socketio.on('start_camera_session')
def handle_start_camera(data):
    """Create (or reset) the live camera session for the calling socket.

    Args:
        data: Event payload; may carry ``session_id``. When it is missing,
            null or empty, a fresh 8-character id is generated.

    Joins the socket to room ``session_<session_id>``, stores a fresh
    per-socket state dict in ``live_camera_sessions`` and emits
    ``camera_session_started`` with the session id.
    """
    from flask_socketio import join_room
    # dict.get(key, default) only falls back when the key is absent. Callers
    # can pass an explicit None/'' here, which previously produced a session
    # id of None and the room "session_None".
    session_id = (data or {}).get('session_id') or str(uuid.uuid4())[:8]
    print(f"[SOCKET] Starting camera session: {session_id} for {request.sid}")
    # Join the session room
    join_room(f"session_{session_id}")
    live_camera_sessions[request.sid] = {
        'session_id': session_id,
        'tracker': sv.ByteTrack(),
        'track_class_history': {},
        'track_violation_memory': {},
        'track_violation_age': {},
        'track_last_seen': {},
        'dead_ids': set(),
        'frame_idx': 0,
        'violations': {},
        'safe_tracks': set(),
        'total_riders': set(),
        'track_plate_cache': {},
        'track_capture_count': {},
        'track_ocr_history': {},
        'plate_buf_manager': GlobalBufferManager(),  # best-frame buffer
        'prev_frame': None,  # for motion scoring
    }
    emit('camera_session_started', {'session_id': session_id})
@socketio.on('join_remote_session')
def handle_join_remote(data):
    """Allows Admin to watch a Publisher's session results.

    Joins the calling socket to the ``session_<session_id>`` room and confirms
    with ``remote_session_joined``. Does nothing when no session id is given.
    """
    from flask_socketio import join_room
    session_id = data.get('session_id')
    if session_id:
        room_name = f"session_{session_id}"
        join_room(room_name)
        print(f"[SOCKET] Admin {request.sid} joined session room: {room_name}")
        emit('remote_session_joined', {'session_id': session_id})
@socketio.on('camera_frame')
def handle_camera_frame(data):
    """Decode one camera frame, run the live pipeline and publish the result.

    Args:
        data: Event payload whose ``frame`` entry is a base64-encoded image,
            either as a ``data:`` URL or as bare base64.

    Emits ``processed_frame`` to the sending socket and
    ``processed_frame_relay`` to the session room, both carrying the annotated
    JPEG, the full violation list and the session counters. Emits ``error``
    when there is no active session or processing fails.
    """
    if request.sid not in live_camera_sessions:
        emit('error', {'message': 'No active session'})
        return
    try:
        # Take whatever follows the last comma, so both
        # "data:image/jpeg;base64,<payload>" and a bare "<payload>" work
        # (base64 itself never contains a comma).
        frame_data = data['frame'].rpartition(',')[2]
        frame_bytes = base64.b64decode(frame_data)
        nparr = np.frombuffer(frame_bytes, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if frame is None:
            # Undecodable image data: drop the frame silently.
            return
        session = live_camera_sessions[request.sid]
        session_id = session['session_id']
        processed_frame, new_violations = process_live_frame(
            frame, session, session_id, request.sid)
        encoded_ok, buffer = cv2.imencode(
            '.jpg', processed_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
        if not encoded_ok:
            emit('error', {'message': 'Failed to encode processed frame'})
            return
        proc_b64 = base64.b64encode(buffer).decode('utf-8')
        # Build the payload once, under the same lock the pipeline and the
        # HTTP routes use when touching the violation/stat containers.
        with json_lock:
            payload = {
                'frame': f'data:image/jpeg;base64,{proc_b64}',
                'violations': list(session['violations'].values()),  # Send FULL list for state sync
                'stats': {
                    'total_riders': len(session.get('total_riders', set())),
                    'safe_count': len(session.get('safe_tracks', set())),
                    'violation_count': len(session.get('violations', {}))
                }
            }
        socketio.emit('processed_frame', payload, room=request.sid)
        # Also relay to any Admin viewers in this session's room
        room_name = f"session_{session_id}"
        socketio.emit('processed_frame_relay', payload, room=room_name)
    except Exception as e:
        print(f"[SOCKET] Frame error: {e}")
        import traceback
        traceback.print_exc()
        emit('error', {'message': str(e)})
@socketio.on('webrtc_offer')
def handle_webrtc_offer(data):
    """Handles WebRTC offer from Publisher (Mobile).

    Args:
        data: Event payload with the offer ``sdp`` and ``type`` and, optionally,
            ``session_id``.

    Ensures a live camera session exists for the socket, creates a peer
    connection, starts a ``VideoProcessTrack`` consumer for the incoming video
    track, registers the track in ``publisher_tracks`` and replies with
    ``webrtc_answer``.
    """
    # Resolve the socket id once. The callbacks below fire later on the
    # asyncio loop, so they should not depend on the request-bound
    # ``request`` proxy still resolving at that point.
    sid = request.sid
    session_id = data.get('session_id')
    if sid not in live_camera_sessions:
        # Never start a session (or key publisher_tracks) under a None id.
        if not session_id:
            session_id = str(uuid.uuid4())[:8]
        handle_start_camera({'session_id': session_id})
    elif not session_id:
        session_id = live_camera_sessions[sid]['session_id']
    offer = RTCSessionDescription(sdp=data['sdp'], type=data['type'])
    pc = RTCPeerConnection()
    pcs.add(pc)
    active_pcs[sid] = pc

    @pc.on("icecandidate")
    def on_icecandidate(candidate):
        if candidate:
            socketio.emit('ice_candidate', {
                'candidate': candidate.candidate,
                'sdpMid': candidate.sdpMid,
                'sdpMLineIndex': candidate.sdpMLineIndex
            }, room=sid)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        print(f"[WEBRTC] Publisher Connection ID {session_id}: {pc.connectionState}")
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)
            active_pcs.pop(sid, None)
            publisher_tracks.pop(session_id, None)

    @pc.on("track")
    def on_track(track):
        if track.kind == "video":
            processor = VideoProcessTrack(relay.subscribe(track), session_id, sid)

            async def run_processor():
                # Pull frames until the track raises. Catch Exception rather
                # than everything so task cancellation still propagates.
                while True:
                    try:
                        await processor.recv()
                    except Exception:
                        break

            asyncio.run_coroutine_threadsafe(run_processor(), loop)
            # Make the raw publisher track available to subscriber offers.
            publisher_tracks[session_id] = track

    async def create_answer():
        await pc.setRemoteDescription(offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        return pc.localDescription

    local_desc = run_async(create_answer())
    emit('webrtc_answer', {'sdp': local_desc.sdp, 'type': local_desc.type})
@socketio.on('subscriber_offer')
def handle_subscriber_offer(data):
    """Handles WebRTC offer from Admin (Viewer) wanting to see a relayed stream.

    Args:
        data: Event payload with ``session_id`` plus the offer ``sdp``/``type``.

    Emits ``error`` when no publisher track is registered for the session;
    otherwise attaches a relay subscription of that track to a new peer
    connection and replies with ``subscriber_answer``.
    """
    session_id = data.get('session_id')
    if session_id not in publisher_tracks:
        emit('error', {'message': f'Stream {session_id} not available'})
        return
    # Resolve the socket id once. The callbacks below fire later on the
    # asyncio loop, so they should not depend on the request-bound
    # ``request`` proxy still resolving at that point.
    sid = request.sid
    offer = RTCSessionDescription(sdp=data['sdp'], type=data['type'])
    pc = RTCPeerConnection()
    pcs.add(pc)
    active_pcs[sid] = pc

    @pc.on("icecandidate")
    def on_icecandidate(candidate):
        if candidate:
            socketio.emit('ice_candidate', {
                'candidate': candidate.candidate,
                'sdpMid': candidate.sdpMid,
                'sdpMLineIndex': candidate.sdpMLineIndex
            }, room=sid)

    track = publisher_tracks[session_id]
    pc.addTrack(relay.subscribe(track))

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)
            active_pcs.pop(sid, None)

    async def create_answer():
        await pc.setRemoteDescription(offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        return pc.localDescription

    local_desc = run_async(create_answer())
    emit('subscriber_answer', {'sdp': local_desc.sdp, 'type': local_desc.type})
@socketio.on('ice_candidate')
def handle_ice_candidate(data):
    """Add a remote (trickled) ICE candidate to the caller's peer connection.

    Args:
        data: Event payload with the browser-style ``candidate`` string and
            optional ``sdpMid`` / ``sdpMLineIndex``. A null or empty candidate
            (end-of-candidates marker) is ignored.

    Does nothing when the socket has no registered peer connection. Parsing
    or add failures are logged, not raised.
    """
    # aiortc's RTCIceCandidate has no ``candidate=`` constructor argument; the
    # SDP candidate line has to be parsed with candidate_from_sdp instead.
    from aiortc.sdp import candidate_from_sdp

    pc = active_pcs.get(request.sid)
    if pc is None:
        return
    print(f"[WEBRTC] Adding remote ICE candidate for {request.sid}")

    async def add_candidate():
        try:
            # Some clients might send null candidate to indicate end-of-candidates
            raw = data.get('candidate')
            if not raw:
                return
            # candidate_from_sdp expects the line without its
            # "a=" / "candidate:" prefixes.
            if raw.startswith('a='):
                raw = raw[2:]
            if raw.startswith('candidate:'):
                raw = raw[len('candidate:'):]
            candidate = candidate_from_sdp(raw)
            candidate.sdpMid = data.get('sdpMid')
            candidate.sdpMLineIndex = data.get('sdpMLineIndex')
            await pc.addIceCandidate(candidate)
        except Exception as e:
            print(f"[WEBRTC] Error adding ICE candidate: {e}")

    run_async(add_candidate())
if __name__ == '__main__':
    # Script entry point: serve the app through Socket.IO on all interfaces,
    # port 7860, passing ssl_context='adhoc' through to the server.
    socketio.run(
        app,
        host='0.0.0.0',
        port=7860,
        ssl_context='adhoc',
    )