Spaces:
Sleeping
Sleeping
| from flask import Flask, render_template, request, jsonify, Response, send_from_directory | |
| from flask_socketio import SocketIO, emit | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import json | |
| import uuid | |
| import threading | |
| import queue | |
| import torch | |
| import time | |
| from datetime import datetime | |
| from collections import Counter | |
| import base64 | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from rfdetr import RFDETRMedium | |
| import supervision as sv | |
| from gradio_client import Client, handle_file | |
# --- APP SETUP ---
app = Flask(__name__)
# Take the signing key from the environment so deployments do not all share
# the placeholder below; the literal is kept only as a development fallback.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'your-secret-key-here')
app.config['SESSION_TYPE'] = 'filesystem'
# NOTE(review): cors_allowed_origins="*" accepts Socket.IO connections from
# any origin - confirm this is intended outside development.
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading', manage_session=False)
# --- CONFIG ---
# Directory containing this file; every path below is built from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Where uploaded videos are stored.
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static/uploads')
# Where rider/plate crops and per-session JSON reports are written.
RESULTS_FOLDER = os.path.join(BASE_DIR, 'static/results')
# Both folders are created at import time if they do not exist yet.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
# Model Paths
HELMET_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoint_best_ema_bike_helmet.pth")
PLATE_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoint_best_ema_plate.pth")
# Only the length of this list is used in this file (as num_classes for the
# rider model). NOTE(review): the per-track logic compares class ids 0, 1
# and 2, i.e. it tolerates both 0-based and 1-based ids - confirm which one
# the checkpoint actually emits.
HELMET_CLASSES = ["motorbike and helmet", "motorbike and no helmet"]
# Thresholds (matching test.py)
CONF_RIDER = 0.25              # passed as conf= to the rider model's predict()
CONF_PLATE = 0.30              # passed as conf= to the plate model's predict()
CONF_HELMET_CONFIRM = 0.40     # minimum confidence to label a track "HELMET"
CONF_NO_HELMET_TRIGGER = 0.35  # minimum confidence for a no-helmet hit to flag a violation
# Prefer CUDA when torch reports it as available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INIT] Targeted Device: {device}")
# --- MODEL INITIALIZATION ---
def _load_detector(title, short_name, num_classes, weights):
    """Build an RFDETRMedium detector and attempt to optimise it for inference.

    Optimisation is best-effort: any exception raised while optimising (or
    while switching to half precision on CUDA) is printed as a warning and
    the detector is returned anyway.

    Args:
        title: Name used in the "[INIT] Loading ..." message.
        short_name: Name used in the "[WARNING] ... model opt failed" message.
        num_classes: Number of classes the checkpoint was trained with.
        weights: Path to the checkpoint file.

    Returns:
        The constructed detector.
    """
    print(f"[INIT] Loading {title}...")
    detector = RFDETRMedium(num_classes=num_classes, pretrain_weights=weights)
    try:
        detector.optimize_for_inference(compile=True, batch_size=1)
        if device.type == "cuda":
            detector.model.half()
    except Exception as e:
        print(f"[WARNING] {short_name} model opt failed: {e}")
    return detector


