# npd / app_secure.py
# hostel's picture
# Upload 16 files
# ce6e2a4 verified
# Raw
# History Blame Contribute Delete
# 9.14 kB
import streamlit as st
import pandas as pd
import cv2
import numpy as np
from ultralytics import YOLO
from datetime import datetime
import plotly.express as px
import os
import torch
import torch.nn as nn
import sys
import subprocess
from collections import OrderedDict
import yaml
from pathlib import Path
import pickle
import importlib
# Globals that SafeUnpickler.find_class is allowed to resolve, spelled as
# "module.name". They are grouped by defining module below; the resulting
# set is the flat whitelist that find_class checks against.
# NOTE(review): nothing in this file hands SafeUnpickler to a loader, so this
# list is not currently enforced anywhere. Several ultralytics paths (e.g.
# `ultralytics.yolo.utils`) look tied to one ultralytics package layout —
# verify against the version that produced best.pt before relying on it.
SAFE_MODULES = {
    f"{module}.{name}"
    for module, names in (
        ("torch", ("Size", "LongStorage", "HalfStorage", "FloatStorage")),
        ("torch.nn.modules.container", ("Sequential", "ModuleList")),
        ("torch.nn.modules.activation", ("SiLU",)),
        ("torch.nn.modules.conv", ("Conv2d",)),
        ("torch.nn.modules.batchnorm", ("BatchNorm2d",)),
        ("torch._utils", ("_rebuild_tensor_v2", "_rebuild_parameter")),
        ("collections", ("OrderedDict",)),
        ("numpy.core.multiarray", ("scalar",)),
        ("numpy", ("dtype",)),
        (
            "ultralytics.nn.modules",
            ("Detect", "SPPF", "DFL", "Conv", "Bottleneck", "C2f", "Concat"),
        ),
        ("ultralytics.nn.tasks", ("DetectionModel",)),
        ("ultralytics.yolo.utils", ("IterableSimpleNamespace",)),
    )
    for name in names
}
# Restricted unpickler backed by the SAFE_MODULES whitelist.
class SafeUnpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that only resolves globals listed in ``SAFE_MODULES``.

    Any other ``module.name`` reference in the pickle stream raises
    ``pickle.UnpicklingError`` instead of being imported.

    NOTE(review): nothing in this file instantiates this class; model loading
    goes through ``YOLO(...)``, which does its own deserialisation.
    """

    def find_class(self, module, name):
        fullname = f"{module}.{name}"
        # Reject before importing anything, so a non-whitelisted module is
        # never executed as a side effect of the lookup.
        if fullname not in SAFE_MODULES:
            raise pickle.UnpicklingError(f"Attempting to unpickle unsafe module: {fullname}")
        if module not in sys.modules:
            importlib.import_module(module)
        owner = sys.modules[module]
        return getattr(owner, name)
# Page configuration; this runs before any other st.* call in this file.
st.set_page_config(page_title="License Plate Detection", page_icon="🚗", layout="wide")

# Seed per-session state on the first run only; later reruns keep whatever
# value is already stored under each key.
# NOTE(review): "processed_frames" is never read or updated elsewhere in this file.
for _state_key, _state_default in (
    ("model", None),
    ("processed_frames", 0),
    ("detections", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# Model loading.
# NOTE(review): SafeUnpickler is NOT involved here — YOLO(...) deserialises
# best.pt itself, so only point this at a checkpoint you trust.
@st.cache_resource
def _load_model_cached():
    """Load ``best.pt`` as a YOLO detector on CUDA if available, else CPU.

    Kept separate from ``load_model`` so that failures *raise* out of the
    cached function: ``st.cache_resource`` does not store a result when the
    function raises, so a later rerun (e.g. after ``best.pt`` appears) retries
    the load instead of being served a cached ``None``.

    Raises:
        FileNotFoundError: ``best.pt`` does not exist relative to the current
            working directory.
        ValueError: the loaded object has no ``nn.Module`` at ``.model``.
    """
    model_path = Path('best.pt')
    if not model_path.exists():
        raise FileNotFoundError("Model file not found")

    model = YOLO(model_path, task='detect', verbose=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Sanity-check that the wrapper actually holds a torch network.
    if not hasattr(model, 'model') or not isinstance(model.model, nn.Module):
        raise ValueError("Invalid model structure")
    return model


def load_model():
    """Return the process-wide YOLO model, or ``None`` if loading fails.

    On failure the error is shown with ``st.error`` and ``None`` is returned;
    because the failure is not cached, the next call tries again.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None


