Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import subprocess | |
| import logging | |
| import warnings | |
| import cv2 | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import time | |
| from simple_salesforce import Salesforce | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.units import inch | |
| from io import BytesIO | |
| import base64 | |
| from retrying import retry | |
| import uuid | |
| from multiprocessing import Pool, cpu_count | |
| from functools import partial | |
| import tempfile | |
| import shutil | |
# ==========================
# Configuration and Setup
# ==========================
# Everything this app writes (YOLO settings, snapshots, PDF reports) goes
# under a fresh temporary directory instead of the working directory.
TEMP_DIR = tempfile.mkdtemp(prefix="Ultralytics_")
# NOTE(review): ultralytics is imported above this point; confirm it still
# honours YOLO_CONFIG_DIR when the variable is set after import.
os.environ["YOLO_CONFIG_DIR"] = TEMP_DIR

# Snapshots and reports are written here.
OUTPUT_DIR = os.path.join(TEMP_DIR, "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Check for FFmpeg availability so video-processing problems are reported early.
def check_ffmpeg():
    """Return True if an ``ffmpeg`` executable can be run, otherwise False.

    Runs ``ffmpeg -version`` and logs the outcome. Any failure to launch the
    binary (missing, not executable, non-zero exit, or hanging) is reported as
    "not available" instead of propagating and crashing the app at import time.
    """
    try:
        subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
            timeout=10,  # a version query should be instant; never block startup on it
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError (not on PATH) and PermissionError.
        logger.error("FFmpeg is not installed or not found in PATH. Video processing may fail.")
        return False
    logger.info("FFmpeg is available.")
    return True


FFMPEG_AVAILABLE = check_ffmpeg()
# ==========================
# ByteTrack Implementation
# ==========================
class BYTETracker:
    """Greedy IoU/distance tracker that assigns stable worker IDs to detections.

    This is a simplified association scheme, not the full ByteTrack algorithm
    (there is no motion model and no second low-score association pass).
    For every detection whose score is at least ``track_thresh``:

    1. match the live track with the highest IoU above ``match_thresh``;
    2. otherwise reuse the known worker whose last centre is closest, provided
       it is closer than ``max_worker_distance`` (same units as the boxes);
    3. otherwise allocate a new ID.

    Tracks not updated for ``track_buffer / frame_rate`` seconds are dropped.
    Boxes are centre-based ``[x, y, w, h]``.
    """

    def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30,
                 max_worker_distance=100):
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.max_worker_distance = max_worker_distance
        self.next_id = 1
        self.tracks = {}          # track_id -> {'bbox', 'score', 'cls', 'last_seen'}
        self.worker_history = {}  # track_id -> list of [x, y] centres, oldest first
        self.last_positions = {}  # track_id -> most recent [x, y] centre

    def update(self, dets, scores, cls, timestamp=None):
        """Associate one frame's detections with tracks and return the assignments.

        Args:
            dets: sequence of ``[x_center, y_center, w, h]`` boxes.
            scores: detection confidences aligned with ``dets``.
            cls: class ids aligned with ``dets``.
            timestamp: optional time of this frame in seconds. Defaults to the
                wall clock; pass the video timestamp so track expiry follows
                the video timeline rather than processing speed.

        Returns:
            list of ``{'id', 'bbox', 'score', 'cls'}`` dicts, one per accepted
            detection, in input order.
        """
        now = time.time() if timestamp is None else timestamp
        max_age = self.track_buffer / self.frame_rate
        tracks = []
        for det, score, cl in zip(dets, scores, cls):
            if score < self.track_thresh:
                continue
            x, y, w, h = det
            bbox = [x, y, w, h]

            track_id = self._best_iou_match(bbox, now, max_age)
            if track_id is None:
                track_id = self._nearest_worker([x, y])
            if track_id is None:
                track_id = self.next_id
                self.next_id += 1

            # Refresh the track and its position state on every assignment path
            # so later distance matching uses the newest centre.
            self.tracks[track_id] = {'bbox': bbox, 'score': score, 'cls': cl, 'last_seen': now}
            self.worker_history.setdefault(track_id, []).append([x, y])
            self.last_positions[track_id] = [x, y]
            tracks.append({'id': track_id, 'bbox': list(bbox), 'score': score, 'cls': cl})

        self._drop_stale(now, max_age)
        return tracks

    def _best_iou_match(self, bbox, now, max_age):
        """Return the id of the live track with the highest IoU above match_thresh, or None."""
        best_id, best_iou = None, self.match_thresh
        for track_id, info in self.tracks.items():
            if now - info['last_seen'] > max_age:
                continue
            iou = self._calculate_iou(bbox, info['bbox'])
            if iou > best_iou:
                best_id, best_iou = track_id, iou
        return best_id

    def _nearest_worker(self, pos):
        """Return the closest known worker strictly within max_worker_distance of pos, or None."""
        best_id, best_dist = None, self.max_worker_distance
        for worker_id, (px, py) in self.last_positions.items():
            dist = float(np.hypot(pos[0] - px, pos[1] - py))
            if dist < best_dist:
                best_id, best_dist = worker_id, dist
        return best_id

    def _drop_stale(self, now, max_age):
        """Forget every track (and its position state) not seen within max_age seconds."""
        stale_ids = [tid for tid, info in self.tracks.items() if now - info['last_seen'] > max_age]
        for tid in stale_ids:
            del self.tracks[tid]
            self.worker_history.pop(tid, None)
            self.last_positions.pop(tid, None)

    def _calculate_iou(self, box1, box2):
        """Return the IoU of two centre-based ``[x, y, w, h]`` boxes (0.0 for degenerate boxes)."""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        # Corners of the intersection rectangle.
        x_left = max(x1 - w1/2, x2 - w2/2)
        y_top = max(y1 - h1/2, y2 - h2/2)
        x_right = min(x1 + w1/2, x2 + w2/2)
        y_bottom = min(y1 + h1/2, y2 + h2/2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        union_area = w1 * h1 + w2 * h2 - intersection_area
        if union_area <= 0:
            # Two zero-area boxes: avoid ZeroDivisionError.
            return 0.0
        return intersection_area / union_area

    def _is_same_worker(self, pos1, pos2, threshold=None):
        """Return True if two centres are closer than ``threshold`` (default: max_worker_distance)."""
        if threshold is None:
            threshold = self.max_worker_distance
        x1, y1 = pos1
        x2, y2 = pos2
        distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
        return distance < threshold
# ==========================
# Optimized Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": OUTPUT_DIR,
    # Class index reported by the model -> violation key used everywhere else.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # BGR colours (OpenCV channel order) for the bounding boxes.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),           # Red
        "no_harness": (0, 165, 255),        # Orange
        "unsafe_posture": (0, 255, 0),      # Green
        "unsafe_zone": (255, 0, 0),         # Blue
        "improper_tool_use": (255, 255, 0)  # Cyan
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use",
    },
    # Salesforce credentials must come from the environment (e.g. Space
    # secrets). Never commit real values here: this source file may be
    # publicly visible. Credentials that were ever committed should be
    # considered leaked and rotated.
    "SF_CREDENTIALS": {
        "username": os.getenv("SF_USERNAME", ""),
        "password": os.getenv("SF_PASSWORD", ""),
        "security_token": os.getenv("SF_SECURITY_TOKEN", ""),
        "domain": os.getenv("SF_DOMAIN", "login"),
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Per-violation minimum confidence; detections below it are discarded.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.5,
        "no_harness": 0.3,
        "unsafe_posture": 0.3,
        "unsafe_zone": 0.3,
        "improper_tool_use": 0.3,
    },
    # NOTE(review): MIN_VIOLATION_FRAMES, VIOLATION_COOLDOWN,
    # WORKER_TRACKING_DURATION, MAX_PROCESSING_TIME, PARALLEL_WORKERS and
    # MAX_WORKER_DISTANCE are not read by the code visible in this file.
    "MIN_VIOLATION_FRAMES": 1,
    "VIOLATION_COOLDOWN": 30.0,
    "WORKER_TRACKING_DURATION": 5.0,
    "MAX_PROCESSING_TIME": 60,
    "FRAME_SKIP": 2,   # run detection on one frame out of every FRAME_SKIP
    "BATCH_SIZE": 8,   # Reduced batch size to lower memory usage
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),
    "TRACK_BUFFER": 30,
    "TRACK_THRESH": 0.3,
    "MATCH_THRESH": 0.7,
    "SNAPSHOT_QUALITY": 95,  # JPEG quality for saved snapshots
    "MAX_WORKER_DISTANCE": 100,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector and move it to ``device``.

    Uses CONFIG["MODEL_PATH"] when that file exists. Otherwise it falls back to
    CONFIG["FALLBACK_MODEL"], downloading the stock yolov8n weights first if
    the file is missing. Detections are later mapped to violations purely by
    class index (CONFIG["VIOLATION_LABELS"]), so a warning is logged when the
    loaded model's class names do not match those labels.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        Exception: whatever downloading or loading raises, after logging it.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
        model = YOLO(model_path).to(device)
        # Log only after the load actually succeeded.
        logger.info(f"Model loaded: {model_path}")
        logger.info(f"Model classes: {model.names}")
        _warn_on_label_mismatch(model.names)
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