# Both detectors are loaded once at import time and shared by all requests.
rider_model = _load_detector("Rider Detection Model", "Rider", len(HELMET_CLASSES), HELMET_WEIGHTS)
plate_model = _load_detector("License Plate Model", "Plate", 1, PLATE_WEIGHTS)
# --- GLOBAL STATE & WORKERS ---
# We use a dictionary to separate state by session_id in a real app,
# but for simplicity here we reset globals on upload or use a simplified structure.
# State of the single uploaded-video session; upload_video() replaces the
# four dicts on every new upload and drains the queue.
current_session_data = {
    "violations": {},  # {track_id: {data}}
    "ocr_queue": queue.Queue(),  # tasks are (track_id, plate_path, session_id) tuples
    "track_plate_cache": {},  # {track_id: consensus plate text}
    "track_capture_count": {},  # {track_id: number of plate snapshots saved so far}
    "track_ocr_history": {}  # {track_id: [accepted OCR readings]}
}
# Live Camera Session State
# Keyed by the Socket.IO request.sid of the connected client; each value
# holds that client's session_id, tracker and per-track bookkeeping.
live_camera_sessions = {}  # {request.sid: {session_id, tracker, history, etc}}
# Guards the violation records shared between request threads and the OCR worker.
json_lock = threading.Lock()
def get_best_consensus(results):
    """Merge several OCR readings of one plate into a single string.

    Each reading is normalised (newlines become spaces, surrounding
    whitespace is stripped) and placeholder values are discarded. The
    remaining readings vote character by character: for every position the
    most frequent character among the readings long enough to reach that
    position is chosen. On a tie the character seen first wins.

    Args:
        results: Iterable of raw OCR strings.

    Returns:
        "PENDING..." when no usable reading remains, the reading itself when
        exactly one remains, otherwise the per-position majority string.
    """
    placeholders = {"API_ERROR", "PENDING...", ""}

    # Normalise first, then filter: otherwise whitespace-only readings and
    # padded placeholders (e.g. "API_ERROR\n") slip through as votes.
    cleaned = []
    for raw in results:
        text = raw.replace("\n", " ").strip()
        if text not in placeholders:
            cleaned.append(text)

    if not cleaned:
        return "PENDING..."
    if len(cleaned) == 1:
        return cleaned[0]

    longest = max(len(text) for text in cleaned)
    voted = []
    for pos in range(longest):
        # At least one reading reaches every pos < longest, so the column
        # is never empty.
        column = [text[pos] for text in cleaned if pos < len(text)]
        voted.append(Counter(column).most_common(1)[0][0])
    return "".join(voted).strip()
def clamp_box(box, w, h):
    """Truncate a box to integers and clamp it inside a w x h image.

    Every coordinate is limited to [0, w - 1] (x) or [0, h - 1] (y). Clamping
    both sides of each coordinate matters when a box lies partly or wholly
    outside the frame: a negative x2/y2 would otherwise be read as a
    from-the-end index when the result is used to slice an array.

    Args:
        box: Four numbers (x1, y1, x2, y2).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        [x1, y1, x2, y2] as a list of ints inside the image bounds.
    """
    x1, y1, x2, y2 = map(int, box)
    max_x = max(w - 1, 0)
    max_y = max(h - 1, 0)
    return [
        min(max(x1, 0), max_x),
        min(max(y1, 0), max_y),
        min(max(x2, 0), max_x),
        min(max(y2, 0), max_y),
    ]
def expand_box(box, w, h):
    """Derive the plate search window from a rider box.

    The box is widened by 10% of its width on each side and both its top and
    bottom edges are moved down by 40% of its height; the result is then
    clamped to the image by clamp_box.

    NOTE(review): despite the name, the vertical extent is shifted rather
    than enlarged - confirm this is the intended search region.
    NOTE(review): an identical definition appears again later in this file;
    the later one is the binding in effect at call time.

    Args:
        box: Four numbers (x1, y1, x2, y2).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        The clamped window as returned by clamp_box.
    """
    left, top, right, bottom = (int(v) for v in box)
    width = right - left
    height = bottom - top
    shifted = [
        left - width * 0.1,
        top + height * 0.4,
        right + width * 0.1,
        bottom + height * 0.4,
    ]
    return clamp_box(shifted, w, h)
