"""Worksite Safety Violation Analyzer.

Gradio app that samples frames from an uploaded video, runs a YOLO detector to
find safety violations, scores compliance, renders a PDF report and pushes the
result to Salesforce.
"""
import base64
import logging
import os
import time
import uuid
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import torch
from reportlab.lib.pagesizes import letter
from reportlab.lib.units import inch
from reportlab.pdfgen import canvas
from retrying import retry
from simple_salesforce import Salesforce
from ultralytics import YOLO

# ==========================
# Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8n.pt",  # Lightweight model, must be trained for violations only
    "OUTPUT_DIR": "static/output",
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",  # Ignored in processing
    },
    "DISPLAY_NAMES": {  # Mapping for user-friendly violation names
        "no_helmet": "Missing Helmet",
        "no_harness": "Missing Harness",
        "unsafe_posture": "Unsafe Posture",
    },
    # Labels that are actually reported; any other class the model emits is ignored.
    "TARGET_VIOLATIONS": ("no_helmet", "no_harness", "unsafe_posture"),
    # Credentials are read from the environment (e.g. Space secrets) instead of
    # being committed to source control. Any credentials that were previously
    # hardcoded in this file must be treated as leaked and rotated.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    # NOTE(review): files written to OUTPUT_DIR at runtime are not part of the
    # Space repository, so URLs under this "resolve/main" base may 404 — confirm.
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    "FRAME_SKIP": 15,  # Process every 15th frame
    "MAX_PROCESSING_TIME": 25,  # Cap video processing at 25s
    "MAX_VIDEO_SECONDS": 60,  # Only the first minute of video is analysed
    "DEFAULT_FPS": 30.0,  # Used when the container does not report a usable frame rate
    "CONFIDENCE_THRESHOLD": 0.5,  # Minimum confidence for violation detection
}

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Ensure output directory exists
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# ==========================
# Device Setup
# ==========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


# ==========================
# Model Loading
# ==========================
def load_model():
    """Load the YOLO detector from CONFIG["MODEL_PATH"] onto ``device``.

    Returns:
        The loaded ultralytics ``YOLO`` model.

    Raises:
        Exception: whatever ultralytics raises when the weights cannot be
            loaded; it is logged and re-raised so the app fails at startup.
    """
    # NOTE(review): stock yolov8n.pt weights are COCO-pretrained (class 0 is
    # "person"), so with VIOLATION_LABELS as configured every detected person
    # would likely be reported as "no_helmet" — confirm custom weights are used.
    try:
        model = YOLO(CONFIG["MODEL_PATH"]).to(device)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
    logger.warning(
        "Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. "
        "Replace with custom-trained yolov8_safety.pt if unexpected classes are detected."
    )
    return model


model = load_model()


# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def connect_to_salesforce():
    """Open a Salesforce session using CONFIG["SF_CREDENTIALS"].

    Makes up to 2 attempts, 1 s apart. ``describe()`` is called so that bad
    credentials fail here rather than on the first real API call.

    Returns:
        A connected ``simple_salesforce.Salesforce`` instance.

    Raises:
        Exception: the last connection error, after retries are exhausted.
    """
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        sf.describe()
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise


def generate_violation_pdf(violations, score):
    """Render a PDF report of *violations* and write it into CONFIG["OUTPUT_DIR"].

    Args:
        violations: list of violation dicts as built by ``process_video``
            (keys read here: "violation", "timestamp", "confidence").
        score: compliance percentage printed in the report header.

    Returns:
        ``(pdf_path, public_url, pdf_buffer)`` on success, where ``pdf_buffer``
        is a ``BytesIO`` rewound to position 0; ``("", "", None)`` on failure.
    """
    try:
        # The uuid suffix prevents two reports created in the same second from
        # overwriting each other.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch

        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in report_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.3 * inch

        y_position -= 0.3 * inch
        c.drawString(1 * inch, y_position, "Violation Details:")
        y_position -= 0.3 * inch
        if not violations:
            c.drawString(1 * inch, y_position, "No violations detected.")
        else:
            for v in violations:
                display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
                text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
                c.drawString(1 * inch, y_position, text)
                y_position -= 0.3 * inch
                # Start a new page before running into the bottom margin.
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch

        c.showPage()
        c.save()
        pdf_file.seek(0)

        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None


