"""Worksite safety violation analyzer.

Samples frames from an uploaded video, runs a YOLO model to detect safety
violations, aggregates detections per tracked worker, renders a PDF report,
pushes the result to Salesforce and exposes everything through a Gradio UI.
"""
import base64
import logging
import os
import time
import uuid
from io import BytesIO

import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from retrying import retry

# ==========================
# Enhanced Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": "static/output",
    # Maps the model's class IDs to internal violation labels.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # Colours passed to cv2 drawing calls on frames read by OpenCV (BGR order).
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),
        "no_harness": (0, 165, 255),
        "unsafe_posture": (0, 255, 0),
        "unsafe_zone": (255, 0, 0),
        "improper_tool_use": (255, 255, 0),
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture Violation",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use",
    },
    # Credentials are read from the environment (e.g. Space secrets) so they are
    # never committed to source control.
    # NOTE(review): credentials were previously hard-coded here -- rotate them.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Sample roughly one frame out of every FRAME_SKIP frames (at least one frame).
    "FRAME_SKIP": 50,
    # Per-label minimum confidence; detections below these are discarded.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.5,
        "no_harness": 0.15,
        "unsafe_posture": 0.15,
        "unsafe_zone": 0.15,
        "improper_tool_use": 0.15,
    },
    # Used both as the model's NMS IoU and as the minimum IoU for matching a
    # detection to an already-tracked worker.
    "IOU_THRESHOLD": 0.4,
    # A worker must be flagged in at least this many sampled frames for a
    # violation to be reported.
    "MIN_VIOLATION_FRAMES": 2,
    # A no_helmet detection below this confidence marks the matched worker as
    # helmet-compliant, suppressing later no_helmet reports for that worker.
    "HELMET_CONFIDENCE_THRESHOLD": 0.7,
    # Wall-clock budget (seconds) for the frame-sampling loop.
    "MAX_PROCESSING_TIME": 120,
}

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