def background_ocr_worker():
    """Consume plate crops from the shared OCR queue and record the text read.

    Runs forever in a background thread. Each task is a
    (track_id, plate_path, session_id) tuple. The plate image is sent to the
    remote Gradio OCR endpoint; accepted readings are appended to the
    per-track history, a consensus plate is computed with
    get_best_consensus, and the matching violation record is updated and
    persisted to "session_<session_id>.json" in RESULTS_FOLDER.

    Tasks whose session_id starts with "live_" belong to a live-camera
    session and are applied to that session's own state in
    live_camera_sessions instead of the upload session's state; previously
    they were written into current_session_data, so live sessions never
    received their plate text and could overwrite an upload track with the
    same id. A live task whose session no longer exists is dropped.

    If the OCR client cannot be created the worker logs the failure and
    exits; queued tasks are then never processed.
    """
    print("[OCR] Worker Thread Started")
    try:
        # Initialize Gradio Client
        client = Client("WebashalarForML/demo-glm-ocr")
    except Exception as e:
        print(f"[OCR] Connection Failed: {e}")
        return

    live_prefix = "live_"
    while True:
        ocr_queue = current_session_data["ocr_queue"]
        task = ocr_queue.get()
        try:
            if task is None:
                continue  # Keep alive
            track_id, plate_path, session_id = task

            # Pick the state dict that owns this task.
            state = current_session_data
            if isinstance(session_id, str) and session_id.startswith(live_prefix):
                live_id = session_id[len(live_prefix):]
                # Snapshot the values: sessions are added/removed by socket threads.
                state = next(
                    (s for s in list(live_camera_sessions.values())
                     if s.get("session_id") == live_id),
                    None,
                )
                if state is None:
                    # The client disconnected; there is nothing left to update.
                    continue

            print(f"[OCR] Processing ID {track_id}...")
            try:
                # Call external API
                result = client.predict(image=handle_file(plate_path), api_name="/proses_intelijen")
                plate_text = str(result).strip()
            except Exception as e:
                print(f"[OCR] API Error: {e}")
                plate_text = "API_ERROR"

            with json_lock:
                # Update History
                history = state["track_ocr_history"].setdefault(track_id, [])
                if plate_text not in ["API_ERROR", ""]:
                    history.append(plate_text)
                # Consensus
                final_plate = get_best_consensus(history)
                state["track_plate_cache"][track_id] = final_plate
                # Update Main JSON Record
                record = state["violations"].get(track_id)
                if record is not None:
                    record["plate_number"] = final_plate
                    record["ocr_attempts"] = history
                    # Save to JSON file for persistence
                    json_path = os.path.join(RESULTS_FOLDER, f"session_{session_id}.json")
                    with open(json_path, 'w') as f:
                        json.dump(list(state["violations"].values()), f, indent=4)
        except Exception as e:
            print(f"[OCR] Loop Error: {e}")
        finally:
            # Always balance the get() above, including on errors and skipped
            # tasks, so Queue.join() can never block on a lost task.
            ocr_queue.task_done()
# Start OCR Thread
# Daemon thread: it does not keep the process alive at interpreter shutdown.
_ocr_thread = threading.Thread(target=background_ocr_worker, daemon=True)
_ocr_thread.start()
def parse_preds(preds, W, H):
    """Extract boxes, scores and labels from a prediction object as arrays.

    Attributes that are not already numpy arrays are converted with
    .cpu().numpy(). If the largest box coordinate is at most 1.01 the boxes
    are treated as normalised and scaled to pixels on a copy (x by W, y by
    H), so the caller's array is not modified.

    Args:
        preds: Object exposing xyxy, confidence and class_id.
        W: Image width in pixels.
        H: Image height in pixels.

    Returns:
        (boxes, scores, labels). All three are empty arrays when preds has
        no xyxy attribute.
    """
    def _as_numpy(value):
        # Arrays pass through untouched; anything else is assumed to offer
        # the tensor-style .cpu().numpy() chain.
        return value if isinstance(value, np.ndarray) else value.cpu().numpy()

    if not hasattr(preds, "xyxy"):
        return np.array([]), np.array([]), np.array([])

    boxes = _as_numpy(preds.xyxy)
    scores = _as_numpy(preds.confidence)
    labels = _as_numpy(preds.class_id)

    looks_normalised = boxes.size > 0 and boxes.max() <= 1.01
    if looks_normalised:
        boxes = boxes.copy()
        boxes[:, [0, 2]] *= W
        boxes[:, [1, 3]] *= H
    return boxes, scores, labels
def expand_box(box, w, h):
    """Return the clamped plate search window for a rider box.

    Horizontally the box grows by 10% of its width on both sides; vertically
    both edges are moved down by 40% of its height. The result is passed
    through clamp_box.

    NOTE(review): this repeats, statement for statement, a definition of the
    same name that appears earlier in this file; this later definition is
    the binding in effect at call time.
    """
    x_min, y_min, x_max, y_max = [int(coord) for coord in box]
    box_w = x_max - x_min
    box_h = y_max - y_min
    x_pad = box_w * 0.1
    y_shift = box_h * 0.4
    return clamp_box([x_min - x_pad, y_min + y_shift, x_max + x_pad, y_max + y_shift], w, h)
