File size: 11,010 Bytes
39f3a9b
e7f0673
39f3a9b
 
 
 
 
e7f0673
 
 
34b8c6c
9199467
 
39f3a9b
 
 
 
 
 
 
 
e667a82
39f3a9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9199467
39f3a9b
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import os
os.environ["OMP_NUM_THREADS"] = "1"
import time
import glob
import tempfile
from typing import List, Tuple


# ... rest of your imports (import streamlit, import cv2, etc.) ...

import cv2
import numpy as np
import streamlit as st
from ultralytics import YOLO

st.set_page_config(page_title="Weapon Detection", layout="wide")

# =========================
# Sidebar: model, source selection and inference settings
# =========================
# Every widget gets an explicit key so its value can also be read back through
# st.session_state later in the script.
st.sidebar.header("Model & Source")

model_path = st.sidebar.text_input(
    "Model path",
    value="src/wd.pt",
    help="Absolute or relative path to your trained model weights.",
    key="model_path",
)

use_gpu = st.sidebar.checkbox("Use GPU (if available)", value=False, help="Requires CUDA-enabled PyTorch", key="use_gpu")

source_mode = st.sidebar.radio(
    "Choose source",
    options=[
        "Upload image(s)",
        "Local image path",
        "Upload a video",
        "Local video path",
        "Webcam",
    ],
    index=0,
    key="source_mode",
)

conf = st.sidebar.slider("Confidence threshold", 0.05, 0.95, 0.35, 0.01, key="conf")
iou = st.sidebar.slider("IoU (NMS)", 0.10, 0.90, 0.45, 0.01, key="iou")
imgsz = st.sidebar.selectbox("Inference size (imgsz)", [320, 416, 512, 640, 960], index=3, key="imgsz")

# Skip-frames option (1 = run the model on every frame)
skip_n = st.sidebar.number_input(
    "Process every Nth frame (video/webcam)", min_value=1, max_value=10, value=2, step=1, key="skip_n"
)

# Source-specific inputs. All are declared up front with neutral defaults so
# that the names exist whichever source mode is selected.
uploaded_images: List = []
uploaded_video = None
local_image_path = ""
local_video_path = ""
cam_index = 0

if source_mode == "Upload image(s)":
    uploaded_images = st.sidebar.file_uploader(
        "Upload image(s)",
        type=["jpg", "jpeg", "png", "bmp", "webp"],
        accept_multiple_files=True,
        key="uploader_images",
    )
elif source_mode == "Local image path":
    local_image_path = st.sidebar.text_input(
        "Image file OR folder path (reads *.jpg, *.jpeg, *.png, *.bmp, *.webp)",
        value=r"d:/datasets/1 weapons/sample.jpg",
        key="local_image_path",
    )
elif source_mode == "Upload a video":
    uploaded_video = st.sidebar.file_uploader(
        "Upload a video", type=["mp4", "avi", "mov", "mkv"], key="uploader_video"
    )
elif source_mode == "Local video path":
    local_video_path = st.sidebar.text_input(
        "Video file path",
        value=r"e:/gun 2 video.mp4",
        # The text box value is used verbatim as a filesystem path, so Python
        # string syntax (quotes, an r'' prefix) must NOT be typed here.
        help="Use a full path. Spaces are fine; type the path as-is, without quotes or an r'' prefix.",
        key="local_video_path",
    )
else:
    cam_index = st.sidebar.number_input("Webcam index", min_value=0, value=0, step=1, key="cam_index")

start_clicked = st.sidebar.button("▶ Start", key="btn_start")

