Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import time | |
| from simple_salesforce import Salesforce | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.units import inch | |
| from io import BytesIO | |
| import base64 | |
| import logging | |
| from retrying import retry | |
# ==========================
# OPTIMIZED CONFIGURATION
# ==========================
def sf_credentials_from_env(env=None):
    """Build the Salesforce login keyword arguments from environment variables.

    Reads ``SF_USERNAME``, ``SF_PASSWORD``, ``SF_SECURITY_TOKEN`` and
    ``SF_DOMAIN`` (the domain defaults to ``"login"``). Missing values become
    empty strings so the returned dict always has the same four keys.

    Args:
        env: Mapping to read from; defaults to ``os.environ``.

    Returns:
        Dict with ``username``, ``password``, ``security_token`` and ``domain``.
    """
    source = os.environ if env is None else env
    return {
        "username": source.get("SF_USERNAME", ""),
        "password": source.get("SF_PASSWORD", ""),
        "security_token": source.get("SF_SECURITY_TOKEN", ""),
        "domain": source.get("SF_DOMAIN", "login"),
    }


CONFIG = {
    # Weights for the custom detector, and the stock model used when they are absent.
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": "static/output",
    # Detector class index -> internal violation key.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use",
    },
    # Box colours are passed straight to OpenCV drawing calls, i.e. BGR order.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),          # Red
        "no_harness": (0, 165, 255),       # Orange
        "unsafe_posture": (0, 255, 0),     # Green
        "unsafe_zone": (255, 0, 0),        # Blue
        "improper_tool_use": (255, 255, 0),  # Cyan in BGR (yellow would be (0, 255, 255))
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet",
        "no_harness": "No Harness",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone",
        "improper_tool_use": "Improper Tool Use",
    },
    # Secrets must never live in source control; they are read from the
    # environment instead.
    # NOTE(review): credentials were previously hard-coded here - rotate them.
    "SF_CREDENTIALS": sf_credentials_from_env(),
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    "FRAME_SKIP": 3,                # analyze every 3rd frame
    "MAX_PROCESSING_TIME": 60,      # wall-clock budget in seconds
    # Per-violation minimum detector confidence.
    "CONFIDENCE_THRESHOLD": {
        "no_helmet": 0.4,
        "no_harness": 0.3,
        "unsafe_posture": 0.25,
        "unsafe_zone": 0.3,
        "improper_tool_use": 0.35,
    },
    "IOU_THRESHOLD": 0.4,           # minimum overlap to match a known worker
    "MIN_VIOLATION_FRAMES": 3,
    "MIN_VIOLATION_DURATION": 1.5,  # seconds a violation must persist
}
# Initialize logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Make sure the output directory exists before anything is written to it.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
def load_model():
    """Load the YOLO detector and move it to the configured device.

    The custom weights at ``CONFIG["MODEL_PATH"]`` are used when that file
    exists; otherwise ``CONFIG["FALLBACK_MODEL"]`` is loaded and a warning is
    logged.

    Returns:
        The loaded model.

    Raises:
        Exception: Any loading error is logged and re-raised unchanged.
    """
    try:
        has_custom_weights = os.path.exists(CONFIG["MODEL_PATH"])
        weights = CONFIG["MODEL_PATH"] if has_custom_weights else CONFIG["FALLBACK_MODEL"]
        detector = YOLO(weights).to(device)
        if has_custom_weights:
            logger.info("Loaded custom safety model")
        else:
            logger.warning("Using fallback model - recommend training yolov8_safety.pt")
        return detector
    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise


