"""Helmet-violation detection web app.

Serves an uploaded-video pipeline (MJPEG stream via /video_feed) and a live
camera pipeline (Socket.IO frames), both running two RF-DETR models:
rider/helmet classification and license-plate localization. Plate crops are
OCR'd asynchronously by a background worker calling an external Gradio API.
"""

from flask import Flask, render_template, request, jsonify, Response, send_from_directory
from flask_socketio import SocketIO, emit
import cv2
import numpy as np
import os
import json
import uuid
import threading
import queue
import torch
import time
from datetime import datetime
from collections import Counter
import base64
import sys

# Make the repository root importable so `rfdetr` resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from rfdetr import RFDETRMedium
import supervision as sv
from gradio_client import Client, handle_file

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-here'
app.config['SESSION_TYPE'] = 'filesystem'
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading', manage_session=False)

# --- CONFIG ---
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static/uploads')
RESULTS_FOLDER = os.path.join(BASE_DIR, 'static/results')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)

# Model Paths
HELMET_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoint_best_ema_bike_helmet.pth")
PLATE_WEIGHTS = os.path.join(BASE_DIR, "Model/checkpoint_best_ema_plate.pth")
HELMET_CLASSES = ["motorbike and helmet", "motorbike and no helmet"]

# Thresholds (matching test.py)
CONF_RIDER = 0.25
CONF_PLATE = 0.30
CONF_HELMET_CONFIRM = 0.40
CONF_NO_HELMET_TRIGGER = 0.35

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INIT] Targeted Device: {device}")

# --- MODEL INITIALIZATION ---
print("[INIT] Loading Rider Detection Model...")
rider_model = RFDETRMedium(num_classes=len(HELMET_CLASSES), pretrain_weights=HELMET_WEIGHTS)
try:
    rider_model.optimize_for_inference(compile=True, batch_size=1)
    if device.type == "cuda":
        rider_model.model.half()
except Exception as e:
    # Optimization is best-effort: fall back to the unoptimized model.
    print(f"[WARNING] Rider model opt failed: {e}")

print("[INIT] Loading License Plate Model...")
plate_model = RFDETRMedium(num_classes=1, pretrain_weights=PLATE_WEIGHTS)
try:
    plate_model.optimize_for_inference(compile=True, batch_size=1)
    if device.type == "cuda":
        plate_model.model.half()
except Exception as e:
    print(f"[WARNING] Plate model opt failed: {e}")

# --- GLOBAL STATE & WORKERS ---
# We use a dictionary to separate state by session_id in a real app,
# but for simplicity here we reset globals on upload or use a simplified structure.
current_session_data = {
    "violations": {},          # {track_id: {data}}
    "ocr_queue": queue.Queue(),
    "track_plate_cache": {},   # {track_id: consensus plate string}
    "track_capture_count": {}, # {track_id: number of plate snapshots taken (max 3)}
    "track_ocr_history": {}    # {track_id: [raw OCR results]}
}

# Live Camera Session State
live_camera_sessions = {}  # {socket sid: {tracker, history, etc}}
json_lock = threading.Lock()


def get_best_consensus(results):
    """Return a per-character majority-vote consensus of OCR attempts.

    Error/placeholder entries are filtered out; with no usable attempts the
    placeholder "PENDING..." is returned.
    """
    cleaned = [r.replace("\n", " ").strip() for r in results if r not in ["API_ERROR", "PENDING...", ""]]
    if not cleaned:
        return "PENDING..."
    if len(cleaned) == 1:
        return cleaned[0]
    max_len = max(len(r) for r in cleaned)
    final_chars = []
    for i in range(max_len):
        # Most common character at position i across all attempts long enough.
        char_pool = [r[i] for r in cleaned if i < len(r)]
        if char_pool:
            final_chars.append(Counter(char_pool).most_common(1)[0][0])
    return "".join(final_chars).strip()


def clamp_box(box, w, h):
    """Clamp an (x1, y1, x2, y2) box to integer coordinates inside a w x h image."""
    x1, y1, x2, y2 = map(int, box)
    return [max(0, x1), max(0, y1), min(w - 1, x2), min(h - 1, y2)]


def expand_box(box, w, h):
    """Widen a rider box by 10% and shift it downward by 40% of its height.

    Used to search for the license plate, which sits below the rider's
    bounding box. Result is clamped to the image bounds.
    """
    x1, y1, x2, y2 = map(int, box)
    bw, bh = x2 - x1, y2 - y1
    return clamp_box([x1 - bw * 0.1, y1 + bh * 0.4, x2 + bw * 0.1, y2 + bh * 0.4], w, h)


