Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import time | |
| from simple_salesforce import Salesforce | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.units import inch | |
| from io import BytesIO | |
| import base64 | |
| import logging | |
| from retrying import retry | |
| import uuid | |
# ==========================
# Configuration
# ==========================
# Secrets and deployment-specific values are read from environment variables
# (for a Hugging Face Space: repository secrets) so they are never committed.
# NOTE(review): the Salesforce credentials previously hard-coded here were
# published with the source — treat them as compromised and rotate them.
CONFIG = {
    # Detections are mapped to violations purely by class index (see
    # VIOLATION_LABELS), so this must be a checkpoint trained on those classes.
    "MODEL_PATH": os.environ.get("MODEL_PATH", "yolov8n.pt"),
    "OUTPUT_DIR": "static/output",
    # Model class index -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",  # Ignored in processing
    },
    "DISPLAY_NAMES": {  # User-facing names for the processed violation labels
        "no_helmet": "Missing Helmet",
        "no_harness": "Missing Harness",
        "unsafe_posture": "Unsafe Posture",
    },
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    # NOTE(review): reports and snapshots are written to OUTPUT_DIR at runtime;
    # URLs under the repository's resolve/main/ path presumably only serve
    # committed files — verify these links actually resolve.
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    "FRAME_SKIP": 15,  # Process every 15th frame
    "MAX_PROCESSING_TIME": 25,  # Cap video processing at 25s of wall-clock time
    "CONFIDENCE_THRESHOLD": 0.5,  # Minimum confidence for violation detection
}
# --- Logging: timestamped, INFO-level output shared by the whole app ---
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Output directory for temp videos, snapshots and PDF reports ---
os.makedirs(name=CONFIG["OUTPUT_DIR"], exist_ok=True)

# --- Inference device: CUDA when torch can see a GPU, otherwise CPU ---
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
logger.info("Using device: %s", device)
# ==========================
# Model Loading
# ==========================
def _warn_on_label_mismatch(yolo_model, model_path):
    """Warn when the model's class names differ from CONFIG["VIOLATION_LABELS"].

    Detections are turned into violations purely by class index, so a model
    trained on other classes would silently report wrong violations.
    """
    names = getattr(yolo_model, "names", None) or {}
    if isinstance(names, (list, tuple)):
        names = dict(enumerate(names))
    mismatches = {
        idx: (names.get(idx), expected)
        for idx, expected in CONFIG["VIOLATION_LABELS"].items()
        if names.get(idx) != expected
    }
    if mismatches:
        # NOTE(review): the stock yolov8n.pt checkpoint is, as far as I know,
        # COCO-pretrained (index 0 = "person"), which would turn every detected
        # person into "no_helmet" — confirm and deploy a custom-trained model.
        logger.warning(
            "Model %s class names do not match VIOLATION_LABELS "
            "{index: (model name, expected label)}: %s. Detections will be "
            "mislabelled; use a checkpoint trained for 'no_helmet', "
            "'no_harness', 'unsafe_posture'.",
            model_path,
            mismatches,
        )


def load_model():
    """Load the YOLO model at CONFIG["MODEL_PATH"] onto ``device``.

    Logs a warning if the model's class-name table does not line up with
    CONFIG["VIOLATION_LABELS"].

    Returns:
        The loaded ultralytics ``YOLO`` model.

    Raises:
        Exception: whatever the loader raised, re-raised after logging.
    """
    model_path = CONFIG["MODEL_PATH"]
    try:
        loaded = YOLO(model_path).to(device)
    except Exception:
        logger.exception("Failed to load model: %s", model_path)
        raise
    logger.info("Model loaded: %s", model_path)
    _warn_on_label_mismatch(loaded, model_path)
    return loaded


model = load_model()
# ==========================
# Salesforce Integration
# ==========================
def connect_to_salesforce():
    """Open a Salesforce session from CONFIG["SF_CREDENTIALS"] and verify it.

    The session is exercised with a ``describe()`` call before it is returned,
    so a client that cannot make API requests fails here, and "Connected" is
    only logged once that check has passed.

    Returns:
        An authenticated ``simple_salesforce.Salesforce`` client.

    Raises:
        ValueError: if the username or password is empty.
        Exception: any login/API error from simple_salesforce, re-raised.
    """
    credentials = CONFIG["SF_CREDENTIALS"]
    missing = [key for key in ("username", "password") if not credentials.get(key)]
    if missing:
        message = f"Salesforce credentials missing: {', '.join(missing)}"
        logger.error(message)
        raise ValueError(message)
    try:
        sf = Salesforce(**credentials)
        sf.describe()
    except Exception:
        logger.exception("Salesforce connection failed")
        raise
    logger.info("Connected to Salesforce")
    return sf