# Loaded once at import time and shared by every request.
model = load_model()
def draw_detections(frame, detections):
    """Draw bounding boxes with labels and confidence scores.

    Args:
        frame: Image to draw on; it is modified in place.
        detections: Iterable of dicts with ``violation``, ``confidence`` and
            ``bounding_box`` (center x, center y, width, height) entries.

    Returns:
        The same ``frame`` object, annotated.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in detections:
        violation = det["violation"]
        cx, cy, box_w, box_h = (int(v) for v in det["bounding_box"])
        box_color = CONFIG["CLASS_COLORS"].get(violation, (0, 0, 255))

        # Convert the center/size box into corner coordinates.
        left, top = int(cx - box_w / 2), int(cy - box_h / 2)
        right, bottom = int(cx + box_w / 2), int(cy + box_h / 2)
        cv2.rectangle(frame, (left, top), (right, bottom), box_color, 2)

        display_name = CONFIG["DISPLAY_NAMES"].get(violation, violation)
        caption = f"{display_name}: {det['confidence']:.2f}"
        (caption_w, caption_h), _ = cv2.getTextSize(caption, font, 0.5, 1)

        # Filled strip above the box so the white caption stays readable.
        cv2.rectangle(frame, (left, top - caption_h - 10), (left + caption_w, top), box_color, -1)
        cv2.putText(frame, caption, (left, top - 5), font, 0.5, (255, 255, 255), 1)
    return frame
def calculate_iou(box1, box2):
    """Calculate Intersection over Union for two bounding boxes.

    Args:
        box1: ``(center_x, center_y, width, height)`` sequence.
        box2: ``(center_x, center_y, width, height)`` sequence.

    Returns:
        The IoU as a float in ``[0.0, 1.0]``. Boxes that do not overlap, or
        whose combined area is zero, give ``0.0``.
    """
    # Convert center/size boxes to (x_min, y_min, x_max, y_max) corners.
    box1 = [box1[0] - box1[2]/2, box1[1] - box1[3]/2, box1[0] + box1[2]/2, box1[1] + box1[3]/2]
    box2 = [box2[0] - box2[2]/2, box2[1] - box2[3]/2, box2[0] + box2[2]/2, box2[1] + box2[3]/2]

    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = float(box1_area + box2_area - intersection_area)

    # Two degenerate (zero-size) boxes at the same point have no area at all;
    # treat that as "no overlap" instead of dividing by zero.
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area
def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF and save it in the output directory.

    Args:
        violations: List of violation dicts (``violation``, ``timestamp``,
            ``confidence``, ``worker_id``).
        score: Compliance score shown in the report header.

    Returns:
        ``(pdf_path, public_url, pdf_file)`` where ``pdf_file`` is the
        in-memory ``BytesIO`` holding the PDF, rewound to the start. On any
        failure the error is logged and ``("", "", None)`` is returned.
    """
    try:
        left_margin = 1 * inch
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)

        # Title
        c.setFont("Helvetica-Bold", 14)
        c.drawString(left_margin, 10.5 * inch, "Worksite Safety Violation Report")

        # Summary lines
        c.setFont("Helvetica", 12)
        y_position = 10 * inch
        report_date = time.strftime("%Y-%m-%d %H:%M:%S")
        summary_lines = [
            f"Compliance Score: {score}%",
            f"Total Violations: {len(violations)}",
            f"Report Date: {report_date}",
        ]
        for summary_line in summary_lines:
            c.drawString(left_margin, y_position, summary_line)
            y_position -= 0.4 * inch

        # Separator rule between the summary and the detail list
        y_position -= 0.2 * inch
        c.line(left_margin, y_position, 7.5 * inch, y_position)
        y_position -= 0.3 * inch

        c.setFont("Helvetica-Bold", 12)
        c.drawString(left_margin, y_position, "Violation Details:")
        y_position -= 0.3 * inch
        c.setFont("Helvetica", 10)

        if violations:
            for v in violations:
                display_name = CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])
                detail_line = (
                    f"{display_name} "
                    f"at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, "
                    f"Worker: {v['worker_id']})"
                )
                c.drawString(left_margin, y_position, detail_line)
                y_position -= 0.25 * inch
                if y_position < 1 * inch:
                    # Out of room: start a new page and set the body font again.
                    c.showPage()
                    y_position = 10 * inch
                    c.setFont("Helvetica", 10)
        else:
            c.drawString(left_margin, y_position, "No violations detected.")

        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as out:
            out.write(pdf_file.getvalue())

        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"Generated PDF report: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"PDF generation failed: {str(e)}")
        return "", "", None
