import os
import sys
import subprocess
import logging
import warnings
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
from retrying import retry
import uuid
from multiprocessing import Pool, cpu_count
from functools import partial

# ==========================
# Configuration and Setup
# ==========================

# Ultralytics persists settings to a config directory; point it at /tmp so the
# app also runs on read-only / sandboxed filesystems (e.g. Hugging Face Spaces).
os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
os.makedirs('/tmp/Ultralytics', exist_ok=True)

# Configure logging exactly once (the file previously called basicConfig twice;
# the second call is a no-op anyway, so the duplicate was removed).
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


# ==========================
# ByteTrack Implementation
# ==========================
class BYTETracker:
    """Lightweight greedy-IoU multi-object tracker (simplified ByteTrack stand-in).

    The previous implementation assigned a brand-new ID to every detection in
    every frame, i.e. it performed no cross-frame association at all.  Because
    downstream code requires the *same* worker ID to appear in at least
    ``MIN_VIOLATION_FRAMES`` frames before confirming a violation, that meant
    no violation could ever be confirmed.  This version keeps the same public
    interface (``update(dets, scores, cls)`` returning dicts with keys
    ``id``/``bbox``/``score``/``cls``) but associates detections with existing
    tracks of the same class by IoU, so IDs persist across frames.
    """

    def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
        # track_thresh: minimum detection score to consider at all.
        # track_buffer: number of consecutive missed updates before a track is dropped.
        # match_thresh: ByteTrack-style matching threshold; we use
        #               (1 - match_thresh) as the minimum IoU for association.
        # frame_rate:   accepted for API compatibility; not used by this simple tracker.
        self.track_thresh = track_thresh
        self.track_buffer = track_buffer
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.next_id = 1
        self.tracks = []  # active tracks: list of {"id", "bbox", "cls", "missed"}

    @staticmethod
    def _iou(a, b):
        """IoU of two centre-based [x, y, w, h] boxes."""
        ax1, ay1, ax2, ay2 = a[0] - a[2] / 2, a[1] - a[3] / 2, a[0] + a[2] / 2, a[1] + a[3] / 2
        bx1, by1, bx2, by2 = b[0] - b[2] / 2, b[1] - b[3] / 2, b[0] + b[2] / 2, b[1] + b[3] / 2
        iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
        ih = max(0.0, min(ay2, by2) - max(ay1, by1))
        inter = iw * ih
        union = a[2] * a[3] + b[2] * b[3] - inter
        return inter / union if union > 0 else 0.0

    def update(self, dets, scores, cls):
        """Associate detections with tracks and return the tracked objects.

        dets:   iterable of [x, y, w, h] (centre-based) boxes
        scores: iterable of confidences
        cls:    iterable of integer class ids
        Returns a list of {"id", "bbox", "score", "cls"} dicts, one per
        detection whose score is >= track_thresh.
        """
        min_iou = 1.0 - self.match_thresh
        matched_ids = set()
        results = []
        for det, score, cl in zip(dets, scores, cls):
            if score < self.track_thresh:
                continue
            x, y, w, h = (float(v) for v in det)
            bbox = [x, y, w, h]
            # Greedy match: best-IoU unmatched track of the same class.
            best, best_iou = None, min_iou
            for tr in self.tracks:
                if tr["id"] in matched_ids or tr["cls"] != int(cl):
                    continue
                iou = self._iou(tr["bbox"], bbox)
                if iou >= best_iou:
                    best, best_iou = tr, iou
            if best is not None:
                best["bbox"] = bbox
                best["missed"] = 0
                tid = best["id"]
            else:
                tid = self.next_id
                self.next_id += 1
                self.tracks.append({"id": tid, "bbox": bbox, "cls": int(cl), "missed": 0})
            matched_ids.add(tid)
            results.append({"id": tid, "bbox": bbox, "score": float(score), "cls": cl})
        # Age out tracks that went unmatched for more than track_buffer updates.
        for tr in self.tracks:
            if tr["id"] not in matched_ids:
                tr["missed"] += 1
        self.tracks = [tr for tr in self.tracks if tr["missed"] <= self.track_buffer]
        return results