def background_ocr_worker():
    """Daemon loop: consume (track_id, plate_path, session_id) OCR tasks.

    Calls the external Gradio OCR API, keeps a per-track history of results,
    updates the consensus plate number in the violation record, and persists
    the session's violations to a JSON file.
    """
    print("[OCR] Worker Thread Started")
    try:
        # Initialize Gradio Client
        client = Client("WebashalarForML/demo-glm-ocr")
    except Exception as e:
        print(f"[OCR] Connection Failed: {e}")
        return
    while True:
        try:
            task = current_session_data["ocr_queue"].get()
            if task is None:
                continue  # Keep alive
            track_id, plate_path, session_id = task
            print(f"[OCR] Processing ID {track_id}...")
            try:
                # Call external API
                result = client.predict(image=handle_file(plate_path), api_name="/proses_intelijen")
                plate_text = str(result).strip()
            except Exception as e:
                print(f"[OCR] API Error: {e}")
                plate_text = "API_ERROR"
            # Update History
            if track_id not in current_session_data["track_ocr_history"]:
                current_session_data["track_ocr_history"][track_id] = []
            if plate_text not in ["API_ERROR", ""]:
                current_session_data["track_ocr_history"][track_id].append(plate_text)
            # Consensus
            final_plate = get_best_consensus(current_session_data["track_ocr_history"][track_id])
            current_session_data["track_plate_cache"][track_id] = final_plate
            # Update Main JSON Record
            with json_lock:
                if track_id in current_session_data["violations"]:
                    current_session_data["violations"][track_id]["plate_number"] = final_plate
                    current_session_data["violations"][track_id]["ocr_attempts"] = current_session_data["track_ocr_history"][track_id]
                    # Save to JSON file for persistence
                    json_path = os.path.join(RESULTS_FOLDER, f"session_{session_id}.json")
                    with open(json_path, 'w') as f:
                        json.dump(list(current_session_data["violations"].values()), f, indent=4)
            current_session_data["ocr_queue"].task_done()
        except Exception as e:
            # Never let the worker die on a single bad task.
            print(f"[OCR] Loop Error: {e}")


# Start OCR Thread
threading.Thread(target=background_ocr_worker, daemon=True).start()


def parse_preds(preds, W, H):
    """Extract (boxes, scores, labels) numpy arrays from a model prediction.

    Handles both torch-tensor and numpy attributes, and rescales normalized
    ([0, 1]) boxes to pixel coordinates for a W x H image.
    """
    boxes, scores, labels = np.array([]), np.array([]), np.array([])
    if hasattr(preds, "xyxy"):
        boxes = preds.xyxy.cpu().numpy() if not isinstance(preds.xyxy, np.ndarray) else preds.xyxy
        scores = preds.confidence.cpu().numpy() if not isinstance(preds.confidence, np.ndarray) else preds.confidence
        labels = preds.class_id.cpu().numpy() if not isinstance(preds.class_id, np.ndarray) else preds.class_id
        if boxes.size > 0 and boxes.max() <= 1.01:
            boxes = boxes.copy()
            boxes[:, [0, 2]] *= W
            boxes[:, [1, 3]] *= H
    return boxes, scores, labels


# NOTE(review): the original file defined expand_box twice (identical bodies);
# the duplicate after parse_preds has been removed.


