import base64
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
import warnings
from functools import partial
from io import BytesIO
from multiprocessing import Pool, cpu_count

import cv2
import gradio as gr
import numpy as np
import tenacity
import torch
from reportlab.lib.pagesizes import letter
from reportlab.lib.units import inch
from reportlab.pdfgen import canvas
from retrying import retry
from simple_salesforce import Salesforce
from ultralytics import YOLO

# ==========================
# Configuration and Setup
# ==========================
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def check_ffmpeg():
    """Return True if an ``ffmpeg`` executable can be run from PATH, else False.

    Failures (missing binary, non-zero exit, hang, permission error) are logged
    rather than raised.
    """
    try:
        subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
            timeout=10,  # a wedged binary must not block module import forever
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        logger.error("FFmpeg is not installed or not found in PATH. Video processing may fail.")
        return False
    logger.info("FFmpeg is available.")
    return True


FFMPEG_AVAILABLE = check_ffmpeg()


# ==========================
# ByteTrack Implementation
# ==========================
class BYTETracker:
    """Assign persistent worker ids to per-frame violation detections.

    This is a simplified tracker, not a full ByteTrack: there is no motion
    model and no second low-score association pass. Each detection whose score
    is at least ``track_thresh`` is associated, in order of preference, with:

    1. the live track with the highest IoU above ``match_thresh``;
    2. a recently pruned track whose last centre lies within ``max_distance``;
    3. any live track whose last centre lies within ``max_distance``;
    4. otherwise a brand-new track id.

    Boxes are centre-format ``[cx, cy, w, h]`` (see ``_calculate_iou``).
    """

    # Number of positions kept per worker in ``worker_history``.
    HISTORY_LENGTH = 30
    # Seconds a pruned track remains eligible for re-identification.
    REID_WINDOW = 2.0

    def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.3, frame_rate=30,
                 max_distance=100):
        self.track_thresh = track_thresh
        # In frames; converted to seconds via frame_rate when ageing tracks.
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        # Guard against 0/unknown FPS, which would otherwise divide by zero.
        self.frame_rate = frame_rate if frame_rate and frame_rate > 0 else 30
        self.next_id = 1
        self.tracks = {}                    # id -> {"bbox", "score", "cls", "last_seen"}
        self.worker_history = {}            # id -> list of {"pos", "time"}
        self.last_positions = {}            # id -> [cx, cy] of the latest observation
        self.recently_removed = {}          # pruned tracks kept briefly for re-identification
        self.track_attributes = {}          # id -> extra attributes such as "appearance"
        self.active_workers = set()         # ids of live (non-pruned) tracks
        self.worker_violation_history = {}  # id -> set of class ids ever observed
        self.max_worker_distance = max_distance

    def update(self, dets, scores, cls, timestamp=None):
        """Associate one frame's detections with tracks.

        Args:
            dets: iterable of ``[cx, cy, w, h]`` boxes.
            scores: per-detection confidences; entries below ``track_thresh``
                are ignored.
            cls: per-detection class ids; ``None`` entries are tracked but not
                added to the violation history.
            timestamp: time of this frame in seconds. Pass the *video* time so
                track ageing does not depend on processing speed. When omitted,
                wall-clock ``time.time()`` is used (legacy behaviour).

        Returns:
            One dict ``{"id", "bbox", "score", "cls"}`` per accepted detection,
            in input order. Several detections may share an ``id``, for example
            two different violation classes on the same worker.
        """
        now = time.time() if timestamp is None else float(timestamp)
        self._prune(now)

        tracks = []
        for det, score, cl in zip(dets, scores, cls):
            if score < self.track_thresh:
                continue
            x, y, w, h = det
            bbox = [x, y, w, h]

            track_id = self._match_by_iou(bbox)
            if track_id is None:
                track_id = self._match_recently_removed(bbox)
            if track_id is None:
                track_id = self._match_by_distance(bbox)
            if track_id is None:
                track_id = self.next_id
                self.next_id += 1

            tracks.append(self._record(track_id, bbox, score, cl, now))
        return tracks

    def _prune(self, now):
        """Retire tracks unseen for ``track_buffer`` frames and expire old re-id candidates."""
        max_age = self.track_buffer / self.frame_rate
        stale_ids = [tid for tid, info in self.tracks.items() if now - info["last_seen"] > max_age]
        for tid in stale_ids:
            self.recently_removed[tid] = {
                "bbox": self.tracks.pop(tid)["bbox"],
                "last_seen": now,
                "last_position": self.last_positions.pop(tid, [0, 0]),
                "appearance": self.track_attributes.get(tid, {}).get("appearance"),
            }
            self.worker_history.pop(tid, None)
            self.active_workers.discard(tid)

        expired = [
            tid for tid, info in self.recently_removed.items()
            if now - info["last_seen"] > self.REID_WINDOW
        ]
        for tid in expired:
            del self.recently_removed[tid]

    def _match_by_iou(self, bbox):
        """Return the live track id with the best IoU above ``match_thresh``, or None."""
        best_id, best_iou = None, 0
        for track_id, info in self.tracks.items():
            iou = self._calculate_iou(bbox, info["bbox"])
            if iou > self.match_thresh and iou > best_iou:
                best_id, best_iou = track_id, iou
        return best_id

    def _match_recently_removed(self, bbox):
        """Revive (and un-list) the first pruned track near *bbox*'s centre, or return None."""
        for track_id, info in self.recently_removed.items():
            if self._is_same_worker(bbox[:2], info["last_position"]):
                # Safe: we return immediately, so iteration does not continue after the delete.
                del self.recently_removed[track_id]
                return track_id
        return None

    def _match_by_distance(self, bbox):
        """Return the first live track whose last centre is near *bbox*'s centre, or None."""
        for track_id, last_pos in self.last_positions.items():
            if self._is_same_worker(bbox[:2], last_pos):
                return track_id
        return None

    def _record(self, track_id, bbox, score, cl, now):
        """Store an observation for *track_id* and return its output dict."""
        self.tracks[track_id] = {"bbox": bbox, "score": score, "cls": cl, "last_seen": now}

        attrs = self.track_attributes.setdefault(track_id, {})
        if "appearance" not in attrs:
            attrs["appearance"] = self._extract_appearance_features(bbox)

        history = self.worker_history.setdefault(track_id, [])
        history.append({"pos": list(bbox[:2]), "time": now})
        if len(history) > self.HISTORY_LENGTH:
            del history[:-self.HISTORY_LENGTH]

        self.last_positions[track_id] = list(bbox[:2])
        self.active_workers.add(track_id)
        if cl is not None:
            self.worker_violation_history.setdefault(track_id, set()).add(int(cl))

        return {"id": track_id, "bbox": list(bbox), "score": score, "cls": cl}

    def _calculate_iou(self, box1, box2):
        """Intersection-over-union of two centre-format ``[cx, cy, w, h]`` boxes (0.0 if degenerate)."""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        x_left = max(x1 - w1 / 2, x2 - w2 / 2)
        y_top = max(y1 - h1 / 2, y2 - h2 / 2)
        x_right = min(x1 + w1 / 2, x2 + w2 / 2)
        y_bottom = min(y1 + h1 / 2, y2 + h2 / 2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        union_area = w1 * h1 + w2 * h2 - intersection_area
        if union_area <= 0:
            return 0.0
        return intersection_area / union_area

    def _is_same_worker(self, pos1, pos2):
        """True when two ``[x, y]`` centres are closer than ``max_worker_distance`` pixels."""
        x1, y1 = pos1
        x2, y2 = pos2
        return bool(np.hypot(x1 - x2, y1 - y2) < self.max_worker_distance)

    def _extract_appearance_features(self, bbox):
        """Simple appearance feature extraction (placeholder): ``[w, h, w/h]``."""
        _, _, w, h = bbox
        return [w, h, w / h if h else 0.0]

    def get_active_worker_count(self):
        """Number of currently live track ids."""
        return len(self.active_workers)

    def get_worker_violation_types(self, worker_id):
        """Set of class ids ever observed for *worker_id* (empty set if unknown)."""
        return self.worker_violation_history.get(worker_id, set())

    def get_all_workers(self):
        """Ids of live tracks plus those still held for re-identification."""
        return set(self.tracks) | set(self.recently_removed)


# ==========================
# Optimized Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    # Model class id -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # Colours are passed straight to OpenCV drawing calls, i.e. BGR order.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),
        "no_harness": (0, 165, 255),
        "unsafe_posture": (0, 255, 0),
        "unsafe_zone": (255, 0, 0),
        "improper_tool_use": (255, 255, 0),
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use",
    },
    # Credentials come from the environment only; secrets must never be committed.
    # NOTE(review): the values previously hard-coded here should be treated as leaked and rotated.
    "SF_CREDENTIALS": {
        "username": os.getenv("SF_USERNAME", ""),
        "password": os.getenv("SF_PASSWORD", ""),
        "security_token": os.getenv("SF_SECURITY_TOKEN", ""),
        "domain": "login",
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Per-label minimum confidence; detections below it are discarded before tracking.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.4,
        "no_harness": 0.25,
        "unsafe_posture": 0.25,
        "unsafe_zone": 0.25,
        "improper_tool_use": 0.25,
    },
    # NOTE(review): MIN_VIOLATION_FRAMES, VIOLATION_COOLDOWN, WORKER_TRACKING_DURATION,
    # MAX_PROCESSING_TIME and PARALLEL_WORKERS are not referenced anywhere in this file.
    "MIN_VIOLATION_FRAMES": 1,
    "VIOLATION_COOLDOWN": 30.0,
    "WORKER_TRACKING_DURATION": 10.0,
    "MAX_PROCESSING_TIME": 60,
    "FRAME_SKIP": 1,
    "BATCH_SIZE": 15,
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),
    "TRACK_BUFFER": 150,  # 5.0 seconds at 30 fps
    "TRACK_THRESH": 0.3,
    "MATCH_THRESH": 0.3,
    "SNAPSHOT_QUALITY": 95,
    "MAX_WORKER_DISTANCE": 100,
    "TARGET_RESOLUTION": (384, 384),
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