# ==========================
# Optimized Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": "static/output",
    # YOLO class index -> internal violation label.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use"
    },
    # BGR colours used when drawing boxes.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),         # Red
        "no_harness": (0, 165, 255),      # Orange
        "unsafe_posture": (0, 255, 0),    # Green
        "unsafe_zone": (255, 0, 0),       # Blue
        "improper_tool_use": (255, 255, 0)  # Yellow
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use"
    },
    # SECURITY: credentials were previously hard-coded in source.  They are now
    # read from the environment, with the original values kept only as a
    # backward-compatible fallback.  Rotate these secrets and delete the
    # fallbacks as soon as possible.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", "prashanth1ai@safety.com"),
        "password": os.environ.get("SF_PASSWORD", "SaiPrash461"),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", "AP4AQnPoidIKPvSvNEfAHyoK"),
        "domain": os.environ.get("SF_DOMAIN", "login")
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Per-class minimum confidence for a detection to count as a violation.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.75,
        "no_harness": 0.4,
        "unsafe_posture": 0.4,
        "unsafe_zone": 0.4,
        "improper_tool_use": 0.4
    },
    # A (worker_id, violation) pair must appear in at least this many frames.
    "MIN_VIOLATION_FRAMES": 3,
    "WORKER_TRACKING_DURATION": 5.0,
    "MAX_PROCESSING_TIME": 60,
    "FRAME_SKIP": 1,
    "BATCH_SIZE": 32,
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),
    "TRACK_BUFFER": 30,
    "TRACK_THRESH": 0.4,
    "MATCH_THRESH": 0.8
}

os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


def load_model():
    """Load the custom safety model, falling back to (and downloading) yolov8n.

    Returns the YOLO model moved to the active device.  Raises on any failure
    so the app fails fast at startup rather than mid-request.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
            logger.info(f"Model loaded: {model_path}")
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file(
                    'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
                    model_path
                )
        model = YOLO(model_path).to(device)
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


model = load_model()
# ==========================
# Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a labelled bounding box for each detection onto ``frame`` (in place).

    Each detection dict carries a centre-based ``bounding_box`` [x, y, w, h],
    a ``violation`` label and a ``confidence``.  Returns the annotated frame.
    """
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
        # Convert centre-based box to corner coordinates for OpenCV.
        x1 = int(x - w / 2)
        y1 = int(y - h / 2)
        x2 = int(x + w / 2)
        y2 = int(y + h / 2)
        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
        cv2.putText(frame, display_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame


def calculate_safety_score(violations):
    """Return a 0-100 compliance score: 100 minus a fixed penalty per violation."""
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
    score = 100 - total_penalty
    return max(score, 0)  # clamp at 0 so many violations can't go negative


def generate_violation_pdf(violations, score):
    """Render a violation report PDF to disk and return (path, public_url, BytesIO).

    On failure returns ("", "", None) so callers can degrade gracefully.
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)
        y_position = 9.5 * inch
        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in report_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.3 * inch
        y_position -= 0.3 * inch
        c.drawString(1 * inch, y_position, "Violation Details:")
        y_position -= 0.3 * inch
        if not violations:
            c.drawString(1 * inch, y_position, "No violations detected.")
        else:
            for v in violations:
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                text = (
                    f"{display_name} from {v.get('start_timestamp', 0.0):.2f}s to "
                    f"{v.get('end_timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f}, "
                    f"Worker ID: {v.get('worker_id', 'N/A')})"
                )
                c.drawString(1 * inch, y_position, text)
                y_position -= 0.3 * inch
                # Start a new page before running off the bottom margin.
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch
        c.showPage()
        c.save()
        pdf_file.seek(0)
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None


# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session (retried up to 3 times, 2 s apart)."""
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        sf.describe()  # cheap call to verify the session actually works
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise


def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Upload the PDF as a ContentVersion linked to ``report_id``.

    Returns the download URL, or "" on any failure.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        content_version_data = {
            "Title": f"Safety_Violation_Report_{int(time.time())}",
            "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id
        }
        content_version = sf.ContentVersion.create(content_version_data)
        # Sanity-check that the ContentVersion actually exists before building a URL.
        result = sf.query(
            f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'"
        )
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""


def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Safety_Video_Report__c record (falling back to Account) and attach the PDF.

    Returns (record_id, pdf_url); (None, "") if everything failed.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} from "
            f"{v.get('start_timestamp', 0.0):.2f}s to {v.get('end_timestamp', 0.0):.2f}s "
            f"(Confidence: {v.get('confidence', 0.0):.2f}, Worker ID: {v.get('worker_id', 'N/A')})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            # Custom object may not exist in this org; fall back to a bare Account.
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                try:
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated Account record {record_id} with PDF URL")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""