def generate_violation_pdf(violations, score):
    """Build a PDF report of the detected violations and save it to OUTPUT_DIR.

    Args:
        violations: violation dicts with "violation", "timestamp" (seconds) and
            "confidence" keys, as produced by process_video.
        score: compliance score printed in the report header.

    Returns:
        (pdf_path, public_url, pdf_buffer): the saved file path, its URL under
        CONFIG["PUBLIC_URL_BASE"], and the in-memory BytesIO rewound to 0.
        On any error the error is logged and ("", "", None) is returned.
    """
    try:
        # Timestamp plus a random suffix: concurrent requests within the same
        # second must not overwrite each other's report file.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        left = 1 * inch
        line_step = 0.3 * inch
        bottom_margin = 1 * inch

        c.setFont("Helvetica", 12)
        c.drawString(left, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch
        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        for key, value in report_data.items():
            c.drawString(left, y_position, f"{key}: {value}")
            y_position -= line_step
        y_position -= line_step
        c.drawString(left, y_position, "Violation Details:")
        y_position -= line_step

        if not violations:
            c.drawString(left, y_position, "No violations detected.")
        for v in violations:
            # Break the page *before* drawing a line, so a list that ends right
            # at the bottom margin does not leave an empty trailing page.
            if y_position < bottom_margin:
                c.showPage()
                c.setFont("Helvetica", 10)
                y_position = 10 * inch
            display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
            text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            c.drawString(left, y_position, text)
            y_position -= line_step
        c.showPage()
        c.save()

        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        pdf_file.seek(0)
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.exception(f"Error generating PDF: {e}")
        return "", "", None
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach a PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: authenticated simple_salesforce client.
        pdf_file: buffer whose ``getvalue()`` returns the PDF bytes, or None.
        report_id: Id of the record the file is first published to.

    Returns:
        The ContentVersion download URL, or "" when there is no PDF, the new
        ContentVersion cannot be read back, or any error occurs (errors are
        logged, not raised).
    """
    if not pdf_file:
        logger.error("No PDF file provided for upload")
        return ""
    try:
        # Single timestamp so Title and PathOnClient always carry the same value.
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode("utf-8")
        content_version = sf.ContentVersion.create({
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version["id"]
        # Read the record back to confirm the upload actually persisted.
        result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{version_id}'")
        if not result["records"]:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.exception(f"Error uploading PDF to Salesforce: {e}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for an analysed video and attach its PDF.

    A ``Safety_Video_Report__c`` record is created; if that fails, an
    ``Account`` record is created instead so the run is not lost. If the PDF
    uploads, its download URL is written back to the same record
    (``PDF_Report_URL__c`` on the custom object, ``Description`` on Account).
    Failure to write the URL back is logged but does not discard the record.

    Args:
        violations: violation dicts as produced by process_video.
        score: compliance score stored on the record.
        pdf_path: local path of the generated PDF ("" if none).
        pdf_file: in-memory PDF buffer, or None.

    Returns:
        (record_id, pdf_url): pdf_url is the Salesforce download URL when the
        upload succeeded, otherwise the public URL derived from pdf_path.
        (None, "") if no record could be created.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url,
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            sobject_name, url_field = "Safety_Video_Report__c", "PDF_Report_URL__c"
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            sobject_name, url_field = "Account", "Description"
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
    except Exception as e:
        logger.exception(f"Salesforce record creation failed: {e}")
        return None, ""

    if pdf_file:
        uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if uploaded_url:
            try:
                # Update the sObject type that was actually created above, so a
                # custom-object Id is never sent to Account.update (or vice versa).
                getattr(sf, sobject_name).update(record_id, {url_field: uploaded_url})
                logger.info(f"Updated {sobject_name} record {record_id} with PDF URL: {uploaded_url}")
            except Exception as e:
                # The record already exists; keep its Id even if this update fails.
                logger.error(f"Failed to store PDF URL on {sobject_name} {record_id}: {e}")
            pdf_url = uploaded_url
    return record_id, pdf_url
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of violation records.

    Each record's "violation" label deducts a fixed penalty from 100
    (no_helmet 25, no_harness 30, unsafe_posture 20); any other label costs
    nothing. Every record counts, so repeated detections of one type keep
    lowering the score. The result is clamped at 0.
    """
    deductions = {"no_helmet": 25, "no_harness": 30, "unsafe_posture": 20}
    total_penalty = sum(deductions.get(record["violation"], 0) for record in violations)
    return max(100 - total_penalty, 0)
# ==========================
# Video Processing
# ==========================
def _empty_video_result(error=None):
    """Return the result dict for a run with no reportable violations.

    When ``error`` is given it is stored under an extra "error" key so callers
    can tell a failed run from a clean one (the score stays 100 for backward
    compatibility).
    """
    result = {
        "violations": [],
        "snapshots": [],
        "score": 100,
        "salesforce_record_id": None,
        "violation_details_url": "",
    }
    if error is not None:
        result["error"] = error
    return result


def _normalize_fps(raw_fps, default=30.0):
    """Return ``raw_fps`` if it is a positive, finite frame rate, else ``default``.

    OpenCV returns 0 for properties the backend cannot report; using that as
    the frame rate would process zero frames and divide by zero in timestamps.
    """
    if 0 < raw_fps < float("inf"):
        return raw_fps
    return default


def _detect_frame_violations(results, frame_count, fps):
    """Turn one frame's YOLO results into violation records.

    Keeps only target labels at or above CONFIG["CONFIDENCE_THRESHOLD"], and at
    most one record per label (the first qualifying box) for the frame.
    """
    target_labels = {"no_helmet", "no_harness", "unsafe_posture"}
    found, seen = [], set()
    for result in results:
        for box in result.boxes:
            cls, conf = int(box.cls), float(box.conf)
            label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
            if label not in target_labels:
                # Debug level: this fires for every box on every sampled frame.
                logger.debug("Ignoring detection: %s (cls: %s, conf: %s) - not a target violation", label, cls, conf)
                continue
            if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                logger.debug("Skipping low-confidence detection: %s (conf: %s)", label, conf)
                continue
            if label in seen:
                continue
            seen.add(label)
            found.append({
                "frame": frame_count,
                "violation": label,
                "confidence": round(conf, 2),
                # float() turns numpy scalars into plain Python floats so the box
                # prints cleanly and is JSON-serialisable.
                "bounding_box": [round(float(x), 2) for x in box.xywh.cpu().numpy()[0]],
                "timestamp": frame_count / fps,
            })
    return found


def _save_snapshot(frame, label, frame_count, run_id):
    """Write ``frame`` to OUTPUT_DIR as a JPEG and return its snapshot record.

    Returns None (after logging a warning) if OpenCV fails to write the image.
    """
    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{run_id}_{frame_count}_{label}.jpg")
    if not cv2.imwrite(snapshot_path, frame):
        logger.warning("Failed to write snapshot: %s", snapshot_path)
        return None
    with open(snapshot_path, "rb") as img_file:
        img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
    return {
        "violation": label,
        "frame": frame_count,
        "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
        "snapshot_base64": f"data:image/jpeg;base64,{img_base64}",
    }


def process_video(video_data):
    """Detect safety violations in a video and report them.

    The bytes are written to a uniquely named temp file in OUTPUT_DIR. Up to
    the first 60 seconds of footage are sampled every CONFIG["FRAME_SKIP"]
    frames, stopping early once CONFIG["MAX_PROCESSING_TIME"] seconds of
    wall-clock time have elapsed. One snapshot is saved per violation type. If
    any violations are found, a PDF report is generated and pushed to
    Salesforce. The capture handle and temp file are always cleaned up.

    Args:
        video_data: raw bytes of the uploaded video file.

    Returns:
        dict with keys "violations", "snapshots", "score",
        "salesforce_record_id" and "violation_details_url". On failure the
        empty result is returned with an additional "error" message.
    """
    # Unique per call so concurrent uploads never clobber each other's files.
    run_id = uuid.uuid4().hex[:12]
    video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{run_id}.mp4")
    video = None
    try:
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")

        fps = _normalize_fps(video.get(cv2.CAP_PROP_FPS))
        max_frames = int(60 * fps)  # Process up to 1 minute of footage
        violations, snapshots = [], []
        snapshotted = set()  # Violation types that already have a snapshot
        frame_count = 0
        start_time = time.time()
        while frame_count < max_frames:
            ret, frame = video.read()
            if not ret:
                break
            if frame_count % CONFIG["FRAME_SKIP"] == 0:
                if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                    logger.info("Processing time limit reached")
                    break
                results = model(frame, device=device)
                for violation in _detect_frame_violations(results, frame_count, fps):
                    violations.append(violation)
                    label = violation["violation"]
                    if label in snapshotted:
                        continue
                    snapshot = _save_snapshot(frame, label, frame_count, run_id)
                    if snapshot is not None:
                        snapshots.append(snapshot)
                        snapshotted.add(label)
            frame_count += 1

        if not violations:
            logger.info("No violations detected")
            return _empty_video_result()

        score = calculate_safety_score(violations)
        pdf_path, _public_pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url,
        }
    except Exception as e:
        logger.exception(f"Error processing video: {e}")
        return _empty_video_result(error=str(e))
    finally:
        # Release the capture and delete the temp copy on every exit path.
        if video is not None:
            video.release()
        if os.path.exists(video_path):
            try:
                os.remove(video_path)
            except OSError as exc:
                logger.warning("Could not remove temp video %s: %s", video_path, exc)