def load_model():
    """Load the safety YOLO model (or the generic fallback) onto ``device``.

    Uses ``MODEL_PATH`` when present; otherwise ``FALLBACK_MODEL``, downloading
    it first if it is missing. On CUDA the weights are converted to half precision.

    Raises:
        Exception: any load/download failure is logged and re-raised.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file(
                    "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt",
                    model_path,
                )
        loaded = YOLO(model_path).to(device)
        if device.type == "cuda":
            loaded.model.half()
        logger.info(f"Model loaded: {model_path}")
        # NOTE(review): a generic fallback model's class ids will not match VIOLATION_LABELS.
        logger.info(f"Model classes: {loaded.names}")
        return loaded
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


model = load_model()


# ==========================
# Helper Functions
# ==========================
def preprocess_frame(frame):
    """Resize *frame* to ``TARGET_RESOLUTION`` and apply a fixed contrast/brightness boost."""
    target_res = CONFIG["TARGET_RESOLUTION"]
    frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
    return cv2.convertScaleAbs(frame, alpha=1.2, beta=20)


def draw_detections(frame, detections):
    """Return a copy of *frame* with each detection's box, label and confidence drawn.

    Each detection is a dict with ``violation``, ``confidence``, centre-format
    ``bounding_box`` ``[cx, cy, w, h]`` and ``worker_id`` keys.
    """
    result_frame = frame.copy()
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
        worker_id = det.get("worker_id", "Unknown")

        x1 = int(x - w / 2)
        y1 = int(y - h / 2)
        x2 = int(x + w / 2)
        y2 = int(y + h / 2)

        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)

        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
        text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
        # Solid black strip behind the label so it stays legible on any background.
        cv2.rectangle(result_frame, (x1, y1 - text_size[1] - 10), (x1 + text_size[0] + 10, y1),
                      (0, 0, 0), -1)
        cv2.putText(result_frame, display_text, (x1 + 5, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                    (255, 255, 255), 2)

        conf_text = f"Conf: {confidence:.2f}"
        cv2.putText(result_frame, conf_text, (x1 + 5, y2 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                    (255, 255, 255), 2)
    return result_frame


def calculate_safety_score(violations):
    """Return ``max(0, 100 - penalties)``.

    Each distinct violation type counts once per worker, however many times it was seen.
    """
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    worker_violations = {}
    for v in violations:
        worker_violations.setdefault(v.get("worker_id", "Unknown"), set()).add(
            v.get("violation", "Unknown")
        )
    total_penalty = sum(
        penalties.get(violation_type, 0)
        for types in worker_violations.values()
        for violation_type in types
    )
    return max(0, 100 - total_penalty)


def generate_violation_pdf(violations, score, output_dir):
    """Render a PDF report of *violations* and save it in *output_dir*.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)`` where ``pdf_buffer`` is the
        in-memory ``BytesIO`` copy, or ``("", "", None)`` if anything fails.
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(output_dir, pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)

        c.setFont("Helvetica-Bold", 16)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
        c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
        c.setFont("Helvetica-Bold", 14)
        c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")

        y_position = 8.2 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Summary:")
        y_position -= 0.3 * inch

        worker_violations = {}
        for v in violations:
            worker_violations.setdefault(v.get("worker_id", "Unknown"), []).append(v)

        c.setFont("Helvetica", 10)
        summary_data = {
            "Total Workers with Violations": len(worker_violations),
            "Total Violations Found": len(violations),
            "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in summary_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.25 * inch

        y_position -= 0.5 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Violations by Worker:")
        y_position -= 0.3 * inch
        c.setFont("Helvetica", 10)

        for worker_id, worker_vios in worker_violations.items():
            lines = [(1 * inch, f"Worker {worker_id}:")]
            for v in worker_vios:
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                time_str = f"{v.get('timestamp', 0.0):.2f}s"
                conf_str = f"{v.get('confidence', 0.0):.2f}"
                lines.append(
                    (1.2 * inch, f" - {display_name} at {time_str} (Confidence: {conf_str})")
                )
            for x_offset, text in lines:
                # Break the page before *any* line (headers included) that would fall off it.
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch
                c.drawString(x_offset, y_position, text)
                y_position -= 0.2 * inch

        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None


@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Log in to Salesforce with ``SF_CREDENTIALS`` (3 attempts, 2 s apart) and verify the session."""
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        sf.describe()  # cheap call that fails fast if the session is unusable
        logger.info("Connected to Salesforce")
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise


def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach *pdf_file* (a ``BytesIO``) to record *report_id* as a ContentVersion.

    Returns:
        The ContentVersion download URL, or ``""`` on any failure.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        content_version = sf.ContentVersion.create({
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version["id"]
        result = sf.query(
            f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'"
        )
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""


def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a report record in Salesforce and attach the PDF.

    A ``Safety_Video_Report__c`` record is preferred. If that object cannot be
    created, an ``Account`` is used instead.

    Returns:
        ``(record_id, pdf_url)``, or ``("N/A", "Salesforce integration failed.")``
        when credentials are missing or the integration fails.
    """
    creds = CONFIG["SF_CREDENTIALS"]
    if not (creds.get("username") and creds.get("password")):
        logger.warning("Salesforce credentials not configured (SF_USERNAME / SF_PASSWORD).")
        return "N/A", "Salesforce integration failed."

    try:
        sf = connect_to_salesforce()

        violations_text = "".join(
            f"Worker {v.get('worker_id', 'Unknown')}: "
            f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} "
            f"at {v.get('timestamp', 0.0):.2f}s (Conf: {v.get('confidence', 0.0):.2f})\n"
            for v in violations
        ) or "No violations detected."

        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")

        used_account_fallback = False
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            used_account_fallback = True
            logger.warning(f"Fell back to Account record: {record['id']}")

        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # Update the object type that was actually created; a failed URL update is
                # logged but must not discard the record id or the uploaded file link.
                try:
                    if used_account_fallback:
                        sf.Account.update(record_id, {"Description": uploaded_url})
                        logger.info(f"Updated Account record {record_id} with PDF URL")
                    else:
                        sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                        logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    logger.error(f"Failed to update record {record_id} with PDF URL: {e}")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return "N/A", "Salesforce integration failed."