def _warn_on_label_mismatch(model_names):
    """Log a warning listing violation class ids whose model class name differs from CONFIG."""
    names = model_names if isinstance(model_names, dict) else dict(enumerate(model_names))
    mismatched = {
        cls_id: names.get(cls_id)
        for cls_id, label in CONFIG["VIOLATION_LABELS"].items()
        if names.get(cls_id) != label
    }
    if mismatched:
        logger.warning(
            f"Model class names {mismatched} do not match CONFIG['VIOLATION_LABELS']; "
            "detections of these classes will still be reported as safety violations."
        )


model = load_model()
# ==========================
# Helper Functions
# ==========================
def preprocess_frame(frame):
    """Brighten a frame and boost its contrast before detection.

    Applies ``cv2.convertScaleAbs`` with gain 1.2 and offset 20, i.e. each
    value becomes ``saturate(|1.2 * p + 20|)``, and returns the new image.
    """
    gain, offset = 1.2, 20
    return cv2.convertScaleAbs(frame, alpha=gain, beta=offset)
def draw_detections(frame, detections):
    """Return an annotated copy of ``frame`` with one labelled box per detection.

    Each detection is a dict. The keys used are ``violation`` (a key into
    CONFIG's colour/display-name tables), ``confidence``, ``bounding_box``
    (centre-based ``[x, y, w, h]``) and ``worker_id``; missing keys fall back
    to placeholders. The input ``frame`` is not modified.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    annotated = frame.copy()
    for item in detections:
        violation_key = item.get("violation", "Unknown")
        score = item.get("confidence", 0.0)
        cx, cy, bw, bh = item.get("bounding_box", [0, 0, 0, 0])
        worker = item.get("worker_id", "Unknown")

        # Centre-based box -> integer pixel corners.
        left, top = int(cx - bw/2), int(cy - bh/2)
        right, bottom = int(cx + bw/2), int(cy + bh/2)

        box_color = CONFIG["CLASS_COLORS"].get(violation_key, (0, 0, 255))
        cv2.rectangle(annotated, (left, top), (right, bottom), box_color, 3)

        # Caption sits above the box on a filled black background.
        caption = f"{CONFIG['DISPLAY_NAMES'].get(violation_key, violation_key)} (Worker {worker})"
        (text_w, text_h), _ = cv2.getTextSize(caption, font, 0.6, 2)
        cv2.rectangle(annotated, (left, top - text_h - 10), (left + text_w + 10, top), (0, 0, 0), -1)
        cv2.putText(annotated, caption, (left + 5, top - 5), font, 0.6, (255, 255, 255), 2)

        # Confidence is printed just below the box.
        cv2.putText(annotated, f"Conf: {score:.2f}", (left + 5, bottom + 20), font, 0.5, (255, 255, 255), 2)
    return annotated
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of violation dicts.

    Each distinct (worker_id, violation) pair deducts that violation's penalty
    once, no matter how often it was reported. Unknown violation types cost
    nothing, and the score never drops below 0.
    """
    penalty_for = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    distinct_pairs = {
        (entry.get("worker_id", "Unknown"), entry.get("violation", "Unknown"))
        for entry in violations
    }
    deduction = sum(penalty_for.get(kind, 0) for _worker, kind in distinct_pairs)
    return max(0, 100 - deduction)