@retry(stop_max_attempt_number=2, wait_fixed=1000)
def _create_content_version(sf, content_version_data):
    """Create a Salesforce ContentVersion; retried once (after 1 s) on any exception."""
    return sf.ContentVersion.create(content_version_data)


def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the PDF in *pdf_file* to the Salesforce record *report_id*.

    Only the ContentVersion creation is retried. Retrying is done inside
    ``_create_content_version`` because this function deliberately swallows
    errors, which would otherwise stop a decorator on it from ever firing.

    Args:
        sf: a connected ``Salesforce`` instance.
        pdf_file: ``BytesIO`` holding the PDF bytes (falsy => nothing to upload).
        report_id: Id of the record used as ``FirstPublishLocationId``.

    Returns:
        The download URL of the uploaded file, or ``""`` on any failure.
    """
    if not pdf_file:
        logger.error("No PDF file provided for upload")
        return ""
    try:
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        content_version = _create_content_version(sf, {
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version["id"]
        # version_id comes from the Salesforce API response, not from user input.
        result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{version_id}'")
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""


def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Safety_Video_Report__c record (falling back to Account) and attach the PDF.

    Connection attempts are retried inside ``connect_to_salesforce``. Record
    creation is intentionally not retried, because a retry after a partial
    success could create duplicate records.

    Args:
        violations: list of violation dicts as built by ``process_video``.
        score: compliance percentage.
        pdf_path: local path of the generated PDF ("" if generation failed).
        pdf_file: ``BytesIO`` with the PDF bytes, or ``None``.

    Returns:
        ``(record_id, pdf_url)``. ``pdf_url`` is the Salesforce download URL
        when the upload succeeded, otherwise the public URL. Returns
        ``(None, "")`` on failure.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""

        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")

        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                try:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    # Expected when record_id is the Account fallback.
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated Account record {record_id} with PDF URL")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return None, ""


# ==========================
# Safety Score Calculation
# ==========================
_PENALTIES = {
    "no_helmet": 25,
    "no_harness": 30,
    "unsafe_posture": 20,
}


def calculate_safety_score(violations):
    """Return 100 minus a fixed penalty per violation entry, floored at 0.

    Every entry counts, so one violation seen on several sampled frames is
    penalised once per frame. Unknown labels carry no penalty.
    """
    score = 100 - sum(_PENALTIES.get(v["violation"], 0) for v in violations)
    return max(score, 0)


# ==========================
# Video Processing
# ==========================
def _empty_result():
    """Return the result dict used when nothing is reported (including on failure)."""
    return {
        "violations": [],
        "snapshots": [],
        "score": 100,
        "salesforce_record_id": None,
        "violation_details_url": "",
    }


def _save_snapshot(frame, frame_count, label):
    """Write *frame* as a JPEG in OUTPUT_DIR and return its snapshot entry."""
    # The uuid suffix keeps concurrent requests from overwriting each other's snapshots.
    snapshot_name = f"snapshot_{frame_count}_{label}_{uuid.uuid4().hex[:8]}.jpg"
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_name)
    cv2.imwrite(snapshot_path, frame)
    with open(snapshot_path, "rb") as img_file:
        img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
    return {
        "violation": label,
        "frame": frame_count,
        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_name}",
        "snapshot_base64": f"data:image/jpeg;base64,{img_base64}",
    }


def _detect_violations(video_path):
    """Run the detector on every FRAME_SKIP-th frame of the video at *video_path*.

    Analysis stops at MAX_VIDEO_SECONDS of video, at MAX_PROCESSING_TIME
    seconds of wall-clock time, or at end of stream, whichever comes first.

    Returns:
        ``(violations, snapshots)``: at most one violation entry per label per
        sampled frame, and at most one snapshot per label for the whole video.

    Raises:
        ValueError: if OpenCV cannot open the file.
    """
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError("Could not open video file")

        fps = video.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN). Without a fallback, max_frames
        # would be 0 (nothing analysed) and timestamps would divide by zero.
        if not fps > 0:
            logger.warning(f"Invalid FPS reported ({fps}); assuming {CONFIG['DEFAULT_FPS']}")
            fps = CONFIG["DEFAULT_FPS"]
        max_frames = int(CONFIG["MAX_VIDEO_SECONDS"] * fps)
        frame_skip = max(1, CONFIG["FRAME_SKIP"])
        threshold = CONFIG["CONFIDENCE_THRESHOLD"]
        targets = set(CONFIG["TARGET_VIOLATIONS"])

        violations, snapshots = [], []
        snapshotted = set()
        frame_count = 0
        start_time = time.time()

        while frame_count < max_frames:
            if frame_count % frame_skip != 0:
                # grab() advances the stream without decoding the frame.
                if not video.grab():
                    break
                frame_count += 1
                continue

            ok, frame = video.read()
            if not ok:
                break

            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info("Processing time limit reached")
                break

            seen_in_frame = set()
            for result in model(frame, device=device):
                for box in result.boxes:
                    cls, conf = int(box.cls), float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
                    if label not in targets:
                        logger.debug("Ignoring detection: %s (cls: %s, conf: %s) - not a target violation", label, cls, conf)
                        continue
                    if conf < threshold:
                        logger.debug("Skipping low-confidence detection: %s (conf: %s)", label, conf)
                        continue
                    if label in seen_in_frame:
                        continue
                    seen_in_frame.add(label)

                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(conf, 2),
                        # float() so values display as plain numbers, not numpy scalar reprs.
                        "bounding_box": [round(float(x), 2) for x in box.xywh.cpu().numpy()[0]],
                        "timestamp": frame_count / fps,
                    })
                    if label not in snapshotted:
                        snapshots.append(_save_snapshot(frame, frame_count, label))
                        snapshotted.add(label)
            frame_count += 1

        return violations, snapshots
    finally:
        video.release()


def process_video(video_data):
    """Analyse raw video bytes and publish a report.

    Args:
        video_data: the encoded video file contents (bytes).

    Returns:
        dict with keys "violations", "snapshots", "score",
        "salesforce_record_id" and "violation_details_url".
    """
    # NOTE(review): on failure this returns score 100 (the same as a clean
    # video), which a safety tool may want to surface differently.
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{uuid.uuid4().hex}.mp4")
    try:
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        violations, snapshots = _detect_violations(video_path)
        if not violations:
            logger.info("No violations detected")
            return _empty_result()

        score = calculate_safety_score(violations)
        pdf_path, _pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
        }
    except Exception as e:
        logger.exception(f"Error processing video: {e}")
        return _empty_result()
    finally:
        # Always remove the temp copy, even when decoding or detection failed.
        if os.path.exists(video_path):
            try:
                os.remove(video_path)
            except OSError as e:
                logger.warning(f"Could not remove temp video {video_path}: {e}")


# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the five output panels.

    Args:
        video_file: filesystem path of the uploaded video (falsy if none).

    Returns:
        A 5-tuple: violation table (markdown), score text, snapshots
        (markdown), Salesforce record text, and report URL text.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()

        result = process_video(video_data)

        violation_table = "No violations detected."
        if result["violations"]:
            header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
            separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
            rows = [
                f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
                for v in result["violations"]
            ]
            violation_table = header + separator + "\n".join(rows)

        snapshots_text = "No snapshots captured."
        if result["snapshots"]:
            snapshots_text = "\n".join(
                f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
                for s in result["snapshots"]
            )

        return (
            violation_table,
            f"Safety Score: {result['score']}%",
            snapshots_text,
            f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
            result["violation_details_url"] or "N/A",
        )
    except Exception as e:
        logger.exception(f"Error in Gradio interface: {e}")
        return f"Error: {str(e)}", "", "Error in processing.", "", ""


interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored.",
)

if __name__ == "__main__":
    logger.info("Launching Safety Analyzer App...")
    interface.launch()