def process_video_gen(video_path, session_id):
    """Generator yielding MJPEG frames for an uploaded video.

    Runs rider detection + ByteTrack tracking, confirms "no helmet"
    violations via per-track history, crops and OCR-queues license plates
    (max 3 snapshots per track), and draws annotations on each frame.
    """
    cap = cv2.VideoCapture(video_path)
    tracker = sv.ByteTrack()
    # Session specific tracking (matching test.py structure)
    track_class_history = {}
    track_violation_memory = {}
    track_last_seen = {}
    dead_ids = set()
    frame_idx = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_idx += 1
        # Cleanup dead tracks (matching test.py): unseen for > 50 frames
        to_kill = [tid for tid, last in track_last_seen.items() if frame_idx - last > 50 and tid not in dead_ids]
        for tk in to_kill:
            dead_ids.add(tk)
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h_orig, w_orig = frame.shape[:2]
        # 1. RIDER DETECTION
        with torch.no_grad():
            rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.5)
        r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig)
        if r_boxes.size > 0:
            detections = sv.Detections(
                xyxy=r_boxes.astype(np.float32),
                confidence=r_scores.astype(np.float32),
                class_id=r_labels.astype(np.int32)
            )
        else:
            detections = sv.Detections.empty()
        detections = tracker.update_with_detections(detections)
        for i, (xyxy, mask, confidence, class_id, tracker_id, data) in enumerate(detections):
            if tracker_id is None:
                continue
            tid = int(tracker_id)
            track_last_seen[tid] = frame_idx
            x1, y1, x2, y2 = map(int, xyxy)
            cls_idx = int(class_id)
            # Track class history (matching test.py logic), capped at 20 entries
            if tid not in track_class_history:
                track_class_history[tid] = []
            track_class_history[tid].append({"class": cls_idx, "conf": confidence})
            if len(track_class_history[tid]) > 20:
                track_class_history[tid].pop(0)
            # Robust helmet state mapping (matching test.py)
            is_nh_current = (cls_idx == 2)
            is_h_current = (cls_idx == 0)
            if cls_idx == 1:
                # Ambiguous class: resolve using the track's history.
                hist = [h['class'] for h in track_class_history[tid]]
                if 2 in hist:
                    is_h_current = True
                else:
                    is_nh_current = True
            # Violation memory logic (matching test.py): once set, sticky.
            if tid not in track_violation_memory:
                if is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
                    track_violation_memory[tid] = True
                hist = [h['class'] for h in track_class_history[tid]]
                nh_hits = sum(1 for c in hist if c == 2 or (c == 1 and 2 not in hist))
                if nh_hits > 3:
                    track_violation_memory[tid] = True
            is_no_helmet = track_violation_memory.get(tid, False)
            # Display name logic (matching test.py)
            if is_no_helmet:
                display_name, color = "VIOLATION: NO HELMET", (0, 0, 255)
            elif is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
                display_name, color = "WARNING: NO HELMET", (0, 165, 255)
            elif is_nh_current and confidence > 0.15:
                display_name, color = "ANALYZING...", (0, 255, 255)
            elif is_h_current and confidence >= CONF_HELMET_CONFIRM:
                display_name, color = "HELMET", (0, 255, 0)
            else:
                display_name, color = f"TRACKING (C{cls_idx})", (180, 180, 180)
            if is_no_helmet and tid not in dead_ids:
                # 3. LOG VIOLATION & CROP
                with json_lock:
                    if tid not in current_session_data["violations"]:
                        ts = datetime.now()
                        # Save Rider Image
                        rider_img_name = f"viol_{session_id}_{tid}_rider.jpg"
                        rider_path = os.path.join(RESULTS_FOLDER, rider_img_name)
                        cv2.imwrite(rider_path, frame[y1:y2, x1:x2])
                        # Initialize Record
                        current_session_data["violations"][tid] = {
                            "id": tid,
                            "timestamp": ts.strftime('%H:%M:%S'),
                            "type": "No Helmet",
                            "plate_number": "Scanning...",
                            "image_url": f"/static/results/{rider_img_name}",
                            "plate_image_url": None,  # Will fill later
                            "ocr_attempts": [],
                            "raw": {
                                "confidence": float(confidence),
                                "box": xyxy.tolist()
                            }
                        }
                        current_session_data["track_capture_count"][tid] = 0
                # 4. PLATE DETECTION (Only if violation confirmed)
                # Expand rider box to find plate inside/near it (matching test.py)
                eb = expand_box(xyxy, w_orig, h_orig)
                rider_crop = frame[eb[1]:eb[3], eb[0]:eb[2]]
                if rider_crop.size > 0 and current_session_data["track_capture_count"].get(tid, 0) < 3:
                    # Run Plate Model on Crop
                    with torch.no_grad():
                        plate_preds = plate_model.predict(cv2.cvtColor(rider_crop, cv2.COLOR_BGR2RGB), conf=CONF_PLATE, iou=0.5)
                    pb, ps, pl = parse_preds(plate_preds, rider_crop.shape[1], rider_crop.shape[0])
                    if pb.size > 0:
                        # Get best plate
                        best_idx = np.argmax(ps)
                        px1, py1, px2, py2 = map(int, pb[best_idx])
                        # Validate size
                        plate_crop = rider_crop[py1:py2, px1:px2]
                        if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20:
                            # Save Plate Image
                            s_idx = current_session_data['track_capture_count'][tid] + 1
                            plate_img_name = f"viol_{session_id}_{tid}_plate_snap{s_idx}.jpg"
                            plate_path = os.path.join(RESULTS_FOLDER, plate_img_name)
                            cv2.imwrite(plate_path, plate_crop)
                            # Update JSON with plate image URL (use the latest one)
                            with json_lock:
                                current_session_data["violations"][tid]["plate_image_url"] = f"/static/results/{plate_img_name}"
                            # Trigger OCR
                            current_session_data["ocr_queue"].put((tid, plate_path, session_id))
                            current_session_data["track_capture_count"][tid] += 1
                # Draw UI (matching test.py style)
                plate_text = current_session_data["track_plate_cache"].get(tid, "")
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, f"ID:{tid} {display_name} {confidence:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                if plate_text:
                    cv2.putText(frame, f"Plate: {plate_text}", (x1, y2 + 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
            else:
                # Normal Helmet/Tracking (matching test.py)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, f"ID:{tid} {display_name} {confidence:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        # Encode for streaming
        _, buffer = cv2.imencode('.jpg', frame)
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
    cap.release()


# --- ROUTES ---
@app.route('/')
def index():
    return render_template('landing.html')


@app.route('/dashboard')
def dashboard():
    return render_template('dashboard.html')


@app.route('/camera_debug')
def camera_debug():
    return render_template('camera_debug.html')


@app.route('/test_simple')
def test_simple():
    return send_from_directory('.', 'test_simple.html')


@app.route('/test_socket_echo')
def test_socket_echo():
    return send_from_directory('.', 'test_socket_echo.html')


@app.route('/upload', methods=['POST'])
def upload_video():
    """Accept a video upload, reset the global session state, and save the file."""
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    # 1. Generate Session ID
    session_id = str(uuid.uuid4())[:8]
    # 2. Reset Session Data
    with json_lock:
        current_session_data["violations"] = {}
        current_session_data["track_plate_cache"] = {}
        current_session_data["track_capture_count"] = {}
        current_session_data["track_ocr_history"] = {}
        # Clear queue of any stale OCR tasks
        while not current_session_data["ocr_queue"].empty():
            try:
                current_session_data["ocr_queue"].get_nowait()
            except queue.Empty:
                break
    # basename() strips any client-supplied directory components (path traversal)
    filename = f"{session_id}_{os.path.basename(file.filename)}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)
    return jsonify({"filename": filename, "session_id": session_id})


@app.route('/get_violations')
def get_violations():
    # Return list of violations, newest first
    with json_lock:
        data = list(current_session_data["violations"].values())
    data.reverse()
    return jsonify(data)


# FIX(review): the mangled source had '/video_feed//' with no URL converters,
# which cannot bind the view's parameters; converters restored.
@app.route('/video_feed/<filename>/<session_id>')
def video_feed(filename, session_id):
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    return Response(process_video_gen(filepath, session_id),
                    mimetype='multipart/x-mixed-replace; boundary=frame')


@app.route('/mobile/<session_id>')
def mobile_node(session_id):
    return render_template('mobile.html', session_id=session_id)


@app.route('/upload_frame/<session_id>', methods=['POST'])
def upload_frame(session_id):
    # (Simplified for now - can extend to run detection on mobile frames too)
    return jsonify({"status": "received"})


# --- SOCKET.IO HANDLERS FOR LIVE CAMERA ---
@socketio.on('connect')
def handle_connect():
    print(f"[SOCKET] Client connected: {request.sid}")
    emit('connection_response', {'status': 'connected'})


@socketio.on('disconnect')
def handle_disconnect():
    print(f"[SOCKET] Client disconnected: {request.sid}")
    # Cleanup session if exists
    if request.sid in live_camera_sessions:
        del live_camera_sessions[request.sid]


@socketio.on('start_camera_session')
def handle_start_camera(data):
    """Create per-connection tracking state for a live camera session."""
    session_id = data.get('session_id', str(uuid.uuid4())[:8])
    print(f"[SOCKET] Starting camera session: {session_id}")
    # Initialize session state
    live_camera_sessions[request.sid] = {
        'session_id': session_id,
        'tracker': sv.ByteTrack(),
        'track_class_history': {},
        'track_violation_memory': {},
        'track_last_seen': {},
        'dead_ids': set(),
        'frame_idx': 0,
        'violations': {},
        'track_plate_cache': {},
        'track_capture_count': {},
        'track_ocr_history': {}
    }
    emit('camera_session_started', {'session_id': session_id})


@socketio.on('camera_frame')
def handle_camera_frame(data):
    """Decode a base64 JPEG frame, run detection, and emit the annotated frame."""
    if request.sid not in live_camera_sessions:
        print(f"[SOCKET] No session found for {request.sid}")
        emit('error', {'message': 'No active session'})
        return
    try:
        print(f"[SOCKET] Received frame from {request.sid}, size: {len(data.get('frame', ''))} bytes")
        # Decode base64 frame
        frame_data = data['frame'].split(',')[1]  # Remove data:image/jpeg;base64,
        frame_bytes = base64.b64decode(frame_data)
        nparr = np.frombuffer(frame_bytes, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if frame is None:
            print("[SOCKET] Failed to decode frame")
            return
        print(f"[SOCKET] Frame decoded: {frame.shape}")
        # Process frame
        session = live_camera_sessions[request.sid]
        session_id = session['session_id']
        processed_frame, new_violations = process_live_frame(frame, session, session_id)
        print(f"[SOCKET] Frame processed, violations: {len(new_violations)}")
        # Encode processed frame
        _, buffer = cv2.imencode('.jpg', processed_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
        processed_b64 = base64.b64encode(buffer).decode('utf-8')
        print(f"[SOCKET] Encoded frame size: {len(processed_b64)} bytes")
        # Send back processed frame (use explicit room targeting)
        socketio.emit('processed_frame', {
            'frame': f'data:image/jpeg;base64,{processed_b64}',
            'violations': new_violations
        }, room=request.sid)
        print(f"[SOCKET] Emitted processed_frame to client {request.sid}")
    except Exception as e:
        print(f"[SOCKET] Frame processing error: {e}")
        import traceback
        traceback.print_exc()
        emit('error', {'message': str(e)})


def process_live_frame(frame, session, session_id):
    """Process a single frame from live camera feed.

    Mirrors process_video_gen's detection/violation logic but keeps all
    tracking state in the per-connection `session` dict. Returns the
    annotated frame and a list of newly created violation records.
    """
    tracker = session['tracker']
    track_class_history = session['track_class_history']
    track_violation_memory = session['track_violation_memory']
    track_last_seen = session['track_last_seen']
    dead_ids = session['dead_ids']
    session['frame_idx'] += 1
    frame_idx = session['frame_idx']
    # Cleanup dead tracks
    to_kill = [tid for tid, last in track_last_seen.items() if frame_idx - last > 50 and tid not in dead_ids]
    for tk in to_kill:
        dead_ids.add(tk)
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    h_orig, w_orig = frame.shape[:2]
    # 1. RIDER DETECTION
    with torch.no_grad():
        rider_preds = rider_model.predict(rgb_frame, conf=CONF_RIDER, iou=0.5)
    r_boxes, r_scores, r_labels = parse_preds(rider_preds, w_orig, h_orig)
    if r_boxes.size > 0:
        detections = sv.Detections(
            xyxy=r_boxes.astype(np.float32),
            confidence=r_scores.astype(np.float32),
            class_id=r_labels.astype(np.int32)
        )
    else:
        detections = sv.Detections.empty()
    detections = tracker.update_with_detections(detections)
    new_violations = []
    for i, (xyxy, mask, confidence, class_id, tracker_id, data) in enumerate(detections):
        if tracker_id is None:
            continue
        tid = int(tracker_id)
        track_last_seen[tid] = frame_idx
        x1, y1, x2, y2 = map(int, xyxy)
        cls_idx = int(class_id)
        # Track class history, capped at 20 entries
        if tid not in track_class_history:
            track_class_history[tid] = []
        track_class_history[tid].append({"class": cls_idx, "conf": confidence})
        if len(track_class_history[tid]) > 20:
            track_class_history[tid].pop(0)
        # Robust helmet state mapping
        is_nh_current = (cls_idx == 2)
        is_h_current = (cls_idx == 0)
        if cls_idx == 1:
            # Ambiguous class: resolve using the track's history.
            hist = [h['class'] for h in track_class_history[tid]]
            if 2 in hist:
                is_h_current = True
            else:
                is_nh_current = True
        # Violation memory logic: once set, sticky.
        if tid not in track_violation_memory:
            if is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
                track_violation_memory[tid] = True
            hist = [h['class'] for h in track_class_history[tid]]
            nh_hits = sum(1 for c in hist if c == 2 or (c == 1 and 2 not in hist))
            if nh_hits > 3:
                track_violation_memory[tid] = True
        is_no_helmet = track_violation_memory.get(tid, False)
        # Display name logic
        if is_no_helmet:
            display_name, color = "VIOLATION: NO HELMET", (0, 0, 255)
        elif is_nh_current and confidence >= CONF_NO_HELMET_TRIGGER:
            display_name, color = "WARNING: NO HELMET", (0, 165, 255)
        elif is_nh_current and confidence > 0.15:
            display_name, color = "ANALYZING...", (0, 255, 255)
        elif is_h_current and confidence >= CONF_HELMET_CONFIRM:
            display_name, color = "HELMET", (0, 255, 0)
        else:
            display_name, color = f"TRACKING (C{cls_idx})", (180, 180, 180)
        # LOG VIOLATION & CROP
        if is_no_helmet and tid not in dead_ids:
            with json_lock:
                if tid not in session['violations']:
                    ts = datetime.now()
                    # Save Rider Image
                    rider_img_name = f"viol_live_{session_id}_{tid}_rider.jpg"
                    rider_path = os.path.join(RESULTS_FOLDER, rider_img_name)
                    cv2.imwrite(rider_path, frame[y1:y2, x1:x2])
                    # Initialize Record
                    violation_record = {
                        "id": tid,
                        "timestamp": ts.strftime('%H:%M:%S'),
                        "type": "No Helmet",
                        "plate_number": "Scanning...",
                        "image_url": f"/static/results/{rider_img_name}",
                        "plate_image_url": None,
                        "ocr_attempts": [],
                        "raw": {
                            "confidence": float(confidence),
                            "box": xyxy.tolist()
                        }
                    }
                    session['violations'][tid] = violation_record
                    session['track_capture_count'][tid] = 0
                    new_violations.append(violation_record)
            # PLATE DETECTION
            eb = expand_box(xyxy, w_orig, h_orig)
            rider_crop = frame[eb[1]:eb[3], eb[0]:eb[2]]
            if rider_crop.size > 0 and session['track_capture_count'].get(tid, 0) < 3:
                with torch.no_grad():
                    plate_preds = plate_model.predict(cv2.cvtColor(rider_crop, cv2.COLOR_BGR2RGB), conf=CONF_PLATE, iou=0.5)
                pb, ps, pl = parse_preds(plate_preds, rider_crop.shape[1], rider_crop.shape[0])
                if pb.size > 0:
                    best_idx = np.argmax(ps)
                    px1, py1, px2, py2 = map(int, pb[best_idx])
                    plate_crop = rider_crop[py1:py2, px1:px2]
                    if plate_crop.size > 0 and plate_crop.shape[0] > 10 and plate_crop.shape[1] > 20:
                        s_idx = session['track_capture_count'][tid] + 1
                        plate_img_name = f"viol_live_{session_id}_{tid}_plate_snap{s_idx}.jpg"
                        plate_path = os.path.join(RESULTS_FOLDER, plate_img_name)
                        cv2.imwrite(plate_path, plate_crop)
                        with json_lock:
                            session['violations'][tid]["plate_image_url"] = f"/static/results/{plate_img_name}"
                        # Trigger OCR (using shared queue)
                        current_session_data["ocr_queue"].put((tid, plate_path, f"live_{session_id}"))
                        session['track_capture_count'][tid] += 1
            # Draw UI
            plate_text = session['track_plate_cache'].get(tid, "")
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, f"ID:{tid} {display_name} {confidence:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
            if plate_text:
                cv2.putText(frame, f"Plate: {plate_text}", (x1, y2 + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        else:
            # Normal Helmet/Tracking
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, f"ID:{tid} {display_name} {confidence:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame, new_violations


@app.route('/get_live_violations/<session_id>')
def get_live_violations(session_id):
    """Get violations for a specific live camera session"""
    for sid, session in live_camera_sessions.items():
        if session['session_id'] == session_id:
            with json_lock:
                data = list(session['violations'].values())
            data.reverse()
            return jsonify(data)
    return jsonify([])


if __name__ == '__main__':
    # NOTE(review): debug=True is unsafe outside local development.
    socketio.run(app, host='0.0.0.0', debug=True, port=7860, allow_unsafe_werkzeug=True)