# Author: smhs16
# Last commit: 1e03a8b ("Update app.py", verified)
import os
import tempfile

import cv2
import streamlit as st
from PIL import Image, ImageDraw
from ultralytics import YOLO
# --- Page Config ---
# Must be the first Streamlit call of the script run.
st.set_page_config(
    page_title="Weapon Detection System",
    page_icon="🛡️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# --- Sidebar ---
# Everything inside this context renders in the sidebar; the two slider values
# are module-level globals consumed by the detection code further down.
with st.sidebar:
    st.title("⚙️ Settings")
    conf_thres = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.25,
        step=0.05,
    )
    iou_thres = st.slider(
        "IoU Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.45,
        step=0.05,
    )
    st.info("Model: YOLOv11n (Custom Trained)")
    st.info("Classes: Weapon, Knife, Pistol")

# --- App Header ---
st.title("🛡️ AI-Powered Weapon Detection (YOLOv11n)")
st.markdown("Upload an image or video to detect weapons instantly.")
st.markdown("---")
# --- Model Loading ---
@st.cache_resource
def load_model(model_path="best.pt"):
    """Return a YOLO detector built from the weights file at *model_path*.

    Decorated with ``st.cache_resource`` so the constructed model is reused
    across Streamlit reruns instead of being rebuilt on every interaction
    (one cached instance per distinct ``model_path``).
    """
    detector = YOLO(model_path)
    return detector
try:
    model = load_model()
except Exception as load_error:
    # Top-level boundary: report the failure in the UI and end this script
    # run, since nothing below can work without a loaded model.
    st.error(f"Error loading YOLOv11n model: {load_error}")
    st.stop()
# --- File Upload ---
# Accepted upload extensions, split by how the file is processed.
IMAGE_TYPES = ("png", "jpg", "jpeg")
VIDEO_TYPES = ("mp4", "avi", "mov")


def _draw_detections(image, result, names):
    """Return (annotated copy of *image*, {class label: count}) for *result*'s boxes."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    class_counts = {}
    for box in result.boxes:
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
        label = names[int(box.cls[0])]
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        draw.text((x1, y1), label, fill="red")
        class_counts[label] = class_counts.get(label, 0) + 1
    return annotated, class_counts


def _analyze_image(image, detector, conf, iou):
    """Run *detector* on a PIL *image* and render the annotated image plus a per-class report."""
    with st.spinner("Processing..."):
        results = detector.predict(image, conf=conf, iou=iou, verbose=False)
    result = results[0]
    if len(result.boxes) == 0:
        st.success("✅ Safe. No weapons detected.")
        return
    detected_image, class_counts = _draw_detections(image, result, detector.names)
    st.image(detected_image, use_container_width=True)
    # A detection is an alert, not a success: use the red error banner.
    st.error(f"⚠️ Threat Detected! Found {len(result.boxes)} potential weapon(s).")
    st.markdown("#### Detailed Report:")
    for name, count in class_counts.items():
        st.write(f"- **{name.capitalize()}**: {count}")


def _run_video_detection(video_path, detector, conf, iou):
    """Stream per-frame detections for the video at *video_path* into the page.

    Returns False if OpenCV could not open the file, True once every frame
    has been processed. The capture is always released, even if ``predict``
    raises mid-stream.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return False
    try:
        frame_slot = st.empty()  # one placeholder, overwritten each frame
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = detector.predict(frame, conf=conf, iou=iou, verbose=False)
            # plot() output is shown with channels="BGR", matching OpenCV frame order.
            frame_slot.image(results[0].plot(), channels="BGR", use_container_width=True)
    finally:
        cap.release()
    return True


uploaded_file = st.file_uploader("Choose a file...", type=[*IMAGE_TYPES, *VIDEO_TYPES])
if uploaded_file is not None:
    file_type = uploaded_file.name.rsplit(".", 1)[-1].lower()

    # --- Image Processing ---
    if file_type in IMAGE_TYPES:
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("📸 Original Image")
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, use_container_width=True)
        with col2:
            st.subheader("🎯 Detection Results")
            if st.button("Analyze Image"):
                _analyze_image(image, model, conf_thres, iou_thres)

    # --- Video Processing ---
    elif file_type in VIDEO_TYPES:
        st.subheader("🎥 Video Analysis")
        # getvalue() returns the full upload regardless of the stream position.
        video_bytes = uploaded_file.getvalue()
        st.video(video_bytes)
        if st.button("Start Video Verification"):
            # OpenCV needs a real file path. Keep the original extension, and
            # delete the file afterwards: Streamlit reruns this script on every
            # interaction, so an undeleted temp file would leak each time.
            with tempfile.NamedTemporaryFile(suffix=f".{file_type}", delete=False) as tfile:
                tfile.write(video_bytes)
            try:
                opened = _run_video_detection(tfile.name, model, conf_thres, iou_thres)
            finally:
                os.unlink(tfile.name)
            if opened:
                st.success("Video processing complete.")
            else:
                st.error("Could not open the uploaded video for processing.")