def process_video_gen(video_path, session_id):
    """Stream an annotated rendering of a video while logging helmet violations.

    Generator for a multipart/x-mixed-replace response. For every frame it
    runs the rider detector, updates a ByteTrack tracker, decides per track
    whether a no-helmet violation is established, saves rider and plate
    crops for violating tracks, queues plate crops for OCR, draws the
    overlays and yields the JPEG-encoded frame as one multipart chunk.

    Violation records are stored in current_session_data (shared with the
    OCR worker); crops are written to RESULTS_FOLDER.

    Args:
        video_path: Path of the video file to read.
        session_id: Identifier used in saved file names and OCR tasks.

    Yields:
        bytes: one "--frame" multipart chunk per processed frame.
    """
    cap = cv2.VideoCapture(video_path)
    tracker = sv.ByteTrack()
    # Session specific tracking (matching test.py structure)
    track_class_history = {}
    track_violation_memory = {}
    track_last_seen = {}
    dead_ids = set()
    frame_idx = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1

            # Cleanup dead tracks (matching test.py): a track unseen for more
            # than 50 frames is retired and never triggers captures again.
            dead_ids.update(
                tid for tid, last in track_last_seen.items() if frame_idx - last > 50
            )

            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            h_orig, w_orig = frame.shape[:2]

            # 1. RIDER DETECTION
            # NOTE(review): confirm that this predict() honours conf=/iou=;
            # the public RF-DETR API documents a `threshold` argument.
            with torch.no_grad():
                rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.5)
            r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig)
            if r_boxes.size > 0:
                detections = sv.Detections(
                    xyxy=r_boxes.astype(np.float32),
                    confidence=r_scores.astype(np.float32),
                    class_id=r_labels.astype(np.int32)
                )
            else:
                detections = sv.Detections.empty()
            detections = tracker.update_with_detections(detections)

            for xyxy, _mask, confidence, class_id, tracker_id, _data in detections:
                if tracker_id is None:
                    continue
                tid = int(tracker_id)
                track_last_seen[tid] = frame_idx
                x1, y1, x2, y2 = map(int, xyxy)
                cls_idx = int(class_id)

                # Track class history (matching test.py logic): keep the last 20 hits.
                history = track_class_history.setdefault(tid, [])
                history.append({"class": cls_idx, "conf": confidence})
                if len(history) > 20:
                    history.pop(0)
                hist = [h['class'] for h in history]

                # Robust helmet state mapping (matching test.py): class 1 is
                # read as "helmet" once class 2 has been seen for this track,
                # otherwise as "no helmet".
                is_nh_current = (cls_idx == 2)
                is_h_current = (cls_idx == 0)
                if cls_idx == 1:
                    if 2 in hist:
                        is_h_current = True
                    else:
                        is_nh_current = True

                # Violation memory logic (matching test.py): once set, a
                # violation is never cleared for the track.
                if tid not in track_violation_memory:
                    if is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
                        track_violation_memory[tid] = True
                    nh_hits = sum(1 for c in hist if c == 2 or (c == 1 and 2 not in hist))
                    if nh_hits > 3:
                        track_violation_memory[tid] = True
                is_no_helmet = track_violation_memory.get(tid, False)

                # Display name logic (matching test.py)
                if is_no_helmet:
                    display_name, color = "VIOLATION: NO HELMET", (0, 0, 255)
                elif is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
                    display_name, color = "WARNING: NO HELMET", (0, 165, 255)
                elif is_nh_current and confidence > 0.15:
                    display_name, color = "ANALYZING...", (0, 255, 255)
                elif is_h_current and confidence >= CONF_HELMET_CONFIRM:
                    display_name, color = "HELMET", (0, 255, 0)
                else:
                    display_name, color = f"TRACKING (C{cls_idx})", (180, 180, 180)

                plate_text = ""
                if is_no_helmet and tid not in dead_ids:
                    # 3. LOG VIOLATION & CROP
                    with json_lock:
                        if tid not in current_session_data["violations"]:
                            ts = datetime.now()
                            # Save Rider Image. Negative coordinates would be
                            # read as from-the-end indices, so floor them at 0;
                            # an empty crop is not written (imwrite would raise).
                            rider_img_name = f"viol_{session_id}_{tid}_rider.jpg"
                            rider_path = os.path.join(RESULTS_FOLDER, rider_img_name)
                            rider_img = frame[max(0, y1):max(0, y2), max(0, x1):max(0, x2)]
                            if rider_img.size > 0:
                                cv2.imwrite(rider_path, rider_img)
                            # Initialize Record
                            current_session_data["violations"][tid] = {
                                "id": tid,
                                "timestamp": ts.strftime('%H:%M:%S'),
                                "type": "No Helmet",
                                "plate_number": "Scanning...",
                                "image_url": f"/static/results/{rider_img_name}",
                                "plate_image_url": None,  # Will fill later
                                "ocr_attempts": [],
                                "raw": {
                                    "confidence": float(confidence),
                                    "box": xyxy.tolist()
                                }
                            }
                            current_session_data["track_capture_count"][tid] = 0

                    # 4. PLATE DETECTION (Only if violation confirmed)
                    # Expand rider box to find plate inside/near it (matching test.py)
                    eb = expand_box(xyxy, w_orig, h_orig)
                    rider_crop = frame[eb[1]:eb[3], eb[0]:eb[2]]
                    captured = current_session_data["track_capture_count"].get(tid, 0)
                    # At most three plate snapshots are taken per track.
                    if rider_crop.size > 0 and captured < 3:
                        # Run Plate Model on Crop
                        with torch.no_grad():
                            plate_preds = plate_model.predict(
                                cv2.cvtColor(rider_crop, cv2.COLOR_BGR2RGB), conf=CONF_PLATE, iou=0.5)
                        pb, ps, pl = parse_preds(plate_preds, rider_crop.shape[1], rider_crop.shape[0])
                        if pb.size > 0:
                            # Get best plate
                            best_idx = np.argmax(ps)
                            px1, py1, px2, py2 = map(int, pb[best_idx])
                            # Validate size
                            plate_crop = rider_crop[py1:py2, px1:px2]
                            if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20:
                                # Save Plate Image
                                s_idx = captured + 1
                                plate_img_name = f"viol_{session_id}_{tid}_plate_snap{s_idx}.jpg"
                                plate_path = os.path.join(RESULTS_FOLDER, plate_img_name)
                                cv2.imwrite(plate_path, plate_crop)
                                # Update JSON with plate image URL (use the latest one).
                                # The record can vanish if a new upload reset the
                                # shared state while this stream was still running.
                                with json_lock:
                                    record = current_session_data["violations"].get(tid)
                                    if record is not None:
                                        record["plate_image_url"] = f"/static/results/{plate_img_name}"
                                # Trigger OCR
                                current_session_data["ocr_queue"].put((tid, plate_path, session_id))
                                current_session_data["track_capture_count"][tid] = captured + 1
                    plate_text = current_session_data["track_plate_cache"].get(tid, "")

                # Draw UI (matching test.py style): box and label for every
                # track, plus the plate text for active violations.
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, f"ID:{tid} {display_name} {confidence:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                if plate_text:
                    cv2.putText(frame, f"Plate: {plate_text}", (x1, y2 + 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

            # Encode for streaming
            _, buffer = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
    finally:
        # Also runs when the consumer closes the generator early (client
        # disconnect), so the capture handle is never leaked.
        cap.release()
| # --- ROUTES --- | |
def index():
    """Render the landing page template.

    NOTE(review): no route decorator is visible above this view in this
    listing - confirm it is registered with the app.
    """
    page = render_template('landing.html')
    return page
def dashboard():
    """Render the dashboard template."""
    page = render_template('dashboard.html')
    return page
def camera_debug():
    """Render the camera debugging template."""
    page = render_template('camera_debug.html')
    return page
def test_simple():
    """Serve the static test_simple.html page from the directory '.'."""
    directory, page_name = '.', 'test_simple.html'
    return send_from_directory(directory, page_name)
def test_socket_echo():
    """Serve the static test_socket_echo.html page from the directory '.'."""
    directory, page_name = '.', 'test_socket_echo.html'
    return send_from_directory(directory, page_name)
def upload_video():
    """Store an uploaded video and start a fresh upload session.

    Expects the video in the multipart field "file". A new 8-character
    session id is generated, the shared upload-session state is reset, the
    pending OCR tasks are discarded, and the file is saved in UPLOAD_FOLDER
    as "<session_id>_<name>".

    Returns:
        JSON {"filename", "session_id"} on success, or a JSON error with
        status 400 when no usable file was supplied.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    upload = request.files['file']
    # Keep only the final path component of the client-supplied name so it
    # cannot escape UPLOAD_FOLDER (e.g. "../../app.py"). Backslashes are
    # treated as separators too, since browsers may send Windows paths.
    raw_name = upload.filename or ''
    safe_name = os.path.basename(raw_name.replace("\\", "/"))
    if not safe_name:
        return jsonify({"error": "No selected file"}), 400

    # 1. Generate Session ID
    session_id = str(uuid.uuid4())[:8]

    # 2. Reset Session Data
    with json_lock:
        current_session_data["violations"] = {}
        current_session_data["track_plate_cache"] = {}
        current_session_data["track_capture_count"] = {}
        current_session_data["track_ocr_history"] = {}

    # Clear queue: drop tasks that belong to the previous session.
    pending = current_session_data["ocr_queue"]
    while True:
        try:
            pending.get_nowait()
        except queue.Empty:
            break
        # Balance the get so the queue's unfinished-task counter stays accurate.
        pending.task_done()

    filename = f"{session_id}_{safe_name}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    upload.save(filepath)
    return jsonify({"filename": filename, "session_id": session_id})
def get_violations():
    """Return the upload session's violation records as a JSON list.

    Records are returned in reverse insertion order (most recently added
    first); no explicit sort on the timestamp field is performed.
    """
    with json_lock:
        records = list(current_session_data["violations"].values())
    newest_first = records[::-1]
    return jsonify(newest_first)
def video_feed(filename, session_id):
    """Stream the annotated frames of an uploaded video.

    Args:
        filename: Name of a file inside UPLOAD_FOLDER.
        session_id: Session identifier forwarded to process_video_gen.

    Returns:
        A multipart/x-mixed-replace Response fed by process_video_gen.
    """
    source_path = os.path.join(UPLOAD_FOLDER, filename)
    frames = process_video_gen(source_path, session_id)
    return Response(frames, mimetype='multipart/x-mixed-replace; boundary=frame')
def mobile_node(session_id):
    """Render the mobile page template for the given session id."""
    page = render_template('mobile.html', session_id=session_id)
    return page
def upload_frame(session_id):
    """Acknowledge a frame upload; session_id is currently unused."""
    # (Simplified for now - can extend to run detection on mobile frames too)
    acknowledgement = {"status": "received"}
    return jsonify(acknowledgement)
| # --- SOCKET.IO HANDLERS FOR LIVE CAMERA --- | |
def handle_connect():
    """Log a new socket client and acknowledge the connection.

    NOTE(review): no Socket.IO event decorator is visible above this handler
    in this listing - confirm it is registered.
    """
    client_sid = request.sid
    print(f"[SOCKET] Client connected: {client_sid}")
    emit('connection_response', {'status': 'connected'})
def handle_disconnect():
    """Log a socket disconnect and drop the client's live-camera state, if any."""
    client_sid = request.sid
    print(f"[SOCKET] Client disconnected: {client_sid}")
    # Cleanup session if exists
    live_camera_sessions.pop(client_sid, None)
def handle_start_camera(data):
    """Create the live-camera state for the calling socket client.

    Args:
        data: Mapping that may carry 'session_id'; a random 8-character id
            is used when it is absent.

    Emits 'camera_session_started' with the session id in use. Any state
    previously stored for this client is replaced.
    """
    fallback_id = str(uuid.uuid4())[:8]
    session_id = data.get('session_id', fallback_id)
    print(f"[SOCKET] Starting camera session: {session_id}")

    # Initialize session state (stored under the socket's request.sid)
    fresh_state = {
        'session_id': session_id,
        'tracker': sv.ByteTrack(),
        'track_class_history': {},
        'track_violation_memory': {},
        'track_last_seen': {},
        'dead_ids': set(),
        'frame_idx': 0,
        'violations': {},
        'track_plate_cache': {},
        'track_capture_count': {},
        'track_ocr_history': {},
    }
    live_camera_sessions[request.sid] = fresh_state
    emit('camera_session_started', {'session_id': session_id})
def handle_camera_frame(data):
    """Decode a camera frame from the client, analyse it and send it back.

    Args:
        data: Mapping whose 'frame' entry is a base64-encoded JPEG, with or
            without a leading "data:image/jpeg;base64," header.

    Emits 'processed_frame' (annotated frame as a data URL plus the
    violations first seen in this frame) to the calling client only. Emits
    'error' when the client has no active session, when no frame data was
    sent, or when processing raises.
    """
    session = live_camera_sessions.get(request.sid)
    if session is None:
        print(f"[SOCKET] No session found for {request.sid}")
        emit('error', {'message': 'No active session'})
        return
    try:
        payload = data.get('frame', '')
        print(f"[SOCKET] Received frame from {request.sid}, size: {len(payload)} bytes")
        if not payload:
            emit('error', {'message': 'No frame data'})
            return

        # Decode base64 frame. Base64 never contains a comma, so everything
        # after the last comma is the payload whether or not the data-URL
        # header is present.
        frame_data = payload.rpartition(',')[2]
        frame_bytes = base64.b64decode(frame_data)
        nparr = np.frombuffer(frame_bytes, np.uint8)
        # An empty buffer cannot be decoded; treat it like a decode failure.
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
        if frame is None:
            print("[SOCKET] Failed to decode frame")
            return
        print(f"[SOCKET] Frame decoded: {frame.shape}")

        # Process frame
        session_id = session['session_id']
        processed_frame, new_violations = process_live_frame(frame, session, session_id)
        print(f"[SOCKET] Frame processed, violations: {len(new_violations)}")

        # Encode processed frame
        encoded_ok, buffer = cv2.imencode('.jpg', processed_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
        if not encoded_ok:
            raise RuntimeError("JPEG encoding of processed frame failed")
        processed_b64 = base64.b64encode(buffer).decode('utf-8')
        print(f"[SOCKET] Encoded frame size: {len(processed_b64)} bytes")

        # Send back processed frame (use explicit room targeting)
        socketio.emit('processed_frame', {
            'frame': f'data:image/jpeg;base64,{processed_b64}',
            'violations': new_violations
        }, room=request.sid)
        print(f"[SOCKET] Emitted processed_frame to client {request.sid}")
    except Exception as e:
        print(f"[SOCKET] Frame processing error: {e}")
        import traceback
        traceback.print_exc()
        emit('error', {'message': str(e)})
def process_live_frame(frame, session, session_id):
    """Process a single frame from live camera feed.

    Runs the rider detector, updates the session's tracker, decides per
    track whether a no-helmet violation is established, saves rider and
    plate crops for violating tracks, queues plate crops for OCR (tagged
    "live_<session_id>") and draws the overlays onto the frame in place.

    Args:
        frame: BGR image as a numpy array; it is annotated in place.
        session: The per-client state dict created by handle_start_camera.
        session_id: Identifier used in saved file names and OCR tasks.

    Returns:
        (frame, new_violations): the annotated frame and the list of
        violation records first created while processing this frame.
    """
    tracker = session['tracker']
    track_class_history = session['track_class_history']
    track_violation_memory = session['track_violation_memory']
    track_last_seen = session['track_last_seen']
    dead_ids = session['dead_ids']
    violations = session['violations']
    capture_count = session['track_capture_count']

    session['frame_idx'] += 1
    frame_idx = session['frame_idx']

    # Cleanup dead tracks: a track unseen for more than 50 frames is retired
    # and never triggers captures again.
    dead_ids.update(
        tid for tid, last in track_last_seen.items() if frame_idx - last > 50
    )

    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    h_orig, w_orig = frame.shape[:2]

    # 1. RIDER DETECTION
    # NOTE(review): confirm that this predict() honours conf=/iou=; the
    # public RF-DETR API documents a `threshold` argument.
    with torch.no_grad():
        rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.5)
    r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig)
    if r_boxes.size > 0:
        detections = sv.Detections(
            xyxy=r_boxes.astype(np.float32),
            confidence=r_scores.astype(np.float32),
            class_id=r_labels.astype(np.int32)
        )
    else:
        detections = sv.Detections.empty()
    detections = tracker.update_with_detections(detections)

    new_violations = []
    for xyxy, _mask, confidence, class_id, tracker_id, _data in detections:
        if tracker_id is None:
            continue
        tid = int(tracker_id)
        track_last_seen[tid] = frame_idx
        x1, y1, x2, y2 = map(int, xyxy)
        cls_idx = int(class_id)

        # Track class history: keep the last 20 hits.
        history = track_class_history.setdefault(tid, [])
        history.append({"class": cls_idx, "conf": confidence})
        if len(history) > 20:
            history.pop(0)
        hist = [h['class'] for h in history]

        # Robust helmet state mapping: class 1 is read as "helmet" once
        # class 2 has been seen for this track, otherwise as "no helmet".
        is_nh_current = (cls_idx == 2)
        is_h_current = (cls_idx == 0)
        if cls_idx == 1:
            if 2 in hist:
                is_h_current = True
            else:
                is_nh_current = True

        # Violation memory logic: once set, never cleared for the track.
        if tid not in track_violation_memory:
            if is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
                track_violation_memory[tid] = True
            nh_hits = sum(1 for c in hist if c == 2 or (c == 1 and 2 not in hist))
            if nh_hits > 3:
                track_violation_memory[tid] = True
        is_no_helmet = track_violation_memory.get(tid, False)

        # Display name logic
        if is_no_helmet:
            display_name, color = "VIOLATION: NO HELMET", (0, 0, 255)
        elif is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
            display_name, color = "WARNING: NO HELMET", (0, 165, 255)
        elif is_nh_current and confidence > 0.15:
            display_name, color = "ANALYZING...", (0, 255, 255)
        elif is_h_current and confidence >= CONF_HELMET_CONFIRM:
            display_name, color = "HELMET", (0, 255, 0)
        else:
            display_name, color = f"TRACKING (C{cls_idx})", (180, 180, 180)

        plate_text = ""
        if is_no_helmet and tid not in dead_ids:
            # LOG VIOLATION & CROP
            with json_lock:
                if tid not in violations:
                    ts = datetime.now()
                    # Save Rider Image. Negative coordinates would be read as
                    # from-the-end indices, so floor them at 0; an empty crop
                    # is not written (imwrite would raise).
                    rider_img_name = f"viol_live_{session_id}_{tid}_rider.jpg"
                    rider_path = os.path.join(RESULTS_FOLDER, rider_img_name)
                    rider_img = frame[max(0, y1):max(0, y2), max(0, x1):max(0, x2)]
                    if rider_img.size > 0:
                        cv2.imwrite(rider_path, rider_img)
                    # Initialize Record
                    violation_record = {
                        "id": tid,
                        "timestamp": ts.strftime('%H:%M:%S'),
                        "type": "No Helmet",
                        "plate_number": "Scanning...",
                        "image_url": f"/static/results/{rider_img_name}",
                        "plate_image_url": None,
                        "ocr_attempts": [],
                        "raw": {
                            "confidence": float(confidence),
                            "box": xyxy.tolist()
                        }
                    }
                    violations[tid] = violation_record
                    capture_count[tid] = 0
                    new_violations.append(violation_record)

            # PLATE DETECTION
            eb = expand_box(xyxy, w_orig, h_orig)
            rider_crop = frame[eb[1]:eb[3], eb[0]:eb[2]]
            captured = capture_count.get(tid, 0)
            # At most three plate snapshots are taken per track.
            if rider_crop.size > 0 and captured < 3:
                with torch.no_grad():
                    plate_preds = plate_model.predict(
                        cv2.cvtColor(rider_crop, cv2.COLOR_BGR2RGB), conf=CONF_PLATE, iou=0.5)
                pb, ps, pl = parse_preds(plate_preds, rider_crop.shape[1], rider_crop.shape[0])
                if pb.size > 0:
                    best_idx = np.argmax(ps)
                    px1, py1, px2, py2 = map(int, pb[best_idx])
                    plate_crop = rider_crop[py1:py2, px1:px2]
                    if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20:
                        s_idx = captured + 1
                        plate_img_name = f"viol_live_{session_id}_{tid}_plate_snap{s_idx}.jpg"
                        plate_path = os.path.join(RESULTS_FOLDER, plate_img_name)
                        cv2.imwrite(plate_path, plate_crop)
                        with json_lock:
                            record = violations.get(tid)
                            if record is not None:
                                record["plate_image_url"] = f"/static/results/{plate_img_name}"
                        # Trigger OCR (using shared queue)
                        current_session_data["ocr_queue"].put((tid, plate_path, f"live_{session_id}"))
                        capture_count[tid] = captured + 1
            plate_text = session['track_plate_cache'].get(tid, "")

        # Draw UI: box and label for every track, plus the plate text for
        # active violations.
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        cv2.putText(frame, f"ID:{tid} {display_name} {confidence:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        if plate_text:
            cv2.putText(frame, f"Plate: {plate_text}", (x1, y2 + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    return frame, new_violations
def get_live_violations(session_id):
    """Get violations for a specific live camera session.

    Looks for the live session whose 'session_id' matches and returns its
    violation records in reverse insertion order as JSON; returns an empty
    JSON list when no such session exists.
    """
    for state in live_camera_sessions.values():
        if state['session_id'] != session_id:
            continue
        with json_lock:
            records = list(state['violations'].values())
        records.reverse()
        return jsonify(records)
    return jsonify([])
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # and this server listens on all interfaces, so debug mode is opt-in via
    # the FLASK_DEBUG environment variable instead of always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    socketio.run(app, host='0.0.0.0', debug=debug_enabled, port=7860, allow_unsafe_werkzeug=True)