# =========================
# Utilities
# =========================
@st.cache_resource(show_spinner=True)
def load_model(weights_path: str, want_gpu: bool):
    """Load YOLO weights once per (path, GPU flag) pair and cache the model.

    Args:
        weights_path: Filesystem path to the trained weights file.
        want_gpu: When True, try to move the model to CUDA. If CUDA is
            unavailable, or the move fails, show a warning and stay on CPU.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found: {weights_path}")

    detector = YOLO(weights_path)
    if not want_gpu:
        return detector

    # GPU placement is best-effort: any failure (torch missing, driver issue)
    # is surfaced as a warning and the CPU model is returned unchanged.
    try:
        import torch

        cuda_ready = torch.cuda.is_available()
        if not cuda_ready:
            st.warning("CUDA not available; running on CPU.")
        else:
            detector.to("cuda")
    except Exception as err:
        st.warning(f"Could not move model to GPU: {err}")
    return detector

def read_image_from_upload(upload) -> "np.ndarray | None":
    """Decode an uploaded image file into a BGR numpy array.

    Args:
        upload: The object returned by ``st.file_uploader`` (an in-memory
            buffer with ``getvalue()``), or any object with ``read()``.

    Returns:
        The decoded BGR image, or ``None`` when the upload is empty or its
        bytes cannot be decoded as an image.
    """
    # getvalue() returns the whole buffer regardless of the current stream
    # position, so an upload that was already read once still decodes.
    # Fall back to read() for plain file-like objects.
    data = upload.getvalue() if hasattr(upload, "getvalue") else upload.read()
    if not data:
        # cv2.imdecode raises an assertion error on an empty buffer instead of
        # returning None, and callers only handle the None case.
        return None
    file_bytes = np.frombuffer(data, dtype=np.uint8)
    return cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)  # BGR, or None if undecodable

def collect_local_images(path_str: str) -> List[str]:
    """Return image paths from a single file or from the top level of a folder.

    Args:
        path_str: Path to an image file or to a directory. Surrounding
            whitespace and quotes are ignored, e.g. a path pasted via the
            Windows "Copy as path" action.

    Returns:
        For a directory: the sorted paths of its regular files whose extension
        is one of .jpg/.jpeg/.png/.bmp/.webp (case-insensitive), skipping
        hidden dot-files as the previous glob-based version did. For a file:
        a one-element list with that path. Otherwise: an empty list.
    """
    if not path_str:
        return []
    path_str = path_str.strip().strip("\"'")
    if not path_str:
        return []

    if os.path.isdir(path_str):
        image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        # Listing the directory directly (instead of glob) means folder names
        # containing glob metacharacters such as "[2023]" still work, and
        # upper-case extensions like ".JPG" are matched on every OS.
        with os.scandir(path_str) as entries:
            files = [
                os.path.join(path_str, entry.name)
                for entry in entries
                if not entry.name.startswith(".")
                and entry.is_file()
                and os.path.splitext(entry.name)[1].lower() in image_exts
            ]
        return sorted(files)

    if os.path.isfile(path_str):
        return [path_str]
    return []

def infer_and_annotate_images(
    model: YOLO, images_bgr: List[Tuple[str, np.ndarray]], conf: float, iou: float, imgsz: int
) -> List[Tuple[str, np.ndarray, dict]]:
    """Run YOLO on each (name, BGR image) pair and return annotated RGB results.

    Args:
        model: Loaded YOLO model.
        images_bgr: ``(display_name, bgr_image)`` pairs.
        conf: Confidence threshold passed to ``model.predict``.
        iou: NMS IoU threshold passed to ``model.predict``.
        imgsz: Inference size passed to ``model.predict``.

    Returns:
        One ``(name, annotated_rgb, summary)`` tuple per input, in input order.
        ``summary`` has these keys:

        - ``"detections"``: ``{class_id: count}`` with plain Python int keys.
        - ``"labels"``: ``{class_name: count}``; the class id is used as the
          name if the result has no names mapping.
        - ``"shape"``: the shape of the annotated image.
    """
    out = []
    for name, bgr in images_bgr:
        res = model.predict(bgr, conf=conf, iou=iou, imgsz=imgsz, verbose=False)[0]
        annotated_rgb = cv2.cvtColor(res.plot(), cv2.COLOR_BGR2RGB)  # plot() returns BGR

        counts = {}
        labels = {}
        if res.boxes is not None and len(res.boxes) > 0:
            class_names = getattr(res, "names", None) or {}
            # .tolist() yields plain Python ints. numpy scalar keys would show up
            # as "np.int64(0)" in captions under NumPy >= 2 and are not
            # JSON-serializable.
            for cid in res.boxes.cls.cpu().numpy().astype(int).tolist():
                counts[cid] = counts.get(cid, 0) + 1
                label = str(class_names.get(cid, cid))
                labels[label] = labels.get(label, 0) + 1

        out.append(
            (name, annotated_rgb, {"detections": counts, "labels": labels, "shape": annotated_rgb.shape})
        )
    return out

def open_video_capture(mode, uploaded_file, local_path_str, cam_idx):
    """Open the video source selected in the sidebar.

    Args:
        mode: One of "Upload a video", "Local video path"; any other value is
            treated as "Webcam".
        uploaded_file: The ``st.file_uploader`` result for upload mode.
        local_path_str: Path typed by the user for local-file mode. Surrounding
            whitespace and quotes are ignored.
        cam_idx: Webcam index for webcam mode.

    Returns:
        ``(cap, cleanup, opened_path)``. ``cleanup`` deletes the temporary file
        created for uploads (otherwise it is None) and should be called after
        ``cap.release()``. On failure, a Streamlit message is shown and
        ``(None, None, None)`` is returned, with any temp file already removed.
    """
    cleanup = None
    opened_path = None

    if mode == "Upload a video":
        if not uploaded_file:
            st.warning("Please upload a video to start.")
            return None, None, None
        suffix = os.path.splitext(uploaded_file.name)[1]
        # getvalue() does not depend on the buffer's read position.
        data = uploaded_file.getvalue() if hasattr(uploaded_file, "getvalue") else uploaded_file.read()

        # OpenCV needs a real path, so spill the upload to a named temp file.
        # delete=False because the file must outlive this handle.
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        opened_path = tfile.name

        def _cleanup():
            try:
                os.unlink(opened_path)
            except OSError:
                pass  # best-effort: already removed, or still locked by another process

        cleanup = _cleanup
        try:
            with tfile:  # closes the handle even if write() fails
                tfile.write(data)
        except Exception:
            cleanup()  # do not leak the temp file on a failed write
            raise
        cap = cv2.VideoCapture(opened_path)

    elif mode == "Local video path":
        path = (local_path_str or "").strip().strip("\"'")
        # isfile (not exists) so that a directory path is rejected up front.
        if not path or not os.path.isfile(path):
            st.error("Invalid or missing local video path.")
            return None, None, None
        opened_path = path
        cap = cv2.VideoCapture(opened_path)

    else:  # Webcam
        cap = cv2.VideoCapture(int(cam_idx))
        opened_path = f"webcam:{cam_idx}"

    if not cap.isOpened():
        st.error("Failed to open video source. Check the path/index and permissions.")
        # Release first so the handle does not keep the temp file locked on Windows.
        cap.release()
        if cleanup:
            cleanup()
        return None, None, None

    return cap, cleanup, opened_path

# =========================
# Main UI
# =========================
# Markdown shown in the collapsible "Notes & Tips" panel.
_NOTES_MD = (
    "\n"
    "- Renders with `st.image()` (no `cv2.imshow()`).\n"
    "- Linux deps if needed: `sudo apt-get update && sudo apt-get install -y libgl1 ffmpeg`\n"
    "- Lower `imgsz` (e.g., 320) and increase **Process every Nth frame** for more FPS.\n"
    "- Enable **Use GPU** if your PyTorch is CUDA-enabled.\n"
    "        "
)

st.title("🔫 WEAPON DETECTION IN SURVEILLANCE VIDEOS")

notes_panel = st.expander("Notes & Tips", expanded=False)
notes_panel.markdown(_NOTES_MD)

# Placeholders filled in by the run section: one slot for the current frame,
# then a row of three metric columns.
frame_area = st.empty()
stats_cols = st.columns(3)
stats_col1, stats_col2, stats_col3 = stats_cols

# =========================
# Run
# =========================
if start_clicked:
    try:
        model = load_model(st.session_state.model_path, st.session_state.use_gpu)
    except Exception as e:
        st.exception(e)
        st.stop()

    # Widget values cannot change while this script run is executing, so read
    # them once instead of on every image or frame.
    run_conf = st.session_state.conf
    run_iou = st.session_state.iou
    run_imgsz = st.session_state.imgsz
    run_skip_n = st.session_state.skip_n

    # ---------- IMAGE MODES ----------
    if source_mode in ("Upload image(s)", "Local image path"):
        images_to_process: List[Tuple[str, np.ndarray]] = []

        if source_mode == "Upload image(s)":
            if not uploaded_images:
                st.warning("Please upload one or more images.")
                st.stop()
            for up in uploaded_images:
                bgr = read_image_from_upload(up)
                if bgr is None:
                    st.warning(f"Could not read {up.name}")
                    continue
                images_to_process.append((up.name, bgr))
        else:  # Local image path
            paths = collect_local_images(local_image_path)
            if not paths:
                st.error("No images found at the provided path.")
                st.stop()
            for p in paths:
                bgr = cv2.imread(p, cv2.IMREAD_COLOR)
                if bgr is None:
                    st.warning(f"Could not read: {p}")
                    continue
                images_to_process.append((os.path.basename(p), bgr))

        # Every candidate failed to decode: report that instead of a
        # misleading "Processed 0 image(s)" success message.
        if not images_to_process:
            st.error("None of the selected images could be read.")
            st.stop()

        results = infer_and_annotate_images(model, images_to_process, run_conf, run_iou, run_imgsz)

        # Grid of at most 3 columns (fewer when there are fewer images).
        cols = st.columns(min(3, len(results)))
        for idx, (name, annotated_rgb, summary) in enumerate(results):
            with cols[idx % len(cols)]:
                st.image(annotated_rgb, caption=f"{name} | detections: {summary['detections']}", use_container_width=True)

        st.success(f"Processed {len(results)} image(s).")

    # ---------- VIDEO / WEBCAM MODES ----------
    else:
        cap, cleanup_cb, opened_path = open_video_capture(
            source_mode, uploaded_video, local_video_path, st.session_state.get("cam_index", 0)
        )
        if cap is None:
            st.stop()

        st.success(f"Opened source: {opened_path}")

        # Source FPS is informational only; playback is not throttled to it.
        fps_src = cap.get(cv2.CAP_PROP_FPS)
        if not fps_src or fps_src <= 0 or fps_src > 120:
            fps_src = 30.0

        frame_idx = 0  # frames read and displayed so far
        last_annotated = None
        t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement

        try:
            while True:
                ok, frame = cap.read()
                if not ok or frame is None:
                    st.info("End of stream or cannot read frame.")
                    break

                # Run YOLO only on every Nth frame (or when nothing has been
                # inferred yet). Skipped frames reuse the last annotated image.
                if frame_idx % run_skip_n == 0 or last_annotated is None:
                    results = model.predict(
                        frame,
                        conf=run_conf,
                        iou=run_iou,
                        imgsz=run_imgsz,
                        verbose=False,
                    )
                    last_annotated = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)  # plot() returns BGR

                frame_area.image(last_annotated, channels="RGB", use_container_width=True)

                frame_idx += 1
                live_fps = frame_idx / max(time.perf_counter() - t0, 1e-6)
                stats_col1.metric("Source FPS (approx.)", f"{fps_src:.1f}")
                stats_col2.metric("Processed frames", f"{frame_idx}")
                stats_col3.metric("App FPS", f"{live_fps:.1f}")

        finally:
            # Release the capture before deleting any temp upload file.
            cap.release()
            if cleanup_cb:
                cleanup_cb()
            st.success("Processing finished.")