def connect_to_salesforce():
    """Open a Salesforce session using ``CONFIG["SF_CREDENTIALS"]``.

    Returns:
        The connected ``Salesforce`` client.

    Raises:
        Exception: Any connection error is logged and re-raised unchanged.
    """
    credentials = CONFIG["SF_CREDENTIALS"]
    try:
        session = Salesforce(**credentials)
        logger.info("Connected to Salesforce")
        return session
    except Exception as e:
        logger.error(f"Salesforce connection failed: {str(e)}")
        raise
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: Connected Salesforce client.
        pdf_file: In-memory file object exposing ``getvalue()`` with the PDF bytes.
        report_id: Id of the record the file is first published to.

    Returns:
        Download URL of the created ContentVersion, or ``""`` if the upload
        failed (the error is logged).
    """
    try:
        pdf_b64 = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        payload = {
            "Title": f"Safety_Report_{int(time.time())}",
            "PathOnClient": "safety_report.pdf",
            "VersionData": pdf_b64,
            "FirstPublishLocationId": report_id,
        }
        created = sf.ContentVersion.create(payload)
        version_id = created['id']
        return f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
    except Exception as e:
        logger.error(f"PDF upload failed: {str(e)}")
        return ""
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for the analysis and attach the PDF report.

    A ``Safety_Video_Report__c`` record is created; if that fails, an
    ``Account`` record is created as a fallback. When a PDF is supplied it is
    uploaded and its download URL is written back onto the record.

    Args:
        violations: List of violation dicts (``violation``, ``timestamp``,
            ``confidence``, ``worker_id``).
        score: Compliance score to store on the record.
        pdf_path: Local path of the generated PDF (may be empty).
        pdf_file: In-memory PDF file object, or ``None``.

    Returns:
        ``(record_id, report_url)``. ``report_url`` is the Salesforce download
        URL when the upload succeeded, else the public URL derived from
        ``pdf_path`` (or ``""`` when there is no PDF). On an unrecoverable
        error the failure is logged and ``(None, "")`` is returned.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"- {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} "
            f"at {v['timestamp']:.2f}s (Worker {v['worker_id']}, Confidence: {v['confidence']:.2f})"
            for v in violations
        ) or "No violations detected"
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "New",
        }

        # Remember which object type holds the record so the URL write-back
        # below targets the right one instead of guessing.
        used_fallback_account = False
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            record_id = record["id"]
            logger.info(f"Created Salesforce record: {record_id}")
        except Exception as e:
            logger.error(f"Failed to create Safety Report: {str(e)}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            record_id = record["id"]
            used_fallback_account = True
            logger.warning(f"Created fallback Account record: {record_id}")

        pdf_url = ""
        if pdf_file:
            pdf_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if pdf_url:
            try:
                if used_fallback_account:
                    sf.Account.update(record_id, {"Description": pdf_url})
                else:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": pdf_url})
            except Exception as e:
                # The record and the upload already exist; failing to store the
                # link is not a reason to report the whole push as failed.
                logger.warning(f"Could not store PDF URL on record {record_id}: {str(e)}")

        if pdf_url:
            return record_id, pdf_url
        # Only derive a public URL when a PDF was actually written; otherwise
        # the result would be the bare base URL, which points at nothing.
        fallback_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        return record_id, fallback_url
    except Exception as e:
        logger.error(f"Salesforce integration failed: {str(e)}")
        return None, ""
def calculate_safety_score(violations):
    """Compute the 0-100 compliance score for a list of violations.

    Each distinct ``(worker_id, violation)`` pair is penalized once, using a
    fixed per-type penalty; unknown violation types cost nothing. The score
    never goes below 0.

    Args:
        violations: Iterable of dicts with ``worker_id`` and ``violation`` keys.

    Returns:
        Integer score between 0 and 100.
    """
    penalty_by_type = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    already_counted = set()
    deducted = 0
    for entry in violations:
        pair = (entry["worker_id"], entry["violation"])
        if pair in already_counted:
            continue
        already_counted.add(pair)
        deducted += penalty_by_type.get(pair[1], 0)
    remaining = 100 - deducted
    return remaining if remaining > 0 else 0
def _collect_detections(cap, fps):
    """Run the detector over sampled frames and return detections grouped by label."""
    workers = []
    violation_history = {k: [] for k in CONFIG["VIOLATION_LABELS"].values()}
    frame_count = 0
    start_time = time.time()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Only every FRAME_SKIP-th frame is sent through the model.
        if frame_count % CONFIG["FRAME_SKIP"] != 0:
            frame_count += 1
            continue
        if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
            logger.warning("Processing timeout reached")
            break
        current_time = frame_count / fps
        results = model(frame, device=device, verbose=False)
        for result in results:
            for box in result.boxes:
                cls = int(box.cls)
                conf = float(box.conf)
                label = CONFIG["VIOLATION_LABELS"].get(cls)
                if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3):
                    continue
                bbox = box.xywh.cpu().numpy()[0].tolist()

                # Associate the box with the known worker it overlaps most,
                # provided the overlap clears the IoU threshold.
                matched_worker = None
                max_iou = 0
                for worker in workers:
                    iou = calculate_iou(bbox, worker["bbox"])
                    if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
                        max_iou = iou
                        matched_worker = worker
                if matched_worker:
                    worker_id = matched_worker["id"]
                    matched_worker["bbox"] = bbox
                    matched_worker["last_seen"] = current_time
                else:
                    worker_id = len(workers) + 1
                    workers.append({
                        "id": worker_id,
                        "bbox": bbox,
                        "first_seen": current_time,
                        "last_seen": current_time,
                    })
                violation_history[label].append({
                    "frame": frame_count,
                    "violation": label,
                    "confidence": round(conf, 2),
                    "bounding_box": bbox,
                    "timestamp": current_time,
                    "worker_id": worker_id,
                })
        frame_count += 1
    return violation_history


def _confirm_violations(cap, violation_history):
    """Keep violations that persist long enough and save one snapshot per violation type."""
    violations = []
    snapshots = []
    for violation_type, detections in violation_history.items():
        if not detections:
            continue
        worker_groups = {}
        for det in detections:
            worker_groups.setdefault(det["worker_id"], []).append(det)

        snapshot_taken = False
        for worker_dets in worker_groups.values():
            # NOTE(review): CONFIG["MIN_VIOLATION_FRAMES"] is defined but this
            # check uses a fixed minimum of 2 detections - confirm which is intended.
            if len(worker_dets) < 2:
                continue
            duration = worker_dets[-1]["timestamp"] - worker_dets[0]["timestamp"]
            if duration < CONFIG["MIN_VIOLATION_DURATION"]:
                continue
            best_det = max(worker_dets, key=lambda x: x["confidence"])
            violations.append(best_det)
            if snapshot_taken:
                continue

            # Seek back to the most confident frame to capture the evidence image.
            cap.set(cv2.CAP_PROP_POS_FRAMES, best_det["frame"])
            ret, snapshot_frame = cap.read()
            if ret:
                snapshot_frame = draw_detections(snapshot_frame, [best_det])
                filename = f"{violation_type}_{best_det['frame']}.jpg"
                path = os.path.join(CONFIG["OUTPUT_DIR"], filename)
                cv2.imwrite(path, snapshot_frame)
                snapshots.append({
                    "violation": violation_type,
                    "frame": best_det["frame"],
                    "path": path,
                    "url": f"{CONFIG['PUBLIC_URL_BASE']}{filename}",
                })
                snapshot_taken = True
    return violations, snapshots


def process_video(video_data):
    """Analyze raw video bytes for safety violations and publish a report.

    The bytes are written to a temporary file in the output directory, scanned
    with the detector, and reduced to confirmed violations. A PDF is then
    generated and pushed to Salesforce.

    Args:
        video_data: Raw bytes of the uploaded video file.

    Returns:
        Dict with keys ``violations``, ``snapshots``, ``score``,
        ``salesforce_record_id``, ``violation_details_url`` and ``message``.
        On failure the error is logged and ``message`` carries
        ``"Error: ..."`` while the other fields hold empty defaults.
    """
    try:
        temp_video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
        cap = None
        try:
            with open(temp_video_path, "wb") as f:
                f.write(video_data)
            cap = cv2.VideoCapture(temp_video_path)
            # An unreadable upload must surface as an error rather than being
            # reported as a clean, violation-free video.
            if not cap.isOpened():
                raise ValueError("Could not open the uploaded video")
            fps = cap.get(cv2.CAP_PROP_FPS) or 30
            violation_history = _collect_detections(cap, fps)
            violations, snapshots = _confirm_violations(cap, violation_history)
        finally:
            # Always release the capture and delete the temp file, even when
            # analysis fails part-way through.
            if cap is not None:
                cap.release()
            if os.path.exists(temp_video_path):
                os.remove(temp_video_path)

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        record_id, sf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": record_id,
            "violation_details_url": sf_url or pdf_url,
            "message": "",
        }
    except Exception as e:
        logger.error(f"Video processing failed: {str(e)}")
        return {
            "violations": [],
            "snapshots": [],
            "score": 100,
            "salesforce_record_id": None,
            "violation_details_url": "",
            "message": f"Error: {str(e)}",
        }
def gradio_interface(video_file):
    """Gradio callback: analyze the uploaded video and stream UI updates.

    This is a generator. It first yields a progress message, then one final
    5-tuple matching the interface outputs: violations table (Markdown),
    score text, snapshots (Markdown), Salesforce record text and report URL.

    Args:
        video_file: Path of the uploaded video file, or a falsy value when
            nothing was uploaded.

    Yields:
        5-tuples of strings. Errors are reported as ``"Error: ..."`` in the
        first element with the remaining elements empty.
    """
    try:
        yield "Analyzing video...", "", "", "", ""

        if not video_file:
            yield "Error: no video file was provided", "", "", "", ""
            return

        with open(video_file, "rb") as f:
            result = process_video(f.read())

        # A non-empty message means processing failed; show it instead of
        # presenting the empty defaults as a clean 100% result.
        if result.get("message"):
            yield result["message"], "", "", "", ""
            return

        violation_table = (
            "| Violation Type | Timestamp | Confidence | Worker ID |\n"
            "|---------------------|-----------|------------|-----------|\n" +
            "\n".join(
                f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<19} | "
                f"{v['timestamp']:.2f} | "
                f"{v['confidence']:.2f} | "
                f"{v['worker_id']} |"
                for v in result["violations"]
            )
        ) if result["violations"] else "No violations detected"

        snapshots_md = "\n".join(
            f"![{s['violation']} at frame {s['frame']}]({s['url']})"
            for s in result["snapshots"]
        ) if result["snapshots"] else "No snapshots"

        yield (
            violation_table,
            f"Safety Score: {result['score']}%",
            snapshots_md,
            f"Salesforce ID: {result['salesforce_record_id'] or 'None'}",
            result["violation_details_url"] or "None",
        )
    except Exception as e:
        logger.error(f"Interface error: {str(e)}")
        yield f"Error: {str(e)}", "", "", "", ""
# Input component: the site video to analyze.
_video_input = gr.Video(label="Upload Site Video")

# Output components, in the same order as the tuples yielded by gradio_interface.
_report_outputs = [
    gr.Markdown(label="Violations Detected"),
    gr.Textbox(label="Compliance Score"),
    gr.Markdown(label="Evidence Snapshots"),
    gr.Textbox(label="Salesforce Record"),
    gr.Textbox(label="Report URL"),
]

interface = gr.Interface(
    fn=gradio_interface,
    inputs=_video_input,
    outputs=_report_outputs,
    title="AI Safety Compliance Monitor",
    description="Detects 5 violation types: No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use",
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    interface.launch(share=True)