Spaces:
Sleeping
Sleeping
import base64
import os
import tempfile
import time
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
from ultralytics import YOLO
# ==========================
# Configuration
# ==========================
# Every run writes its temp uploads, frame snapshots and PDF reports here;
# the directory is created eagerly at import time.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Custom safety weights, overridable through the SAFETY_MODEL_PATH env var.
DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
# Loaded instead of MODEL_PATH when that file does not exist on disk.
FALLBACK_MODEL = "yolov8n.pt"

# Model class id -> violation label reported to the user and used for scoring.
VIOLATION_LABELS = dict(
    enumerate(("no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"))
)
# ==========================
# Device Setup
# ==========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")
# ==========================
# Load Model
# ==========================
if os.path.isfile(MODEL_PATH):
    selected_model = MODEL_PATH
else:
    selected_model = FALLBACK_MODEL
    # The fallback is a generic pretrained detector, not a safety model: its
    # class ids are still looked up in VIOLATION_LABELS, so ordinary object
    # detections would be reported as helmet/harness violations. Make this
    # loud instead of silent.
    print(
        f"⚠️ Safety model not found at {MODEL_PATH!r}; falling back to "
        f"{FALLBACK_MODEL!r}. Its classes do not correspond to "
        "VIOLATION_LABELS, so reported violations will not be meaningful."
    )
print(f"📦 Loading model: {selected_model}")
model = YOLO(selected_model)
# ==========================
# Video Processing
# ==========================
def _scan_video(video, run_dir, frame_skip, max_frames, time_limit):
    """Run ``model`` on every ``frame_skip``-th frame of an opened capture.

    Stops at the end of the video, after ``max_frames`` frames have been run
    through the model, or once ``time_limit`` seconds have elapsed.

    Returns:
        A ``(violations, snapshots)`` pair of lists of dicts.
    """
    frame_skip = max(1, frame_skip)  # 0 or negative would divide by zero below
    fps = video.get(cv2.CAP_PROP_FPS)  # read once, not once per detected box
    violations = []
    snapshots = []
    saved = set()  # (frame, label) pairs whose snapshot was already written
    frame_count = 0
    processed_frame_count = 0
    start_time = time.time()
    while True:
        ret, frame = video.read()
        if not ret:
            break
        if frame_count % frame_skip != 0:
            frame_count += 1
            continue
        # Model inference
        for result in model(frame, device=device):
            for box in result.boxes:
                cls = int(box.cls)
                label = VIOLATION_LABELS.get(cls, f"class_{cls}")
                violations.append({
                    "frame": frame_count,
                    "violation": label,
                    "confidence": round(float(box.conf), 2),
                    # tolist() yields plain Python floats, which stay
                    # JSON-serialisable (numpy float32 values do not).
                    "bounding_box": [round(x, 2) for x in box.xywh[0].tolist()],
                    # Some containers report 0 FPS; avoid ZeroDivisionError.
                    "timestamp": frame_count / fps if fps > 0 else 0.0,
                })
                # Several boxes with the same label in one frame share one
                # snapshot instead of rewriting the same file repeatedly.
                if (frame_count, label) not in saved:
                    saved.add((frame_count, label))
                    snapshot_path = os.path.join(
                        run_dir, f"snapshot_{frame_count}_{label}.jpg"
                    )
                    cv2.imwrite(snapshot_path, frame)
                    snapshots.append({
                        "violation": label,
                        "frame": frame_count,
                        "snapshot_url": snapshot_path,
                    })
        frame_count += 1
        processed_frame_count += 1
        if processed_frame_count >= max_frames:
            break
        if time.time() - start_time > time_limit:
            print(f"⏰ Exceeded {time_limit} seconds of processing time.")
            break
    return violations, snapshots


def _remove_temp_video(run_dir, video_path):
    """Best-effort removal of the uploaded temp video (and the run dir if empty)."""
    try:
        if video_path is not None and os.path.exists(video_path):
            os.remove(video_path)
        if run_dir is not None and os.path.isdir(run_dir) and not os.listdir(run_dir):
            os.rmdir(run_dir)
    except OSError as e:
        print(f"⚠️ Could not clean up {run_dir}: {e}")


def process_video(video_data, frame_skip=5, max_frames=100, time_limit=30):
    """Detect safety violations in an uploaded video and build a report.

    The bytes are written to a temporary file (OpenCV needs a path), sampled
    frames are passed to ``model``, and every detected box becomes a violation
    record with a JPEG snapshot of its frame. A score and a PDF report are
    then produced from those records.

    Args:
        video_data: Raw bytes of the uploaded video file.
        frame_skip: Run inference on every ``frame_skip``-th frame
            (values below 1 are treated as 1).
        max_frames: Maximum number of frames run through the model.
        time_limit: Stop scanning after this many seconds.

    Returns:
        dict with ``violations``, ``snapshots``, ``score`` and
        ``pdf_report_url``. Any exception is caught; the dict then has empty
        lists, a score of 0, an empty report path and an ``error`` message.
    """
    run_dir = None
    video_path = None
    try:
        # A private per-run directory keeps concurrent requests from
        # overwriting each other's temp video and snapshots.
        run_dir = tempfile.mkdtemp(prefix="run_", dir=OUTPUT_DIR)
        video_path = os.path.join(run_dir, "input.mp4")
        with open(video_path, "wb") as f:
            f.write(video_data)
        video = cv2.VideoCapture(video_path)
        try:
            if not video.isOpened():
                raise ValueError("Could not open video file.")
            violations, snapshots = _scan_video(
                video, run_dir, frame_skip, max_frames, time_limit
            )
        finally:
            video.release()  # released even when inference raises
        score = calculate_safety_score(violations)
        pdf_report_path = generate_pdf_report(violations, snapshots, score)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "pdf_report_url": pdf_report_path,
        }
    except Exception as e:
        print(f"❌ Error processing video: {e}")
        return {
            "violations": [],
            "snapshots": [],
            "score": 0,
            "pdf_report_url": "",
            "error": str(e),
        }
    finally:
        # Snapshots are kept (they are returned and embedded in the PDF);
        # only the uploaded temp video is discarded.
        _remove_temp_video(run_dir, video_path)