def generate_violation_pdf(violations, score):
    """Render a PDF report of ``violations`` and save it under CONFIG["OUTPUT_DIR"].

    Args:
        violations: list of dicts with ``worker_id``, ``violation``,
            ``timestamp`` (seconds) and ``confidence``; missing keys fall back
            to placeholders.
        score: compliance score printed in the header.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)`` on success, where ``pdf_buffer``
        is the BytesIO holding the PDF bytes. On any failure the error is
        logged and ``("", "", None)`` is returned.
    """
    try:
        # Random suffix so two reports generated within the same second don't overwrite each other.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        bottom_margin = 1 * inch
        page_top = 10 * inch

        def ensure_room(y):
            """Return y, or page_top after starting a new page when y is below the bottom margin."""
            if y >= bottom_margin:
                return y
            c.showPage()
            c.setFont("Helvetica", 10)  # font settings do not carry over to the new page
            return page_top

        # Header.
        c.setFont("Helvetica-Bold", 16)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
        c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
        c.setFont("Helvetica-Bold", 14)
        c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")

        y_position = 8.2 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Summary:")
        y_position -= 0.3 * inch

        worker_violations = {}
        for v in violations:
            worker_violations.setdefault(v.get("worker_id", "Unknown"), []).append(v)

        c.setFont("Helvetica", 10)
        summary_data = {
            "Total Workers with Violations": len(worker_violations),
            "Total Violations Found": len(violations),
            "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in summary_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.25 * inch

        y_position -= 0.5 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Violations by Worker:")
        y_position -= 0.3 * inch

        c.setFont("Helvetica", 10)
        for worker_id, worker_vios in worker_violations.items():
            # Check space *before* drawing so no line lands below the margin.
            y_position = ensure_room(y_position)
            c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
            y_position -= 0.2 * inch
            for v in worker_vios:
                y_position = ensure_room(y_position)
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                time_str = f"{v.get('timestamp', 0.0):.2f}s"
                conf_str = f"{v.get('confidence', 0.0):.2f}"
                violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
                c.drawString(1.2 * inch, y_position, violation_text)
                y_position -= 0.2 * inch

        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        # NOTE(review): the file is written to CONFIG["OUTPUT_DIR"], not to the
        # repository path PUBLIC_URL_BASE points at; confirm this URL serves it.
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None
def _is_retryable_sf_error(exc):
    """Retry predicate for connect_to_salesforce: configuration errors (ValueError) are not retried."""
    return not isinstance(exc, ValueError)


@retry(stop_max_attempt_number=3, wait_fixed=2000, retry_on_exception=_is_retryable_sf_error)
def connect_to_salesforce():
    """Open and verify a Salesforce session, retrying transient failures.

    Makes up to 3 attempts, waiting 2 s between them. Every failed attempt is
    logged; after the last attempt the original exception is re-raised.

    Raises:
        ValueError: immediately (no retry) if the username or password in
            CONFIG["SF_CREDENTIALS"] is empty.
        Exception: whatever ``Salesforce(...)`` / ``describe()`` raised on the
            final attempt.
    """
    try:
        creds = CONFIG["SF_CREDENTIALS"]
        if not (creds.get("username") and creds.get("password")):
            raise ValueError("Salesforce credentials are not configured (set SF_USERNAME and SF_PASSWORD).")
        sf = Salesforce(**creds)
        sf.describe()  # cheap authenticated call that confirms the session works
        logger.info("Connected to Salesforce")
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach a PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: authenticated ``Salesforce`` session.
        pdf_file: buffer exposing ``getvalue()`` with the PDF bytes.
        report_id: id used as the version's ``FirstPublishLocationId``.

    Returns:
        A download URL for the new version. Returns ``""`` when no file is
        given, the version cannot be read back, or any call fails; every
        failure is logged.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        payload = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        version = sf.ContentVersion.create({
            "Title": f"Safety_Violation_Report_{int(time.time())}",
            "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
            "VersionData": payload,
            "FirstPublishLocationId": report_id,
        })
        version_id = version['id']
        # Read the version back to confirm it was stored.
        lookup = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not lookup['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        download_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {download_url}")
        return download_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""
def _format_violations_text(violations):
    """Return one newline-terminated line per violation, or a placeholder when there are none."""
    lines = []
    for v in violations:
        display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
        worker_id = v.get('worker_id', 'Unknown')
        timestamp = v.get('timestamp', 0.0)
        confidence = v.get('confidence', 0.0)
        lines.append(f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n")
    return "".join(lines) or "No violations detected."


def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for this analysis and attach the PDF report.

    Tries to create a ``Safety_Video_Report__c`` record. If that fails, it
    falls back to an ``Account`` named ``Safety_Report_<epoch>``. When
    ``pdf_file`` is given, the PDF is uploaded and its download URL is written
    to whichever object was actually created (``PDF_Report_URL__c`` or
    ``Description``). Failing to write that URL is logged but does not discard
    the record.

    Returns:
        ``(record_id, pdf_url)``, or ``("N/A", "Salesforce integration failed.")``
        if no record could be created.
    """
    try:
        sf = connect_to_salesforce()
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": _format_violations_text(violations),
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            # Remember which object/field the PDF URL must be written back to.
            record_object, url_field = sf.Safety_Video_Report__c, "PDF_Report_URL__c"
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            record_object, url_field = sf.Account, "Description"
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return "N/A", "Salesforce integration failed."

    if pdf_file:
        # upload_pdf_to_salesforce logs its own failures and returns "" on error.
        uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if uploaded_url:
            try:
                record_object.update(record_id, {url_field: uploaded_url})
                logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
            except Exception as e:
                # The record and its attached file already exist; keep them.
                logger.error(f"Failed to update record {record_id} with PDF URL: {e}")
            pdf_url = uploaded_url
    return record_id, pdf_url
def _read_frame_batch(cap, total_frames, batch_size, frame_skip):
    """Read up to ``batch_size`` preprocessed frames, skipping ``frame_skip - 1`` frames after each.

    Returns ``(frames, frame_indices)``. The batch may be shorter than
    ``batch_size`` and is empty once no further frame can be read.
    """
    frames, indices = [], []
    for _ in range(batch_size):
        frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
        if frame_idx >= total_frames:
            break
        ret, frame = cap.read()
        if not ret:
            logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
            break
        frames.append(preprocess_frame(frame))
        indices.append(frame_idx)
        for _ in range(frame_skip - 1):
            if not cap.grab():
                break
    return frames, indices


def _violation_candidates(result):
    """Return tracker inputs for boxes of a known violation class above that class's threshold."""
    candidates = []
    for box in result.boxes:
        cls = int(box.cls)
        conf = float(box.conf)
        label = CONFIG["VIOLATION_LABELS"].get(cls, None)
        if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
            continue
        candidates.append({"bbox": box.xywh.cpu().numpy()[0], "conf": conf, "cls": cls})
    return candidates


def _save_violation_snapshot(frame, detection):
    """Write an annotated JPEG for ``detection`` into OUTPUT_DIR and return its snapshot record."""
    label = detection["violation"]
    worker_id = detection["worker_id"]
    current_time = detection["timestamp"]
    snapshot_frame = draw_detections(frame.copy(), [detection])
    cv2.putText(
        snapshot_frame,
        f"Time: {current_time:.2f}s",
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2
    )
    snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
    if not cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]):
        logger.warning(f"Failed to write snapshot: {snapshot_path}")
    logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
    return {
        "violation": label,
        "worker_id": worker_id,
        "timestamp": current_time,
        "snapshot_path": snapshot_path,
        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
        "confidence": detection["confidence"]
    }


def _collect_violations(unique_violations, snapshots):
    """Flatten {worker: {label: first_time}} into violation dicts, taking confidence from the first matching snapshot."""
    confidence_by_key = {}
    for s in snapshots:
        confidence_by_key.setdefault((s["worker_id"], s["violation"]), float(s["confidence"]))
    return [
        {
            "worker_id": worker_id,
            "violation": label,
            "timestamp": detection_time,
            "confidence": confidence_by_key.get((worker_id, label), 0.0)
        }
        for worker_id, worker_violations in unique_violations.items()
        for label, detection_time in worker_violations.items()
    ]


def _format_violation_table(violations):
    """Render violations as a Markdown table sorted by worker id, then time."""
    rows = [
        "| Violation | Worker ID | Time (s) | Confidence |\n",
        "|-----------|-----------|----------|------------|\n",
    ]
    for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
        display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
        worker_id = v.get("worker_id", "Unknown")
        timestamp = v.get("timestamp", 0.0)
        try:
            confidence = float(v.get("confidence", 0.0))
        except (ValueError, TypeError) as e:
            logger.error(f"Invalid confidence value in violation {v}: {e}")
            confidence = 0.0
        rows.append(f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n")
    return "".join(rows)


def _format_snapshots_markdown(snapshots):
    """Render one Markdown heading per snapshot, or a placeholder when there are none."""
    parts = []
    for s in snapshots:
        display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
        worker_id = s.get("worker_id", "Unknown")
        timestamp = s.get("timestamp", 0.0)
        parts.append(f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n")
        # NOTE(review): no image is embedded here although each snapshot has a
        # snapshot_url; that URL points at PUBLIC_URL_BASE, not at OUTPUT_DIR
        # where the JPEG is written. Confirm how images should be served.
        parts.append("\n\n")
    return "".join(parts) or "No snapshots captured."


def process_video(video_data):
    """Detect safety violations in an uploaded video, streaming UI updates.

    This is a generator. Every yielded value is a 5-tuple
    ``(violations_markdown_or_status, score_text, snapshots_markdown,
    record_id_text, pdf_url)``. While frames are processed it yields a progress
    message roughly once per second of wall time. At the end it yields one
    final result, or one error message if anything failed.

    Pipeline:
    1. Write the bytes to a temp file and read it with OpenCV.
    2. Keep one frame in every CONFIG["FRAME_SKIP"] and brighten it.
    3. Run YOLO in batches of CONFIG["BATCH_SIZE"], keep per-class confident
       violation boxes, and assign worker IDs with BYTETracker.
    4. For the first occurrence of each (worker, violation) pair, save an
       annotated JPEG in CONFIG["OUTPUT_DIR"].
    5. Score the results, write the PDF, and push the report to Salesforce.

    Args:
        video_data: encoded video file contents as bytes.
    """
    video_path = None
    cap = None
    try:
        if not video_data:
            raise ValueError("Empty video data provided.")

        video_fd, video_path = tempfile.mkstemp(suffix=".mp4", dir=TEMP_DIR)
        with os.fdopen(video_fd, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
        if total_frames <= 0:
            raise ValueError("Video has no frames.")

        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps
        )
        unique_violations = {}  # worker_id -> {label: first video time (s)}
        snapshots = []
        run_start = time.time()    # never reset: total processing time
        last_progress = run_start  # reset after every progress message

        while True:
            batch_frames, batch_indices = _read_frame_batch(
                cap, total_frames, CONFIG["BATCH_SIZE"], CONFIG["FRAME_SKIP"]
            )
            if not batch_frames:
                logger.info("No more frames to process.")
                break
            try:
                results = model(batch_frames, device=device, conf=0.1, verbose=False)
            except Exception as e:
                logger.error(f"Model inference failed: {e}")
                raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}") from e

            for frame, result, frame_idx in zip(batch_frames, results, batch_indices):
                current_time = frame_idx / fps
                if time.time() - last_progress > 1.0:
                    # Progress follows the position in the video; a count of processed
                    # frames would stop at ~1/FRAME_SKIP because skipped frames aren't counted.
                    progress = min(100.0, (frame_idx + 1) / total_frames * 100)
                    yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx + 1}/{total_frames})", "", "", "", ""
                    last_progress = time.time()

                track_inputs = _violation_candidates(result)
                if not track_inputs:
                    continue
                tracked_objects = tracker.update(
                    np.array([t["bbox"] for t in track_inputs]),
                    np.array([t["conf"] for t in track_inputs]),
                    np.array([t["cls"] for t in track_inputs])
                )
                for obj in tracked_objects:
                    label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
                    if label is None:
                        continue
                    worker_id = obj['id']
                    seen = unique_violations.setdefault(worker_id, {})
                    if label in seen:
                        continue  # each (worker, violation) pair is reported once
                    seen[label] = current_time
                    detection = {
                        "worker_id": worker_id,
                        "violation": label,
                        "confidence": round(float(obj['score']), 2),
                        "bounding_box": obj['bbox'],
                        "timestamp": current_time
                    }
                    snapshots.append(_save_violation_snapshot(frame, detection))

        # Free the decoder and the uploaded copy before the slow reporting steps.
        cap.release()
        cap = None
        if os.path.exists(video_path):
            os.remove(video_path)
        video_path = None

        logger.info(f"Processing complete in {time.time() - run_start:.2f}s")
        logger.info(f"Snapshots: {snapshots}")
        violations = _collect_violations(unique_violations, snapshots)
        logger.info(f"Violations: {violations}")

        if not violations:
            logger.info("No violations detected after processing")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, _public_pdf_url, pdf_file = generate_violation_pdf(violations, score)
        record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        yield (
            _format_violation_table(violations),
            f"Safety Score: {score}%",
            _format_snapshots_markdown(snapshots),
            f"Salesforce Record ID: {record_id}",
            final_pdf_url
        )
    except Exception as e:
        logger.error(f"Error processing video: {str(e)}", exc_info=True)
        yield f"Error processing video: {str(e)}", "", "", "", ""
    finally:
        # Clean up only what this call created. TEMP_DIR must NOT be removed:
        # it also holds OUTPUT_DIR (the snapshots/PDFs just reported) and is
        # the directory later calls create their temp video in.
        if cap is not None:
            cap.release()
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
def gradio_interface(video_file):
    """Gradio callback that streams (violations, score, snapshots, record id, URL) tuples.

    ``video_file`` is the path of the uploaded video as passed by ``gr.Video``.
    This function is a generator, so every outcome (including validation
    failures) must be *yielded*: a ``return value`` inside a generator is
    discarded, and the UI would show nothing.
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    if not FFMPEG_AVAILABLE:
        yield "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        yield from process_video(video_data)
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""
# ==========================
# Gradio Interface
# ==========================
# Input is created before the outputs, matching the original construction order.
video_input = gr.Video(label="Upload Site Video")
result_outputs = [
    gr.Markdown(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
    gr.Markdown(label="Snapshots"),
    gr.Textbox(label="Salesforce Record ID"),
    gr.Textbox(label="Violation Details URL"),
]

interface = gr.Interface(
    fn=gradio_interface,
    inputs=video_input,
    outputs=result_outputs,
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
    # NOTE(review): newer Gradio releases replace allow_flagging with
    # flagging_mode; confirm against the installed Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()