# ==========================
# Fast Video Processing
# ==========================
def process_video(video_data):
    """Generator: analyse raw video bytes for safety violations.

    Yields 5-tuples of (status/table, score text, snapshots text, record id,
    details URL) — progress updates while processing, then the final report.

    Fixes vs. the previous version:
      * the temp video file was deleted *before* the snapshot loop reopened it,
        so snapshot capture always failed — cleanup now happens afterwards;
      * tracked objects are dicts, but were accessed as ``obj.id`` / ``obj.cls``
        (AttributeError) and zipped against an unfiltered list (misalignment);
      * total processing time was measured from a variable that the progress
        throttle kept resetting.
    """
    try:
        # Persist the uploaded bytes to a temp file OpenCV can read.
        video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some containers report 0 FPS
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(
            f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}"
        )

        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps
        )

        violation_tracker = {}  # {worker_id: {violation_type: [detections]}}
        snapshots = []
        overall_start = time.time()   # fixed point for total-time measurement
        last_progress = overall_start  # throttle for progress yields
        frame_skip = CONFIG["FRAME_SKIP"]

        # Read and detect in batches for throughput.
        while True:
            batch_frames = []
            batch_indices = []
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if frame_idx >= total_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                # Cheaply skip frames with grab() (no decode) when FRAME_SKIP > 1.
                for _ in range(frame_skip - 1):
                    if not cap.grab():
                        break
                batch_frames.append(frame)
                batch_indices.append(frame_idx)
            if not batch_frames:
                break

            results = model(batch_frames, device=device, conf=0.1, verbose=False)

            for result, frame_idx in zip(results, batch_indices):
                current_time = frame_idx / fps
                # Emit a progress update at most once per second.
                if time.time() - last_progress > 1.0:
                    progress = (frame_idx / total_frames) * 100
                    yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
                    last_progress = time.time()

                # Keep only boxes that map to a violation class and clear its threshold.
                track_inputs = []
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, None)
                    if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        continue
                    bbox = box.xywh.cpu().numpy()[0]  # centre-based [x, y, w, h]
                    track_inputs.append({"bbox": bbox, "conf": conf, "cls": cls})

                tracked_objects = tracker.update(
                    np.array([t["bbox"] for t in track_inputs]),
                    np.array([t["conf"] for t in track_inputs]),
                    np.array([t["cls"] for t in track_inputs])
                )

                # Tracks are dicts; read their own fields rather than zipping with
                # track_inputs (the tracker may drop low-score entries, which
                # would misalign a zip).
                for obj in tracked_objects:
                    worker_id = obj["id"]
                    label = CONFIG["VIOLATION_LABELS"].get(int(obj["cls"]), None)
                    detection = {
                        "frame": frame_idx,
                        "violation": label,
                        "confidence": round(float(obj["score"]), 2),
                        "bounding_box": [round(float(x), 2) for x in obj["bbox"]],
                        "timestamp": current_time,
                        "worker_id": worker_id
                    }
                    violation_tracker.setdefault(worker_id, {}).setdefault(label, []).append(detection)

        cap.release()
        processing_time = time.time() - overall_start
        logger.info(f"Processing complete in {processing_time:.2f}s")

        # Consolidate per-(worker, type) detection runs into confirmed violations
        # and capture one annotated snapshot for each.  The capture is opened
        # once, and the temp file is deleted only AFTER the snapshots are taken
        # (previously it was removed first, so every snapshot read failed).
        violations = []
        snap_cap = cv2.VideoCapture(video_path)
        for worker_id, worker_violations in violation_tracker.items():
            for label, detections in worker_violations.items():
                if len(detections) < CONFIG["MIN_VIOLATION_FRAMES"]:
                    continue
                best_detection = max(detections, key=lambda x: x["confidence"])
                best_detection["start_timestamp"] = min(d["timestamp"] for d in detections)
                best_detection["end_timestamp"] = max(d["timestamp"] for d in detections)
                violations.append(best_detection)

                snap_cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
                ret, snapshot_frame = snap_cap.read()
                if ret:
                    snapshot_frame = draw_detections(snapshot_frame, [best_detection])
                    snapshot_filename = f"{label}_{best_detection['frame']}.jpg"
                    snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
                    cv2.imwrite(snapshot_path, snapshot_frame)
                    snapshots.append({
                        "violation": label,
                        "frame": best_detection["frame"],
                        "snapshot_path": snapshot_path,
                        "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
                    })
        snap_cap.release()
        os.remove(video_path)

        if not violations:
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        violation_table = "| Violation | Time Range (s) | Confidence | Worker ID |\n"
        violation_table += "|------------------------|----------------|------------|-----------|\n"
        for v in sorted(violations, key=lambda x: x["start_timestamp"]):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            row = (
                f"| {display_name:<22} | {v.get('start_timestamp', 0.0):.2f}-{v.get('end_timestamp', 0.0):.2f} "
                f"| {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
            )
            violation_table += row

        snapshots_text = "\n".join(
            f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame "
            f"{s['frame']}: ![]({s['snapshot_base64']})"
            for s in snapshots
        ) if snapshots else "No snapshots captured."

        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {report_id or 'N/A'}",
            final_pdf_url or "N/A"
        )
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        yield f"Error processing video: {e}", "", "", "", ""
def gradio_interface(video_file):
    """Gradio entry point: stream process_video() results for an uploaded file.

    Yields 5-tuples matching the five output components.  Note this is a
    generator, so the no-file case must *yield* its message — the previous
    ``return value, ...`` only set StopIteration.value, which Gradio never
    displays.
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()
        for status, score, snapshots_text, record_id, details_url in process_video(video_data):
            yield status, score, snapshots_text, record_id, details_url
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""


# ==========================
# Gradio Interface
# ==========================
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
    allow_flagging="never"
)

if __name__ == "__main__":
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()