# Keep the cache-control handle the decorated function used to expose.
load_model.clear = _load_model_cached.clear
# Per-frame inference.
def process_frame(frame, model, conf=0.25, iou=0.45):
    """Run the detector on a single frame and return its raw box rows.

    Args:
        frame: image as a numpy array (``main`` passes BGR frames read by
            ``cv2.VideoCapture``).
        model: a loaded ultralytics ``YOLO`` model, or ``None``.
        conf: confidence threshold forwarded to ``model.predict``. Boxes
            below it are dropped by the model, so it is a floor under any
            threshold the caller applies afterwards.
        iou: IoU threshold forwarded to ``model.predict``.

    Returns:
        ``results[0].boxes.data`` as a numpy array. Callers index columns 0-3
        as the box and column 4 as confidence (presumably column 5 is the
        class id — confirm against the ultralytics version). Returns ``[]``
        when there is no model, the frame is missing/not an ndarray/empty, the
        model returns no results, or prediction raises (the error is shown via
        ``st.error``).
    """
    if model is None:
        return []
    # An empty array cannot be inferred on; bail out before calling predict.
    if frame is None or not isinstance(frame, np.ndarray) or frame.size == 0:
        return []

    try:
        with torch.no_grad():
            results = model.predict(
                source=frame,
                conf=conf,
                iou=iou,
                verbose=False,
            )
        if results:
            return results[0].boxes.data.cpu().numpy()
        return []
    except Exception as e:
        st.error(f"Error processing frame: {e}")
        return []
def _save_upload_to_temp(video_file):
    """Write an uploaded video to a new temporary file and return its path.

    A unique file per call (instead of a fixed ``temp_video.mp4`` in the
    working directory) keeps simultaneous runs from overwriting each other's
    upload. The caller owns the file and must delete it.
    """
    import tempfile

    suffix = Path(video_file.name).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        # getvalue() returns the whole upload regardless of read position.
        tmp.write(video_file.getvalue())
    return tmp.name


def _draw_detections(frame, detections, confidence_threshold):
    """Draw every detection at or above ``confidence_threshold`` onto ``frame``.

    ``frame`` is modified in place. Returns the confidences of the boxes drawn.
    """
    kept = []
    for det in detections:
        if det[4] >= confidence_threshold:
            x1, y1, x2, y2 = map(int, det[:4])
            conf = float(det[4])
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(
                frame,
                f"{conf:.2f}",
                (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 255, 0),
                2,
            )
            kept.append(conf)
    return kept


def _run_video(video_path, model, confidence_threshold):
    """Stream the video at ``video_path`` through the detector, updating the UI.

    Inference runs on every 3rd frame; every frame is displayed. Each drawn
    detection is appended to ``st.session_state.detections``.

    Returns:
        False if OpenCV could not open the file, True otherwise.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            st.error("Error opening video file")
            return False

        # The reported frame count may be 0 or only an estimate depending on
        # backend/container, so progress is guarded and clamped to the 0-100
        # range st.progress accepts.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        progress_bar = st.progress(0)
        frame_placeholder = st.empty()
        stats_placeholder = st.empty()

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if total_frames > 0:
                progress_bar.progress(min(100, int(frame_count / total_frames * 100)))

            if frame_count % 3 == 0:
                detections = process_frame(frame, model)
                for conf in _draw_detections(frame, detections, confidence_threshold):
                    st.session_state.detections.append({
                        'time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'confidence': conf,
                    })

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_placeholder.image(
                frame_rgb,
                channels="RGB",
                use_column_width=True,
            )

            with stats_placeholder:
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Processed Frames", frame_count)
                with col2:
                    st.metric("Detections", len(st.session_state.detections))

        progress_bar.progress(100)
        return True
    finally:
        cap.release()


def _show_statistics(detections):
    """Render a confidence histogram and a table for ``detections`` (no-op if empty)."""
    if not detections:
        return
    st.header("Detection Statistics")
    df = pd.DataFrame(detections)
    fig = px.histogram(
        df,
        x='confidence',
        title='Detection Confidence Distribution',
        labels={'confidence': 'Confidence Score'},
    )
    st.plotly_chart(fig, use_container_width=True)
    st.dataframe(df, use_container_width=True)


def main():
    """Build the page: sidebar controls, model loading, video processing, statistics."""
    st.title("License Plate Detection System")

    with st.sidebar:
        st.header("Controls")
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.25,
            step=0.05,
        )
    # NOTE(review): process_frame is called below without a conf override, so
    # its predict() call already drops boxes under 0.25; slider values below
    # 0.25 therefore have no additional effect.

    if st.session_state.model is None:
        with st.spinner("Loading model..."):
            st.session_state.model = load_model()
    if st.session_state.model is None:
        st.error("Failed to load model. Please check the model file.")
        return

    video_file = st.file_uploader(
        "Upload Video",
        type=['mp4', 'avi', 'mov'],
        help="Upload a video file containing license plates",
    )
    if not video_file:
        return

    temp_path = None
    try:
        temp_path = _save_upload_to_temp(video_file)
        # Streamlit reruns the whole script on every widget change, which
        # reprocesses this upload; start empty so detections aren't counted twice.
        st.session_state.detections = []
        if _run_video(temp_path, st.session_state.model, confidence_threshold):
            _show_statistics(st.session_state.detections)
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        # Runs on every exit path, including the "could not open" early return.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)


if __name__ == "__main__":
    main()