@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_fixed(1),
    retry=tenacity.retry_if_exception_type((IOError, OSError)),
    reraise=True,  # surface the real error instead of tenacity.RetryError
    before_sleep=lambda retry_state: logger.info(
        f"Retrying file access (attempt {retry_state.attempt_number}/3)..."
    ),
)
def verify_and_open_video(video_path):
    """Check that *video_path* exists, is non-empty and readable, then open it with OpenCV.

    OS-level errors (including a missing file) are retried up to 3 times.

    Raises:
        FileNotFoundError: the file does not exist.
        ValueError: the file is empty or OpenCV cannot open it.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Temporary video file not found: {video_path}")
    if os.path.getsize(video_path) == 0:
        raise ValueError(f"Temporary video file is empty: {video_path}")
    with open(video_path, "rb") as f:
        f.read(1)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(
            "Could not open video file. Ensure the video format is supported (e.g., MP4) "
            "and FFmpeg is installed."
        )
    return cap


def _read_frame_batch(cap, total_frames, fps):
    """Read up to ``BATCH_SIZE`` preprocessed frames from *cap*, honouring ``FRAME_SKIP``.

    Returns:
        ``(frames, frame_indices, timestamps_s)``. An empty ``frames`` list means
        no further frame could be read.
    """
    frames, indices, timestamps = [], [], []
    for _ in range(CONFIG["BATCH_SIZE"]):
        # Skip frames BEFORE reading to speed up.
        for _ in range(CONFIG["FRAME_SKIP"] - 1):
            if not cap.grab():
                break
        frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
        if frame_idx >= total_frames:
            break
        ret, frame = cap.read()
        if not ret:
            logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
            break
        frames.append(preprocess_frame(frame))
        indices.append(frame_idx)
        timestamps.append(frame_idx / fps)
    return frames, indices, timestamps


def _frames_to_tensor(frames):
    """Stack OpenCV BGR frames into a normalised RGB ``(B, 3, H, W)`` tensor on ``device``."""
    # Ultralytics does not channel-swap torch.Tensor input (it expects RGB in [0, 1]),
    # so convert from OpenCV's BGR here. The copy also removes negative strides.
    batch = np.ascontiguousarray(np.stack(frames)[..., ::-1])
    tensor = torch.from_numpy(batch).permute(0, 3, 1, 2).float() / 255.0
    tensor = tensor.to(device)
    if device.type == "cuda":
        tensor = tensor.half()
    return tensor


def _run_inference(frames):
    """Run the detector on a batch of preprocessed frames.

    Raises:
        ValueError: wrapping any inference error.
    """
    try:
        return model(_frames_to_tensor(frames), device=device, conf=0.1, verbose=False)
    except Exception as e:
        logger.error(f"Model inference failed: {e}")
        raise ValueError(f"Failed to process video frames with YOLO model: {str(e)}") from e
    finally:
        if device.type == "cuda":
            torch.cuda.empty_cache()


def _extract_track_inputs(result):
    """Return ``(bboxes, scores, class_ids)`` for boxes in *result* that pass filtering.

    A box is kept only if it maps to a known violation label and meets that
    label's confidence threshold.
    """
    bboxes, scores, class_ids = [], [], []
    for box in result.boxes:
        class_id = int(box.cls)
        conf = float(box.conf)
        label = CONFIG["VIOLATION_LABELS"].get(class_id)
        if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
            continue
        bboxes.append(box.xywh.cpu().numpy()[0])
        scores.append(conf)
        class_ids.append(class_id)
    return bboxes, scores, class_ids


def _capture_snapshots(video_path, violations, output_dir):
    """Save an annotated JPEG per violation, using the bbox recorded while tracking.

    Returns:
        A list of snapshot dicts (``violation``, ``worker_id``, ``timestamp``,
        ``snapshot_path``, ``snapshot_url``, ``confidence``).
    """
    snapshots = []
    cap = cv2.VideoCapture(video_path)
    try:
        for violation in violations:
            frame_idx = violation["frame_idx"]
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
                continue
            frame = preprocess_frame(frame)

            label = violation["violation"]
            worker_id = violation["worker_id"]
            timestamp = violation["timestamp"]
            detection = {
                "worker_id": worker_id,
                "violation": label,
                "confidence": violation["confidence"],
                "bounding_box": violation["bounding_box"],
                "timestamp": timestamp,
            }
            snapshot_frame = draw_detections(frame, [detection])
            cv2.putText(snapshot_frame, f"Time: {timestamp:.2f}s", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            # Separators keep names unambiguous (worker 1 @ 23 vs worker 12 @ 3).
            snapshot_filename = f"violation_{label}_worker_{worker_id}_{int(timestamp * 100)}.jpg"
            snapshot_path = os.path.join(output_dir, snapshot_filename)
            if not cv2.imwrite(snapshot_path, snapshot_frame,
                               [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]):
                logger.warning(f"Failed to write snapshot {snapshot_path}.")
                continue

            snapshots.append({
                "violation": label,
                "worker_id": worker_id,
                "timestamp": timestamp,
                "snapshot_path": snapshot_path,
                "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
                "confidence": violation["confidence"],
            })
            logger.info(
                f"Captured snapshot for {label} violation by worker {worker_id} at {timestamp:.2f}s"
            )
    finally:
        cap.release()
    return snapshots


def _build_violation_table(violations, total_workers):
    """Return the Markdown violations table, grouped by worker id and sorted by violation."""
    by_worker = {}
    for v in violations:
        by_worker.setdefault(v.get("worker_id", "Unknown"), []).append(v)
    rows = [
        f"## Total Workers Detected: {total_workers}\n\n",
        "| Worker ID | Violation | Time (s) | Confidence |\n",
        "|-----------|-----------|----------|------------|\n",
    ]
    for worker_id, vios in sorted(by_worker.items(), key=lambda item: item[0]):
        for v in sorted(vios, key=lambda item: item.get("violation", "")):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            rows.append(
                f"| {worker_id} | {display_name} | {v.get('timestamp', 0.0):.2f} "
                f"| {v.get('confidence', 0.0):.2f} |\n"
            )
    return "".join(rows)


def _build_snapshots_markdown(snapshots):
    """Return Markdown embedding each snapshot image, or a placeholder if there are none."""
    parts = []
    for s in snapshots:
        display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
        parts.append(
            f"### {display_name} - Worker {s.get('worker_id', 'Unknown')} "
            f"at {s.get('timestamp', 0.0):.2f}s\n\n"
        )
        parts.append(f"![Violation]({s['snapshot_url']})\n\n")
    return "".join(parts) or "No snapshots captured."


def process_video(video_data, temp_dir):
    """Analyse uploaded video bytes and stream progress/results to the UI.

    Generator yielding 4-tuples ``(violations_markdown, score_text,
    snapshots_markdown, report_url)``. Progress updates put a message in the
    first slot and empty strings elsewhere. Errors are logged and yielded as a
    message, not raised.

    Args:
        video_data: raw bytes of the uploaded video.
        temp_dir: scratch directory. Snapshots and the PDF go to ``<temp_dir>/output``.
    """
    video_path = None
    output_dir = os.path.join(temp_dir, "output")
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): Ultralytics is imported (and the model loaded) before this runs,
    # so this env var likely has no effect on it — confirm before relying on it.
    os.environ["YOLO_CONFIG_DIR"] = temp_dir
    try:
        if not video_data:
            raise ValueError("Empty video data provided.")
        logger.info(f"Received video data size: {len(video_data)} bytes")

        with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
            temp_file.write(video_data)
            temp_file.flush()
            video_path = temp_file.name
        logger.info(f"Video saved to temporary file: {video_path}")

        cap = verify_and_open_video(video_path)  # also validates existence / non-empty size
        logger.info(f"Temporary video file size: {os.path.getsize(video_path)} bytes")
        logger.info(f"Successfully opened video file: {video_path}")

        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS) or 30
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            duration = total_frames / fps
            logger.info(
                f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, "
                f"{width}x{height}"
            )
            if total_frames <= 0:
                raise ValueError("Video has no frames.")

            tracker = BYTETracker(
                track_thresh=CONFIG["TRACK_THRESH"],
                track_buffer=CONFIG["TRACK_BUFFER"],
                match_thresh=CONFIG["MATCH_THRESH"],
                frame_rate=fps,
                max_distance=CONFIG["MAX_WORKER_DISTANCE"],
            )
            # (worker_id, label) -> highest-confidence sighting of that violation.
            best_detections = {}
            all_workers = set()
            start_time = time.time()
            last_yield_time = start_time
            processed_frames = 0

            logger.info("First pass: Worker detection and tracking")
            while processed_frames < total_frames:
                frames, frame_indices, timestamps = _read_frame_batch(cap, total_frames, fps)
                if not frames:
                    logger.info("No more frames to process.")
                    break
                processed_frames += len(frames)
                results = _run_inference(frames)

                now = time.time()
                if now - last_yield_time > 0.1:
                    progress = (processed_frames / total_frames) * 100
                    elapsed_time = now - start_time
                    fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
                    yield (
                        f"Processing video... {progress:.1f}% complete "
                        f"(Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)",
                        "", "", "",
                    )
                    last_yield_time = now

                for result, frame_idx, timestamp in zip(results, frame_indices, timestamps):
                    bboxes, scores, class_ids = _extract_track_inputs(result)
                    if not bboxes:
                        continue
                    # Age tracks in video time so results don't depend on processing speed.
                    tracked_objects = tracker.update(
                        np.array(bboxes), np.array(scores), np.array(class_ids),
                        timestamp=timestamp,
                    )
                    for obj in tracked_objects:
                        tracker_id = obj["id"]
                        all_workers.add(tracker_id)
                        label = CONFIG["VIOLATION_LABELS"].get(int(obj["cls"]))
                        if label is None:
                            continue
                        key = (tracker_id, label)
                        previous = best_detections.get(key)
                        if previous is None or obj["score"] > previous["confidence"]:
                            best_detections[key] = {
                                "timestamp": timestamp,
                                "frame_idx": frame_idx,
                                "confidence": obj["score"],
                                "bbox": obj["bbox"],
                            }
        finally:
            cap.release()

        processing_time = time.time() - start_time
        logger.info(f"Processing complete in {processing_time:.2f}s")
        total_workers = len(all_workers)
        logger.info(f"Total unique workers detected: {total_workers}")

        violations = [
            {
                "worker_id": worker_id,
                "violation": label,
                "timestamp": best["timestamp"],
                "confidence": float(best["confidence"]),
                "frame_idx": best["frame_idx"],
                "bounding_box": best["bbox"],
            }
            for (worker_id, label), best in best_detections.items()
        ]

        if not violations:
            logger.info("No violations detected after processing")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A"
            return

        # NOTE(review): output_dir lives inside temp_dir, which gradio_interface deletes after
        # processing, so PUBLIC_URL_BASE links only resolve if files are published elsewhere.
        snapshots = _capture_snapshots(video_path, violations, output_dir)

        score = calculate_safety_score(violations)
        pdf_path, _pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
        _record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        yield (
            _build_violation_table(violations, total_workers),
            f"Safety Score: {score}% (Based on {total_workers} workers)",
            _build_snapshots_markdown(snapshots),
            final_pdf_url,
        )
    except Exception as e:
        logger.error(f"Error processing video: {str(e)}", exc_info=True)
        yield f"Error processing video: {str(e)}", "", "", ""
    finally:
        if video_path and os.path.exists(video_path):
            try:
                os.remove(video_path)
                logger.info(f"Cleaned up temporary video file: {video_path}")
            except Exception as e:
                logger.error(f"Failed to clean up temporary video file {video_path}: {e}")
        if device.type == "cuda":
            torch.cuda.empty_cache()


def gradio_interface(video_file):
    """Gradio entry point: stream ``process_video`` output for an uploaded file path.

    This function is a generator, so every early exit must *yield* its message;
    the value of a ``return`` statement inside a generator never reaches the UI.
    """
    temp_dir = None
    try:
        if not video_file:
            yield "No file uploaded.", "", "No file uploaded.", ""
            return

        temp_dir = tempfile.mkdtemp(prefix="Ultralytics_")
        logger.info(f"Created temporary directory for video processing: {temp_dir}")

        with open(video_file, "rb") as f:
            video_data = f.read()
        logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
        if not video_data:
            yield "Uploaded video file is empty.", "", "", ""
            return

        if not FFMPEG_AVAILABLE:
            yield (
                "FFmpeg is not available in the environment. Please install FFmpeg to process videos.",
                "", "", "",
            )
            return

        yield from process_video(video_data, temp_dir)
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", ""
    finally:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
            logger.info(f"Cleaned up temporary directory: {temp_dir}")
        if device.type == "cuda":
            torch.cuda.empty_cache()


# ==========================
# Gradio Interface
# ==========================
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description=(
        "Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, "
        "Unsafe Zone, Improper Tool Use). The system tracks individual workers and their "
        "specific violations."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()