def load_model():
    """Load the safety YOLO model, falling back to (and downloading) yolov8n.

    NOTE(review): yolov8n.pt is presumably a generic pretrained checkpoint whose
    class IDs do not correspond to VIOLATION_LABELS -- results from the fallback
    are likely meaningless; verify before relying on them.

    Returns:
        A YOLO model moved to ``device``.

    Raises:
        Exception: re-raised (after logging) if the model cannot be loaded.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
        model = YOLO(model_path).to(device)
        # Log only once loading has actually succeeded.
        logger.info(f"Model loaded: {model_path}")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


model = load_model()


# ==========================
# Enhanced Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a labelled rectangle for each detection onto ``frame`` in place.

    Args:
        frame: image array as returned by ``cv2.VideoCapture.read``.
        detections: dicts with ``violation``, ``confidence`` and a centre-based
            ``bounding_box`` ``[x_center, y_center, width, height]``.

    Returns:
        The same ``frame`` object, annotated.
    """
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
        x1, y1 = int(x - w / 2), int(y - h / 2)
        x2, y2 = int(x + w / 2), int(y + h / 2)
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
        # Keep the caption inside the image when the box touches the top edge.
        text_y = max(y1 - 10, 10)
        cv2.putText(frame, display_text, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame


def calculate_iou(box1, box2):
    """Return the intersection-over-union of two centre-based ``[x, y, w, h]`` boxes.

    Returns 0.0 when the boxes do not overlap or the union is empty.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2
    x1_min, y1_min = x1 - w1 / 2, y1 - h1 / 2
    x1_max, y1_max = x1 + w1 / 2, y1 + h1 / 2
    x2_min, y2_min = x2 - w2 / 2, y2 - h2 / 2
    x2_max, y2_max = x2 + w2 / 2, y2 + h2 / 2
    # The overlap rectangle is bounded by the inner edges of both boxes.
    inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
    inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
    intersection = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - intersection
    return intersection / union if union > 0 else 0.0


# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session (retried up to 3 times, 2 s apart).

    Raises:
        Exception: whatever simple_salesforce raises, after the last retry.
    """
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        # Make a real API call so a bad session fails here, inside the retry.
        sf.describe()
        logger.info("Connected to Salesforce")
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise


def generate_violation_pdf(violations, score):
    """Render a PDF report, write it to OUTPUT_DIR and return it.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)``, or ``("", "", None)`` on failure.
    """
    try:
        # Random suffix prevents collisions between reports made in the same second.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)

        y_position = 9.5 * inch
        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in report_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.3 * inch

        y_position -= 0.3 * inch
        c.drawString(1 * inch, y_position, "Violation Details:")
        y_position -= 0.3 * inch
        if not violations:
            c.drawString(1 * inch, y_position, "No violations detected.")
        else:
            for v in violations:
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
                c.drawString(1 * inch, y_position, text)
                y_position -= 0.3 * inch
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch

        c.showPage()
        c.save()
        pdf_file.seek(0)

        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None


def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload ``pdf_file`` as a ContentVersion linked to ``report_id``.

    Returns:
        A download URL for the uploaded version, or ``""`` on failure.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        stamp = int(time.time())  # one timestamp so title and file name agree
        content_version_data = {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        }
        content_version = sf.ContentVersion.create(content_version_data)
        # The Id comes from the Salesforce API response, not from user input.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""


def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Safety_Video_Report__c record (or a fallback Account) and attach the PDF.

    Returns:
        ``(record_id, pdf_url)``; ``(None, "")`` if Salesforce is unavailable,
        not configured, or record creation fails.
    """
    creds = CONFIG["SF_CREDENTIALS"]
    if not creds.get("username") or not creds.get("password"):
        logger.warning("Salesforce credentials not configured (SF_USERNAME / SF_PASSWORD); skipping Salesforce upload")
        return None, ""
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""

        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            is_custom_record = True
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            is_custom_record = False
            logger.warning(f"Fell back to Account record: {record['id']}")

        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # Update the object type that was actually created; a failure here
                # must not discard the record that already exists.
                try:
                    if is_custom_record:
                        sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                        logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                    else:
                        sf.Account.update(record_id, {"Description": uploaded_url})
                        logger.info(f"Updated Account record {record_id} with PDF URL")
                except Exception as e:
                    logger.error(f"Failed to attach PDF URL to record {record_id}: {e}")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""


def calculate_safety_score(violations):
    """Return 100 minus the summed per-violation penalties, floored at 0."""
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
    score = 100 - total_penalty
    return max(score, 0)


# ==========================
# Enhanced Video Processing
# ==========================
def _sample_frame_indices(total_frames):
    """Return evenly spaced frame indices, about one per FRAME_SKIP frames (at least one).

    Raises:
        ValueError: if the container reports no frames.
    """
    if total_frames <= 0:
        raise ValueError("Video reports no frames")
    target_frames = max(1, total_frames // CONFIG["FRAME_SKIP"])
    return np.linspace(0, total_frames - 1, target_frames, dtype=int)


def _detect_violations(frame, frame_index, timestamp):
    """Run the model on one frame and return detections that pass per-label thresholds."""
    results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"])
    detections = []
    for result in results:
        for box in result.boxes:
            cls = int(box.cls)
            conf = float(box.conf)
            label = CONFIG["VIOLATION_LABELS"].get(cls)
            if label is None:
                logger.warning(f"Unknown class ID {cls} detected, skipping")
                continue
            if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                logger.debug(f"Detection {label} with confidence {conf:.2f} below threshold, skipping")
                continue
            bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
            detections.append({
                "frame": frame_index,
                "violation": label,
                "confidence": round(conf, 2),
                "bounding_box": bbox,
                "timestamp": timestamp,
            })
    return detections


def _match_worker(bbox, workers):
    """Return the tracked worker whose box best overlaps ``bbox`` above IOU_THRESHOLD, else None."""
    best_worker, best_iou = None, 0
    for worker in workers:
        iou = calculate_iou(bbox, worker["bbox"])
        if iou > best_iou and iou > CONFIG["IOU_THRESHOLD"]:
            best_worker, best_iou = worker, iou
    return best_worker


def _new_tracker():
    """Create the mutable per-video tracking state used by ``_track_detections``."""
    return {
        "workers": [],
        # Monotonic counter: IDs must never be reused after workers are pruned.
        "next_worker_id": 1,
        "helmet_compliance": {},
        "violation_history": {label: [] for label in CONFIG["VIOLATION_LABELS"].values()},
    }


def _track_detections(detections, current_time, tracker):
    """Assign each detection to a tracked worker and append it to the violation history.

    Mutates ``tracker`` and adds a ``worker_id`` key to each recorded detection.
    """
    workers = tracker["workers"]
    helmet_compliance = tracker["helmet_compliance"]
    for detection in detections:
        violation_type = detection.get("violation")
        if violation_type is None:
            logger.error(f"Invalid detection, missing 'violation' key: {detection}")
            continue

        matched_worker = _match_worker(detection["bounding_box"], workers)

        if violation_type == "no_helmet" and matched_worker:
            worker_id = matched_worker["id"]
            compliance = helmet_compliance.setdefault(worker_id, {"no_helmet_frames": 0, "compliant": False})
            # NOTE(review): the no_helmet detection that first creates a worker is
            # not counted here -- confirm this is intended.
            compliance["no_helmet_frames"] += 1
            if detection["confidence"] < CONFIG["HELMET_CONFIDENCE_THRESHOLD"]:
                compliance["compliant"] = True
                logger.debug(f"Worker {worker_id} marked as helmet compliant due to low no_helmet confidence")
            if compliance["compliant"]:
                logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
                continue

        if matched_worker:
            matched_worker["bbox"] = detection["bounding_box"]
            matched_worker["last_seen"] = current_time
            worker_id = matched_worker["id"]
        else:
            worker_id = tracker["next_worker_id"]
            tracker["next_worker_id"] += 1
            workers.append({
                "id": worker_id,
                "bbox": detection["bounding_box"],
                "first_seen": current_time,
                "last_seen": current_time,
            })
        helmet_compliance.setdefault(worker_id, {"no_helmet_frames": 0, "compliant": False})

        detection["worker_id"] = worker_id
        tracker["violation_history"][violation_type].append(detection)

    # Forget workers not seen for 5 s of video time so stale boxes stop matching.
    tracker["workers"] = [w for w in workers if current_time - w["last_seen"] < 5.0]


def _save_snapshot(video_path, detection, run_id):
    """Re-read the detection's frame, annotate it and save a JPEG.

    Returns:
        A snapshot record dict, or None if the frame could not be read.
    """
    violation_type = detection["violation"]
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, detection["frame"])
        ret, snapshot_frame = cap.read()
    finally:
        cap.release()
    if not ret:
        logger.error(f"Failed to capture snapshot for {violation_type} at frame {detection['frame']}")
        return None
    snapshot_frame = draw_detections(snapshot_frame, [detection])
    # run_id keeps concurrent requests from overwriting each other's images.
    snapshot_filename = f"{violation_type}_{detection['frame']}_{run_id}.jpg"
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
    cv2.imwrite(snapshot_path, snapshot_frame)
    logger.info(f"Snapshot taken for {violation_type} at frame {detection['frame']}")
    return {
        "violation": violation_type,
        "frame": detection["frame"],
        "snapshot_path": snapshot_path,
        "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
    }


def _confirm_violations(violation_history, helmet_compliance, video_path, run_id):
    """Keep per-worker violations seen in enough frames; snapshot the first of each type.

    Returns:
        ``(violations, snapshots)`` -- the highest-confidence detection per
        confirmed (worker, violation type) pair, and at most one snapshot per type.
    """
    violations = []
    snapshots = []
    snapshot_taken = set()
    for violation_type, detections in violation_history.items():
        if not detections:
            logger.info(f"No detections for {violation_type}")
            continue

        by_worker = {}
        for det in detections:
            by_worker.setdefault(det["worker_id"], []).append(det)

        for worker_id, worker_dets in by_worker.items():
            if len(worker_dets) < CONFIG["MIN_VIOLATION_FRAMES"]:
                continue
            if violation_type == "no_helmet":
                compliance = helmet_compliance.get(worker_id, {"no_helmet_frames": 0, "compliant": False})
                if compliance["compliant"]:
                    logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
                    continue
                if compliance["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
                    logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
                    continue

            best_detection = max(worker_dets, key=lambda d: d["confidence"])
            violations.append(best_detection)

            if violation_type not in snapshot_taken:
                snapshot = _save_snapshot(video_path, best_detection, run_id)
                if snapshot:
                    snapshots.append(snapshot)
                    snapshot_taken.add(violation_type)
    return violations, snapshots


def _remove_temp_file(path):
    """Delete ``path`` if it exists; log rather than raise on other OS errors."""
    try:
        os.remove(path)
        logger.info(f"Video file {path} removed")
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


def _format_violation_table(violations):
    """Render confirmed violations as a Markdown table."""
    rows = [
        "| Violation | Timestamp (s) | Confidence | Worker ID |\n",
        "|------------------------|---------------|------------|-----------|\n",
    ]
    for v in violations:
        display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
        rows.append(f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n")
    return "".join(rows)


def _format_snapshots(snapshots):
    """Render snapshot records as a Markdown bullet list of images."""
    if not snapshots:
        return "No snapshots captured."
    names = CONFIG["DISPLAY_NAMES"]
    return "\n".join(
        f"- Snapshot for {names.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: ![]({s.get('snapshot_base64', '')})"
        for s in snapshots
    )


def process_video(video_data):
    """Analyse raw video bytes and stream progress/results as 5-tuples of strings.

    Yields:
        ``(violations_markdown, score_text, snapshots_markdown, record_text,
        pdf_url)``. Intermediate yields carry a progress message in the first
        slot; the final yield carries the report (or an error message).
    """
    run_id = uuid.uuid4().hex[:8]
    # Unique name so concurrent requests never overwrite each other's upload.
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{run_id}.mp4")
    video = None
    try:
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")

        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = video.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30
        video_duration = total_frames / fps
        logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")

        frame_indices = _sample_frame_indices(total_frames)
        target_frames = len(frame_indices)
        tracker = _new_tracker()
        detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()}
        processed_frames = 0
        start_time = time.time()

        for idx in frame_indices:
            idx = int(idx)
            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info(f"Processing time limit of {CONFIG['MAX_PROCESSING_TIME']} seconds reached. Processed {processed_frames}/{target_frames} frames.")
                break

            video.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = video.read()
            if not ret:
                continue

            processed_frames += 1
            current_time = idx / fps
            progress = (processed_frames / target_frames) * 100
            yield f"Processing video... {progress:.1f}% complete (Frame {idx}/{total_frames})", "", "", "", ""

            current_detections = _detect_violations(frame, idx, current_time)
            for det in current_detections:
                detection_counts[det["violation"]] += 1
            logger.debug(f"Frame {idx}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
            _track_detections(current_detections, current_time, tracker)

        # Reporting "no violations" for a video we never analysed would be a
        # false all-clear, so treat it as an error instead.
        if processed_frames == 0:
            raise ValueError("Could not read any frames from the video")

        logger.info(f"Detection counts: {detection_counts}")
        violations, snapshots = _confirm_violations(
            tracker["violation_history"], tracker["helmet_compliance"], video_path, run_id
        )

        video.release()
        video = None
        _remove_temp_file(video_path)

        if not violations:
            logger.info("No persistent violations detected")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        logger.info(f"Processing complete: {len(violations)} violations detected, score: {score}%")
        yield (
            _format_violation_table(violations),
            f"Safety Score: {score}%",
            _format_snapshots(snapshots),
            f"Salesforce Record ID: {report_id or 'N/A'}",
            # Fall back to the locally generated PDF if Salesforce was unavailable.
            final_pdf_url or pdf_url or "N/A",
        )
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        yield f"Error processing video: {e}", "", "", "", ""
    finally:
        # Runs on success, error, or when the consumer closes the generator early.
        if video is not None:
            video.release()
        _remove_temp_file(video_path)


# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio handler: read the uploaded file and stream ``process_video`` output.

    This is a generator, so every result -- including the "no file" message --
    must be yielded; a ``return value`` would never reach the UI.
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        for status, score, snapshots_text, record_id, details_url in process_video(video_data):
            yield status, score, snapshots_text, record_id, details_url
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""


interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
    allow_flagging="never",
)

if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()