mrvero commited on
Commit
ac9ec2c
·
verified ·
1 Parent(s): cda3782

Upload safety_detector.py

Browse files
Files changed (1) hide show
  1. safety_detector.py +926 -0
safety_detector.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from ultralytics import YOLO
4
+ import torch
5
+ import time
6
+ from datetime import datetime
7
+ import os
8
+ import json
9
+ from threading import Thread
10
+ import queue
11
+ from typing import Dict, List, Tuple, Optional
12
+ import requests
13
+
14
+ class SafetyDetector:
15
+ """
16
+ Real-time safety compliance detection system using YOLO for object detection.
17
+ Detects people and safety equipment like hard hats, safety vests, and safety glasses.
18
+ """
19
+
20
+ def __init__(self, model_path: Optional[str] = None, confidence_threshold: float = 0.5):
21
+ """
22
+ Initialize the safety detector with a specialized PPE detection model.
23
+
24
+ Args:
25
+ model_path: Path to custom model, if None will download PPE model
26
+ confidence_threshold: Minimum confidence for detections
27
+ """
28
+ self.confidence_threshold = confidence_threshold
29
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+ # Stricter confidence thresholds for different equipment types to reduce false positives
32
+ self.equipment_confidence_thresholds = {
33
+ 'hardhat': 0.7, # Higher threshold for hard hats (hair confusion)
34
+ 'safety_vest': 0.75, # Higher threshold for safety vests (clothing confusion)
35
+ 'mask': 0.6, # Moderate threshold for masks
36
+ 'person': 0.5, # Standard threshold for people
37
+ 'no_hardhat': 0.6, # Moderate threshold for NO- detections
38
+ 'no_safety_vest': 0.6,
39
+ 'no_mask': 0.6
40
+ }
41
+
42
+ # Try to load a specialized PPE detection model
43
+ self.model = self._load_ppe_model(model_path)
44
+
45
+ # PPE class names - these are the actual classes we expect from PPE models
46
+ self.ppe_classes = {
47
+ 'hardhat': ['Hardhat', 'hardhat', 'helmet', 'hard hat'],
48
+ 'safety_vest': ['Safety Vest', 'safety vest', 'vest', 'safety-vest', 'Safety-Vest'],
49
+ 'no_hardhat': ['NO-Hardhat', 'no-hardhat', 'no hardhat', 'NO-Helmet'],
50
+ 'no_safety_vest': ['NO-Safety Vest', 'no-safety-vest', 'no safety vest', 'NO-Safety-Vest'],
51
+ 'person': ['Person', 'person'],
52
+ 'mask': ['Mask', 'mask'],
53
+ 'no_mask': ['NO-Mask', 'no-mask', 'no mask'],
54
+ 'safety_gloves': ['Safety Gloves', 'safety-gloves', 'gloves', 'Gloves'],
55
+ 'safety_glasses': ['Safety Glasses', 'safety-glasses', 'glasses', 'Safety-Glasses'],
56
+ 'hearing_protection': ['Hearing Protection', 'hearing-protection', 'ear protection']
57
+ }
58
+
59
+ print(f"Using device: {self.device}")
60
+ print(f"Loaded PPE detection model with stricter confidence thresholds")
61
+ print(f"Equipment thresholds: {self.equipment_confidence_thresholds}")
62
+
63
+ # Colors for bounding boxes
64
+ self.colors = {
65
+ 'person': (0, 255, 0), # Green for compliant person
66
+ 'violation': (0, 0, 255), # Red for safety violation
67
+ 'equipment': (255, 255, 0), # Yellow for safety equipment
68
+ 'warning': (0, 165, 255) # Orange for warnings
69
+ }
70
+
71
+ # Violation tracking
72
+ self.violations = []
73
+ self.violation_images_dir = "violation_captures"
74
+ os.makedirs(self.violation_images_dir, exist_ok=True)
75
+
76
+ def _load_ppe_model(self, model_path: Optional[str] = None) -> YOLO:
77
+ """Load a specialized PPE detection model."""
78
+ if model_path and os.path.exists(model_path):
79
+ print(f"Loading custom model from {model_path}")
80
+ return YOLO(model_path)
81
+
82
+ # Try to download YOLOv8-compatible PPE models
83
+ ppe_model_urls = [
84
+ # Try the snehilsanyal YOLOv8 PPE model (best.pt)
85
+ "https://github.com/snehilsanyal/Construction-Site-Safety-PPE-Detection/raw/main/models/best.pt",
86
+ # Try mayank13-01 YOLOv8 PPE model
87
+ "https://github.com/mayank13-01/Yolov8-PPE/raw/main/YOLO-Weights/ppe.pt"
88
+ ]
89
+
90
+ for i, url in enumerate(ppe_model_urls):
91
+ try:
92
+ model_filename = f"ppe_yolov8_model_{i}.pt"
93
+ if not os.path.exists(model_filename):
94
+ print(f"Downloading PPE detection model from {url}...")
95
+ response = requests.get(url, timeout=60)
96
+ if response.status_code == 200:
97
+ with open(model_filename, 'wb') as f:
98
+ f.write(response.content)
99
+ print(f"Downloaded PPE model successfully as {model_filename}")
100
+
101
+ if os.path.exists(model_filename):
102
+ print(f"Loading YOLOv8 PPE model from {model_filename}")
103
+ model = YOLO(model_filename)
104
+
105
+ # Test if the model loads properly
106
+ classes = self._get_model_classes(model)
107
+ print(f"Model classes: {classes}")
108
+
109
+ # Check if it has PPE-related classes
110
+ ppe_related = any(
111
+ any(keyword in str(cls).lower() for keyword in ['hardhat', 'vest', 'helmet', 'mask', 'person'])
112
+ for cls in classes
113
+ )
114
+
115
+ if ppe_related:
116
+ print(f"✅ Found PPE-capable model with {len(classes)} classes")
117
+ return model
118
+ else:
119
+ print(f"⚠️ Model doesn't seem to have PPE classes: {classes}")
120
+
121
+ except Exception as e:
122
+ print(f"Failed to download/load from {url}: {e}")
123
+ continue
124
+
125
+ # Fallback to YOLOv8 with a warning
126
+ print("⚠️ Warning: Could not load specialized PPE model, falling back to YOLOv8n")
127
+ print(" Note: YOLOv8n can detect people but not safety equipment")
128
+ return YOLO('yolov8n.pt')
129
+
130
+ def _get_model_classes(self, model=None) -> List[str]:
131
+ """Get the list of classes the model can detect."""
132
+ if model is None:
133
+ model = self.model
134
+ if hasattr(model, 'names'):
135
+ return list(model.names.values())
136
+ return []
137
+
138
+ def _get_class_category(self, class_name: str) -> str:
139
+ """Map detected class name to our safety categories."""
140
+ class_name_lower = class_name.lower()
141
+
142
+ for category, variations in self.ppe_classes.items():
143
+ for variation in variations:
144
+ if variation.lower() in class_name_lower or class_name_lower in variation.lower():
145
+ return category
146
+
147
+ return class_name_lower
148
+
149
+ def detect_safety_violations(self, frame: np.ndarray) -> Dict:
150
+ """
151
+ Detect safety violations in the given frame with improved accuracy.
152
+
153
+ Returns:
154
+ Dictionary containing detection results and violations
155
+ """
156
+ start_time = time.time()
157
+
158
+ # Run detection with optimized settings for speed
159
+ results = self.model(frame, conf=0.3, verbose=False, imgsz=640, half=False)
160
+
161
+ detections = []
162
+ people_count = 0
163
+ safety_equipment_detected = {
164
+ 'hardhat': 0,
165
+ 'safety_vest': 0,
166
+ 'safety_gloves': 0,
167
+ 'safety_glasses': 0,
168
+ 'hearing_protection': 0,
169
+ 'mask': 0
170
+ }
171
+ violations = []
172
+ no_equipment_detections = [] # Track NO- detections separately
173
+
174
+ # Process detections with stricter filtering
175
+ for r in results:
176
+ boxes = r.boxes
177
+ if boxes is not None:
178
+ for box in boxes:
179
+ # Get detection info
180
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
181
+ confidence = box.conf[0].cpu().numpy()
182
+ class_id = int(box.cls[0].cpu().numpy())
183
+
184
+ # Get class name
185
+ if hasattr(self.model, 'names'):
186
+ class_name = self.model.names[class_id]
187
+ else:
188
+ class_name = f"class_{class_id}"
189
+
190
+ # Map to our categories
191
+ category = self._get_class_category(class_name)
192
+
193
+ # Apply stricter confidence thresholds based on equipment type
194
+ required_confidence = self.equipment_confidence_thresholds.get(category, self.confidence_threshold)
195
+
196
+ # Skip detections that don't meet the stricter threshold
197
+ if confidence < required_confidence:
198
+ continue
199
+
200
+ detection = {
201
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
202
+ 'confidence': float(confidence),
203
+ 'class': class_name,
204
+ 'category': category
205
+ }
206
+ detections.append(detection)
207
+
208
+ # Count people and safety equipment
209
+ if category == 'person':
210
+ people_count += 1
211
+ elif category in safety_equipment_detected:
212
+ safety_equipment_detected[category] += 1
213
+ elif category in ['hardhat', 'safety_vest', 'mask'] and not category.startswith('no_'):
214
+ safety_equipment_detected[category] += 1
215
+
216
+ # Handle negative detections (NO-Hardhat, NO-Mask, etc.)
217
+ # These indicate violations - a person without required equipment
218
+ if category.startswith('no_'):
219
+ equipment_type = category.replace('no_', '')
220
+ if equipment_type in ['hardhat', 'safety_vest', 'mask']:
221
+ no_equipment_detections.append({
222
+ 'type': f'missing_{equipment_type}',
223
+ 'severity': 'high',
224
+ 'description': f'Person detected without {equipment_type.replace("_", " ").title()}',
225
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
226
+ 'confidence': float(confidence),
227
+ 'equipment_type': equipment_type
228
+ })
229
+
230
+ # Create violations based on NO- detections (these are more reliable)
231
+ violations.extend(no_equipment_detections)
232
+
233
+ # If we have people but no NO- detections, check equipment ratios
234
+ if people_count > 0 and len(no_equipment_detections) == 0:
235
+ required_equipment = ['hardhat', 'safety_vest', 'mask']
236
+
237
+ for equipment in required_equipment:
238
+ detected_count = safety_equipment_detected[equipment]
239
+
240
+ # If significantly fewer equipment than people, assume violations
241
+ if detected_count < people_count * 0.8: # Allow some tolerance
242
+ missing_count = people_count - detected_count
243
+ equipment_name = equipment.replace("_", " ").title()
244
+ violations.append({
245
+ 'type': f'missing_{equipment}',
246
+ 'severity': 'high',
247
+ 'description': f'{missing_count} person(s) likely missing {equipment_name}',
248
+ 'count': missing_count
249
+ })
250
+
251
+ # Special handling for masks - they're often not detected well
252
+ mask_detected = safety_equipment_detected['mask']
253
+ no_mask_detected = len([v for v in no_equipment_detections if v['equipment_type'] == 'mask'])
254
+
255
+ if people_count > 0 and mask_detected == 0 and no_mask_detected == 0:
256
+ # No mask detections at all - assume people are not wearing masks
257
+ violations.append({
258
+ 'type': 'missing_mask',
259
+ 'severity': 'high',
260
+ 'description': f'{people_count} person(s) not wearing Face Mask',
261
+ 'count': people_count
262
+ })
263
+
264
+ processing_time = time.time() - start_time
265
+
266
+ return {
267
+ 'detections': detections,
268
+ 'people_count': people_count,
269
+ 'safety_equipment': safety_equipment_detected,
270
+ 'violations': violations,
271
+ 'processing_time': processing_time,
272
+ 'fps': 1.0 / processing_time if processing_time > 0 else 0
273
+ }
274
+
275
+ def draw_detections(self, frame: np.ndarray, results: Dict) -> np.ndarray:
276
+ """
277
+ Draw premium bounding boxes only for POSITIVE equipment detections.
278
+ No boxes for missing equipment - violations shown through person status only.
279
+
280
+ Args:
281
+ frame: Input frame
282
+ results: Detection results containing detections, violations, etc.
283
+
284
+ Returns:
285
+ Annotated frame with premium styling
286
+ """
287
+ annotated_frame = frame.copy()
288
+ height, width = annotated_frame.shape[:2]
289
+
290
+ # Create overlay for semi-transparent effects
291
+ overlay = annotated_frame.copy()
292
+
293
+ # Premium color scheme
294
+ colors = {
295
+ 'person_compliant': (46, 204, 113), # Emerald green
296
+ 'person_violation': (231, 76, 60), # Red
297
+ 'equipment': (52, 152, 219), # Blue
298
+ 'hardhat': (46, 204, 113), # Green
299
+ 'safety_vest': (241, 196, 15), # Yellow
300
+ 'mask': (0, 191, 255), # Deep sky blue
301
+ 'violation_bg': (231, 76, 60), # Red background
302
+ 'text_bg': (44, 62, 80), # Dark blue-gray
303
+ 'text_primary': (255, 255, 255), # White
304
+ 'text_secondary': (149, 165, 166), # Light gray
305
+ 'shadow': (0, 0, 0), # Black shadow
306
+ 'accent': (155, 89, 182), # Purple accent
307
+ }
308
+
309
+ # Track people and their compliance status
310
+ people_status = {}
311
+
312
+ # First pass: categorize people
313
+ for detection in results.get('detections', []):
314
+ class_name = detection['class'].lower()
315
+ bbox = detection['bbox']
316
+ confidence = detection['confidence']
317
+
318
+ if 'person' in class_name:
319
+ person_id = f"person_{bbox[0]}_{bbox[1]}"
320
+ people_status[person_id] = {
321
+ 'bbox': bbox,
322
+ 'confidence': confidence,
323
+ 'violations': [],
324
+ 'equipment': []
325
+ }
326
+
327
+ # Map violations to people
328
+ for violation in results.get('violations', []):
329
+ if 'bbox' in violation:
330
+ # This is a specific violation with a bounding box (from NO- detections)
331
+ violation_bbox = violation['bbox']
332
+ # Find the closest person to this violation
333
+ closest_person = None
334
+ min_distance = float('inf')
335
+
336
+ for person_id, person_data in people_status.items():
337
+ person_bbox = person_data['bbox']
338
+ # Calculate distance between violation and person
339
+ distance = abs(violation_bbox[0] - person_bbox[0]) + abs(violation_bbox[1] - person_bbox[1])
340
+ if distance < min_distance:
341
+ min_distance = distance
342
+ closest_person = person_id
343
+
344
+ if closest_person and min_distance < 100: # Within reasonable distance
345
+ violation_type = violation['type'].replace('missing_', '')
346
+ people_status[closest_person]['violations'].append(violation_type)
347
+ else:
348
+ # General violation - apply to all people (when equipment count < people count)
349
+ violation_type = violation['type'].replace('missing_', '')
350
+ for person_id in people_status:
351
+ people_status[person_id]['violations'].append(violation_type)
352
+
353
+ # If no specific violations detected but people are present, assume they're missing all required equipment
354
+ if len(people_status) > 0 and len(results.get('violations', [])) == 0:
355
+ # Check if we have any positive equipment detections
356
+ equipment_detected = any(
357
+ detection['category'] in ['hardhat', 'safety_vest', 'mask']
358
+ for detection in results.get('detections', [])
359
+ if detection['category'] in ['hardhat', 'safety_vest', 'mask']
360
+ )
361
+
362
+ # If no equipment detected at all, mark all people as having violations
363
+ if not equipment_detected:
364
+ for person_id in people_status:
365
+ people_status[person_id]['violations'] = ['hardhat', 'safety_vest', 'mask']
366
+
367
+ # ONLY draw POSITIVE equipment detections (when equipment IS being worn)
368
+ for detection in results.get('detections', []):
369
+ class_name = detection['class'].lower()
370
+ category = detection.get('category', '')
371
+
372
+ # Skip people and NO- detections - we only want positive equipment
373
+ if 'person' in class_name or 'no-' in class_name or 'no_' in category:
374
+ continue
375
+
376
+ # Only draw positive equipment detections
377
+ if category in ['hardhat', 'safety_vest', 'mask'] or any(equip in class_name for equip in ['hardhat', 'vest', 'helmet', 'safety', 'mask']):
378
+ bbox = detection['bbox']
379
+ confidence = detection['confidence']
380
+
381
+ # Choose color and label based on equipment type
382
+ if any(x in class_name for x in ['hardhat', 'helmet']) or category == 'hardhat':
383
+ color = colors['hardhat']
384
+ equipment_type = "Hard Hat ✓"
385
+ elif 'vest' in class_name or category == 'safety_vest':
386
+ color = colors['safety_vest']
387
+ equipment_type = "Safety Vest ✓"
388
+ elif 'mask' in class_name or category == 'mask':
389
+ color = colors['mask']
390
+ equipment_type = "Face Mask ✓"
391
+ else:
392
+ color = colors['equipment']
393
+ equipment_type = "Safety Equipment ✓"
394
+
395
+ # Draw equipment with premium styling
396
+ self._draw_premium_bbox(overlay, annotated_frame, bbox, color,
397
+ equipment_type, confidence,
398
+ bbox_type="equipment", colors=colors)
399
+
400
+ # Draw people with compliance status (no violation indicators on person boxes)
401
+ for person_id, person_data in people_status.items():
402
+ bbox = person_data['bbox']
403
+ confidence = person_data['confidence']
404
+ violations = person_data['violations']
405
+
406
+ # Determine person status
407
+ is_compliant = len(violations) == 0
408
+ color = colors['person_compliant'] if is_compliant else colors['person_violation']
409
+ status_text = "COMPLIANT" if is_compliant else "VIOLATION"
410
+
411
+ # Draw person with premium styling (no violation details on the box)
412
+ self._draw_premium_bbox(overlay, annotated_frame, bbox, color,
413
+ f"Person - {status_text}", confidence,
414
+ bbox_type="person", violations=None, # Don't show violation details on person box
415
+ colors=colors)
416
+
417
+ # Blend overlay with original frame for semi-transparent effects
418
+ alpha = 0.15
419
+ cv2.addWeighted(overlay, alpha, annotated_frame, 1 - alpha, 0, annotated_frame)
420
+
421
+ # Statistics are now handled by the web UI, no overlay needed on video feed
422
+
423
+ return annotated_frame
424
+
425
+ def _draw_premium_bbox(self, overlay, frame, bbox, color, label, confidence,
426
+ bbox_type="default", violations=None, colors=None):
427
+ """Draw a premium-styled bounding box with advanced visual effects."""
428
+ x1, y1, x2, y2 = map(int, bbox)
429
+
430
+ # Box dimensions
431
+ box_width = x2 - x1
432
+ box_height = y2 - y1
433
+
434
+ # Draw shadow first (slightly offset)
435
+ shadow_offset = 3
436
+ shadow_color = colors['shadow']
437
+ cv2.rectangle(overlay,
438
+ (x1 + shadow_offset, y1 + shadow_offset),
439
+ (x2 + shadow_offset, y2 + shadow_offset),
440
+ shadow_color, 2)
441
+
442
+ # Main bounding box with thinner lines
443
+ box_thickness = 2 if bbox_type == "person" else 1
444
+
445
+ # Draw main rectangle
446
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, box_thickness)
447
+
448
+ # Draw corner accents for premium look
449
+ corner_length = min(20, box_width // 4, box_height // 4)
450
+ accent_thickness = box_thickness
451
+
452
+ # Top-left corner
453
+ cv2.line(frame, (x1, y1), (x1 + corner_length, y1), color, accent_thickness)
454
+ cv2.line(frame, (x1, y1), (x1, y1 + corner_length), color, accent_thickness)
455
+
456
+ # Top-right corner
457
+ cv2.line(frame, (x2, y1), (x2 - corner_length, y1), color, accent_thickness)
458
+ cv2.line(frame, (x2, y1), (x2, y1 + corner_length), color, accent_thickness)
459
+
460
+ # Bottom-left corner
461
+ cv2.line(frame, (x1, y2), (x1 + corner_length, y2), color, accent_thickness)
462
+ cv2.line(frame, (x1, y2), (x1, y2 - corner_length), color, accent_thickness)
463
+
464
+ # Bottom-right corner
465
+ cv2.line(frame, (x2, y2), (x2 - corner_length, y2), color, accent_thickness)
466
+ cv2.line(frame, (x2, y2), (x2, y2 - corner_length), color, accent_thickness)
467
+
468
+ # Prepare label text
469
+ confidence_text = f"{confidence:.1%}"
470
+ main_text = f"{label}"
471
+
472
+ # Calculate text dimensions
473
+ font = cv2.FONT_HERSHEY_SIMPLEX
474
+ font_scale = 0.5
475
+ thickness = 1
476
+
477
+ (main_w, main_h), _ = cv2.getTextSize(main_text, font, font_scale, thickness)
478
+ (conf_w, conf_h), _ = cv2.getTextSize(confidence_text, font, font_scale - 0.1, thickness - 1)
479
+
480
+ # Label background dimensions
481
+ label_height = max(main_h, conf_h) + 12
482
+ label_width = max(main_w, conf_w) + 16
483
+
484
+ # Position label (above box if space available, otherwise below)
485
+ if y1 - label_height - 5 > 0:
486
+ label_y = y1 - label_height - 5
487
+ else:
488
+ label_y = y2 + 5
489
+
490
+ label_x = x1
491
+
492
+ # Ensure label stays within frame
493
+ if label_x + label_width > frame.shape[1]:
494
+ label_x = frame.shape[1] - label_width - 5
495
+ if label_x < 0:
496
+ label_x = 5
497
+
498
+ # Draw label background with gradient effect
499
+ bg_color = colors['text_bg']
500
+
501
+ # Main background
502
+ cv2.rectangle(overlay,
503
+ (label_x, label_y),
504
+ (label_x + label_width, label_y + label_height),
505
+ bg_color, -1)
506
+
507
+ # Colored top border
508
+ cv2.rectangle(frame,
509
+ (label_x, label_y),
510
+ (label_x + label_width, label_y + 4),
511
+ color, -1)
512
+
513
+ # Add subtle border
514
+ cv2.rectangle(frame,
515
+ (label_x, label_y),
516
+ (label_x + label_width, label_y + label_height),
517
+ color, 1)
518
+
519
+ # Draw main text
520
+ text_y = label_y + main_h + 6
521
+ cv2.putText(frame, main_text,
522
+ (label_x + 8, text_y),
523
+ font, font_scale, colors['text_primary'], thickness)
524
+
525
+ # Draw confidence text
526
+ conf_y = text_y + conf_h + 4
527
+ cv2.putText(frame, confidence_text,
528
+ (label_x + 8, conf_y),
529
+ font, font_scale - 0.1, colors['text_secondary'], max(1, thickness - 1))
530
+
531
+ # Draw violation indicators for people (only if violations are provided)
532
+ if bbox_type == "person" and violations is not None and len(violations) > 0:
533
+ self._draw_violation_indicators(frame, overlay, x1, y1, x2, y2, violations, colors)
534
+
535
+ def _draw_violation_indicators(self, frame, overlay, x1, y1, x2, y2, violations, colors):
536
+ """Draw violation indicators with premium styling."""
537
+ # Warning icon position (top-right of bounding box)
538
+ icon_size = 24
539
+ icon_x = x2 - icon_size - 5
540
+ icon_y = y1 + 5
541
+
542
+ # Draw warning background circle
543
+ cv2.circle(overlay, (icon_x + icon_size//2, icon_y + icon_size//2),
544
+ icon_size//2, colors['violation_bg'], -1)
545
+ cv2.circle(frame, (icon_x + icon_size//2, icon_y + icon_size//2),
546
+ icon_size//2, colors['violation_bg'], 2)
547
+
548
+ # Draw exclamation mark
549
+ center_x = icon_x + icon_size//2
550
+ center_y = icon_y + icon_size//2
551
+
552
+ # Exclamation line
553
+ cv2.line(frame, (center_x, center_y - 6), (center_x, center_y + 2),
554
+ colors['text_primary'], 2)
555
+ # Exclamation dot
556
+ cv2.circle(frame, (center_x, center_y + 5), 1, colors['text_primary'], -1)
557
+
558
+ # Draw violation list below the person if space allows
559
+ violation_text = "Missing: " + ", ".join(violations)
560
+ font = cv2.FONT_HERSHEY_SIMPLEX
561
+ font_scale = 0.5
562
+ thickness = 1
563
+
564
+ (text_w, text_h), _ = cv2.getTextSize(violation_text, font, font_scale, thickness)
565
+
566
+ # Position violation text
567
+ viol_x = x1
568
+ viol_y = y2 + text_h + 8
569
+
570
+ # Ensure text stays within frame
571
+ if viol_y + text_h > frame.shape[0]:
572
+ viol_y = y1 - text_h - 8
573
+ if viol_x + text_w > frame.shape[1]:
574
+ viol_x = frame.shape[1] - text_w - 5
575
+
576
+ # Draw violation text background
577
+ padding = 4
578
+ cv2.rectangle(overlay,
579
+ (viol_x - padding, viol_y - text_h - padding),
580
+ (viol_x + text_w + padding, viol_y + padding),
581
+ colors['violation_bg'], -1)
582
+
583
+ # Draw violation text
584
+ cv2.putText(frame, violation_text,
585
+ (viol_x, viol_y),
586
+ font, font_scale, colors['text_primary'], thickness)
587
+
588
+ def _draw_statistics_overlay(self, frame, results, colors, width, height):
589
+ """Draw statistics overlay with premium styling."""
590
+ # Statistics data
591
+ people_count = results.get('people_count', 0)
592
+ violations = results.get('violations', [])
593
+ violation_count = len(violations)
594
+ compliant_count = people_count - violation_count
595
+ compliance_rate = (compliant_count / max(people_count, 1)) * 100
596
+
597
+ # Statistics text
598
+ stats = [
599
+ f"People: {people_count}",
600
+ f"Compliant: {compliant_count}",
601
+ f"Violations: {violation_count}",
602
+ f"Compliance: {compliance_rate:.1f}%"
603
+ ]
604
+
605
+ # Text properties
606
+ font = cv2.FONT_HERSHEY_SIMPLEX
607
+ font_scale = 0.7
608
+ thickness = 2
609
+
610
+ # Calculate background size
611
+ max_text_width = 0
612
+ total_height = 0
613
+ line_heights = []
614
+
615
+ for text in stats:
616
+ (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
617
+ max_text_width = max(max_text_width, text_w)
618
+ line_heights.append(text_h)
619
+ total_height += text_h + 8
620
+
621
+ # Background dimensions
622
+ bg_width = max_text_width + 24
623
+ bg_height = total_height + 16
624
+
625
+ # Position (top-left corner)
626
+ bg_x = 20
627
+ bg_y = 20
628
+
629
+ # Draw semi-transparent background
630
+ overlay = frame.copy()
631
+ cv2.rectangle(overlay,
632
+ (bg_x, bg_y),
633
+ (bg_x + bg_width, bg_y + bg_height),
634
+ colors['text_bg'], -1)
635
+ cv2.addWeighted(overlay, 0.8, frame, 0.2, 0, frame)
636
+
637
+ # Draw border
638
+ cv2.rectangle(frame,
639
+ (bg_x, bg_y),
640
+ (bg_x + bg_width, bg_y + bg_height),
641
+ colors['accent'], 2)
642
+
643
+ # Draw statistics text
644
+ current_y = bg_y + 24
645
+ for i, text in enumerate(stats):
646
+ # Choose color based on statistic type
647
+ if "Violations:" in text and violation_count > 0:
648
+ text_color = colors['person_violation']
649
+ elif "Compliant:" in text:
650
+ text_color = colors['person_compliant']
651
+ elif "Compliance:" in text:
652
+ if compliance_rate >= 80:
653
+ text_color = colors['person_compliant']
654
+ elif compliance_rate >= 60:
655
+ text_color = colors['safety_vest']
656
+ else:
657
+ text_color = colors['person_violation']
658
+ else:
659
+ text_color = colors['text_primary']
660
+
661
+ cv2.putText(frame, text,
662
+ (bg_x + 12, current_y),
663
+ font, font_scale, text_color, thickness)
664
+ current_y += line_heights[i] + 8
665
+
666
+ def get_model_classes(self) -> List[str]:
667
+ """Get the list of classes the model can detect."""
668
+ return self._get_model_classes()
669
+
670
+ def test_detection(self, test_image_path: str = None):
671
+ """Test the detector with a sample image or webcam."""
672
+ if test_image_path and os.path.exists(test_image_path):
673
+ frame = cv2.imread(test_image_path)
674
+ if frame is not None:
675
+ results = self.detect_safety_violations(frame)
676
+ output = self.draw_detections(frame, results)
677
+
678
+ print(f"Detected classes: {[d['class'] for d in results['detections']]}")
679
+ print(f"Available model classes: {self.get_model_classes()}")
680
+
681
+ cv2.imshow('PPE Detection Test', output)
682
+ cv2.waitKey(0)
683
+ cv2.destroyAllWindows()
684
+ return results
685
+ else:
686
+ print("Testing with webcam - press 'q' to quit")
687
+ cap = cv2.VideoCapture(0)
688
+
689
+ while True:
690
+ ret, frame = cap.read()
691
+ if not ret:
692
+ break
693
+
694
+ results = self.detect_safety_violations(frame)
695
+ output = self.draw_detections(frame, results)
696
+
697
+ cv2.imshow('PPE Detection Test', output)
698
+
699
+ if cv2.waitKey(1) & 0xFF == ord('q'):
700
+ break
701
+
702
+ cap.release()
703
+ cv2.destroyAllWindows()
704
+
705
+ def analyze_safety_compliance(self, detections: List[Dict]) -> Dict:
706
+ """
707
+ Analyze safety compliance based on detected objects.
708
+
709
+ Args:
710
+ detections: List of detected objects
711
+
712
+ Returns:
713
+ Dictionary with compliance analysis
714
+ """
715
+ people_detected = []
716
+ safety_equipment = []
717
+
718
+ # Separate people and safety equipment
719
+ for detection in detections:
720
+ if detection['class'].lower() == 'person':
721
+ people_detected.append(detection)
722
+ elif any(equipment in detection['class'].lower()
723
+ for equipment in ['helmet', 'hardhat', 'vest', 'gloves', 'glasses']):
724
+ safety_equipment.append(detection)
725
+
726
+ # Analyze compliance for each person
727
+ compliance_results = []
728
+ for person in people_detected:
729
+ person_bbox = person['bbox']
730
+
731
+ # Check for nearby safety equipment
732
+ nearby_equipment = self._find_nearby_equipment(person_bbox, safety_equipment)
733
+
734
+ # Determine missing equipment
735
+ required_equipment = ['hardhat', 'safety_vest']
736
+ missing_equipment = []
737
+
738
+ for equipment in required_equipment:
739
+ if not any(equipment.lower() in item['class'].lower()
740
+ for item in nearby_equipment):
741
+ missing_equipment.append(equipment)
742
+
743
+ compliance_results.append({
744
+ 'person': person,
745
+ 'nearby_equipment': nearby_equipment,
746
+ 'missing_equipment': missing_equipment,
747
+ 'is_compliant': len(missing_equipment) == 0,
748
+ 'compliance_score': 1.0 - (len(missing_equipment) / len(required_equipment))
749
+ })
750
+
751
+ return {
752
+ 'total_people': len(people_detected),
753
+ 'compliant_people': sum(1 for result in compliance_results if result['is_compliant']),
754
+ 'violations': sum(len(result['missing_equipment']) for result in compliance_results),
755
+ 'compliance_results': compliance_results,
756
+ 'overall_compliance_rate': (
757
+ sum(result['compliance_score'] for result in compliance_results) /
758
+ max(len(compliance_results), 1)
759
+ )
760
+ }
761
+
762
+ def _find_nearby_equipment(self, person_bbox: List[int], equipment_list: List[Dict],
763
+ proximity_threshold: float = 0.3) -> List[Dict]:
764
+ """Find safety equipment near a person."""
765
+ nearby_equipment = []
766
+
767
+ person_center_x = (person_bbox[0] + person_bbox[2]) / 2
768
+ person_center_y = (person_bbox[1] + person_bbox[3]) / 2
769
+
770
+ for equipment in equipment_list:
771
+ equip_bbox = equipment['bbox']
772
+ equip_center_x = (equip_bbox[0] + equip_bbox[2]) / 2
773
+ equip_center_y = (equip_bbox[1] + equip_bbox[3]) / 2
774
+
775
+ # Calculate normalized distance
776
+ distance = np.sqrt((person_center_x - equip_center_x)**2 +
777
+ (person_center_y - equip_center_y)**2)
778
+
779
+ # Normalize by image diagonal (assuming standard frame size)
780
+ normalized_distance = distance / 1000 # Adjust based on typical frame size
781
+
782
+ if normalized_distance < proximity_threshold:
783
+ nearby_equipment.append(equipment)
784
+
785
+ return nearby_equipment
786
+
787
+ def draw_annotations(self, frame: np.ndarray, analysis: Dict) -> np.ndarray:
788
+ """
789
+ Draw bounding boxes and annotations on the frame.
790
+
791
+ Args:
792
+ frame: Input frame
793
+ analysis: Safety compliance analysis results
794
+
795
+ Returns:
796
+ Annotated frame
797
+ """
798
+ annotated_frame = frame.copy()
799
+
800
+ # Draw safety equipment
801
+ for equipment in analysis['safety_equipment']:
802
+ bbox = equipment['bbox']
803
+ cv2.rectangle(annotated_frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
804
+ self.colors['equipment'], 2)
805
+
806
+ label = f"{equipment.get('equipment_type', equipment['class'])}: {equipment['confidence']:.2f}"
807
+ cv2.putText(annotated_frame, label, (bbox[0], bbox[1] - 10),
808
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['equipment'], 2)
809
+
810
+ # Draw people with compliance status
811
+ for result in analysis['compliance_results']:
812
+ person = result['person']
813
+ bbox = person['bbox']
814
+
815
+ # Choose color based on compliance
816
+ color = self.colors['person'] if result['is_compliant'] else self.colors['violation']
817
+
818
+ # Draw bounding box
819
+ cv2.rectangle(annotated_frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 3)
820
+
821
+ # Create status label
822
+ status = "COMPLIANT" if result['is_compliant'] else "VIOLATION"
823
+ confidence_text = f"Person: {person['confidence']:.2f}"
824
+
825
+ # Draw labels
826
+ cv2.putText(annotated_frame, status, (bbox[0], bbox[1] - 30),
827
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
828
+ cv2.putText(annotated_frame, confidence_text, (bbox[0], bbox[1] - 10),
829
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
830
+
831
+ # Show missing equipment
832
+ if result['missing_equipment']:
833
+ missing_text = f"Missing: {', '.join(result['missing_equipment'])}"
834
+ cv2.putText(annotated_frame, missing_text, (bbox[0], bbox[3] + 20),
835
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['violation'], 2)
836
+
837
+ # Draw summary statistics
838
+ summary_text = [
839
+ f"Total People: {analysis['total_people']}",
840
+ f"Compliant: {analysis['compliant_people']}",
841
+ f"Violations: {analysis['violations']}",
842
+ f"Compliance Rate: {(analysis['compliant_people']/max(analysis['total_people'],1)*100):.1f}%"
843
+ ]
844
+
845
+ for i, text in enumerate(summary_text):
846
+ cv2.putText(annotated_frame, text, (10, 30 + i * 25),
847
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
848
+
849
+ return annotated_frame
850
+
851
+ def capture_violation(self, frame: np.ndarray, violation_data: Dict) -> str:
852
+ """
853
+ Capture and save an image when a safety violation is detected.
854
+
855
+ Args:
856
+ frame: Current frame
857
+ violation_data: Information about the violation
858
+
859
+ Returns:
860
+ Path to saved image
861
+ """
862
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
863
+ filename = f"violation_{timestamp}.jpg"
864
+ filepath = os.path.join(self.violation_images_dir, filename)
865
+
866
+ # Save the frame
867
+ cv2.imwrite(filepath, frame)
868
+
869
+ # Save violation metadata
870
+ metadata = {
871
+ 'timestamp': datetime.now().isoformat(),
872
+ 'filename': filename,
873
+ 'violation_data': violation_data
874
+ }
875
+
876
+ metadata_file = filepath.replace('.jpg', '_metadata.json')
877
+ with open(metadata_file, 'w') as f:
878
+ json.dump(metadata, f, indent=2)
879
+
880
+ self.violations.append(metadata)
881
+ return filepath
882
+
883
+ def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
884
+ """
885
+ Process a single frame for safety monitoring.
886
+
887
+ Args:
888
+ frame: Input video frame
889
+
890
+ Returns:
891
+ Tuple of (annotated_frame, analysis_results)
892
+ """
893
+ # Detect objects and get safety violations
894
+ results = self.detect_safety_violations(frame)
895
+
896
+ # Draw detections on frame using the main drawing method
897
+ annotated_frame = self.draw_detections(frame, results)
898
+
899
+ return annotated_frame, {
900
+ 'detections': results['detections'],
901
+ 'people_count': results['people_count'],
902
+ 'safety_equipment': results['safety_equipment'],
903
+ 'violations': results['violations'],
904
+ 'violation_summary': self.get_violation_summary(),
905
+ 'frame_stats': {
906
+ 'processing_time': results['processing_time'],
907
+ 'fps': results['fps'],
908
+ 'detection_count': len(results['detections'])
909
+ }
910
+ }
911
+
912
+ def get_violation_summary(self) -> Dict:
913
+ """Get a summary of recent violations."""
914
+ # This would typically connect to a database or log file
915
+ # For now, return a placeholder
916
+ return {
917
+ 'total_violations_today': 0,
918
+ 'most_common_violation': 'missing_hardhat',
919
+ 'compliance_trend': [] # Could track compliance over time
920
+ }
921
+
922
+ if __name__ == "__main__":
923
+ # Test the detector
924
+ detector = SafetyDetector()
925
+ print("Available classes:", detector.get_model_classes())
926
+ detector.test_detection()