Upload 22 files
Browse files
- src/__init__.py +22 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/detectors/__init__.py +19 -0
- src/detectors/__pycache__/__init__.cpython-310.pyc +0 -0
- src/detectors/__pycache__/anomaly_detector.cpython-310.pyc +0 -0
- src/detectors/__pycache__/pose_detector.cpython-310.pyc +0 -0
- src/detectors/__pycache__/violence_detector.cpython-310.pyc +0 -0
- src/detectors/__pycache__/weapon_detector.cpython-310.pyc +0 -0
- src/detectors/__pycache__/yolo_detector.cpython-310.pyc +0 -0
- src/detectors/anomaly_detector.py +194 -0
- src/detectors/pose_detector.py +672 -0
- src/detectors/violence_detector.py +296 -0
- src/detectors/weapon_detector.py +377 -0
- src/detectors/yolo_detector.py +86 -0
- src/pipeline/__init__.py +8 -0
- src/pipeline/__pycache__/__init__.cpython-310.pyc +0 -0
- src/pipeline/__pycache__/video_capture.cpython-310.pyc +0 -0
- src/pipeline/video_capture.py +446 -0
- src/utils/__init__.py +34 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/model_downloader.cpython-310.pyc +0 -0
- src/utils/model_downloader.py +147 -0
src/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NETRA Source Code
|
| 3 |
+
Core detection and pipeline modules
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .detectors import (
|
| 7 |
+
YOLODetector,
|
| 8 |
+
ViolenceDetector,
|
| 9 |
+
WeaponPersonDetector,
|
| 10 |
+
PoseDetection,
|
| 11 |
+
AnomalyDetector,
|
| 12 |
+
)
|
| 13 |
+
from .pipeline import VideoCapture
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'YOLODetector',
|
| 17 |
+
'ViolenceDetector',
|
| 18 |
+
'WeaponPersonDetector',
|
| 19 |
+
'PoseDetection',
|
| 20 |
+
'AnomalyDetector',
|
| 21 |
+
'VideoCapture',
|
| 22 |
+
]
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (569 Bytes). View file
|
|
|
src/detectors/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NETRA Detection Modules
|
| 3 |
+
Core AI detection components for video surveillance
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .yolo_detector import YOLODetector, Detection
|
| 7 |
+
from .violence_detector import ViolenceDetector
|
| 8 |
+
from .weapon_detector import WeaponPersonDetector
|
| 9 |
+
from .pose_detector import PoseDetection
|
| 10 |
+
from .anomaly_detector import AnomalyDetector
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
'YOLODetector',
|
| 14 |
+
'Detection',
|
| 15 |
+
'ViolenceDetector',
|
| 16 |
+
'WeaponPersonDetector',
|
| 17 |
+
'PoseDetection',
|
| 18 |
+
'AnomalyDetector',
|
| 19 |
+
]
|
src/detectors/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (697 Bytes). View file
|
|
|
src/detectors/__pycache__/anomaly_detector.cpython-310.pyc
ADDED
|
Binary file (6.25 kB). View file
|
|
|
src/detectors/__pycache__/pose_detector.cpython-310.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
src/detectors/__pycache__/violence_detector.cpython-310.pyc
ADDED
|
Binary file (8.34 kB). View file
|
|
|
src/detectors/__pycache__/weapon_detector.cpython-310.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
src/detectors/__pycache__/yolo_detector.cpython-310.pyc
ADDED
|
Binary file (3.13 kB). View file
|
|
|
src/detectors/anomaly_detector.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anomaly Detection Module
|
| 3 |
+
Loads and runs inference using the anomaly detection model
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, Tuple, Dict, Any
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class AnomalyDetection:
    """Outcome of a single anomaly-detection inference pass."""

    # True when the model score reached the detector's anomaly threshold.
    is_anomaly: bool = False
    # Value compared against the 0.8 / 0.6 cut-offs in ``description``.
    confidence: float = 0.0
    # Score reported in the human-readable description.
    anomaly_score: float = 0.0
    # Alert label; defaults to "SAFE".
    alert_level: str = "SAFE"

    @property
    def description(self) -> str:
        """Return a one-line, human-readable summary of this result."""
        score_text = f"(score: {self.anomaly_score:.2f})"
        if not self.is_anomaly:
            return f"Normal behavior detected {score_text}"

        # Anomalous results are graded by confidence, highest band first.
        if self.confidence >= 0.8:
            headline = "HIGH RISK - Anomaly detected"
        elif self.confidence >= 0.6:
            headline = "MEDIUM RISK - Possible anomaly"
        else:
            headline = "LOW RISK - Minor anomaly"
        return f"{headline} {score_text}"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class AnomalyDetector:
    """Sliding-window anomaly detector around a serialized PyTorch or TensorFlow model.

    Frames are preprocessed and appended to a bounded buffer. Once at least
    half of ``buffer_size`` frames are available, the whole buffer is fed to
    the model and its output is reduced to a single anomaly score.
    """

    def __init__(self,
                 model_path: str,
                 input_size: Tuple[int, int] = (224, 224),
                 device: str = 'cpu',
                 anomaly_threshold: float = 0.5):
        """
        Initialize anomaly detector.

        Args:
            model_path: Path to trained anomaly detection model (.bin file)
            input_size: Target size handed to ``cv2.resize`` (OpenCV reads it
                as ``(width, height)``)
            device: Device to run model on ('cpu' or 'cuda')
            anomaly_threshold: Scores at or above this value count as anomalies

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If neither PyTorch nor TensorFlow can load the file.
        """
        self.model_path = Path(model_path)
        self.input_size = input_size
        self.device = device
        self.anomaly_threshold = anomaly_threshold
        self.model = None
        self.frame_buffer = []
        self.buffer_size = 16  # Number of frames to buffer

        self._load_model()

    def _load_model(self):
        """Load the model, trying PyTorch first and TensorFlow as a fallback.

        Raises:
            FileNotFoundError: If the model file is missing.
            RuntimeError: If both loaders fail; the message carries both errors.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        self.is_tensorflow = False

        try:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only point this at model files you trust.
            # NOTE(review): recent PyTorch releases default to
            # weights_only=True, which rejects fully pickled modules - confirm
            # against the deployed torch version.
            loaded = torch.load(str(self.model_path), map_location=self.device)
            if not callable(getattr(loaded, "eval", None)):
                # A bare state_dict (or any other non-module object) cannot be
                # run; say so explicitly instead of failing on `.eval()`.
                raise TypeError(
                    f"expected a pickled torch module, got {type(loaded).__name__}"
                )
            loaded.eval()
            self.model = loaded
            print(f"[OK] Anomaly detection model (PyTorch) loaded from: {self.model_path}")
        except Exception as e:
            try:
                # Try loading as TensorFlow SavedModel
                import tensorflow as tf
                self.model = tf.keras.models.load_model(str(self.model_path))
                self.is_tensorflow = True
                print(f"[OK] Anomaly detection model (TensorFlow) loaded from: {self.model_path}")
            except Exception as tf_e:
                raise RuntimeError(
                    f"Failed to load anomaly detection model. "
                    f"PyTorch error: {e}, TensorFlow error: {tf_e}"
                ) from tf_e

    def preprocess_frame(self, frame: np.ndarray) -> np.ndarray:
        """Convert a BGR frame to RGB, resize it and scale it to float32 in [0, 1]."""
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, self.input_size)
        return resized.astype(np.float32) / 255.0

    def _min_ready_frames(self) -> int:
        """Return how many buffered frames are required before inference runs."""
        return max(1, self.buffer_size // 2)

    def predict_frame(self, frame: np.ndarray) -> "Optional[AnomalyDetection]":
        """
        Predict anomaly for a single frame.

        Args:
            frame: Input frame (BGR format from OpenCV)

        Returns:
            AnomalyDetection result or None if not enough frames buffered
        """
        self.frame_buffer.append(self.preprocess_frame(frame))

        # Drop the oldest frames so the buffer never exceeds buffer_size, even
        # if buffer_size was lowered after frames were already collected.
        overflow = len(self.frame_buffer) - self.buffer_size
        if overflow > 0:
            del self.frame_buffer[:overflow]

        # Need minimum frames for temporal analysis
        if len(self.frame_buffer) < self._min_ready_frames():
            return None

        return self._inference()

    def _inference(self) -> "AnomalyDetection":
        """Run the model on the buffered frames.

        Returns:
            The scored result, or a non-anomalous result with
            ``alert_level="ERROR"`` when inference raises.
        """
        try:
            # Stacked buffer: one leading axis over the buffered frames.
            # NOTE(review): the layout the model expects is not visible here;
            # with at least one buffered frame this array is 4-D, so the 3-D
            # "add batch dimension" branches below are not expected to fire.
            input_data = np.array(self.frame_buffer, dtype=np.float32)

            if self.is_tensorflow:
                if input_data.ndim == 3:  # Single sample
                    input_data = np.expand_dims(input_data, axis=0)

                prediction = self.model.predict(input_data, verbose=0)
            else:
                input_tensor = torch.as_tensor(input_data, dtype=torch.float32).to(self.device)
                if input_tensor.dim() == 3:  # Add batch dimension
                    input_tensor = input_tensor.unsqueeze(0)

                with torch.no_grad():
                    prediction = self.model(input_tensor)

            anomaly_score = self._extract_score(prediction)
            is_anomaly = anomaly_score >= self.anomaly_threshold
            confidence = self._to_confidence(anomaly_score)

            return AnomalyDetection(
                is_anomaly=is_anomaly,
                confidence=confidence,
                anomaly_score=anomaly_score,
                alert_level=self._get_alert_level(confidence, is_anomaly)
            )
        except Exception as e:
            # Best effort: a failed inference must not stop the caller's loop.
            print(f"[ERROR] Anomaly detection inference failed: {e}")
            return AnomalyDetection(is_anomaly=False, confidence=0.0, alert_level="ERROR")

    @staticmethod
    def _to_confidence(anomaly_score: float) -> float:
        """Clamp a raw anomaly score into the [0, 1] confidence range."""
        return max(0.0, min(anomaly_score, 1.0))

    @staticmethod
    def _extract_score(prediction) -> float:
        """Reduce a model output to one float: the mean of all its values.

        Returns 0.0 for an empty output.
        """
        if isinstance(prediction, torch.Tensor):
            # detach() keeps this working for tensors that still track gradients.
            prediction = prediction.detach().cpu().numpy()

        values = np.asarray(prediction).flatten()
        if values.size == 0:
            return 0.0

        # The mean of a single value is that value, so one path covers both
        # scalar and multi-element outputs.
        return float(np.mean(values))

    @staticmethod
    def _get_alert_level(confidence: float, is_anomaly: bool) -> str:
        """Get alert level based on confidence and anomaly status."""
        if not is_anomaly:
            return "SAFE"
        if confidence >= 0.8:
            return "HIGH RISK"
        if confidence >= 0.6:
            return "MEDIUM RISK"
        return "LOW RISK"

    def reset(self):
        """Reset frame buffer for new session."""
        self.frame_buffer = []

    def get_buffer_status(self) -> Dict[str, Any]:
        """Return the buffered frame count, buffer capacity and readiness flag."""
        buffered = len(self.frame_buffer)
        return {
            'buffered_frames': buffered,
            'buffer_size': self.buffer_size,
            'is_ready': buffered >= self._min_ready_frames()
        }
|
src/detectors/pose_detector.py
ADDED
|
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from ultralytics import YOLO
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PoseDetection:
    """YOLO pose-estimation wrapper that annotates media and rates movement risk.

    Each processed frame is run through the pose model, handed to a
    ``MovementRiskEvaluator`` and (for video/webcam input) stamped with a risk
    banner before being written to disk.
    """

    def __init__(self, model_path="yolo11n-pose.pt", conf=0.25, imgsz=640, device="cpu"):
        """Load the pose model and create the movement evaluator.

        Args:
            model_path: Path to the YOLO pose weights file.
            conf: Confidence value forwarded to ``model.predict``.
            imgsz: Image size forwarded to ``model.predict``.
            device: Device string forwarded to ``model.predict``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(
                f"Pose model not found: {self.model_path}. "
                "Add a YOLO pose model file such as yolo11n-pose.pt to the project or pass --model."
            )

        self.model = YOLO(str(self.model_path))
        self.conf = conf
        self.imgsz = imgsz
        self.device = device
        self.movement_evaluator = MovementRiskEvaluator()

    def reset_movement_state(self):
        """Forget the previous frame's poses so motion starts from scratch."""
        self.movement_evaluator.reset()

    def assess_movement(self, result):
        """Return the movement evaluator's risk assessment for one result."""
        return self.movement_evaluator.assess(result)

    def predict(self, source):
        """Run the pose model on ``source`` and return its list of results."""
        return self.model.predict(
            source=source,
            conf=self.conf,
            imgsz=self.imgsz,
            device=self.device,
            verbose=False,
        )

    def annotate(self, source):
        """Return ``(result, annotated_image)`` for the first prediction of ``source``."""
        result = self.predict(source)[0]
        annotated = result.plot()
        return result, annotated

    @staticmethod
    def format_detections(result, source_name: str):
        """Describe the poses found in ``result`` as a list of text lines.

        Args:
            result: Object exposing ``keypoints`` with ``xy`` and optionally ``conf``.
            source_name: Label used as the prefix of the summary line.

        Returns:
            A single "no poses detected" line, or a summary line followed by
            one line per pose.
        """
        keypoints = getattr(result, "keypoints", None)
        if keypoints is None or keypoints.xy is None or len(keypoints.xy) == 0:
            return [f"{source_name}: no poses detected"]

        lines = [f"{source_name}: detected {len(keypoints.xy)} pose(s)"]
        confidences = getattr(keypoints, "conf", None)
        for index, pose_points in enumerate(keypoints.xy, start=1):
            # A keypoint counts as visible when either coordinate is positive.
            visible_points = int(np.sum(np.any(pose_points.cpu().numpy() > 0, axis=1)))
            if confidences is not None:
                pose_conf = float(np.nanmean(confidences[index - 1].cpu().numpy()))
                lines.append(
                    f"Pose {index}: {visible_points} visible keypoints, average confidence {pose_conf:.2f}"
                )
            else:
                lines.append(f"Pose {index}: {visible_points} visible keypoints")
        return lines

    def save_image_result(self, image_path, output_dir):
        """Annotate one image, save it twice (same name and ``*_preview.png``) and report.

        Returns:
            Dict with keys ``type``, ``source``, ``output_path``,
            ``preview_path``, ``lines`` and ``movement``.
        """
        image_path = Path(image_path)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        self.reset_movement_state()
        result, annotated = self.annotate(str(image_path))
        movement = self.assess_movement(result)
        output_path = output_dir / image_path.name
        preview_path = output_dir / f"{image_path.stem}_preview.png"
        cv2.imwrite(str(output_path), annotated)
        cv2.imwrite(str(preview_path), annotated)

        lines = self.format_detections(result, image_path.name)
        lines.append(
            f"Movement risk: {movement['risk_level']} | action: {movement['action']} | score: {movement['risk_score']:.2f}"
        )

        return {
            "type": "image",
            "source": str(image_path),
            "output_path": str(output_path),
            "preview_path": str(preview_path),
            "lines": lines,
            "movement": movement,
        }

    def _process_stream(self, cap, output_path, preview_path, default_fps, show, window_name):
        """Annotate every frame of an opened capture and write it to ``output_path``.

        The capture and writer are always released, even if inference raises.
        GUI windows are only torn down when ``show`` is true, so headless runs
        never touch the window API.

        Returns:
            Tuple ``(frame_count, last_result, last_movement, risk_counters,
            preview_saved)``.
        """
        frame_count = 0
        last_result = None
        last_movement = None
        risk_counters = {"SAFE": 0, "LOW_RISK": 0, "HIGH_RISK": 0}
        preview_saved = False
        writer = None
        try:
            # Fall back to defaults when the backend reports 0 for a property.
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 480
            fps = cap.get(cv2.CAP_PROP_FPS) or default_fps
            writer = cv2.VideoWriter(
                str(output_path),
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps,
                (width, height),
            )

            self.reset_movement_state()
            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                last_result, annotated = self.annotate(frame)
                last_movement = self.assess_movement(last_result)
                level = last_movement["risk_level"]
                risk_counters[level] = risk_counters.get(level, 0) + 1
                self.draw_movement_banner(annotated, last_movement)
                writer.write(annotated)
                # Count the frame as soon as it is written so a "q" quit below
                # does not leave the final written frame uncounted.
                frame_count += 1
                if not preview_saved:
                    cv2.imwrite(str(preview_path), annotated)
                    preview_saved = True

                if show:
                    cv2.imshow(window_name, annotated)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break
        finally:
            cap.release()
            if writer is not None:
                writer.release()
            if show:
                cv2.destroyAllWindows()

        return frame_count, last_result, last_movement, risk_counters, preview_saved

    def _stream_report(self, kind, source, label, output_path, preview_path, stats):
        """Build the result dict shared by the video and webcam entry points."""
        frame_count, last_result, last_movement, risk_counters, preview_saved = stats

        lines = [f"{label}: processed {frame_count} frames"]
        if last_result is not None:
            lines.extend(self.format_detections(last_result, f"{label} last frame"))
        if last_movement is not None:
            lines.append(
                f"Last movement risk: {last_movement['risk_level']} | action: {last_movement['action']} | score: {last_movement['risk_score']:.2f}"
            )
        lines.append(
            "Risk distribution: "
            f"SAFE={risk_counters['SAFE']}, LOW_RISK={risk_counters['LOW_RISK']}, HIGH_RISK={risk_counters['HIGH_RISK']}"
        )

        return {
            "type": kind,
            "source": source,
            "output_path": str(output_path),
            "preview_path": str(preview_path) if preview_saved else None,
            "lines": lines,
            "movement": last_movement,
            "risk_counters": risk_counters,
        }

    def save_video_result(self, video_path, output_dir, show=False, window_name="Pose Detection"):
        """Annotate a video file into ``<stem>_pose.mp4`` and report the outcome.

        Args:
            video_path: Video file to read.
            output_dir: Directory for the output video and preview image (created if missing).
            show: When true, display each frame; pressing "q" stops early.
            window_name: Title of the display window.

        Returns:
            Dict with keys ``type``, ``source``, ``output_path``,
            ``preview_path``, ``lines``, ``movement`` and ``risk_counters``.

        Raises:
            RuntimeError: If the video cannot be opened.
        """
        video_path = Path(video_path)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            cap.release()
            raise RuntimeError(f"Could not open video: {video_path}")

        output_path = output_dir / f"{video_path.stem}_pose.mp4"
        preview_path = output_dir / f"{video_path.stem}_preview.png"
        stats = self._process_stream(
            cap, output_path, preview_path, default_fps=25.0, show=show, window_name=window_name
        )
        return self._stream_report(
            "video", str(video_path), video_path.name, output_path, preview_path, stats
        )

    def run_webcam(self, camera_index=0, output_dir="runs/pose_inference", show=True):
        """Annotate a live camera feed into ``webcam_<index>_pose.mp4`` and report.

        Args:
            camera_index: Camera index passed to ``cv2.VideoCapture``.
            output_dir: Directory for the output video and preview image (created if missing).
            show: When true, display each frame; pressing "q" stops the capture.

        Returns:
            Dict with the same keys as ``save_video_result``.

        Raises:
            RuntimeError: If the camera cannot be opened.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(camera_index)
        if not cap.isOpened():
            cap.release()
            raise RuntimeError(f"Could not open webcam index: {camera_index}")

        label = f"webcam_{camera_index}"
        output_path = output_dir / f"webcam_{camera_index}_pose.mp4"
        preview_path = output_dir / f"webcam_{camera_index}_preview.png"
        stats = self._process_stream(
            cap, output_path, preview_path, default_fps=20.0, show=show, window_name="Pose Detection"
        )
        return self._stream_report(
            "webcam", str(camera_index), label, output_path, preview_path, stats
        )

    @staticmethod
    def draw_movement_banner(frame, movement):
        """Draw a 40-pixel-high, risk-colored banner with the assessment text onto ``frame`` in place."""
        level = movement["risk_level"]
        action = movement["action"]
        score = movement["risk_score"]

        # Color tuples are passed straight to OpenCV drawing calls.
        banner_colors = {"HIGH_RISK": (0, 0, 255), "LOW_RISK": (0, 165, 255)}
        color = banner_colors.get(level, (0, 128, 0))

        text = f"Risk: {level} | Action: {action} | Score: {score:.2f}"
        cv2.rectangle(frame, (0, 0), (frame.shape[1], 40), color, -1)
        cv2.putText(
            frame,
            text,
            (10, 27),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.62,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class MovementRiskEvaluator:
|
| 264 |
+
def __init__(self, keypoint_conf_threshold=0.3):
|
| 265 |
+
self.keypoint_conf_threshold = keypoint_conf_threshold
|
| 266 |
+
self.prev_people = None
|
| 267 |
+
|
| 268 |
+
def reset(self):
    """Discard the poses remembered from the previously assessed frame."""
    self.prev_people = None
|
| 270 |
+
|
| 271 |
+
def assess(self, result):
    """Rate the movement risk of one pose result.

    Args:
        result: Pose result passed to ``_extract_people``.

    Returns:
        Dict with ``risk_level`` ("SAFE", "LOW_RISK" or "HIGH_RISK"),
        ``action``, ``risk_score`` (capped at 1.5) and ``details`` (the
        per-action score dict, or "no_pose" when nobody was extracted).
    """
    people = self._extract_people(result)
    if not people:
        # Nothing to track: forget the previous frame as well.
        self.prev_people = None
        return {
            "risk_level": "SAFE",
            "action": "other",
            "risk_score": 0.0,
            "details": "no_pose",
        }

    # Feature groups, computed in the order motion -> crowd -> pose.
    features = (
        self._compute_motion(people),
        self._compute_crowd_features(people),
        self._compute_pose_features(people),
    )

    # Dict order matters: ties in max() resolve to the earliest entry.
    action_scores = {
        "punch": self._score_punch(*features),
        "kick": self._score_kick(*features),
        "slap": self._score_slap(*features),
        "push": self._score_push(*features),
        "throw": self._score_throw(*features),
        "headbutt": self._score_headbutt(*features),
        "choking": self._score_choking(*features),
        "weapon": self._score_weapon(*features),
        "aggressive_grab": self._score_aggressive_grab(*features),
        "grappling": self._score_grappling(*features),
        "falling": self._score_falling(*features),
        "defensive": self._score_defensive(*features),
        "running": self._score_running(*features),
        "other": 0.0,
    }

    best_action = max(action_scores, key=action_scores.get)
    risk_score = action_scores[best_action]

    # The two graded groups differ only in their HIGH_RISK cut-off.
    strict_actions = ("punch", "kick", "throw", "headbutt", "choking", "weapon", "grappling")
    lenient_actions = ("slap", "push", "aggressive_grab", "falling")

    risk_level = "SAFE"
    if best_action in strict_actions or best_action in lenient_actions:
        high_cutoff = 0.85 if best_action in strict_actions else 0.8
        if risk_score >= high_cutoff:
            risk_level = "HIGH_RISK"
        elif risk_score >= 0.5:
            risk_level = "LOW_RISK"
        else:
            # Too weak to name: report it as a generic, safe action.
            best_action = "other"
    elif best_action == "defensive":
        # NOTE(review): defensive is LOW_RISK regardless of its score - confirm intended.
        risk_level = "LOW_RISK"
    elif best_action == "running" and risk_score >= 0.5:
        risk_level = "LOW_RISK"

    self.prev_people = people
    return {
        "risk_level": risk_level,
        "action": best_action,
        "risk_score": float(min(risk_score, 1.5)),
        "details": action_scores,
    }
|
| 353 |
+
|
| 354 |
+
def _extract_people(self, result):
|
| 355 |
+
keypoints = getattr(result, "keypoints", None)
|
| 356 |
+
if keypoints is None or keypoints.xy is None or len(keypoints.xy) == 0:
|
| 357 |
+
return []
|
| 358 |
+
|
| 359 |
+
xy_sets = keypoints.xy.cpu().numpy()
|
| 360 |
+
conf_sets = keypoints.conf.cpu().numpy() if getattr(keypoints, "conf", None) is not None else None
|
| 361 |
+
people = []
|
| 362 |
+
|
| 363 |
+
for i, points in enumerate(xy_sets):
|
| 364 |
+
conf = conf_sets[i] if conf_sets is not None else np.ones(points.shape[0], dtype=np.float32)
|
| 365 |
+
valid = conf >= self.keypoint_conf_threshold
|
| 366 |
+
if np.count_nonzero(valid) < 5:
|
| 367 |
+
continue
|
| 368 |
+
|
| 369 |
+
torso = self._torso_scale(points, conf)
|
| 370 |
+
center = self._person_center(points, conf)
|
| 371 |
+
arm_extension = self._arm_extension(points, conf, torso)
|
| 372 |
+
people.append(
|
| 373 |
+
{
|
| 374 |
+
"points": points,
|
| 375 |
+
"conf": conf,
|
| 376 |
+
"torso": torso,
|
| 377 |
+
"center": center,
|
| 378 |
+
"arm_extension": arm_extension,
|
| 379 |
+
}
|
| 380 |
+
)
|
| 381 |
+
return people
|
| 382 |
+
|
| 383 |
+
@staticmethod
|
| 384 |
+
def _distance(a, b):
|
| 385 |
+
return float(np.linalg.norm(np.array(a, dtype=np.float32) - np.array(b, dtype=np.float32)))
|
| 386 |
+
|
| 387 |
+
def _torso_scale(self, points, conf):
    """Estimate a per-person length scale from torso segments.

    Averages the lengths of the segments (5-11), (6-12), (5-6) and (11-12)
    whose two endpoints both meet ``self.keypoint_conf_threshold``.

    Returns:
        The mean segment length, floored at 8.0, or 40.0 when no segment
        is usable.
    """
    thr = self.keypoint_conf_threshold
    segments = ((5, 11), (6, 12), (5, 6), (11, 12))
    spans = [
        self._distance(points[start], points[end])
        for start, end in segments
        if conf[start] >= thr and conf[end] >= thr
    ]
    if not spans:
        return 40.0
    return max(8.0, float(np.mean(spans)))
|
| 396 |
+
|
| 397 |
+
def _person_center(self, points, conf):
    """Return a representative 2-D centre point for one person.

    Uses the mean of keypoints 5, 6, 11 and 12 that meet the confidence
    threshold; if none qualify, falls back to the mean of every confident
    keypoint; if there are still none, returns ``[0.0, 0.0]``.
    """
    thr = self.keypoint_conf_threshold
    chosen = [points[k] for k in (5, 6, 11, 12) if conf[k] >= thr]
    if not chosen:
        chosen = [points[k] for k in range(points.shape[0]) if conf[k] >= thr]
    if chosen:
        return np.mean(np.array(chosen, dtype=np.float32), axis=0)
    return np.array([0.0, 0.0], dtype=np.float32)
|
| 405 |
+
|
| 406 |
+
def _arm_extension(self, points, conf, torso):
    """Return the larger shoulder-to-wrist distance, in torso units.

    Considers the (shoulder, wrist) index pairs (5, 9) and (6, 10); a pair is
    used only when both keypoints meet the confidence threshold. Returns 0.0
    when neither arm is usable.
    """
    thr = self.keypoint_conf_threshold
    divisor = max(torso, 1e-6)  # guard against a zero torso scale
    reaches = [
        self._distance(points[shoulder_idx], points[wrist_idx]) / divisor
        for shoulder_idx, wrist_idx in ((5, 9), (6, 10))
        if conf[shoulder_idx] >= thr and conf[wrist_idx] >= thr
    ]
    if reaches:
        return float(np.max(reaches))
    return 0.0
|
| 414 |
+
|
| 415 |
+
def _compute_motion(self, people):
    """Measure frame-to-frame motion against ``self.prev_people``.

    People are paired with the previous frame by list position (index i with
    index i); only the overlapping prefix is compared. Each speed is divided
    by the mean of the two torso scales.

    Returns:
        Dict with ``body_speed``, ``arm_speed`` and ``leg_speed``: the maximum
        over all pairs, or all zeros when there is nothing to compare.
    """
    previous = self.prev_people
    matched = list(zip(people, previous)) if previous else []
    if not matched:
        return {"body_speed": 0.0, "arm_speed": 0.0, "leg_speed": 0.0}

    body, arms, legs = [], [], []
    for cur, prev in matched:
        scale = max((cur["torso"] + prev["torso"]) * 0.5, 1e-6)
        body.append(self._distance(cur["center"], prev["center"]) / scale)
        arms.append(self._average_joint_speed(cur, prev, [7, 8, 9, 10], scale))
        legs.append(self._average_joint_speed(cur, prev, [13, 14, 15, 16], scale))

    return {
        "body_speed": float(np.max(body)),
        "arm_speed": float(np.max(arms)),
        "leg_speed": float(np.max(legs)),
    }
|
| 441 |
+
|
| 442 |
+
def _average_joint_speed(self, cur, prev, indices, scale):
    """Return the mean scaled displacement of the given joints between frames.

    A joint contributes only when its confidence meets the threshold in both
    ``cur`` and ``prev``. Returns 0.0 when no joint qualifies.
    """
    thr = self.keypoint_conf_threshold
    moves = [
        self._distance(cur["points"][joint], prev["points"][joint]) / scale
        for joint in indices
        if cur["conf"][joint] >= thr and prev["conf"][joint] >= thr
    ]
    return float(np.mean(moves)) if moves else 0.0
|
| 450 |
+
|
| 451 |
+
@staticmethod
def _compute_crowd_features(people):
    """Summarise how close people are to each other in the current frame.

    Args:
        people: List of person dicts carrying ``center``, ``torso`` and
            ``arm_extension`` entries. May be empty.

    Returns:
        Dict with:
          * ``close_people``: ``max(0, 1.2 - d / (2.2 * avg_torso))`` where
            ``d`` is the smallest centre-to-centre distance of any pair;
            0.0 when fewer than two people are present.
          * ``rapid_multi_person``: 1.0 when ``close_people`` exceeds 0.35,
            otherwise 0.0.
          * ``arm_extension``: largest ``arm_extension`` of any person, or 0.0
            when ``people`` is empty.
    """
    # ``default`` keeps an empty frame from raising ValueError in max().
    arm_extension = float(max((p["arm_extension"] for p in people), default=0.0))
    if len(people) < 2:
        return {"close_people": 0.0, "rapid_multi_person": 0.0, "arm_extension": arm_extension}

    avg_torso = max(8.0, float(np.mean([p["torso"] for p in people])))
    # Smallest distance over every unordered pair of people.
    min_center_dist = min(
        float(np.linalg.norm(first["center"] - second["center"]))
        for i, first in enumerate(people)
        for second in people[i + 1:]
    )

    close_people = max(0.0, 1.2 - (min_center_dist / (avg_torso * 2.2)))
    rapid_multi_person = 1.0 if close_people > 0.35 else 0.0
    return {
        "close_people": float(close_people),
        "rapid_multi_person": float(rapid_multi_person),
        "arm_extension": arm_extension,
    }
|
| 471 |
+
|
| 472 |
+
def _compute_pose_features(self, people):
    """Compute detailed pose features for all violence types.

    Only the first person in ``people`` is analysed; vertical motion is
    measured against the first entry of ``self.prev_people``.

    Args:
        people: List of person dicts (``points``, ``conf``, ``torso``,
            ``center``) for the current frame.

    Returns:
        Dict with ``leg_extension``, ``head_position``, ``body_angle``
        (degrees), ``arm_angles`` (degrees), ``upward_motion``,
        ``downward_motion`` and ``ground_contact``. Every entry keeps its
        neutral default when the keypoints it needs are not confident.
    """
    # NOTE(review): the [0.0, 0.0] default for "arm_angles" is non-empty, so
    # consumers that test for sharply bent arms will see 0-degree angles when
    # no arm was measured -- confirm whether an empty list was intended.
    features = {
        "leg_extension": 0.0,
        "head_position": [0.0, 0.0],
        "body_angle": 0.0,
        "arm_angles": [0.0, 0.0],
        "upward_motion": 0.0,
        "downward_motion": 0.0,
        "ground_contact": False,
    }
    if not people:
        return features

    thr = self.keypoint_conf_threshold
    person = people[0]
    points = person["points"]
    conf = person["conf"]
    head_visible = conf[0] >= thr

    # Leg extension: distance between keypoints 13 and 14, in torso units.
    if conf[13] >= thr and conf[14] >= thr:
        leg_ext = self._distance(points[13], points[14]) / max(person["torso"], 1e-6)
        features["leg_extension"] = float(leg_ext)

    # Head position
    if head_visible:
        features["head_position"] = list(map(float, points[0]))

    # Body angle (from head to hip)
    if head_visible and conf[11] >= thr:
        vec = points[11] - points[0]
        features["body_angle"] = float(np.arctan2(vec[1], vec[0]) * 180 / np.pi)

    # Arm angles: interior angle at the elbow for each fully visible arm.
    arm_angles = []
    for shoulder, elbow, wrist in [(5, 7, 9), (6, 8, 10)]:
        if conf[shoulder] >= thr and conf[elbow] >= thr and conf[wrist] >= thr:
            v1 = points[shoulder] - points[elbow]
            v2 = points[wrist] - points[elbow]
            denom = np.linalg.norm(v1) * np.linalg.norm(v2)
            if denom > 1e-6:  # skip degenerate (zero-length) limbs
                cos_angle = np.dot(v1, v2) / denom
                angle = np.arccos(np.clip(cos_angle, -1, 1)) * 180 / np.pi
                arm_angles.append(float(angle))
    if arm_angles:
        features["arm_angles"] = arm_angles

    # Vertical motion if previous frame exists. The head must be confident in
    # BOTH frames: comparing against an undetected previous head position
    # would otherwise register as a large, spurious jump.
    if self.prev_people:
        prev = self.prev_people[0]
        if head_visible and prev["conf"][0] >= thr:
            prev_torso = max(prev["torso"], 1e-6)
            head_motion = self._distance(points[0], prev["points"][0])
            # A smaller y than in the previous frame is reported as upward motion.
            if points[0][1] < prev["points"][0][1]:
                features["upward_motion"] = head_motion / prev_torso
            else:
                features["downward_motion"] = head_motion / prev_torso

    # Ground contact (low ankle/knee positions)
    if conf[15] >= thr or conf[16] >= thr:
        ankle_y = min(
            points[15][1] if conf[15] >= thr else 1e9,
            points[16][1] if conf[16] >= thr else 1e9,
        )
        hip_y = person["center"][1]
        if ankle_y > hip_y * 0.8:
            features["ground_contact"] = True

    return features
|
| 541 |
+
|
| 542 |
+
def _score_punch(self, speed_stats, crowd_stats, pose_stats):
    """Score a punch in [0, 1] from arm speed, proximity and arm extension."""
    proximity = crowd_stats["close_people"]
    raw = speed_stats["arm_speed"] * 1.5
    if proximity > 0.3:
        raw += proximity * 0.7
    if crowd_stats["arm_extension"] > 1.1:
        raw += 0.4
    # Subtract the 0.5 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.5), 1.0))
|
| 551 |
+
|
| 552 |
+
def _score_kick(self, speed_stats, crowd_stats, pose_stats):
    """Score a kick in [0, 1] from leg speed and leg extension."""
    raw = speed_stats["leg_speed"] * 1.8
    if pose_stats["leg_extension"] > 1.2:
        raw += 0.5
    # Subtract the 0.55 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.55), 1.0))
|
| 559 |
+
|
| 560 |
+
def _score_slap(self, speed_stats, crowd_stats, pose_stats):
    """Score a slap in [0, 1] from arm speed, close range and arm extension."""
    raw = speed_stats["arm_speed"] * 1.4
    if crowd_stats["close_people"] > 0.35:
        raw += 0.5
    if crowd_stats["arm_extension"] > 1.0:
        raw += 0.3
    # Subtract the 0.6 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.6), 1.0))
|
| 569 |
+
|
| 570 |
+
def _score_push(self, speed_stats, crowd_stats, pose_stats):
    """Score a push in [0, 1] from arm speed, body speed and proximity."""
    raw = (speed_stats["arm_speed"] * 0.9 + speed_stats["body_speed"] * 1.0)
    if crowd_stats["close_people"] > 0.25:
        raw += 0.6
    # Subtract the 0.45 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.45), 1.0))
|
| 577 |
+
|
| 578 |
+
def _score_throw(self, speed_stats, crowd_stats, pose_stats):
    """Score a throw in [0, 1] from upward motion, arm speed and arm extension."""
    rise = pose_stats["upward_motion"]
    raw = 0.0
    if rise > 0.3:
        raw += rise * 1.5
    raw += speed_stats["arm_speed"] * 0.8
    if crowd_stats["arm_extension"] > 1.3:
        raw += 0.4
    # Subtract the 0.3 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.3), 1.0))
|
| 588 |
+
|
| 589 |
+
def _score_headbutt(self, speed_stats, crowd_stats, pose_stats):
    """Score a headbutt in [0, 1] from proximity, body angle and body speed."""
    lean = pose_stats["body_angle"]
    raw = 0.0
    if crowd_stats["close_people"] > 0.5:
        raw += 0.8
    if lean != 0.0:
        # The angle contributes at most 0.5 (reached at 180 degrees).
        raw += abs(lean) / 180.0 * 0.5
    raw += speed_stats["body_speed"] * 0.6
    # Subtract the 0.4 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.4), 1.0))
|
| 599 |
+
|
| 600 |
+
def _score_choking(self, speed_stats, crowd_stats, pose_stats):
    """Score choking in [0, 1] from very close proximity, bent arms and arm speed."""
    angles = pose_stats["arm_angles"]
    raw = 0.0
    if crowd_stats["close_people"] > 0.6:
        raw += 0.9
    if len(angles) > 0 and min(angles) < 90:
        raw += 0.6
    raw += speed_stats["arm_speed"] * 0.5
    # Subtract the 0.5 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.5), 1.0))
|
| 610 |
+
|
| 611 |
+
def _score_weapon(self, speed_stats, crowd_stats, pose_stats):
    """Score weapon use in [0, 1] from fast arm/leg motion and proximity."""
    raw = 0.0
    if speed_stats["arm_speed"] > 0.8:
        raw += 0.5
    if speed_stats["leg_speed"] > 0.6:
        raw += 0.3
    if crowd_stats["close_people"] > 0.3:
        raw += 0.3
    # Subtract the 0.35 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.35), 1.0))
|
| 622 |
+
|
| 623 |
+
def _score_aggressive_grab(self, speed_stats, crowd_stats, pose_stats):
|
| 624 |
+
"""Detect aggressive grabbing: very close proximity, intertwined arms."""
|
| 625 |
+
score = 0.0
|
| 626 |
+
if crowd_stats["close_people"] > 0.55:
|
| 627 |
+
score += 0.9
|
| 628 |
+
if len(pose_stats["arm_angles"]) > 0 and any(a < 100 for a in pose_stats["arm_angles"]):
|
| 629 |
+
score += 0.5
|
| 630 |
+
score += speed_stats["arm_speed"] * 0.6
|
| 631 |
+
score = max(0.0, score - 0.4)
|
| 632 |
+
return float(min(score, 1.0))
|
| 633 |
+
|
| 634 |
+
def _score_grappling(self, speed_stats, crowd_stats, pose_stats):
    """Score grappling in [0, 1] from body speed and sustained close contact.

    Body speed only counts when ``self.prev_people`` is non-empty.
    """
    history = self.prev_people
    raw = 0.0
    if history and len(history) > 0:
        raw += speed_stats["body_speed"] * 1.1
    if crowd_stats["close_people"] > 0.5:
        raw += 0.7
    if crowd_stats["rapid_multi_person"] > 0.5:
        raw += 0.6
    # Subtract the 0.45 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.45), 1.0))
|
| 645 |
+
|
| 646 |
+
def _score_falling(self, speed_stats, crowd_stats, pose_stats):
    """Score a fall or knock-down in [0, 1] from downward motion, ground contact and body speed."""
    drop = pose_stats["downward_motion"]
    raw = 0.0
    if drop > 0.4:
        raw += drop * 1.5
    if pose_stats["ground_contact"]:
        raw += 0.8
    raw += speed_stats["body_speed"] * 0.7
    # Subtract the 0.35 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.35), 1.0))
|
| 656 |
+
|
| 657 |
+
def _score_defensive(self, speed_stats, crowd_stats, pose_stats):
    """Score a defensive posture in [0, 1] from bent arms and a tilted body."""
    angles = pose_stats["arm_angles"]
    raw = 0.0
    if len(angles) > 0 and np.mean(angles) < 120:
        raw += 0.6
    if abs(pose_stats["body_angle"]) > 30:
        raw += 0.4
    # Subtract the 0.3 baseline, then clamp into [0, 1].
    return float(min(max(0.0, raw - 0.3), 1.0))
|
| 668 |
+
|
| 669 |
+
def _score_running(self, speed_stats, crowd_stats, pose_stats):
    """Score running in [0, 1] from body speed and leg speed."""
    combined = speed_stats["body_speed"] * 1.2 + speed_stats["leg_speed"] * 0.8 - 0.35
    clamped_low = max(0.0, combined)
    return float(min(clamped_low, 1.0))
|
src/detectors/violence_detector.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Violence Detection Module
|
| 3 |
+
Loads and runs inference using the trained violence detection model
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pickle
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, List, Tuple, Dict, Any
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
# Model configuration
|
| 14 |
+
MODEL_PATH = Path(__file__).parent / "model" / "violence_model.h5"
|
| 15 |
+
LABEL_PATH = Path(__file__).parent / "model" / "lb.pickle"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class ViolenceDetection:
    """Result of classifying one frame (or region) for violence."""
    class_name: str = "Unknown"
    confidence: float = 0.0
    is_violence: bool = False
    bbox: Optional[Tuple[int, int, int, int]] = None  # For ROI-based detection

    @property
    def alert_level(self) -> str:
        """Map the detection to "SAFE", "LOW RISK", "MEDIUM RISK" or "HIGH RISK".

        Non-violent detections are always "SAFE"; otherwise the level rises
        with confidence (>= 0.6 medium, >= 0.8 high).
        """
        if not self.is_violence:
            return "SAFE"
        # Ordered from the strictest confidence floor downwards.
        for floor, label in ((0.8, "HIGH RISK"), (0.6, "MEDIUM RISK")):
            if self.confidence >= floor:
                return label
        return "LOW RISK"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ViolenceDetector:
    """Violence detection using trained CNN model."""

    # Substrings that mark a class label as violent (adjust to the model's classes).
    _VIOLENCE_KEYWORDS = (
        'violence', 'violent', 'fight', 'fighting', 'assault', 'attack',
        'aggression', 'aggressive', 'hitting', 'punch', 'kick', 'weapon',
        'gun', 'knife', 'sword', 'bat', 'stick',
    )
    # Leading tokens that negate a label, e.g. "NonViolence" or "no_fight".
    _NEGATION_PREFIXES = ('non', 'not', 'no')
    # BGR colours for each alert level returned by ViolenceDetection.alert_level.
    _ALERT_COLORS = {
        "HIGH RISK": (0, 0, 255),      # Red
        "MEDIUM RISK": (0, 165, 255),  # Orange
        "LOW RISK": (0, 255, 255),     # Yellow
        "SAFE": (0, 255, 0),           # Green
    }

    def __init__(self,
                 model_path: str = str(MODEL_PATH),
                 label_path: str = str(LABEL_PATH),
                 input_size: Tuple[int, int] = (224, 224)):
        """
        Initialize violence detector.

        Args:
            model_path: Path to trained Keras model (.h5 file)
            label_path: Path to label encoder pickle file
            input_size: Frame size handed to ``cv2.resize`` as (width, height);
                replaced by the model's own input size when the model declares one

        Raises:
            ImportError: If TensorFlow is not installed.
            FileNotFoundError: If the model or the label encoder cannot be loaded.
        """
        try:
            from tensorflow import keras
        except ImportError as err:
            raise ImportError(
                "TensorFlow not installed. Run: pip install tensorflow"
            ) from err

        self.model_path = model_path
        self.label_path = label_path
        self.input_size = input_size

        # Load the trained model
        try:
            self.model = keras.models.load_model(model_path, compile=False)
            print(f"Loaded violence detection model: {model_path}")
        except Exception as e:
            raise FileNotFoundError(f"Could not load model from {model_path}: {e}") from e

        # Load label encoder
        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only point label_path at files from a trusted source.
        try:
            with open(label_path, 'rb') as f:
                self.label_encoder = pickle.load(f)
            print(f"Loaded label encoder: {label_path}")
            print(f"Classes: {self.label_encoder.classes_}")
        except Exception as e:
            raise FileNotFoundError(f"Could not load label encoder from {label_path}: {e}") from e

        # Get model input shape
        input_shape = self.model.input_shape
        if len(input_shape) >= 3 and input_shape[1] is not None and input_shape[2] is not None:
            # Assumes a channels-last (batch, height, width, channels) input --
            # TODO confirm for this model. cv2.resize expects (width, height),
            # so the two spatial axes are swapped here.
            self.input_size = (input_shape[2], input_shape[1])
            print(f"Model input size: {self.input_size}")

    @property
    def class_names(self) -> List[str]:
        """Get list of class names."""
        return list(self.label_encoder.classes_)

    def preprocess_frame(self, frame: np.ndarray) -> np.ndarray:
        """
        Preprocess frame for violence detection model.

        Args:
            frame: BGR image (numpy array)

        Returns:
            Float32 RGB frame scaled to [0, 1] with a leading batch axis
        """
        # Resize to model input size
        resized = cv2.resize(frame, self.input_size, interpolation=cv2.INTER_AREA)

        # Convert BGR to RGB
        rgb_frame = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

        # Normalize to [0, 1]
        normalized = rgb_frame.astype(np.float32) / 255.0

        # Add batch dimension
        return np.expand_dims(normalized, axis=0)

    def _decode_prediction(self, scores, confidence_threshold: float) -> ViolenceDetection:
        """Convert one row of model output into a ViolenceDetection."""
        scores = np.asarray(scores).reshape(-1)
        classes = self.label_encoder.classes_

        if scores.size == 1 and len(classes) == 2:
            # Single-unit output with two labels: argmax would always pick
            # index 0. Presumably the unit is the probability of classes_[1]
            # (sigmoid head) -- TODO confirm against the training code.
            positive = float(scores[0])
            predicted_idx = 1 if positive >= 0.5 else 0
            confidence = positive if predicted_idx == 1 else 1.0 - positive
        else:
            predicted_idx = int(np.argmax(scores))
            confidence = float(scores[predicted_idx])

        class_name = str(classes[predicted_idx])
        is_violence = self._is_violent_class(class_name) and confidence >= confidence_threshold
        return ViolenceDetection(
            class_name=class_name,
            confidence=confidence,
            is_violence=is_violence
        )

    def detect_violence(self,
                        frame: np.ndarray,
                        confidence_threshold: float = 0.5) -> ViolenceDetection:
        """
        Detect violence in a single frame.

        Args:
            frame: BGR image (numpy array)
            confidence_threshold: Minimum confidence for violence detection

        Returns:
            ViolenceDetection object with results
        """
        preprocessed = self.preprocess_frame(frame)
        predictions = self.model.predict(preprocessed, verbose=0)[0]
        return self._decode_prediction(predictions, confidence_threshold)

    def detect_batch(self,
                     frames: List[np.ndarray],
                     confidence_threshold: float = 0.5) -> List[ViolenceDetection]:
        """
        Detect violence in multiple frames (batch processing).

        Args:
            frames: List of BGR images
            confidence_threshold: Minimum confidence for violence detection

        Returns:
            List of ViolenceDetection objects, one per input frame
        """
        if not frames:
            return []

        # Each preprocessed frame already has a batch axis, so stacking along
        # axis 0 yields one (N, H, W, C) batch.
        batch = np.vstack([self.preprocess_frame(frame) for frame in frames])
        predictions = self.model.predict(batch, verbose=0)
        return [self._decode_prediction(pred, confidence_threshold) for pred in predictions]

    def _is_violent_class(self, class_name: str) -> bool:
        """
        Determine if a class name indicates violence.

        A label counts as violent when it contains one of the violence
        keywords, unless the keyword is directly negated by a leading
        "non" / "not" / "no" (e.g. "NonViolence", "no_fight").

        Args:
            class_name: Name of the predicted class

        Returns:
            True if class indicates violence
        """
        class_lower = str(class_name).lower()

        # Drop separators so "non-violence", "non_violence" and "NonViolence"
        # are all treated alike.
        compact = "".join(ch for ch in class_lower if ch.isalnum())
        for prefix in self._NEGATION_PREFIXES:
            if compact.startswith(prefix) and compact[len(prefix):].startswith(self._VIOLENCE_KEYWORDS):
                return False

        return any(keyword in class_lower for keyword in self._VIOLENCE_KEYWORDS)

    def draw_violence_detection(self,
                                frame: np.ndarray,
                                detection: ViolenceDetection,
                                roi_bbox: Optional[Tuple[int, int, int, int]] = None) -> np.ndarray:
        """
        Draw violence detection results on frame.

        Args:
            frame: Original BGR frame
            detection: ViolenceDetection object
            roi_bbox: ROI bounding box (x, y, w, h) if detection was on ROI

        Returns:
            Copy of the frame with the detection visualization drawn on it
        """
        display_frame = frame.copy()

        # Choose colors based on alert level
        color = self._ALERT_COLORS.get(detection.alert_level, self._ALERT_COLORS["SAFE"])

        # Draw ROI box if provided
        if roi_bbox is not None:
            x, y, w, h = roi_bbox
            cv2.rectangle(display_frame, (x, y), (x + w, y + h), color, 3)

        # Draw alert banner
        alert_text = f"{detection.alert_level}: {detection.class_name}"
        confidence_text = f"Confidence: {detection.confidence:.2f}"

        # Background rectangle for text
        text_size = cv2.getTextSize(alert_text, cv2.FONT_HERSHEY_SIMPLEX, 1.2, 3)[0]
        cv2.rectangle(display_frame, (10, 10),
                      (text_size[0] + 20, 80), color, -1)

        # White text on colored background
        cv2.putText(display_frame, alert_text, (15, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 3)
        cv2.putText(display_frame, confidence_text, (15, 70),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        return display_frame
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def main():
    """Demo: Test violence detection on webcam feed."""

    detector = ViolenceDetector()
    cap = cv2.VideoCapture(0)

    print("Press 'q' to quit")
    print("Violence detection active")
    print(f"Classes: {detector.class_names}")

    # try/finally so the camera and the preview window are released even
    # when inference or drawing raises mid-loop.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Detect violence
            detection = detector.detect_violence(frame, confidence_threshold=0.5)

            # Draw results
            display_frame = detector.draw_violence_detection(frame, detection)

            # Show frame
            cv2.imshow("Violence Detection", display_frame)

            # Print alerts to console
            if detection.is_violence:
                print(f"ALERT: {detection.alert_level} - {detection.class_name} "
                      f"(Confidence: {detection.confidence:.2f})")

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
|
src/detectors/weapon_detector.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Weapon & Person Detection System
|
| 3 |
+
Detects weapons in frame and counts persons visible while weapon is present.
|
| 4 |
+
Alert stays active until the person with the weapon leaves the frame.
|
| 5 |
+
|
| 6 |
+
Uses two YOLO models:
|
| 7 |
+
- best.pt (ai_models/wepan_detection) or GunDetector.pt fallback — custom weapon model
|
| 8 |
+
- yolov8n.pt — COCO model for person detection (class 0: person)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
import numpy as np
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Tuple, Optional, Dict
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from ..pipeline import VideoCapture
|
| 20 |
+
from .yolo_detector import YOLODetector, Detection
|
| 21 |
+
except ImportError:
|
| 22 |
+
# Fallback for direct execution
|
| 23 |
+
import sys
|
| 24 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 25 |
+
from src.pipeline import VideoCapture
|
| 26 |
+
from src.detectors.yolo_detector import YOLODetector, Detection
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Model paths
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
_BASE = Path(__file__).parent
|
| 32 |
+
_PROJECT_ROOT = _BASE.parent
|
| 33 |
+
|
| 34 |
+
# Prefer the shared trained model under ai_models, then fall back to legacy path.
|
| 35 |
+
_WEAPON_MODEL_CANDIDATES = [
|
| 36 |
+
_PROJECT_ROOT / "ai_models" / "wepan_detection" / "best.pt",
|
| 37 |
+
_PROJECT_ROOT / "ai_models" / "weapon_detection" / "best.pt",
|
| 38 |
+
_BASE / "model" / "GunDetector.pt",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
def _resolve_weapon_model_path() -> str:
    """Return the first weapon-model candidate that exists on disk.

    When none exists, the first candidate is returned anyway so callers get a
    clear file-not-found path.
    """
    found = next(
        (path for path in _WEAPON_MODEL_CANDIDATES if path.exists()),
        None,
    )
    chosen = _WEAPON_MODEL_CANDIDATES[0] if found is None else found
    return str(chosen)
|
| 47 |
+
|
| 48 |
+
GUN_MODEL_PATH = _resolve_weapon_model_path()
|
| 49 |
+
|
| 50 |
+
_PERSON_MODEL_CANDIDATES = [
|
| 51 |
+
_PROJECT_ROOT / "ai_models" / "object_detection" / "yolov8n.pt",
|
| 52 |
+
_BASE / "model" / "yolov8s.pt",
|
| 53 |
+
_BASE / "model" / "yolov8n.pt",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
def _resolve_person_model_path() -> str:
    """Return the first person-model candidate that exists on disk.

    Falls back to the first candidate when none of them exists.
    """
    found = next(
        (path for path in _PERSON_MODEL_CANDIDATES if path.exists()),
        None,
    )
    chosen = _PERSON_MODEL_CANDIDATES[0] if found is None else found
    return str(chosen)
|
| 61 |
+
|
| 62 |
+
PERSON_MODEL_PATH = _resolve_person_model_path()
|
| 63 |
+
|
| 64 |
+
# COCO class IDs from yolov8n
|
| 65 |
+
PERSON_CLASS_ID = 0
|
| 66 |
+
KNIFE_CLASS_ID = 43 # "knife" in COCO
|
| 67 |
+
|
| 68 |
+
# How many consecutive "no-weapon" frames before we clear the alert.
|
| 69 |
+
WEAPON_COOLDOWN_FRAMES = 10
|
| 70 |
+
|
| 71 |
+
# How many consecutive frames a weapon must appear before alert triggers.
|
| 72 |
+
# This prevents random one-off false positives from firing an alert.
|
| 73 |
+
WEAPON_CONFIRM_FRAMES = 3
|
| 74 |
+
|
| 75 |
+
# Minimum bounding-box area (pixels) to accept a gun detection.
|
| 76 |
+
# Tiny boxes are almost always false positives.
|
| 77 |
+
MIN_WEAPON_AREA = 1500
|
| 78 |
+
|
| 79 |
+
# Gun model needs a higher confidence bar because it produces false positives.
|
| 80 |
+
GUN_CONF_THRESHOLD = 0.60
|
| 81 |
+
|
| 82 |
+
# Person / knife confidence can stay lower (COCO model is reliable).
|
| 83 |
+
PERSON_CONF_THRESHOLD = 0.30
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
class FrameResult:
    """Outcome of analysing one video frame."""

    # Weapon detections reported for this frame.
    weapons: List[Detection] = field(default_factory=list)
    # Person detections reported for this frame.
    persons: List[Detection] = field(default_factory=list)
    # Whether the weapon alert is active after this frame.
    alert_active: bool = False
    # Number of persons in the frame.
    person_count: int = 0
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class WeaponPersonDetector:
    """Real-time weapon detection with person counting using dual YOLO models.

    One model looks for guns; a second (COCO) model supplies person and knife
    detections. The alert is only raised after a weapon has been seen for
    ``confirm_frames`` consecutive frames and is cleared after
    ``cooldown_frames`` consecutive frames without any weapon.
    """

    def __init__(self,
                 gun_model_path: str = GUN_MODEL_PATH,
                 person_model_path: str = PERSON_MODEL_PATH,
                 gun_conf: float = GUN_CONF_THRESHOLD,
                 person_conf: float = PERSON_CONF_THRESHOLD,
                 cooldown_frames: int = WEAPON_COOLDOWN_FRAMES,
                 confirm_frames: int = WEAPON_CONFIRM_FRAMES,
                 min_weapon_area: int = MIN_WEAPON_AREA):
        """
        Load both models and reset the alert state.

        Args:
            gun_model_path: Path to weapon model weights (.pt).
            person_model_path: Path to the person/knife model weights (.pt).
            gun_conf: Confidence threshold for gun detections (higher = fewer false positives).
            person_conf: Confidence threshold for person/knife detections.
            cooldown_frames: Frames without weapon before alert clears.
            confirm_frames: Consecutive weapon frames needed to trigger alert.
            min_weapon_area: Minimum bbox area (px) to accept a gun detection.
        """
        print(f"Loading weapon model ({gun_model_path})...")
        self.gun_detector = YOLODetector(gun_model_path)
        # Report the weights actually being loaded; the path is configurable,
        # so a fixed file name in this message could be wrong.
        print(f"Loading person model ({person_model_path})...")
        self.person_detector = YOLODetector(person_model_path)

        self.gun_conf = gun_conf
        self.person_conf = person_conf
        self.cooldown_frames = cooldown_frames
        self.confirm_frames = confirm_frames
        self.min_weapon_area = min_weapon_area

        # State
        self._alert_active = False
        self._frames_since_last_weapon = 0
        self._consecutive_weapon_frames = 0
        self._alert_start_time: Optional[float] = None

        # Build human-readable names: the gun model's own classes plus the
        # knife class that comes from the person (COCO) model.
        self.weapon_names = dict(self.gun_detector.class_names)
        self.weapon_names[f"yolo_{KNIFE_CLASS_ID}"] = "knife"
        self.person_names = {PERSON_CLASS_ID: self.person_detector.class_names.get(PERSON_CLASS_ID, "person")}

        print("\nWeapon-Person Detector initialised")
        print(f" Gun model : {gun_model_path}")
        print(f" Gun confidence : {self.gun_conf} (high to reduce false positives)")
        print(f" Person conf : {self.person_conf}")
        print(f" Min weapon area: {self.min_weapon_area} px")
        print(f" Confirm frames : {self.confirm_frames} (weapon must appear this many frames in a row)")
        print(f" Cooldown : {self.cooldown_frames} frames")
        print(f" Weapon classes : {self.weapon_names}")

    # ------------------------------------------------------------------
    # Core detection
    # ------------------------------------------------------------------

    def process_frame(self, frame: np.ndarray) -> FrameResult:
        """Run both models on a single BGR frame and update alert state.

        Args:
            frame: BGR frame to analyse.

        Returns:
            A ``FrameResult`` with this frame's weapons and persons plus the
            alert state after the frame has been accounted for.
        """
        # --- Gun detector (high confidence + size filter) ---
        raw_guns = self.gun_detector.detect(
            frame,
            conf_threshold=self.gun_conf,
        )
        # Filter out small bounding boxes (false positives)
        weapons = [d for d in raw_guns if d.area >= self.min_weapon_area]

        # --- COCO model for person + knife ---
        coco_detections = self.person_detector.detect(
            frame,
            conf_threshold=self.person_conf,
            classes=[PERSON_CLASS_ID, KNIFE_CLASS_ID],
        )
        persons = [d for d in coco_detections if d.class_id == PERSON_CLASS_ID]
        knives = [d for d in coco_detections if d.class_id == KNIFE_CLASS_ID]
        weapons.extend(knives)

        # --- Multi-frame confirmation before triggering alert ---
        if weapons:
            self._consecutive_weapon_frames += 1
            self._frames_since_last_weapon = 0
        else:
            self._consecutive_weapon_frames = 0
            self._frames_since_last_weapon += 1

        # Only activate alert after weapon seen N frames in a row
        if self._consecutive_weapon_frames >= self.confirm_frames:
            if not self._alert_active:
                self._alert_active = True
                self._alert_start_time = time.time()

        # Clear alert after cooldown with no weapons
        if self._frames_since_last_weapon >= self.cooldown_frames:
            self._alert_active = False
            self._alert_start_time = None

        return FrameResult(
            weapons=weapons,
            persons=persons,
            alert_active=self._alert_active,
            person_count=len(persons),
        )

    # ------------------------------------------------------------------
    # Drawing / overlay helpers
    # ------------------------------------------------------------------

    def draw_results(self, frame: np.ndarray, result: FrameResult) -> np.ndarray:
        """Draw bounding boxes and compact HUD overlay on the frame.

        Args:
            frame: Frame to annotate; it is copied, not modified.
            result: Detections and alert state for this frame.

        Returns:
            An annotated copy of ``frame``.
        """
        out = frame.copy()
        frame_w = out.shape[1]

        # --- Draw person boxes (thin green) ---
        for i, det in enumerate(result.persons, 1):
            x1, y1, x2, y2 = det.bbox
            cv2.rectangle(out, (x1, y1), (x2, y2), (0, 200, 0), 1)
            self._draw_label(out, f"P{i}", (x1, y1), bg_color=(0, 200, 0),
                             scale=0.35, thickness=1)

        # --- Draw weapon boxes (red) ---
        for det in result.weapons:
            x1, y1, x2, y2 = det.bbox
            cv2.rectangle(out, (x1, y1), (x2, y2), (0, 0, 255), 2)
            self._draw_label(out, f"{det.class_name} {det.confidence:.0%}",
                             (x1, y1), bg_color=(0, 0, 255),
                             scale=0.35, thickness=1)

        # --- Compact top-right info box ---
        if result.alert_active:
            # Falls back to 0s elapsed if the start time is unset.
            elapsed = time.time() - (self._alert_start_time or time.time())
            # During cooldown the alert is still active with no weapon in view.
            weapon_names = ", ".join(d.class_name for d in result.weapons) or "last seen"
            lines = [
                f"WEAPON: {weapon_names}",
                f"Persons: {result.person_count} Time: {elapsed:.0f}s",
            ]
            box_color = (0, 0, 180)
        else:
            lines = [
                "No Weapon",
                f"Persons: {result.person_count}",
            ]
            box_color = (0, 130, 0)

        # Draw compact semi-transparent box in top-right corner
        line_h = 16
        pad = 4
        box_h = len(lines) * line_h + pad * 2
        max_tw = max(
            cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.38, 1)[0][0]
            for text in lines
        )
        box_w = max_tw + pad * 2
        x0 = frame_w - box_w - 5
        y0 = 5

        overlay = out.copy()
        cv2.rectangle(overlay, (x0, y0), (x0 + box_w, y0 + box_h), box_color, -1)
        cv2.addWeighted(overlay, 0.55, out, 0.45, 0, out)

        for i, line in enumerate(lines):
            cv2.putText(out, line, (x0 + pad, y0 + pad + (i + 1) * line_h - 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.38, (255, 255, 255), 1)

        return out

    @staticmethod
    def _draw_label(frame, text, origin, bg_color=(0, 0, 0),
                    font=cv2.FONT_HERSHEY_SIMPLEX, scale=0.55, thickness=1):
        """Draw a text label with filled background.

        The label normally sits directly above ``origin`` (the box's top-left
        corner). When there is not enough room above it, the label is drawn
        just below ``origin`` instead so it stays inside the frame.

        Args:
            frame: Image to draw on (modified in place).
            text: Label text.
            origin: ``(x, y)`` anchor, the top-left corner of the box.
            bg_color: BGR fill colour of the label background.
            font: OpenCV font face.
            scale: Font scale.
            thickness: Text stroke thickness.
        """
        x, y = origin
        (tw, th), _baseline = cv2.getTextSize(text, font, scale, thickness)
        label_h = th + 8
        # Bottom edge of the label: at the anchor when it fits above it,
        # otherwise one label height below the anchor.
        bottom = y if y - label_h >= 0 else y + label_h
        cv2.rectangle(frame, (x, bottom - label_h), (x + tw + 4, bottom), bg_color, -1)
        cv2.putText(frame, text, (x + 2, bottom - 4), font, scale, (255, 255, 255), thickness)

    # ------------------------------------------------------------------
    # High-level run loop
    # ------------------------------------------------------------------

    def run(self, source=0, show_window: bool = True):
        """
        Start the live detection loop.

        Prints an error and returns if the video source cannot be opened.
        The capture and any OpenCV windows are released when the loop ends.

        Args:
            source: Camera index (0, 1, ...) or RTSP URL string.
            show_window: Whether to display the OpenCV window.

        Controls (only while the window is shown):
            q    quit
            s    save screenshot
            +/-  increase / decrease the gun confidence threshold
        """
        capture = VideoCapture(source, use_motion_detection=False)
        if not capture.start(verbose=True):
            print("ERROR: Could not open video source.")
            return

        print("\n--- Weapon + Person Detection Running ---")
        print("Controls: q = quit | s = screenshot | +/- = gun confidence")
        print(f"Watching for: {list(self.weapon_names.values())}")
        print(f"Gun conf: {self.gun_conf} | Min area: {self.min_weapon_area}px | Confirm: {self.confirm_frames} frames")
        print()

        fps_timer = time.time()
        frame_count = 0

        try:
            for original, _preprocessed in capture.stream_frames():
                result = self.process_frame(original)

                # Log alerts to console
                if result.weapons:
                    names = ", ".join(f"{d.class_name}({d.confidence:.0%})" for d in result.weapons)
                    print(f"[ALERT] Weapons: {names} | Persons in frame: {result.person_count}")

                if show_window:
                    display = self.draw_results(original, result)

                    # FPS counter, refreshed every 15 frames
                    frame_count += 1
                    if frame_count % 15 == 0:
                        now = time.time()
                        fps = 15 / max(now - fps_timer, 1e-9)
                        fps_timer = now
                        cv2.setWindowTitle("Weapon Detection", f"Weapon Detection [{fps:.1f} FPS]")

                    cv2.imshow("Weapon Detection", display)
                    key = cv2.waitKey(1) & 0xFF
                    if key == ord("q"):
                        break
                    elif key == ord("s"):
                        fname = f"weapon_screenshot_{int(time.time())}.jpg"
                        cv2.imwrite(fname, display)
                        print(f"Screenshot saved: {fname}")
                    elif key == ord("+") or key == ord("="):
                        # "=" is accepted too, so "+" works without Shift.
                        self.gun_conf = min(0.95, self.gun_conf + 0.05)
                        print(f"Gun confidence threshold -> {self.gun_conf:.2f}")
                    elif key == ord("-"):
                        self.gun_conf = max(0.10, self.gun_conf - 0.05)
                        print(f"Gun confidence threshold -> {self.gun_conf:.2f}")
        finally:
            capture.stop()
            cv2.destroyAllWindows()
            print("Detection stopped.")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# -----------------------------------------------------------------------
|
| 337 |
+
# CLI entry point
|
| 338 |
+
# -----------------------------------------------------------------------
|
| 339 |
+
|
| 340 |
+
def main():
    """Parse command-line options and start the live detection loop."""
    import argparse

    cli = argparse.ArgumentParser(description="Weapon & Person Detection System")
    cli.add_argument("--source", default="0",
                     help="Camera index (0,1,..) or RTSP URL (default: 0)")

    # (flag, type, default, help text without the "(default: ...)" suffix)
    tunables = (
        ("--gun-conf", float, GUN_CONF_THRESHOLD, "Gun model confidence threshold"),
        ("--person-conf", float, PERSON_CONF_THRESHOLD, "Person/knife confidence threshold"),
        ("--cooldown", int, WEAPON_COOLDOWN_FRAMES, "Frames before alert clears"),
        ("--confirm", int, WEAPON_CONFIRM_FRAMES, "Consecutive frames to confirm weapon"),
        ("--min-area", int, MIN_WEAPON_AREA, "Min weapon bbox area in pixels"),
    )
    for flag, caster, default, summary in tunables:
        cli.add_argument(flag, type=caster, default=default,
                         help=f"{summary} (default: {default})")

    cli.add_argument("--gun-model", default=GUN_MODEL_PATH,
                     help="Path to gun detection model weights")
    cli.add_argument("--person-model", default=PERSON_MODEL_PATH,
                     help="Path to person detection model weights")
    opts = cli.parse_args()

    # A purely numeric source is a camera index; anything else is passed
    # through unchanged (e.g. an RTSP URL).
    if opts.source.strip().isdigit():
        source = int(opts.source)
    else:
        source = opts.source

    detector = WeaponPersonDetector(
        gun_model_path=opts.gun_model,
        person_model_path=opts.person_model,
        gun_conf=opts.gun_conf,
        person_conf=opts.person_conf,
        cooldown_frames=opts.cooldown,
        confirm_frames=opts.confirm,
        min_weapon_area=opts.min_area,
    )
    detector.run(source=source)


if __name__ == "__main__":
    main()
|
src/detectors/yolo_detector.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
YOLO Object Detection Module
|
| 3 |
+
Lightweight wrapper around Ultralytics YOLO results for consistent app usage.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from ultralytics import YOLO
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class Detection:
    """One detection in the normalized form shared across the NETRA app."""

    class_id: int
    class_name: str
    confidence: float
    bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2)

    @property
    def area(self) -> int:
        """Area of ``bbox``; a side with negative extent counts as zero."""
        left, top, right, bottom = self.bbox
        width = max(0, right - left)
        height = max(0, bottom - top)
        return width * height
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class YOLODetector:
    """Thin inference wrapper around a single Ultralytics YOLO model."""

    def __init__(self, model_path: Optional[str] = None):
        """Load the model.

        Args:
            model_path: Weights file to load. When empty or ``None`` a
                default location is chosen by ``_resolve_model_path``.
        """
        weights = self._resolve_model_path(model_path)
        self.model = YOLO(weights)
        self.model_path = weights
        # Copy the model's id -> name table; tolerate a missing/empty one.
        model_names = getattr(self.model, "names", {})
        self.class_names: Dict[int, str] = dict(model_names or {})

    @staticmethod
    def _resolve_model_path(model_path: Optional[str]) -> str:
        """Return ``model_path`` if given, else the first default that exists.

        When none of the default locations exist, the first one is returned
        anyway.
        """
        if model_path:
            return model_path

        here = Path(__file__).resolve().parent
        search_order = (
            here.parent / "ai_models" / "object_detection" / "yolov8n.pt",
            here / "model" / "yolov8n.pt",
            here / "model" / "yolov8s.pt",
        )
        chosen = next((p for p in search_order if p.exists()), search_order[0])
        return str(chosen)

    def detect(
        self,
        frame: np.ndarray,
        conf_threshold: float = 0.25,
        classes: Optional[Sequence[int]] = None,
    ) -> List[Detection]:
        """Run inference on one BGR frame and return normalized detections.

        Args:
            frame: Image to run the model on.
            conf_threshold: Forwarded to the model as ``conf``.
            classes: Forwarded to the model as ``classes``.

        Returns:
            One ``Detection`` per box of the first result; an empty list when
            the model returns no results or the result has no boxes.
        """
        results = self.model.predict(frame, conf=conf_threshold, classes=classes, verbose=False)
        if not results:
            return []

        boxes = results[0].boxes
        if boxes is None:
            return []

        def to_detection(box) -> Detection:
            # Missing class / confidence fall back to -1 / 0.0.
            label_id = -1 if box.cls is None else int(box.cls.item())
            score = 0.0 if box.conf is None else float(box.conf.item())
            # Coordinates are truncated to ints.
            left, top, right, bottom = (int(v) for v in box.xyxy[0].tolist())
            return Detection(
                class_id=label_id,
                class_name=self.class_names.get(label_id, str(label_id)),
                confidence=score,
                bbox=(left, top, right, bottom),
            )

        return [to_detection(box) for box in boxes]
|
src/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Video Processing Pipeline
|
| 3 |
+
Frame capture and preprocessing modules
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .video_capture import VideoCapture
|
| 7 |
+
|
| 8 |
+
__all__ = ['VideoCapture']
|
src/pipeline/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (412 Bytes). View file
|
|
|
src/pipeline/__pycache__/video_capture.cpython-310.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
src/pipeline/video_capture.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
Video Capture Module for AI Processing
|
| 4 |
+
Captures frames from webcam or RTSP stream and preprocesses them for YOLOv8
|
| 5 |
+
Uses motion detection (MOG2) to extract ROIs for faster inference
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import platform
|
| 11 |
+
from typing import Optional, Tuple, Generator, List
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ROI:
    """A rectangular region of the frame in which motion was found."""

    # Top-left corner of the region.
    x: int
    y: int
    # Extent of the region.
    width: int
    height: int
    # The original pixels cropped from the frame.
    cropped_frame: np.ndarray
    # The crop resized to 640x640 for YOLO.
    preprocessed: np.ndarray
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MotionDetector:
    """Detects motion using MOG2 background subtraction."""

    def __init__(self,
                 history: int = 500,
                 var_threshold: float = 16,
                 detect_shadows: bool = True,
                 min_contour_area: int = 500,
                 merge_distance: int = 50):
        """
        Initialize MOG2 background subtractor.

        Args:
            history: Number of frames for background model
            var_threshold: Variance threshold for background/foreground segmentation
            detect_shadows: Whether to detect shadows (marks them gray vs white)
            min_contour_area: Minimum area (pixels) to consider as valid motion
            merge_distance: Distance to merge nearby contours into single ROI
        """
        self.bg_subtractor = cv2.createBackgroundSubtractorMOG2(
            history=history,
            varThreshold=var_threshold,
            detectShadows=detect_shadows
        )
        self.min_contour_area = min_contour_area
        self.merge_distance = merge_distance
        # Structuring element shared by all morphological clean-up steps.
        self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    def get_foreground_mask(self, frame: np.ndarray) -> np.ndarray:
        """
        Get binary mask of moving objects.

        Each call feeds ``frame`` to the background subtractor.

        Args:
            frame: Input BGR frame

        Returns:
            Binary mask where white = motion
        """
        # Apply background subtraction
        fg_mask = self.bg_subtractor.apply(frame)

        # Remove shadows (gray pixels become black)
        _, fg_mask = cv2.threshold(fg_mask, 250, 255, cv2.THRESH_BINARY)

        # Morphological operations to clean up noise
        fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, self.kernel)
        fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, self.kernel)
        fg_mask = cv2.dilate(fg_mask, self.kernel, iterations=2)

        return fg_mask

    def _merge_bounding_boxes(self, boxes: List[Tuple[int, int, int, int]]) -> List[Tuple[int, int, int, int]]:
        """Merge nearby bounding boxes to reduce fragmentation.

        Boxes are ``(x, y, width, height)``. A box is absorbed into a group
        when it lies within ``merge_distance`` of the group's current
        enclosing rectangle; the group keeps growing until nothing more can
        be absorbed.
        """
        if not boxes:
            return []

        merged = []
        used = [False] * len(boxes)

        for i, (x1, y1, w1, h1) in enumerate(boxes):
            if used[i]:
                continue

            # Start with current box
            min_x, min_y = x1, y1
            max_x, max_y = x1 + w1, y1 + h1
            used[i] = True

            # Find and merge nearby boxes; repeat because every merge may
            # bring further boxes within reach of the grown rectangle.
            changed = True
            while changed:
                changed = False
                for j, (x2, y2, w2, h2) in enumerate(boxes):
                    if used[j]:
                        continue

                    # Check if boxes are close enough to merge
                    if (x2 < max_x + self.merge_distance and
                            x2 + w2 > min_x - self.merge_distance and
                            y2 < max_y + self.merge_distance and
                            y2 + h2 > min_y - self.merge_distance):

                        min_x = min(min_x, x2)
                        min_y = min(min_y, y2)
                        max_x = max(max_x, x2 + w2)
                        max_y = max(max_y, y2 + h2)
                        used[j] = True
                        changed = True

            merged.append((min_x, min_y, max_x - min_x, max_y - min_y))

        return merged

    @staticmethod
    def _pad_and_clamp(x: int, y: int, bw: int, bh: int, padding: int,
                       frame_w: int, frame_h: int) -> Tuple[int, int, int, int]:
        """Grow a box by ``padding`` on every side, clipped to the frame.

        Both corners are clamped independently, so a box near the left/top
        edge is not extended further right/down to make up for padding that
        did not fit on the clipped side.

        Returns:
            The padded box as ``(x, y, width, height)``.
        """
        left = max(0, x - padding)
        top = max(0, y - padding)
        right = min(frame_w, x + bw + padding)
        bottom = min(frame_h, y + bh + padding)
        return left, top, right - left, bottom - top

    def detect_motion_regions(self, frame: np.ndarray,
                              padding: int = 20) -> List[Tuple[int, int, int, int]]:
        """
        Detect regions with motion.

        Args:
            frame: Input BGR frame
            padding: Pixels to add around detected regions

        Returns:
            List of bounding boxes (x, y, width, height)
        """
        fg_mask = self.get_foreground_mask(frame)

        # Find contours of moving objects
        contours, _ = cv2.findContours(fg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        boxes = []
        h, w = frame.shape[:2]

        for contour in contours:
            # Ignore blobs too small to be meaningful motion.
            if cv2.contourArea(contour) < self.min_contour_area:
                continue

            x, y, bw, bh = cv2.boundingRect(contour)

            # Add padding and clamp to frame bounds
            boxes.append(self._pad_and_clamp(x, y, bw, bh, padding, w, h))

        # Merge nearby boxes
        return self._merge_bounding_boxes(boxes)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class VideoCapture:
|
| 159 |
+
"""Captures and preprocesses video frames for AI inference."""
|
| 160 |
+
|
| 161 |
+
# YOLOv8 native input size
|
| 162 |
+
TARGET_SIZE = (640, 640)
|
| 163 |
+
|
| 164 |
+
def __init__(self, source: int | str = 0, use_motion_detection: bool = True):
|
| 165 |
+
"""
|
| 166 |
+
Initialize video capture.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
source: Camera index (0 for default webcam) or RTSP URL string
|
| 170 |
+
Example RTSP: "rtsp://username:password@ip_address:port/stream"
|
| 171 |
+
use_motion_detection: Enable MOG2 motion detection for ROI extraction
|
| 172 |
+
"""
|
| 173 |
+
self.source = self._normalize_source(source)
|
| 174 |
+
self.cap: Optional[cv2.VideoCapture] = None
|
| 175 |
+
self.use_motion_detection = use_motion_detection
|
| 176 |
+
self.motion_detector: Optional[MotionDetector] = None
|
| 177 |
+
self.active_source: Optional[int | str] = None
|
| 178 |
+
self.active_backend: Optional[int] = None
|
| 179 |
+
|
| 180 |
+
@staticmethod
|
| 181 |
+
def _normalize_source(source: int | str) -> int | str:
|
| 182 |
+
"""Normalize source values so numeric strings map to camera indices."""
|
| 183 |
+
if isinstance(source, str) and source.strip().isdigit():
|
| 184 |
+
return int(source.strip())
|
| 185 |
+
return source
|
| 186 |
+
|
| 187 |
+
@staticmethod
def _backend_name(backend: int) -> str:
    """Return a readable name for an OpenCV capture backend id.

    Ids other than CAP_ANY / CAP_DSHOW / CAP_MSMF are returned as their
    decimal string.
    """
    for label in ("CAP_ANY", "CAP_DSHOW", "CAP_MSMF"):
        if getattr(cv2, label) == backend:
            return label
    return str(backend)
|
| 196 |
+
|
| 197 |
+
def _source_candidates(self) -> List[int | str]:
|
| 198 |
+
"""Return source candidates to try opening in order."""
|
| 199 |
+
if isinstance(self.source, int):
|
| 200 |
+
candidates = [self.source]
|
| 201 |
+
if self.source == 0:
|
| 202 |
+
candidates.extend([1, 2])
|
| 203 |
+
return candidates
|
| 204 |
+
return [self.source]
|
| 205 |
+
|
| 206 |
+
def _backend_candidates(self) -> List[int]:
    """Return the OpenCV backends to try, based on platform and source type.

    Only integer (camera index) sources on Windows get the explicit
    DSHOW -> MSMF -> ANY order; everything else uses CAP_ANY alone.
    """
    is_camera_index = not isinstance(self.source, str)
    if is_camera_index and platform.system().lower().startswith("win"):
        return [cv2.CAP_DSHOW, cv2.CAP_MSMF, cv2.CAP_ANY]
    return [cv2.CAP_ANY]
|
| 215 |
+
|
| 216 |
+
def start(self, verbose: bool = True) -> bool:
    """
    Open the video source, trying each source/backend combination in turn.

    Any previously opened capture is released first.

    Args:
        verbose: Print diagnostics about the attempts and the opened stream.

    Returns:
        True if capture started successfully, False otherwise
    """
    self.stop()

    tried = []
    for candidate in self._source_candidates():
        for api in self._backend_candidates():
            handle = cv2.VideoCapture(candidate, api)
            tried.append(f"{candidate} via {self._backend_name(api)}")

            if not handle.isOpened():
                handle.release()
                continue

            self.cap = handle
            self.active_source = candidate
            self.active_backend = api
            break
        else:
            # No backend worked for this source; move on to the next one.
            continue
        break

    if self.cap is None:
        if verbose:
            print(f"Error: Could not open video source: {self.source}")
            print("Tried:")
            for entry in tried:
                print(f" - {entry}")
        return False

    # Set buffer size to minimize latency (useful for RTSP streams)
    self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    # Initialize motion detector if enabled
    if self.use_motion_detection:
        self.motion_detector = MotionDetector()
        if verbose:
            print("Motion detection enabled (MOG2)")

    # Query capture properties for the summary below
    width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = self.cap.get(cv2.CAP_PROP_FPS)
    if verbose:
        if self.active_backend is not None:
            backend_name = self._backend_name(self.active_backend)
        else:
            backend_name = "Unknown"
        print(f"Video capture started: source={self.active_source}, backend={backend_name}")
        print(f"Resolution: {width}x{height} @ {fps:.1f} FPS")

    return True
|
| 269 |
+
|
| 270 |
+
def stop(self):
|
| 271 |
+
"""Release the video capture resources."""
|
| 272 |
+
if self.cap is not None:
|
| 273 |
+
self.cap.release()
|
| 274 |
+
self.cap = None
|
| 275 |
+
self.active_source = None
|
| 276 |
+
self.active_backend = None
|
| 277 |
+
print("Video capture stopped")
|
| 278 |
+
|
| 279 |
+
def preprocess_frame(self, frame: np.ndarray) -> np.ndarray:
    """
    Resize a frame to the YOLOv8 input size.

    Args:
        frame: Raw BGR frame from camera

    Returns:
        The frame resized to ``TARGET_SIZE`` (640x640) with linear interpolation
    """
    return cv2.resize(frame, self.TARGET_SIZE, interpolation=cv2.INTER_LINEAR)
|
| 292 |
+
|
| 293 |
+
def read_frame(self) -> Tuple[bool, Optional[np.ndarray], Optional[np.ndarray]]:
|
| 294 |
+
"""
|
| 295 |
+
Read and preprocess a single frame.
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
Tuple of (success, original_frame, preprocessed_frame)
|
| 299 |
+
"""
|
| 300 |
+
if self.cap is None:
|
| 301 |
+
return False, None, None
|
| 302 |
+
|
| 303 |
+
ret, frame = self.cap.read()
|
| 304 |
+
|
| 305 |
+
if not ret or frame is None:
|
| 306 |
+
return False, None, None
|
| 307 |
+
|
| 308 |
+
preprocessed = self.preprocess_frame(frame)
|
| 309 |
+
return True, frame, preprocessed
|
| 310 |
+
|
| 311 |
+
def extract_rois(self, frame: np.ndarray) -> List[ROI]:
    """
    Extract regions of interest (moving objects) from frame.

    Args:
        frame: Input BGR frame

    Returns:
        List of ROI objects containing cropped and preprocessed regions.
        Without a motion detector the whole frame is returned as one ROI.
        Boxes that select no pixels of the frame are skipped.
    """
    if self.motion_detector is None:
        # If no motion detection, return whole frame as single ROI
        preprocessed = self.preprocess_frame(frame)
        return [ROI(0, 0, frame.shape[1], frame.shape[0], frame, preprocessed)]

    boxes = self.motion_detector.detect_motion_regions(frame)

    if not boxes:
        return []

    rois = []
    for x, y, w, h in boxes:
        cropped = frame[y:y + h, x:x + w]
        if cropped.size == 0:
            # A zero-area or out-of-frame box yields an empty crop, and
            # cv2.resize raises on an empty source image; drop the box
            # instead of aborting the whole frame.
            continue
        preprocessed = cv2.resize(cropped, self.TARGET_SIZE, interpolation=cv2.INTER_LINEAR)
        rois.append(ROI(x, y, w, h, cropped, preprocessed))

    return rois
|
| 338 |
+
|
| 339 |
+
def read_frame_with_rois(self) -> Tuple[bool, Optional[np.ndarray], List[ROI], Optional[np.ndarray]]:
    """
    Read frame and extract ROIs for motion regions.

    Returns:
        Tuple of (success, original_frame, list_of_rois, foreground_mask).
        On failure the tuple is (False, None, [], None). The mask is None
        when no motion detector is configured.
    """
    if self.cap is None:
        return False, None, [], None

    grabbed, raw = self.cap.read()
    if not (grabbed and raw is not None):
        return False, None, [], None

    rois = self.extract_rois(raw)

    # Foreground mask is only used for visualization.
    # NOTE(review): the detector is consulted here and again inside
    # extract_rois for the same frame; if each call updates its background
    # model, confirm the double update is intended.
    detector = self.motion_detector
    fg_mask = detector.get_foreground_mask(raw) if detector is not None else None

    return True, raw, rois, fg_mask
|
| 362 |
+
|
| 363 |
+
def stream_frames(self) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
    """
    Generator that continuously yields frames (no motion detection).

    Iteration ends as soon as ``read_frame`` reports failure.

    Yields:
        Tuple of (original_frame, preprocessed_frame)
    """
    outcome = self.read_frame()
    while outcome[0]:
        _, original, preprocessed = outcome
        yield original, preprocessed
        # Fetch the next frame only after the consumer asks for it
        outcome = self.read_frame()
|
| 375 |
+
|
| 376 |
+
def stream_rois(self) -> Generator[Tuple[np.ndarray, List[ROI], Optional[np.ndarray]], None, None]:
    """
    Generator that yields frames with motion-detected ROIs.

    Iteration ends as soon as ``read_frame_with_rois`` reports failure.

    Yields:
        Tuple of (original_frame, list_of_rois, foreground_mask)
    """
    outcome = self.read_frame_with_rois()
    while outcome[0]:
        _, original, rois, fg_mask = outcome
        yield original, rois, fg_mask
        # Fetch the next frame only after the consumer asks for it
        outcome = self.read_frame_with_rois()
|
| 388 |
+
|
| 389 |
+
def __enter__(self):
    """
    Context manager entry: start the capture and hand back this object.

    Raises:
        RuntimeError: If ``start()`` reports that the source could not be opened
    """
    if self.start():
        return self
    raise RuntimeError(f"Could not open video source: {self.source}")
|
| 394 |
+
|
| 395 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Context manager exit: release the capture via ``stop()``.

    Returns None, so an exception raised inside the ``with`` body is not
    suppressed.
    """
    self.stop()
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def main():
    """
    Demo: Capture frames with motion detection from webcam.

    Opens the default webcam, draws a labelled box around every motion ROI,
    shows the annotated frame (and the foreground mask when available) and
    runs until the stream ends or 'q' is pressed.
    """

    # Use 0 for default webcam, or provide RTSP URL for IP camera
    # Example RTSP: "rtsp://admin:password@192.168.1.100:554/stream1"
    source = 0

    try:
        with VideoCapture(source, use_motion_detection=True) as capture:
            print("Press 'q' to quit")
            print("Motion detection active - only moving regions will be processed")

            for original, rois, fg_mask in capture.stream_rois():
                # Draw on a copy so the captured frame stays untouched
                display_frame = original.copy()

                for index, roi in enumerate(rois, start=1):
                    # Draw green rectangle around ROI
                    cv2.rectangle(display_frame,
                                  (roi.x, roi.y),
                                  (roi.x + roi.width, roi.y + roi.height),
                                  (0, 255, 0), 2)

                    # Label with ROI index and size. The text baseline goes
                    # above the box; when the box touches the top of the
                    # frame there is no room there, so draw it just inside
                    # the box instead of off-screen.
                    label = f"ROI {index}: {roi.width}x{roi.height}"
                    label_y = roi.y - 10
                    if label_y < 10:
                        label_y = roi.y + 15
                    cv2.putText(display_frame, label, (roi.x, label_y),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

                # Show info
                info = f"Motion ROIs: {len(rois)} | Press 'q' to quit"
                cv2.putText(display_frame, info, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

                cv2.imshow("Video Capture - Motion Detection", display_frame)

                # Show foreground mask
                if fg_mask is not None:
                    cv2.imshow("Foreground Mask", fg_mask)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Close the preview windows even if the loop raised (e.g. Ctrl+C)
        cv2.destroyAllWindows()

    print("Done!")


if __name__ == "__main__":
    main()
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility Functions
|
| 3 |
+
Common helper functions and utilities
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Setup logging
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ensure_dir(directory: Path) -> Path:
    """
    Ensure directory exists, create if needed

    Missing parent directories are created as well, and an already existing
    directory is left untouched.

    Args:
        directory: Path to directory. A plain string path is accepted too
            and is converted to a Path.

    Returns:
        Path object for the directory
    """
    # Coerce first so callers holding a str do not fail on the missing .mkdir
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    return directory
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_project_root() -> Path:
    """
    Get root directory of NETRA project

    The root is taken to be three directory levels above this file
    (this file -> utils/ -> src/ -> project root).

    Returns:
        Absolute path to project root
    """
    # Resolve first: with a relative __file__ the bare .parent chain would
    # collapse to "." instead of climbing to the real project directory.
    return Path(__file__).resolve().parent.parent.parent
|
src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (983 Bytes). View file
|
|
|
src/utils/__pycache__/model_downloader.cpython-310.pyc
ADDED
|
Binary file (3.89 kB). View file
|
|
|
src/utils/model_downloader.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Downloader - Downloads AI models from Hugging Face Hub
|
| 3 |
+
Automatically caches models locally after first download
|
| 4 |
+
FULLY PORTABLE - Works on any device with any project path
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import shutil
|
| 12 |
+
|
| 13 |
+
# Detect PROJECT_ROOT dynamically
|
| 14 |
+
def get_project_root():
    """
    Locate the project root starting from this file's resolved location.

    The first ancestor directory that contains both a ``config`` and a
    ``webapp`` entry is treated as the root. If no ancestor qualifies, the
    directory three levels above this file is returned instead.
    """
    here = Path(__file__).resolve()  # Full path to this file
    markers = ('config', 'webapp')

    # Walk upwards from src/utils/model_downloader.py towards the filesystem root
    match = next(
        (
            candidate
            for candidate in here.parents
            if all((candidate / marker).exists() for marker in markers)
        ),
        None,
    )
    if match is not None:
        return match

    # Fallback: assume parent of src/
    return here.parent.parent.parent
|
| 28 |
+
|
| 29 |
+
# Computed once at import time; CACHE_DIR below is derived from it.
PROJECT_ROOT = get_project_root()
# Hugging Face Hub repository id passed to hf_hub_download in download_model.
REPO_ID = "itsluckysharma01/NETRA-Models"
CACHE_DIR = PROJECT_ROOT / 'ai_models'  # Models cached in project root

# Import-time diagnostics showing which paths were detected.
# NOTE(review): the leading "π" looks like a mis-decoded emoji; the intended
# glyph cannot be determined from this view, so the text is left as found.
print(f"\nπ [Model Downloader] PROJECT_ROOT detected: {PROJECT_ROOT}")
print(f"π [Model Downloader] CACHE_DIR: {CACHE_DIR}\n")
|
| 35 |
+
|
| 36 |
+
def download_model(filename):
    """
    Download model from Hugging Face Hub with automatic path handling

    The file is looked up at ``CACHE_DIR / filename`` first. When it is
    missing it is fetched from ``REPO_ID`` through ``hf_hub_download`` (which
    keeps its own cache layout under ``CACHE_DIR``) and then copied to
    ``CACHE_DIR / filename``.

    Args:
        filename: Model file path (e.g., 'ai_models/activity_recognition/violence_model.h5')

    Returns:
        str: Path to downloaded/cached model (absolute path), or None when
        the download or the copy failed (the error is printed, not raised)
    """
    try:
        # Ensure cache directory exists
        CACHE_DIR.mkdir(parents=True, exist_ok=True)

        # Check if model already exists in flat structure
        # NOTE(review): CACHE_DIR already ends in 'ai_models' and the names
        # listed in setup_all_models start with 'ai_models/', so this resolves
        # to .../ai_models/ai_models/...; confirm the nesting is intended.
        local_path = CACHE_DIR / filename
        if local_path.exists():
            print(f"✅ Model cached: {filename}")
            return str(local_path)

        # Download from Hugging Face Hub (goes to HF cache)
        print(f"📥 Downloading: {filename}")
        downloaded_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            cache_dir=str(CACHE_DIR),
            local_files_only=False,
        )

        # Copy from HF cache structure to flat ai_models/ structure
        src_path = Path(downloaded_path)

        # Create destination directory
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy to a sibling temp file and rename it into place. A copy that
        # is interrupted halfway must never leave a truncated file under the
        # final name, because the exists() check above would then treat it
        # as a valid cached model on every later run.
        partial_path = local_path.with_name(local_path.name + ".part")
        try:
            shutil.copy2(src_path, partial_path)
            os.replace(partial_path, local_path)
        finally:
            # Only still present when the copy or the rename failed
            if partial_path.exists():
                partial_path.unlink()

        print(f"✅ Downloaded and cached: {filename}")
        return str(local_path)

    except Exception as e:
        # Deliberate best-effort: report the problem and let the caller decide
        print(f"❌ Error downloading {filename}: {e}")
        return None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def ensure_model_exists(filename):
    """
    Make sure a model file is available locally, fetching it when missing

    Args:
        filename: Model file path, relative to CACHE_DIR

    Returns:
        bool: True if the file is already present or download_model
        returned a path for it, False otherwise
    """
    # Short-circuit: the download is attempted only when the file is absent
    already_present = (CACHE_DIR / filename).exists()
    return already_present or download_model(filename) is not None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def setup_all_models(models=None):
    """
    Download all required models on startup

    Each model that is not yet present under CACHE_DIR is fetched with
    download_model. Progress and a final summary are printed; failures are
    counted and reported but never raised.

    Args:
        models: Optional iterable of model file paths to set up. When
            omitted, the built-in list of NETRA models is used.
    """
    if models is None:
        models = [
            "ai_models/activity_recognition/violence_model.h5",
            "ai_models/object_detection/yolov8n.pt",
            "ai_models/pose_detection/yolo11n-pose.pt",
            "ai_models/weapon_detection/best.pt",
            "ai_models/analysis_models/binarycnn200.h5",
            "ai_models/analysis_models/CNN93.h5",
            "ai_models/analysis_models/CustomCNN.h5",
            "ai_models/analysis_models/fight_detection_model.h5",
        ]

    # NOTE(review): the leading "π" in several messages below looks like a
    # mis-decoded emoji; the intended glyph cannot be determined from this
    # view, so those strings are left as found.
    print("\n" + "=" * 60)
    print("📥 SETTING UP AI MODELS FROM HUGGING FACE HUB")
    print("=" * 60)
    print(f"π PROJECT_ROOT: {PROJECT_ROOT}")
    print(f"π CACHE_DIR: {CACHE_DIR}")
    print(f"π Cache exists: {CACHE_DIR.exists()}")
    print("=" * 60)

    downloaded = 0
    cached = 0
    failed = 0

    for model in models:
        if (CACHE_DIR / model).exists():
            print(f"✅ Cached: {model}")
            cached += 1
            continue

        try:
            result = download_model(model)
        except Exception as e:
            # download_model reports its own errors and returns None, so this
            # only fires on something unexpected; include the cause so the
            # failure can be diagnosed.
            print(f"⚠️ Warning: Could not load {model}: {e}")
            failed += 1
            continue

        if result:
            downloaded += 1
        else:
            failed += 1

    print("\n" + "=" * 60)
    print(f"✅ Setup Complete: {downloaded} downloaded, {cached} cached, {failed} warnings")
    print(f"π Models should be at: {CACHE_DIR}")
    print("=" * 60 + "\n")
|