# ==========================
# Safety Score Calculation
# ==========================
# Points deducted per detected violation, keyed by violation label.
_DEFAULT_PENALTIES = {
    "no_helmet": 25,
    "no_harness": 30,
    "unsafe_posture": 20,
    "unsafe_zone": 25,
}


def calculate_safety_score(violations, *, penalties=None, base_score=100):
    """Return a compliance score: ``base_score`` minus per-violation penalties.

    Every record deducts the penalty for its ``"violation"`` label. Labels not
    in the table (e.g. ``class_7``) and records without a label cost nothing.
    Each detection counts separately, so a violation seen in many frames is
    penalised many times. The result never goes below 0.

    Args:
        violations: Iterable of dicts carrying a ``"violation"`` label.
        penalties: Optional mapping of label -> points deducted; defaults to
            the built-in table.
        base_score: Starting score before deductions.

    Returns:
        The clamped score.
    """
    if penalties is None:
        penalties = _DEFAULT_PENALTIES
    deducted = sum(penalties.get(v.get("violation"), 0) for v in violations)
    return max(base_score - deducted, 0)
# ==========================
# PDF Report Generation
# ==========================
def generate_pdf_report(violations, snapshots, score):
    """Write a PDF summary of one analysis run and return its path.

    The report shows the compliance score, then one text line per violation
    (label, timestamp, confidence). When the snapshot recorded for the same
    frame and label exists on disk, it is placed directly below that line.

    Args:
        violations: Dicts with ``frame``, ``violation``, ``timestamp`` and
            ``confidence`` keys.
        snapshots: Dicts with ``frame``, ``violation`` and ``snapshot_url``
            (a local file path) keys.
        score: Compliance score printed in the header.

    Returns:
        Path of the generated PDF inside OUTPUT_DIR.
    """
    # mkstemp guarantees a unique name; int(time.time()) collided whenever two
    # reports were generated within the same second.
    fd, pdf_path = tempfile.mkstemp(prefix="report_", suffix=".pdf", dir=OUTPUT_DIR)
    os.close(fd)

    c = canvas.Canvas(pdf_path, pagesize=letter)
    _, height = letter
    left = 50
    page_top = height - 50
    bottom_margin = 50
    line_h = 20
    img_w, img_h, img_gap = 200, 150, 20

    # Title
    c.setFont("Helvetica-Bold", 16)
    c.drawString(left, height - 50, "Worksite Safety Compliance Report")
    # Compliance Score
    c.setFont("Helvetica", 12)
    c.drawString(left, height - 80, f"Compliance Score: {score}%")
    # Violations Table
    y = height - 120
    c.setFont("Helvetica-Bold", 12)
    c.drawString(left, y, "Detected Violations:")
    y -= line_h

    # Index snapshots once (first match wins, as with a linear search)
    # instead of scanning the whole list for every violation.
    snapshot_paths = {}
    for s in snapshots:
        snapshot_paths.setdefault((s["frame"], s["violation"]), s["snapshot_url"])

    if not violations:
        c.setFont("Helvetica", 10)
        c.drawString(left, y, "No violations detected.")

    for v in violations:
        path = snapshot_paths.get((v["frame"], v["violation"]))
        has_image = bool(path) and os.path.exists(path)
        # Break the page *before* drawing, so neither the text line nor its
        # image can run past the bottom margin.
        needed = line_h + (img_h if has_image else 0)
        if y - needed < bottom_margin:
            c.showPage()
            y = page_top
        c.setFont("Helvetica", 10)  # showPage() resets the font state
        text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
        c.drawString(left, y, text)
        y -= line_h
        if has_image:
            # drawImage anchors at the image's bottom-left corner, so the image
            # occupies [y - img_h, y]: directly below the text, not over it.
            c.drawImage(ImageReader(path), left, y - img_h, width=img_w, height=img_h)
            y -= img_h + img_gap
    c.save()
    return pdf_path
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the app outputs.

    Args:
        video_file: Path of the uploaded video as supplied by ``gr.Video``,
            or a falsy value when nothing was uploaded.

    Returns:
        A 4-tuple matching the interface's output components:
        (violations JSON, score text, PDF report path, snapshots JSON).
        Every branch returns exactly four values, because Gradio expects one
        value per output component.
    """
    if not video_file:
        return {"error": "Please upload a video file."}, "", "", []
    with open(video_file, "rb") as f:
        video_data = f.read()
    result = process_video(video_data)
    if result.get("error"):
        # Surface the failure instead of presenting it as a 0% score.
        return {"error": result["error"]}, "Safety Score: N/A", "", []
    return (
        result["violations"],
        f"Safety Score: {result['score']}%",
        result["pdf_report_url"],
        result["snapshots"],
    )
# ==========================
# Gradio App
# ==========================
# One component per value returned by gradio_interface, in the same order.
_video_input = gr.Video(label="Upload Site Video")
_result_outputs = [
    gr.JSON(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
    gr.Textbox(label="PDF Report URL"),
    gr.JSON(label="Snapshots"),
]

interface = gr.Interface(
    fn=gradio_interface,
    inputs=_video_input,
    outputs=_result_outputs,
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture).",
)


def _main():
    """Start the Gradio server for the analyzer UI."""
    print("🚀 Launching Safety Analyzer App...")
    interface.launch()


if __name__ == "__main__":
    _main()