"""Streamlit app that runs a YOLO detector over an uploaded video.

The app loads ``best.pt`` from the working directory, lets the user upload a
video, runs detection on every third frame, draws the boxes that pass the
sidebar confidence threshold, and finally shows a histogram and a table of the
recorded detection confidences.
"""

import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from collections import OrderedDict
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn as nn
import yaml
from ultralytics import YOLO

# Whitelist of fully-qualified globals that SafeUnpickler is allowed to resolve.
SAFE_MODULES = {
    'torch.Size',
    'torch.LongStorage',
    'torch.HalfStorage',
    'torch.FloatStorage',
    'torch.nn.modules.container.Sequential',
    'torch.nn.modules.container.ModuleList',
    'torch.nn.modules.activation.SiLU',
    'torch.nn.modules.conv.Conv2d',
    'torch.nn.modules.batchnorm.BatchNorm2d',
    'torch._utils._rebuild_tensor_v2',
    'torch._utils._rebuild_parameter',
    'collections.OrderedDict',
    'numpy.core.multiarray.scalar',
    'numpy.dtype',
    'ultralytics.nn.modules.Detect',
    'ultralytics.nn.modules.SPPF',
    'ultralytics.nn.modules.DFL',
    'ultralytics.nn.modules.Conv',
    'ultralytics.nn.modules.Bottleneck',
    'ultralytics.nn.modules.C2f',
    'ultralytics.nn.modules.Concat',
    'ultralytics.nn.tasks.DetectionModel',
    'ultralytics.yolo.utils.IterableSimpleNamespace',
}


class SafeUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals listed in ``SAFE_MODULES``.

    NOTE(review): nothing in this file instantiates this class; ``load_model``
    hands the checkpoint path straight to ``YOLO(...)``. It is kept so that
    existing importers keep working.
    """

    def find_class(self, module, name):
        """Return the object ``module.name`` if it is whitelisted.

        Args:
            module: Dotted module path recorded in the pickle stream.
            name: Attribute name inside that module.

        Raises:
            pickle.UnpicklingError: If ``"{module}.{name}"`` is not in
                ``SAFE_MODULES``.
        """
        fullname = f"{module}.{name}"
        if fullname not in SAFE_MODULES:
            raise pickle.UnpicklingError(
                f"Attempting to unpickle unsafe module: {fullname}"
            )
        # Import lazily so whitelisted modules need not be preloaded.
        if module not in sys.modules:
            importlib.import_module(module)
        return getattr(sys.modules[module], name)


# Configure page
st.set_page_config(
    page_title="License Plate Detection",
    page_icon="🚗",
    layout="wide"
)

# Initialize session state
if 'model' not in st.session_state:
    st.session_state.model = None
if 'processed_frames' not in st.session_state:
    st.session_state.processed_frames = 0
if 'detections' not in st.session_state:
    st.session_state.detections = []


@st.cache_resource
def _load_model_cached():
    """Build the YOLO model from ``best.pt`` and move it to the best device.

    Errors are raised rather than swallowed here so that a failed load is not
    stored by the resource cache as a ``None`` result.

    Raises:
        FileNotFoundError: If ``best.pt`` does not exist.
        ValueError: If the loaded object has no ``nn.Module`` under ``.model``.
    """
    model_path = Path('best.pt')
    if not model_path.exists():
        raise FileNotFoundError("Model file not found")

    # SECURITY NOTE(review): the checkpoint is loaded by ``YOLO`` itself, not
    # through ``SafeUnpickler``. Only use a ``best.pt`` from a trusted source.
    model = YOLO(
        model_path,
        task='detect',
        verbose=False
    )

    # Force model to appropriate device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Verify model structure
    if not hasattr(model, 'model') or not isinstance(model.model, nn.Module):
        raise ValueError("Invalid model structure")

    return model


def load_model():
    """Return the cached YOLO model, or ``None`` if it cannot be loaded.

    Any failure is reported to the user with ``st.error``. Because the
    exception is caught outside the cached function, a later call retries the
    load instead of receiving a previously cached failure.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # UI boundary: report and degrade to None.
        st.error(f"Error loading model: {str(e)}")
        return None