# ==========================
# Gradio Interface
# ==========================
def _format_violation_table(violations, details_url):
    """Render violation records as a Markdown table, one row per detection."""
    if not violations:
        return "No violations detected."
    header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
    separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
    rows = []
    for v in violations:
        display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
        # float() accepts plain floats and numpy scalars alike, so the cell
        # never shows reprs such as "np.float32(...)".
        bbox = "[" + ", ".join(f"{float(x):.2f}" for x in v["bounding_box"]) + "]"
        rows.append(
            f"| {display_name:<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {bbox} | {details_url} |"
        )
    return header + separator + "\n".join(rows)


def _format_snapshots(snapshots):
    """Render one Markdown bullet per snapshot with the captured image embedded."""
    if not snapshots:
        return "No snapshots captured."
    lines = []
    for s in snapshots:
        display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], s["violation"])
        # Prefer the self-contained data URI; fall back to the public URL.
        # NOTE(review): relies on gr.Markdown rendering data: image URIs —
        # confirm for the deployed Gradio version.
        image_src = s.get("snapshot_base64") or s.get("snapshot_url", "")
        lines.append(f"- Snapshot for {display_name} at frame {s['frame']}: ![{display_name}]({image_src})")
    return "\n".join(lines)


def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the results.

    Args:
        video_file: path of the uploaded video (falsy when nothing was uploaded).

    Returns:
        A 5-tuple in the order of the interface outputs: violations Markdown
        table, compliance score text, snapshots Markdown, Salesforce record id
        text, and the violation-details URL.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        result = process_video(video_data)
        if result.get("error"):
            # Surface a processing failure reported via an "error" key instead
            # of displaying an empty table and a clean 100% score.
            return f"Error: {result['error']}", "", "Error in processing.", "", ""
        return (
            _format_violation_table(result["violations"], result["violation_details_url"]),
            f"Safety Score: {result['score']}%",
            _format_snapshots(result["snapshots"]),
            f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
            result["violation_details_url"] or "N/A",
        )
    except Exception as e:
        logger.exception(f"Error in Gradio interface: {e}")
        return f"Error: {str(e)}", "", "Error in processing.", "", ""
# Single video input; outputs are listed in the same order as the tuple
# returned by gradio_interface.
_video_input = gr.Video(label="Upload Site Video")
_result_outputs = [
    gr.Markdown(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
    gr.Markdown(label="Snapshots"),
    gr.Textbox(label="Salesforce Record ID"),
    gr.Textbox(label="Violation Details URL"),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_video_input,
    outputs=_result_outputs,
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored.",
)


def _launch_app():
    """Log start-up and launch the Gradio interface."""
    logger.info("Launching Safety Analyzer App...")
    interface.launch()


if __name__ == "__main__":
    _launch_app()