import os os.environ["OMP_NUM_THREADS"] = "1" import time import glob import tempfile from typing import List, Tuple # ... rest of your imports (import streamlit, import cv2, etc.) ... import cv2 import numpy as np import streamlit as st from ultralytics import YOLO st.set_page_config(page_title="Weapon Detection", layout="wide") st.sidebar.header("Model & Source") model_path = st.sidebar.text_input( "Model path", value="src/wd.pt", help="Absolute or relative path to your trained model weights.", key="model_path", ) use_gpu = st.sidebar.checkbox("Use GPU (if available)", value=False, help="Requires CUDA-enabled PyTorch", key="use_gpu") source_mode = st.sidebar.radio( "Choose source", options=[ "Upload image(s)", "Local image path", "Upload a video", "Local video path", "Webcam", ], index=0, key="source_mode", ) conf = st.sidebar.slider("Confidence threshold", 0.05, 0.95, 0.35, 0.01, key="conf") iou = st.sidebar.slider("IoU (NMS)", 0.10, 0.90, 0.45, 0.01, key="iou") imgsz = st.sidebar.selectbox("Inference size (imgsz)", [320, 416, 512, 640, 960], index=3, key="imgsz") # Skip-frames option (1 = no skip) skip_n = st.sidebar.number_input( "Process every Nth frame (video/webcam)", min_value=1, max_value=10, value=2, step=1, key="skip_n" ) # Inputs (declared once) uploaded_images: List = [] uploaded_video = None local_image_path = "" local_video_path = "" cam_index = 0 if source_mode == "Upload image(s)": uploaded_images = st.sidebar.file_uploader( "Upload image(s)", type=["jpg", "jpeg", "png", "bmp", "webp"], accept_multiple_files=True, key="uploader_images", ) elif source_mode == "Local image path": local_image_path = st.sidebar.text_input( "Image file OR folder path (reads *.jpg, *.jpeg, *.png, *.bmp, *.webp)", value=r"d:/datasets/1 weapons/sample.jpg", key="local_image_path", ) elif source_mode == "Upload a video": uploaded_video = st.sidebar.file_uploader( "Upload a video", type=["mp4", "avi", "mov", "mkv"], key="uploader_video" ) elif source_mode == "Local video path": local_video_path = st.sidebar.text_input( "Video file path", value=r"e:/gun 2 video.mp4", help="Use a full path. For spaces, prefer raw string like r'e:/gun 2 video.mp4'.", key="local_video_path", ) else: cam_index = st.sidebar.number_input("Webcam index", min_value=0, value=0, step=1, key="cam_index") start_clicked = st.sidebar.button("▶ Start", key="btn_start") # ========================= # Utilities # ========================= @st.cache_resource(show_spinner=True) def load_model(weights_path: str, want_gpu: bool): if not os.path.exists(weights_path): raise FileNotFoundError(f"Model weights not found: {weights_path}") m = YOLO(weights_path) if want_gpu: try: import torch if torch.cuda.is_available(): m.to("cuda") else: st.warning("CUDA not available; running on CPU.") except Exception as e: st.warning(f"Could not move model to GPU: {e}") return m def read_image_from_upload(upload) -> np.ndarray: """Read an uploaded image file_uploader object into a BGR numpy array.""" file_bytes = np.asarray(bytearray(upload.read()), dtype=np.uint8) return cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) # BGR def collect_local_images(path_str: str) -> List[str]: """Return list of image paths from a file or a directory.""" if not path_str: return [] if os.path.isdir(path_str): exts = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp"] files = [] for ext in exts: files.extend(glob.glob(os.path.join(path_str, ext))) return sorted(files) if os.path.isfile(path_str): return [path_str] return [] def infer_and_annotate_images( model: YOLO, images_bgr: List[Tuple[str, np.ndarray]], conf: float, iou: float, imgsz: int ) -> List[Tuple[str, np.ndarray, dict]]: """ Run inference on list of (name, BGR image) and return (name, RGB annotated, summary dict). """ out = [] for name, bgr in images_bgr: res = model.predict(bgr, conf=conf, iou=iou, imgsz=imgsz, verbose=False)[0] annotated_bgr = res.plot() annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB) counts = {} if res.boxes is not None and len(res.boxes) > 0: cls_ids = res.boxes.cls.cpu().numpy().astype(int) for cid in cls_ids: counts[cid] = counts.get(cid, 0) + 1 out.append((name, annotated_rgb, {"detections": counts, "shape": annotated_rgb.shape})) return out def open_video_capture(mode, uploaded_file, local_path_str, cam_idx): """ Return (cv2.VideoCapture, cleanup_callback or None, opened_path_str or None). """ cleanup = None opened_path = None if mode == "Upload a video": if not uploaded_file: st.warning("Please upload a video to start.") return None, None, None suffix = os.path.splitext(uploaded_file.name)[1] tfile = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) tfile.write(uploaded_file.read()) tfile.flush() tfile.close() opened_path = tfile.name cap = cv2.VideoCapture(opened_path) def _cleanup(): try: os.unlink(opened_path) except Exception: pass cleanup = _cleanup elif mode == "Local video path": if not local_path_str or not os.path.exists(local_path_str): st.error("Invalid or missing local video path.") return None, None, None opened_path = local_path_str cap = cv2.VideoCapture(opened_path) else: # Webcam cap = cv2.VideoCapture(int(cam_idx)) opened_path = f"webcam:{cam_idx}" if not cap or not cap.isOpened(): st.error("Failed to open video source. Check the path/index and permissions.") if cleanup: cleanup() return None, None, None return cap, cleanup, opened_path # ========================= # Main UI # ========================= st.title("🔫 WEAPON DETECTION IN SURVEILLANCE VIDEOS") with st.expander("Notes & Tips", expanded=False): st.markdown( """ - Renders with `st.image()` (no `cv2.imshow()`). - Linux deps if needed: `sudo apt-get update && sudo apt-get install -y libgl1 ffmpeg` - Lower `imgsz` (e.g., 320) and increase **Process every Nth frame** for more FPS. - Enable **Use GPU** if your PyTorch is CUDA-enabled. """ ) frame_area = st.empty() stats_col1, stats_col2, stats_col3 = st.columns(3) # ========================= # Run # ========================= if start_clicked: try: model = load_model(st.session_state.model_path, st.session_state.use_gpu) except Exception as e: st.exception(e) st.stop() # ---------- IMAGE MODES ---------- if source_mode in ("Upload image(s)", "Local image path"): images_to_process: List[Tuple[str, np.ndarray]] = [] if source_mode == "Upload image(s)": if not uploaded_images: st.warning("Please upload one or more images.") st.stop() for up in uploaded_images: bgr = read_image_from_upload(up) if bgr is None: st.warning(f"Could not read {up.name}") continue images_to_process.append((up.name, bgr)) else: # Local image path paths = collect_local_images(local_image_path) if not paths: st.error("No images found at the provided path.") st.stop() for p in paths: bgr = cv2.imread(p, cv2.IMREAD_COLOR) if bgr is None: st.warning(f"Could not read: {p}") continue images_to_process.append((os.path.basename(p), bgr)) # Inference on images results = infer_and_annotate_images( model, images_to_process, st.session_state.conf, st.session_state.iou, st.session_state.imgsz ) # Display results (grid) n = len(results) cols = st.columns(3) if n >= 3 else st.columns(max(1, n)) for idx, (name, annotated_rgb, summary) in enumerate(results): with cols[idx % len(cols)]: st.image(annotated_rgb, caption=f"{name} | detections: {summary['detections']}", use_container_width=True) st.success(f"Processed {len(results)} image(s).") # ---------- VIDEO / WEBCAM MODES ---------- else: cap, cleanup_cb, opened_path = open_video_capture( source_mode, uploaded_video, local_video_path, st.session_state.get("cam_index", 0) ) if cap is None: st.stop() st.success(f"Opened source: {opened_path}") # FPS (for info only; we don't throttle) fps_src = cap.get(cv2.CAP_PROP_FPS) if not fps_src or fps_src <= 0 or fps_src > 120: fps_src = 30.0 frames = 0 frame_idx = 0 last_annotated = None t0 = time.time() try: while True: ok, frame = cap.read() if not ok or frame is None: st.info("End of stream or cannot read frame.") break # Skip-frame logic: run YOLO only every Nth frame or if no previous result if frame_idx % st.session_state.skip_n == 0 or last_annotated is None: results = model.predict( frame, conf=st.session_state.conf, iou=st.session_state.iou, imgsz=st.session_state.imgsz, verbose=False, ) annotated_bgr = results[0].plot() # BGR last_annotated = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB) # Display the latest annotated frame (reused for skipped frames) frame_area.image(last_annotated, channels="RGB", use_container_width=True) # Stats frames += 1 frame_idx += 1 elapsed = max(time.time() - t0, 1e-6) live_fps = frames / elapsed stats_col1.metric("Source FPS (approx.)", f"{fps_src:.1f}") stats_col2.metric("Processed frames", f"{frames}") stats_col3.metric("App FPS", f"{live_fps:.1f}") # Optional tiny sleep for UI responsiveness; comment for max throughput # time.sleep(0.001) finally: cap.release() if cleanup_cb: cleanup_cb() st.success("Processing finished.")