Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import torch | |
| import time | |
| from datetime import datetime | |
| import os | |
| import json | |
| from threading import Thread | |
| import queue | |
| from typing import Dict, List, Tuple, Optional | |
| import requests | |
class SafetyDetector:
    """
    Real-time safety compliance detection system using YOLO for object detection.

    Detects people and personal protective equipment (hard hats, safety vests,
    masks, ...) in video frames, derives per-frame safety violations, draws
    annotated frames and can persist violation snapshots to disk.
    """

    # Equipment whose absence is reported as a violation.
    REQUIRED_EQUIPMENT = ('hardhat', 'safety_vest', 'mask')

    # Palette used by the "premium" drawing helpers.
    # NOTE(review): OpenCV drawing calls treat these tuples as BGR when the
    # frame is BGR; the colour names below only match the values when read as
    # RGB - confirm the channel order of the frames passed to draw_detections.
    PALETTE = {
        'person_compliant': (46, 204, 113),   # Emerald green
        'person_violation': (231, 76, 60),    # Red
        'equipment': (52, 152, 219),          # Blue
        'hardhat': (46, 204, 113),            # Green
        'safety_vest': (241, 196, 15),        # Yellow
        'mask': (0, 191, 255),                # Deep sky blue
        'violation_bg': (231, 76, 60),        # Red background
        'text_bg': (44, 62, 80),              # Dark blue-gray
        'text_primary': (255, 255, 255),      # White
        'text_secondary': (149, 165, 166),    # Light gray
        'shadow': (0, 0, 0),                  # Black shadow
        'accent': (155, 89, 182),             # Purple accent
    }

    def __init__(self, model_path: Optional[str] = None, confidence_threshold: float = 0.5):
        """
        Initialize the safety detector with a specialized PPE detection model.

        Side effects: may download model weights, prints status messages and
        creates the ``violation_captures`` directory in the working directory.

        Args:
            model_path: Path to custom model, if None will download PPE model
            confidence_threshold: Minimum confidence for detections whose
                category has no entry in ``equipment_confidence_thresholds``
        """
        self.confidence_threshold = confidence_threshold
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Stricter confidence thresholds for different equipment types to reduce false positives
        self.equipment_confidence_thresholds = {
            'hardhat': 0.7,        # Higher threshold for hard hats (hair confusion)
            'safety_vest': 0.75,   # Higher threshold for safety vests (clothing confusion)
            'mask': 0.6,           # Moderate threshold for masks
            'person': 0.5,         # Standard threshold for people
            'no_hardhat': 0.6,     # Moderate threshold for NO- detections
            'no_safety_vest': 0.6,
            'no_mask': 0.6,
        }

        # Try to load a specialized PPE detection model
        self.model = self._load_ppe_model(model_path)

        # PPE class names - these are the actual classes we expect from PPE models
        self.ppe_classes = {
            'hardhat': ['Hardhat', 'hardhat', 'helmet', 'hard hat'],
            'safety_vest': ['Safety Vest', 'safety vest', 'vest', 'safety-vest', 'Safety-Vest'],
            'no_hardhat': ['NO-Hardhat', 'no-hardhat', 'no hardhat', 'NO-Helmet'],
            'no_safety_vest': ['NO-Safety Vest', 'no-safety-vest', 'no safety vest', 'NO-Safety-Vest'],
            'person': ['Person', 'person'],
            'mask': ['Mask', 'mask'],
            'no_mask': ['NO-Mask', 'no-mask', 'no mask'],
            'safety_gloves': ['Safety Gloves', 'safety-gloves', 'gloves', 'Gloves'],
            'safety_glasses': ['Safety Glasses', 'safety-glasses', 'glasses', 'Safety-Glasses'],
            'hearing_protection': ['Hearing Protection', 'hearing-protection', 'ear protection'],
        }

        print(f"Using device: {self.device}")
        print("Loaded PPE detection model with stricter confidence thresholds")
        print(f"Equipment thresholds: {self.equipment_confidence_thresholds}")

        # Colors for the plain annotations drawn by draw_annotations (BGR order)
        self.colors = {
            'person': (0, 255, 0),       # Green for compliant person
            'violation': (0, 0, 255),    # Red for safety violation
            'equipment': (255, 255, 0),  # Yellow for safety equipment
            'warning': (0, 165, 255),    # Orange for warnings
        }

        # Violation tracking
        self.violations = []
        self.violation_images_dir = "violation_captures"
        os.makedirs(self.violation_images_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    @staticmethod
    def _download_file(url: str, destination: str, timeout: float = 60) -> None:
        """
        Stream ``url`` to ``destination`` without leaving partial files behind.

        The payload is written to ``<destination>.part`` and only renamed once
        the transfer completed, so an interrupted download can never be
        mistaken for a cached model on the next start.

        Raises:
            RuntimeError: if the server does not answer with HTTP 200.
            requests.RequestException: on network errors.
        """
        print(f"Downloading PPE detection model from {url}...")
        partial_path = destination + ".part"
        try:
            with requests.get(url, timeout=timeout, stream=True) as response:
                if response.status_code != 200:
                    raise RuntimeError(f"unexpected HTTP status {response.status_code}")
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            os.replace(partial_path, destination)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"Downloaded PPE model successfully as {destination}")

    def _load_ppe_model(self, model_path: Optional[str] = None) -> "YOLO":
        """
        Load a specialized PPE detection model.

        Order of preference: the custom ``model_path`` (if it exists), then
        each known PPE weight file (downloaded once and cached in the working
        directory), and finally the generic ``yolov8n.pt`` as a fallback.

        Args:
            model_path: Optional path to custom weights.

        Returns:
            A loaded YOLO model.
        """
        if model_path and os.path.exists(model_path):
            print(f"Loading custom model from {model_path}")
            return YOLO(model_path)

        # Try to download YOLOv8-compatible PPE models
        ppe_model_urls = [
            # Try the snehilsanyal YOLOv8 PPE model (best.pt)
            "https://github.com/snehilsanyal/Construction-Site-Safety-PPE-Detection/raw/main/models/best.pt",
            # Try mayank13-01 YOLOv8 PPE model
            "https://github.com/mayank13-01/Yolov8-PPE/raw/main/YOLO-Weights/ppe.pt",
        ]
        ppe_keywords = ('hardhat', 'vest', 'helmet', 'mask', 'person')

        for i, url in enumerate(ppe_model_urls):
            model_filename = f"ppe_yolov8_model_{i}.pt"
            # Best-effort: any failure for one candidate moves on to the next.
            try:
                if not os.path.exists(model_filename):
                    self._download_file(url, model_filename)

                print(f"Loading YOLOv8 PPE model from {model_filename}")
                model = YOLO(model_filename)

                # Test if the model loads properly
                classes = self._get_model_classes(model)
                print(f"Model classes: {classes}")

                # Check if it has PPE-related classes
                ppe_related = any(
                    keyword in str(cls).lower()
                    for cls in classes
                    for keyword in ppe_keywords
                )
                if ppe_related:
                    print(f"✅ Found PPE-capable model with {len(classes)} classes")
                    return model
                print(f"⚠️ Model doesn't seem to have PPE classes: {classes}")
            except Exception as e:
                print(f"Failed to download/load from {url}: {e}")
                continue

        # Fallback to YOLOv8 with a warning
        print("⚠️ Warning: Could not load specialized PPE model, falling back to YOLOv8n")
        print(" Note: YOLOv8n can detect people but not safety equipment")
        return YOLO('yolov8n.pt')

    def _get_model_classes(self, model=None) -> List[str]:
        """
        Get the list of classes the model can detect.

        Args:
            model: Model to inspect; defaults to ``self.model``.

        Returns:
            Class names taken from ``model.names`` (dict values or sequence
            items), or an empty list when the model exposes no names.
        """
        if model is None:
            model = self.model
        names = getattr(model, 'names', None)
        if not names:
            return []
        if isinstance(names, dict):
            return list(names.values())
        return list(names)

    # ------------------------------------------------------------------
    # Class-name mapping
    # ------------------------------------------------------------------
    @staticmethod
    def _normalize_label(label: str) -> str:
        """Lower-case a label and unify '-', '_' and repeated blanks to single spaces."""
        return ' '.join(str(label).lower().replace('-', ' ').replace('_', ' ').split())

    def _get_class_category(self, class_name: str) -> str:
        """
        Map detected class name to our safety categories.

        Matching is done in decreasing order of reliability so that negative
        classes such as ``NO-Hardhat`` are never swallowed by the positive
        ``hardhat`` category merely because they contain its name:

        1. exact match against a known variation,
        2. ``no <something>`` where ``<something>`` is a known positive category,
        3. a known variation contained in the name (negative categories first),
        4. the name contained in a known variation (legacy abbreviations).

        Args:
            class_name: Class name as reported by the model.

        Returns:
            The category key, or the lower-cased class name if nothing matches.
        """
        name = self._normalize_label(class_name)
        catalog = [
            (category, [self._normalize_label(v) for v in variations])
            for category, variations in self.ppe_classes.items()
        ]

        # 1. Exact match.
        for category, variations in catalog:
            if name in variations:
                return category

        # 2. Generic negation of a known positive category ("no vest").
        if name.startswith('no '):
            base = self._get_class_category(name[3:])
            if base in self.ppe_classes and not base.startswith('no_'):
                return f'no_{base}'

        # 3. Variation contained in the name; sorted() is stable, so the
        #    original order is kept within the negative/positive groups.
        negatives_first = sorted(catalog, key=lambda entry: not entry[0].startswith('no_'))
        for category, variations in negatives_first:
            if any(v and v in name for v in variations):
                return category

        # 4. Name contained in a variation (kept for abbreviated class names).
        if name:
            for category, variations in catalog:
                if any(name in v for v in variations):
                    return category

        return str(class_name).lower()

    def _detection_category(self, detection: Dict) -> str:
        """Return the detection's stored category, deriving it from the class name if absent."""
        return detection.get('category') or self._get_class_category(detection['class'])

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------
    def detect_safety_violations(self, frame: np.ndarray) -> Dict:
        """
        Detect safety violations in the given frame with improved accuracy.

        Args:
            frame: Image passed to the model unchanged.

        Returns:
            Dictionary with the keys ``detections``, ``people_count``,
            ``safety_equipment``, ``violations``, ``processing_time`` (seconds)
            and ``fps``.
        """
        start_time = time.time()

        # The model-side pre-filter must not be stricter than the loosest
        # per-category threshold applied below (0.3 with the defaults).
        pre_filter_conf = min(
            0.3,
            self.confidence_threshold,
            *self.equipment_confidence_thresholds.values()
        )
        # Run detection with optimized settings for speed
        results = self.model(frame, conf=pre_filter_conf, verbose=False, imgsz=640, half=False)

        detections = []
        people_count = 0
        safety_equipment_detected = {
            'hardhat': 0,
            'safety_vest': 0,
            'safety_gloves': 0,
            'safety_glasses': 0,
            'hearing_protection': 0,
            'mask': 0,
        }
        no_equipment_detections = []  # Track NO- detections separately
        has_names = hasattr(self.model, 'names')

        # Process detections with stricter filtering
        for r in results:
            boxes = r.boxes
            if boxes is None:
                continue
            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                confidence = float(box.conf[0].cpu().numpy())
                class_id = int(box.cls[0].cpu().numpy())
                class_name = self.model.names[class_id] if has_names else f"class_{class_id}"

                # Map to our categories and apply the per-category threshold
                category = self._get_class_category(class_name)
                required_confidence = self.equipment_confidence_thresholds.get(
                    category, self.confidence_threshold)
                if confidence < required_confidence:
                    continue

                bbox = [int(x1), int(y1), int(x2), int(y2)]
                detections.append({
                    'bbox': bbox,
                    'confidence': confidence,
                    'class': class_name,
                    'category': category,
                })

                if category == 'person':
                    people_count += 1
                elif category in safety_equipment_detected:
                    safety_equipment_detected[category] += 1
                elif category.startswith('no_'):
                    # Negative detections (NO-Hardhat, NO-Mask, ...) indicate a
                    # person without the required equipment.
                    equipment_type = category[len('no_'):]
                    if equipment_type in self.REQUIRED_EQUIPMENT:
                        no_equipment_detections.append({
                            'type': f'missing_{equipment_type}',
                            'severity': 'high',
                            'description': f'Person detected without {equipment_type.replace("_", " ").title()}',
                            'bbox': list(bbox),
                            'confidence': confidence,
                            'equipment_type': equipment_type,
                        })

        violations = self._evaluate_violations(
            people_count, safety_equipment_detected, no_equipment_detections)

        processing_time = time.time() - start_time
        return {
            'detections': detections,
            'people_count': people_count,
            'safety_equipment': safety_equipment_detected,
            'violations': violations,
            'processing_time': processing_time,
            'fps': 1.0 / processing_time if processing_time > 0 else 0,
        }

    def _evaluate_violations(self, people_count: int, equipment_counts: Dict,
                             no_equipment_detections: List[Dict]) -> List[Dict]:
        """
        Turn per-frame counts into the list of violations.

        Args:
            people_count: Number of accepted person detections.
            equipment_counts: Positive detections per equipment category; must
                contain every entry of ``REQUIRED_EQUIPMENT``.
            no_equipment_detections: Violations built from NO- detections.

        Returns:
            NO- violations first, followed by count-based violations. At most
            one count-based entry per equipment type is produced.
        """
        # NO- detections are the most reliable signal and are always reported.
        violations = list(no_equipment_detections)

        # If we have people but no NO- detections, check equipment ratios
        if people_count > 0 and not no_equipment_detections:
            for equipment in self.REQUIRED_EQUIPMENT:
                detected_count = equipment_counts[equipment]
                # If significantly fewer equipment than people, assume violations
                if detected_count < people_count * 0.8:  # Allow some tolerance
                    missing_count = people_count - detected_count
                    equipment_name = equipment.replace("_", " ").title()
                    violations.append({
                        'type': f'missing_{equipment}',
                        'severity': 'high',
                        'description': f'{missing_count} person(s) likely missing {equipment_name}',
                        'count': missing_count,
                    })

        # Special handling for masks - they're often not detected well. Only
        # added when no missing_mask violation exists yet, so the ratio check
        # above and this rule never report the same problem twice.
        mask_reported = any(v['type'] == 'missing_mask' for v in violations)
        if people_count > 0 and equipment_counts['mask'] == 0 and not mask_reported:
            violations.append({
                'type': 'missing_mask',
                'severity': 'high',
                'description': f'{people_count} person(s) not wearing Face Mask',
                'count': people_count,
            })
        return violations

    # ------------------------------------------------------------------
    # Drawing ("premium" style)
    # ------------------------------------------------------------------
    @staticmethod
    def _match_violation_to_person(violation_bbox, people_status: Dict,
                                   max_corner_distance: float = 100) -> Optional[str]:
        """
        Find the person a located (NO-) violation belongs to.

        A person whose box contains the centre of the violation box wins;
        otherwise the person with the nearest top-left corner is used, provided
        it is closer than ``max_corner_distance`` pixels (Manhattan distance).

        Returns:
            The key of the matching entry in ``people_status`` or ``None``.
        """
        center_x = (violation_bbox[0] + violation_bbox[2]) / 2
        center_y = (violation_bbox[1] + violation_bbox[3]) / 2

        containing, containing_distance = None, float('inf')
        nearest, nearest_distance = None, float('inf')
        for person_id, person_data in people_status.items():
            px1, py1, px2, py2 = person_data['bbox']
            distance = abs(violation_bbox[0] - px1) + abs(violation_bbox[1] - py1)
            if px1 <= center_x <= px2 and py1 <= center_y <= py2 and distance < containing_distance:
                containing, containing_distance = person_id, distance
            if distance < nearest_distance:
                nearest, nearest_distance = person_id, distance

        if containing is not None:
            return containing
        if nearest is not None and nearest_distance < max_corner_distance:
            return nearest
        return None

    def draw_detections(self, frame: np.ndarray, results: Dict) -> np.ndarray:
        """
        Draw premium bounding boxes only for POSITIVE equipment detections.
        No boxes for missing equipment - violations shown through person status only.

        Args:
            frame: Input frame (not modified; a copy is annotated)
            results: Detection results containing detections, violations, etc.

        Returns:
            Annotated frame with premium styling
        """
        annotated_frame = frame.copy()
        # Separate layer for shadows/label backgrounds, blended in at the end.
        overlay = annotated_frame.copy()
        colors = dict(self.PALETTE)

        detections = results.get('detections', [])
        violations = results.get('violations', [])

        # First pass: collect people. The detection index keeps ids unique
        # even when two boxes share the same top-left corner.
        people_status = {}
        for index, detection in enumerate(detections):
            if 'person' in detection['class'].lower():
                people_status[f"person_{index}"] = {
                    'bbox': detection['bbox'],
                    'confidence': detection['confidence'],
                    'violations': [],
                    'equipment': [],
                }

        # Map violations to people
        for violation in violations:
            violation_type = violation['type'].replace('missing_', '')
            if 'bbox' in violation:
                # Specific violation with a bounding box (from NO- detections)
                person_id = self._match_violation_to_person(violation['bbox'], people_status)
                if person_id is not None:
                    people_status[person_id]['violations'].append(violation_type)
            else:
                # General violation - apply to all people (equipment count < people count)
                for person_data in people_status.values():
                    person_data['violations'].append(violation_type)

        # People present, no violations reported and no positive equipment at
        # all: assume everything required is missing.
        if people_status and not violations:
            equipment_detected = any(
                detection.get('category') in self.REQUIRED_EQUIPMENT
                for detection in detections
            )
            if not equipment_detected:
                for person_data in people_status.values():
                    person_data['violations'] = list(self.REQUIRED_EQUIPMENT)

        # ONLY draw POSITIVE equipment detections (when equipment IS being worn)
        equipment_keywords = ('hardhat', 'vest', 'helmet', 'safety', 'mask')
        for detection in detections:
            class_name = detection['class'].lower()
            category = detection.get('category', '')

            # Skip people and NO- detections - we only want positive equipment
            if 'person' in class_name or 'no-' in class_name or 'no_' in category:
                continue
            if category not in self.REQUIRED_EQUIPMENT and not any(
                    keyword in class_name for keyword in equipment_keywords):
                continue

            # Labels are ASCII only: cv2.putText replaces characters the
            # Hershey fonts cannot render with question marks.
            if category == 'hardhat' or 'hardhat' in class_name or 'helmet' in class_name:
                color, equipment_type = colors['hardhat'], "Hard Hat OK"
            elif category == 'safety_vest' or 'vest' in class_name:
                color, equipment_type = colors['safety_vest'], "Safety Vest OK"
            elif category == 'mask' or 'mask' in class_name:
                color, equipment_type = colors['mask'], "Face Mask OK"
            else:
                color, equipment_type = colors['equipment'], "Safety Equipment OK"

            self._draw_premium_bbox(overlay, annotated_frame, detection['bbox'], color,
                                    equipment_type, detection['confidence'],
                                    bbox_type="equipment", colors=colors)

        # Draw people with compliance status (no violation details on person boxes)
        for person_data in people_status.values():
            is_compliant = not person_data['violations']
            color = colors['person_compliant'] if is_compliant else colors['person_violation']
            status_text = "COMPLIANT" if is_compliant else "VIOLATION"
            self._draw_premium_bbox(overlay, annotated_frame, person_data['bbox'], color,
                                    f"Person - {status_text}", person_data['confidence'],
                                    bbox_type="person", violations=None,
                                    colors=colors)

        # Blend overlay with the annotated frame for semi-transparent effects
        alpha = 0.15
        cv2.addWeighted(overlay, alpha, annotated_frame, 1 - alpha, 0, annotated_frame)

        # Statistics are handled by the web UI, no overlay needed on video feed
        return annotated_frame

    def _draw_premium_bbox(self, overlay, frame, bbox, color, label, confidence,
                           bbox_type="default", violations=None, colors=None):
        """
        Draw a premium-styled bounding box with advanced visual effects.

        Args:
            overlay: Layer receiving the shadow and the label background.
            frame: Layer receiving the box, corner accents and text.
            bbox: ``(x1, y1, x2, y2)`` box coordinates.
            color: Color of the box and label border.
            label: Main label text.
            confidence: Confidence in ``[0, 1]``, rendered as a percentage.
            bbox_type: ``"person"`` boxes get thicker lines and may show
                violation indicators.
            violations: Optional list of missing equipment names (persons only).
            colors: Palette dictionary; defaults to ``PALETTE``.
        """
        if colors is None:
            colors = self.PALETTE

        x1, y1, x2, y2 = map(int, bbox)
        box_width = x2 - x1
        box_height = y2 - y1

        # Draw shadow first (slightly offset)
        shadow_offset = 3
        cv2.rectangle(overlay,
                      (x1 + shadow_offset, y1 + shadow_offset),
                      (x2 + shadow_offset, y2 + shadow_offset),
                      colors['shadow'], 2)

        # Main bounding box with thinner lines for equipment
        box_thickness = 2 if bbox_type == "person" else 1
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, box_thickness)

        # Corner accents: one horizontal and one vertical stroke per corner,
        # pointing towards the inside of the box.
        corner_length = min(20, box_width // 4, box_height // 4)
        corners = ((x1, y1, 1, 1), (x2, y1, -1, 1), (x1, y2, 1, -1), (x2, y2, -1, -1))
        for corner_x, corner_y, step_x, step_y in corners:
            cv2.line(frame, (corner_x, corner_y),
                     (corner_x + step_x * corner_length, corner_y), color, box_thickness)
            cv2.line(frame, (corner_x, corner_y),
                     (corner_x, corner_y + step_y * corner_length), color, box_thickness)

        # Prepare label text
        confidence_text = f"{confidence:.1%}"
        main_text = f"{label}"

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1
        conf_scale = font_scale - 0.1
        conf_thickness = max(1, thickness - 1)  # same value for measuring and drawing
        (main_w, main_h), _ = cv2.getTextSize(main_text, font, font_scale, thickness)
        (conf_w, conf_h), _ = cv2.getTextSize(confidence_text, font, conf_scale, conf_thickness)

        # The label holds BOTH text lines: 6 px top padding, 4 px between the
        # lines and 4 px below the second baseline.
        label_height = main_h + conf_h + 14
        label_width = max(main_w, conf_w) + 16

        # Position label (above box if space available, otherwise below)
        if y1 - label_height - 5 > 0:
            label_y = y1 - label_height - 5
        else:
            label_y = y2 + 5
            if label_y + label_height > frame.shape[0]:
                # No room below either: put the label inside the box.
                label_y = max(0, y1 + 5)
        label_x = x1

        # Ensure label stays within frame
        if label_x + label_width > frame.shape[1]:
            label_x = frame.shape[1] - label_width - 5
        if label_x < 0:
            label_x = 5

        label_top_left = (label_x, label_y)
        label_bottom_right = (label_x + label_width, label_y + label_height)
        # Main background
        cv2.rectangle(overlay, label_top_left, label_bottom_right, colors['text_bg'], -1)
        # Colored top border
        cv2.rectangle(frame, label_top_left, (label_x + label_width, label_y + 4), color, -1)
        # Subtle border
        cv2.rectangle(frame, label_top_left, label_bottom_right, color, 1)

        # Draw main text
        text_y = label_y + main_h + 6
        cv2.putText(frame, main_text, (label_x + 8, text_y),
                    font, font_scale, colors['text_primary'], thickness)

        # Draw confidence text
        conf_y = text_y + conf_h + 4
        cv2.putText(frame, confidence_text, (label_x + 8, conf_y),
                    font, conf_scale, colors['text_secondary'], conf_thickness)

        # Draw violation indicators for people (only if violations are provided)
        if bbox_type == "person" and violations:
            self._draw_violation_indicators(frame, overlay, x1, y1, x2, y2, violations, colors)

    def _draw_violation_indicators(self, frame, overlay, x1, y1, x2, y2, violations, colors):
        """
        Draw violation indicators with premium styling.

        Draws a warning icon in the top-right corner of the box and a
        "Missing: ..." line below it (above it when there is no room below).
        """
        # Warning icon position (top-right of bounding box)
        icon_size = 24
        radius = icon_size // 2
        center_x = x2 - icon_size - 5 + radius
        center_y = y1 + 5 + radius

        # Warning background circle: filled on the overlay, outlined on the frame
        cv2.circle(overlay, (center_x, center_y), radius, colors['violation_bg'], -1)
        cv2.circle(frame, (center_x, center_y), radius, colors['violation_bg'], 2)

        # Exclamation mark: line plus dot
        cv2.line(frame, (center_x, center_y - 6), (center_x, center_y + 2),
                 colors['text_primary'], 2)
        cv2.circle(frame, (center_x, center_y + 5), 1, colors['text_primary'], -1)

        # Violation list below the person if space allows
        violation_text = "Missing: " + ", ".join(violations)
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1
        (text_w, text_h), _ = cv2.getTextSize(violation_text, font, font_scale, thickness)

        viol_x = x1
        viol_y = y2 + text_h + 8
        # Ensure text stays within frame
        if viol_y + text_h > frame.shape[0]:
            viol_y = y1 - text_h - 8
        if viol_x + text_w > frame.shape[1]:
            viol_x = frame.shape[1] - text_w - 5

        # Text background on the overlay, text itself on the frame
        padding = 4
        cv2.rectangle(overlay,
                      (viol_x - padding, viol_y - text_h - padding),
                      (viol_x + text_w + padding, viol_y + padding),
                      colors['violation_bg'], -1)
        cv2.putText(frame, violation_text, (viol_x, viol_y),
                    font, font_scale, colors['text_primary'], thickness)

    def _draw_statistics_overlay(self, frame, results, colors, width, height):
        """
        Draw statistics overlay with premium styling.

        Args:
            frame: Frame to draw on (modified in place).
            results: Detection results (``people_count`` and ``violations``).
            colors: Palette dictionary.
            width: Unused; kept for interface compatibility.
            height: Unused; kept for interface compatibility.
        """
        people_count = results.get('people_count', 0)
        violation_count = len(results.get('violations', []))
        # Several violation entries can refer to the same person, so the
        # difference is clamped to avoid negative counts and rates.
        compliant_count = max(0, people_count - violation_count)
        compliance_rate = (compliant_count / max(people_count, 1)) * 100

        stats = [
            f"People: {people_count}",
            f"Compliant: {compliant_count}",
            f"Violations: {violation_count}",
            f"Compliance: {compliance_rate:.1f}%",
        ]

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.7
        thickness = 2
        line_gap = 8

        # Measure every line once to size the background panel
        sizes = [cv2.getTextSize(text, font, font_scale, thickness)[0] for text in stats]
        line_heights = [text_h for _, text_h in sizes]
        bg_width = max(text_w for text_w, _ in sizes) + 24
        bg_height = sum(text_h + line_gap for text_h in line_heights) + 16

        # Position (top-left corner)
        bg_x = 20
        bg_y = 20
        top_left = (bg_x, bg_y)
        bottom_right = (bg_x + bg_width, bg_y + bg_height)

        # Semi-transparent background
        overlay = frame.copy()
        cv2.rectangle(overlay, top_left, bottom_right, colors['text_bg'], -1)
        cv2.addWeighted(overlay, 0.8, frame, 0.2, 0, frame)
        # Border
        cv2.rectangle(frame, top_left, bottom_right, colors['accent'], 2)

        if compliance_rate >= 80:
            rate_color = colors['person_compliant']
        elif compliance_rate >= 60:
            rate_color = colors['safety_vest']
        else:
            rate_color = colors['person_violation']

        current_y = bg_y + 24
        for text, text_h in zip(stats, line_heights):
            # Choose color based on statistic type
            if "Violations:" in text and violation_count > 0:
                text_color = colors['person_violation']
            elif "Compliant:" in text:
                text_color = colors['person_compliant']
            elif "Compliance:" in text:
                text_color = rate_color
            else:
                text_color = colors['text_primary']
            cv2.putText(frame, text, (bg_x + 12, current_y),
                        font, font_scale, text_color, thickness)
            current_y += text_h + line_gap

    # ------------------------------------------------------------------
    # Public helpers
    # ------------------------------------------------------------------
    def get_model_classes(self) -> List[str]:
        """Get the list of classes the model can detect."""
        return self._get_model_classes()

    def test_detection(self, test_image_path: Optional[str] = None):
        """
        Test the detector with a sample image or webcam.

        Args:
            test_image_path: Image to analyse. When omitted or not found the
                default webcam (device 0) is used until 'q' is pressed.

        Returns:
            The detection results for the image, or ``None`` in webcam mode
            and when the image cannot be read.
        """
        if test_image_path and os.path.exists(test_image_path):
            frame = cv2.imread(test_image_path)
            if frame is None:
                print(f"Could not read image: {test_image_path}")
                return None
            results = self.detect_safety_violations(frame)
            output = self.draw_detections(frame, results)
            print(f"Detected classes: {[d['class'] for d in results['detections']]}")
            print(f"Available model classes: {self.get_model_classes()}")
            cv2.imshow('PPE Detection Test', output)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            return results

        print("Testing with webcam - press 'q' to quit")
        cap = cv2.VideoCapture(0)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                results = self.detect_safety_violations(frame)
                output = self.draw_detections(frame, results)
                cv2.imshow('PPE Detection Test', output)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release the camera even when detection or drawing raises.
            cap.release()
            cv2.destroyAllWindows()
        return None

    # ------------------------------------------------------------------
    # Proximity-based compliance analysis
    # ------------------------------------------------------------------
    def analyze_safety_compliance(self, detections: List[Dict]) -> Dict:
        """
        Analyze safety compliance based on detected objects.

        Every person is compared with the positive equipment detections near
        it; a person is compliant when a hard hat and a safety vest are nearby.
        Categories (not raw class names) are compared, so negative classes
        such as ``NO-Hardhat`` never count as equipment.

        Args:
            detections: List of detected objects (``bbox``, ``class``,
                ``confidence`` and optionally ``category``)

        Returns:
            Dictionary with ``total_people``, ``compliant_people``,
            ``violations`` (number of missing items), ``compliance_results``,
            ``safety_equipment`` and ``overall_compliance_rate``
        """
        wearable_categories = ('hardhat', 'safety_vest', 'safety_gloves', 'safety_glasses')
        required_equipment = ['hardhat', 'safety_vest']

        # Separate people and safety equipment
        people_detected = []
        safety_equipment = []
        for detection in detections:
            category = self._detection_category(detection)
            if category == 'person':
                people_detected.append(detection)
            elif category in wearable_categories:
                safety_equipment.append(detection)

        # Analyze compliance for each person
        compliance_results = []
        for person in people_detected:
            nearby_equipment = self._find_nearby_equipment(person['bbox'], safety_equipment)
            nearby_categories = {self._detection_category(item) for item in nearby_equipment}
            missing_equipment = [
                equipment for equipment in required_equipment
                if equipment not in nearby_categories
            ]
            compliance_results.append({
                'person': person,
                'nearby_equipment': nearby_equipment,
                'missing_equipment': missing_equipment,
                'is_compliant': not missing_equipment,
                'compliance_score': 1.0 - (len(missing_equipment) / len(required_equipment)),
            })

        return {
            'total_people': len(people_detected),
            'compliant_people': sum(1 for result in compliance_results if result['is_compliant']),
            'violations': sum(len(result['missing_equipment']) for result in compliance_results),
            'compliance_results': compliance_results,
            # Consumed by draw_annotations.
            'safety_equipment': safety_equipment,
            'overall_compliance_rate': (
                sum(result['compliance_score'] for result in compliance_results) /
                max(len(compliance_results), 1)
            ),
        }

    def _find_nearby_equipment(self, person_bbox: List[int], equipment_list: List[Dict],
                               proximity_threshold: float = 0.3,
                               frame_diagonal: float = 1000.0) -> List[Dict]:
        """
        Find safety equipment near a person.

        Args:
            person_bbox: ``[x1, y1, x2, y2]`` of the person.
            equipment_list: Equipment detections with a ``bbox`` entry.
            proximity_threshold: Maximum centre distance, as a fraction of
                ``frame_diagonal``, for equipment to count as nearby.
            frame_diagonal: Pixel length used to normalise distances. The
                default of 1000 is the value that was previously hard-coded.

        Returns:
            The entries of ``equipment_list`` that are close enough.

        Raises:
            ValueError: if ``frame_diagonal`` is not positive.
        """
        if frame_diagonal <= 0:
            raise ValueError("frame_diagonal must be positive")

        person_center_x = (person_bbox[0] + person_bbox[2]) / 2
        person_center_y = (person_bbox[1] + person_bbox[3]) / 2

        nearby_equipment = []
        for equipment in equipment_list:
            equip_bbox = equipment['bbox']
            equip_center_x = (equip_bbox[0] + equip_bbox[2]) / 2
            equip_center_y = (equip_bbox[1] + equip_bbox[3]) / 2
            distance = float(np.hypot(person_center_x - equip_center_x,
                                      person_center_y - equip_center_y))
            if distance / frame_diagonal < proximity_threshold:
                nearby_equipment.append(equipment)
        return nearby_equipment

    def draw_annotations(self, frame: np.ndarray, analysis: Dict) -> np.ndarray:
        """
        Draw bounding boxes and annotations on the frame.

        Args:
            frame: Input frame (not modified; a copy is annotated)
            analysis: Safety compliance analysis results as returned by
                analyze_safety_compliance

        Returns:
            Annotated frame
        """
        annotated_frame = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Draw safety equipment (tolerates analyses without equipment details)
        for equipment in analysis.get('safety_equipment', []):
            bbox = equipment['bbox']
            cv2.rectangle(annotated_frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                          self.colors['equipment'], 2)
            label = f"{equipment.get('equipment_type', equipment['class'])}: {equipment['confidence']:.2f}"
            cv2.putText(annotated_frame, label, (bbox[0], bbox[1] - 10),
                        font, 0.5, self.colors['equipment'], 2)

        # Draw people with compliance status
        for result in analysis['compliance_results']:
            person = result['person']
            bbox = person['bbox']
            compliant = result['is_compliant']

            color = self.colors['person'] if compliant else self.colors['violation']
            cv2.rectangle(annotated_frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 3)

            status = "COMPLIANT" if compliant else "VIOLATION"
            confidence_text = f"Person: {person['confidence']:.2f}"
            cv2.putText(annotated_frame, status, (bbox[0], bbox[1] - 30),
                        font, 0.7, color, 2)
            cv2.putText(annotated_frame, confidence_text, (bbox[0], bbox[1] - 10),
                        font, 0.5, color, 2)

            # Show missing equipment
            if result['missing_equipment']:
                missing_text = f"Missing: {', '.join(result['missing_equipment'])}"
                cv2.putText(annotated_frame, missing_text, (bbox[0], bbox[3] + 20),
                            font, 0.5, self.colors['violation'], 2)

        # Draw summary statistics
        summary_text = [
            f"Total People: {analysis['total_people']}",
            f"Compliant: {analysis['compliant_people']}",
            f"Violations: {analysis['violations']}",
            f"Compliance Rate: {(analysis['compliant_people']/max(analysis['total_people'],1)*100):.1f}%",
        ]
        for i, text in enumerate(summary_text):
            cv2.putText(annotated_frame, text, (10, 30 + i * 25),
                        font, 0.6, (255, 255, 255), 2)
        return annotated_frame

    # ------------------------------------------------------------------
    # Persistence and frame pipeline
    # ------------------------------------------------------------------
    def capture_violation(self, frame: np.ndarray, violation_data: Dict) -> str:
        """
        Capture and save an image when a safety violation is detected.

        Writes ``violation_<timestamp>.jpg`` plus a ``..._metadata.json`` file
        into ``violation_images_dir`` and appends the metadata to
        ``self.violations``.

        Args:
            frame: Current frame
            violation_data: Information about the violation; must be JSON
                serialisable

        Returns:
            Path to saved image

        Raises:
            OSError: if the image could not be written.
        """
        # One timestamp for both the file name and the metadata.
        captured_at = datetime.now()
        timestamp = captured_at.strftime("%Y%m%d_%H%M%S_%f")[:-3]
        filename = f"violation_{timestamp}.jpg"
        filepath = os.path.join(self.violation_images_dir, filename)

        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(filepath, frame):
            raise OSError(f"Could not write violation image to {filepath}")

        metadata = {
            'timestamp': captured_at.isoformat(),
            'filename': filename,
            'violation_data': violation_data,
        }
        metadata_file = os.path.splitext(filepath)[0] + '_metadata.json'
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        self.violations.append(metadata)
        return filepath

    def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """
        Process a single frame for safety monitoring.

        Args:
            frame: Input video frame

        Returns:
            Tuple of (annotated_frame, analysis_results)
        """
        # Detect objects and get safety violations
        results = self.detect_safety_violations(frame)
        # Draw detections on frame using the main drawing method
        annotated_frame = self.draw_detections(frame, results)

        analysis = {
            'detections': results['detections'],
            'people_count': results['people_count'],
            'safety_equipment': results['safety_equipment'],
            'violations': results['violations'],
            'violation_summary': self.get_violation_summary(),
            'frame_stats': {
                'processing_time': results['processing_time'],
                'fps': results['fps'],
                'detection_count': len(results['detections']),
            },
        }
        return annotated_frame, analysis

    def get_violation_summary(self) -> Dict:
        """
        Get a summary of recent violations.

        NOTE: placeholder - the values are constants and do not reflect
        ``self.violations``. A real implementation would typically read from a
        database or log file.
        """
        return {
            'total_violations_today': 0,
            'most_common_violation': 'missing_hardhat',
            'compliance_trend': [],  # Could track compliance over time
        }
if __name__ == "__main__":
    # Manual smoke test: build a detector with the default model, list the
    # classes it knows and start the interactive webcam test.
    safety_detector = SafetyDetector()
    available_classes = safety_detector.get_model_classes()
    print("Available classes:", available_classes)
    safety_detector.test_detection()