def process_frame(frame, model, conf=0.25):
    """Run the detector on one frame and return the raw box rows.

    Args:
        frame: Image as a ``numpy.ndarray``; anything else yields ``[]``.
        model: Object exposing ``predict(source=..., conf=..., iou=...,
            verbose=...)``; ``None`` yields ``[]``.
        conf: Minimum confidence passed to ``model.predict``. Defaults to the
            previously hard-coded ``0.25``.

    Returns:
        ``results[0].boxes.data`` as a NumPy array, or ``[]`` when there is
        nothing to return or prediction failed (the error is shown with
        ``st.error``).
    """
    if model is None:
        return []

    # Ensure frame is valid
    if frame is None or not isinstance(frame, np.ndarray):
        return []

    try:
        with torch.no_grad():
            results = model.predict(
                source=frame,
                conf=conf,
                iou=0.45,
                verbose=False
            )

        if results:
            return results[0].boxes.data.cpu().numpy()
        return []
    except Exception as e:  # Keep the video loop alive on a bad frame.
        st.error(f"Error processing frame: {str(e)}")
        return []


def _progress_percent(frame_count, total_frames):
    """Return progress as an int in ``[0, 100]``, tolerating a bad frame total."""
    if total_frames <= 0:
        # The container did not report a usable frame count.
        return 0
    return max(0, min(100, int((frame_count / total_frames) * 100)))


def _annotate_detections(frame, detections, confidence_threshold):
    """Draw boxes at or above the threshold on ``frame``; return their confidences."""
    confidences = []
    for det in detections:
        if det[4] < confidence_threshold:
            continue
        x1, y1, x2, y2 = map(int, det[:4])
        conf = float(det[4])

        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            frame,
            f"{conf:.2f}",
            (x1, y1-10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 255, 0),
            2
        )
        confidences.append(conf)
    return confidences


def _process_video(video_path, model, confidence_threshold):
    """Run detection over the video at ``video_path`` and update the UI.

    Appends one record per accepted detection to
    ``st.session_state.detections``. Returns ``False`` if the video could not
    be opened, ``True`` otherwise. The capture is always released.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            st.error("Error opening video file")
            return False

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Display elements
        progress_bar = st.progress(0)
        frame_placeholder = st.empty()
        stats_placeholder = st.empty()

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            progress_bar.progress(_progress_percent(frame_count, total_frames))

            # Process every 3rd frame
            if frame_count % 3 != 0:
                continue

            detections = process_frame(frame, model, conf=confidence_threshold)
            for conf in _annotate_detections(
                frame, detections, confidence_threshold
            ):
                st.session_state.detections.append({
                    'time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    'confidence': conf
                })

            # Display frame
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_placeholder.image(
                frame_rgb,
                channels="RGB",
                use_column_width=True
            )

            # Update stats
            with stats_placeholder:
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Processed Frames", frame_count)
                with col2:
                    st.metric("Detections", len(st.session_state.detections))
        return True
    finally:
        cap.release()


def _show_results(detections):
    """Render the confidence histogram and results table, if there are detections."""
    if not detections:
        return

    st.header("Detection Statistics")
    df = pd.DataFrame(detections)

    # Confidence distribution
    fig = px.histogram(
        df,
        x='confidence',
        title='Detection Confidence Distribution',
        labels={'confidence': 'Confidence Score'}
    )
    st.plotly_chart(fig, use_container_width=True)

    # Results table
    st.dataframe(
        df,
        use_container_width=True
    )


def main():
    """Render the page: sidebar controls, model loading, upload and processing."""
    st.title("License Plate Detection System")

    # Sidebar controls
    with st.sidebar:
        st.header("Controls")
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.25,
            step=0.05
        )

    # Load model
    if st.session_state.model is None:
        with st.spinner("Loading model..."):
            st.session_state.model = load_model()
    if st.session_state.model is None:
        st.error("Failed to load model. Please check the model file.")
        return

    # File uploader
    video_file = st.file_uploader(
        "Upload Video",
        type=['mp4', 'avi', 'mov'],
        help="Upload a video file containing license plates"
    )
    if not video_file:
        return

    temp_path = None
    try:
        # A unique temp file avoids collisions between concurrent sessions;
        # keeping the upload's suffix lets OpenCV pick the right demuxer.
        suffix = Path(video_file.name).suffix or ".mp4"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            temp_path = tmp.name
            # getvalue() is independent of the stream position.
            tmp.write(video_file.getvalue())

        # Start fresh so a rerun does not append duplicates of earlier results.
        st.session_state.detections = []

        if _process_video(
            temp_path, st.session_state.model, confidence_threshold
        ):
            _show_results(st.session_state.detections)
    except Exception as e:  # UI boundary: surface the error to the user.
        st.error(f"An error occurred: {str(e)}")
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)


if __name__ == "__main__":
    main()