AI_Safety_Demo7 / app.py
PrashanthB461's picture
Update app.py
04fdfdf verified
raw
history blame
6.79 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.utils import ImageReader
from io import BytesIO
import base64
from PIL import Image
# ==========================
# Configuration
# ==========================
# Path to the fine-tuned safety-detection weights baked into the repo.
DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
# Stock YOLOv8-nano checkpoint used when the custom weights are absent.
FALLBACK_MODEL = "yolov8n.pt"
# The weights path can be overridden at deploy time via SAFETY_MODEL_PATH.
MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
OUTPUT_DIR = "output" # Directory to store snapshots and PDFs
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Maps the custom model's class indices to human-readable violation names.
# NOTE(review): these indices presumably match the custom safety model only;
# the COCO-trained fallback model has a different class map — confirm before
# trusting labels when the fallback is in use.
VIOLATION_LABELS = {
    0: "no_helmet",
    1: "no_harness",
    2: "unsafe_posture",
    3: "unsafe_zone"
}
# ==========================
# Device Setup
# ==========================
# Prefer CUDA when available; inference silently falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")
# ==========================
# Load Model
# ==========================
# Use the custom safety weights if the file exists on disk; otherwise fall
# back to the stock yolov8n checkpoint (downloaded by ultralytics on demand).
selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
model = YOLO(selected_model)
# ==========================
# Video Processing
# ==========================
def process_video(video_data, frame_skip=5, max_frames=100):
    """Run safety-violation detection over raw uploaded video bytes.

    Writes the bytes to a temp file, samples every ``frame_skip``-th frame
    through the YOLO model (up to ``max_frames`` sampled frames or 30s of
    wall-clock time), snapshots each violating frame, and builds a PDF report.

    Args:
        video_data: Raw bytes of an uploaded video file.
        frame_skip: Run inference on every N-th frame only.
        max_frames: Hard cap on the number of frames actually inferred.

    Returns:
        dict with keys "violations", "snapshots", "score", "pdf_report_url"
        (plus "error" on failure, where the other fields are empty/zero).
    """
    video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
    try:
        # Persist the upload so cv2.VideoCapture can open it by path.
        with open(video_path, "wb") as f:
            f.write(video_data)
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file.")
        try:
            # Hoist the FPS query out of the loop and guard against containers
            # that report 0 FPS (would otherwise raise ZeroDivisionError).
            fps = video.get(cv2.CAP_PROP_FPS) or 30.0
            frame_count = 0
            violations = []
            snapshots = []
            processed_frame_count = 0
            start_time = time.time()
            while True:
                ret, frame = video.read()
                if not ret:
                    break
                if frame_count % frame_skip != 0:
                    frame_count += 1
                    continue
                # Model inference on the sampled frame.
                results = model(frame, device=device)
                for result in results:
                    for box in result.boxes:
                        cls = int(box.cls)
                        conf = float(box.conf)
                        xywh = box.xywh.cpu().numpy()[0]
                        label = VIOLATION_LABELS.get(cls, f"class_{cls}")
                        violations.append({
                            "frame": frame_count,
                            "violation": label,
                            "confidence": round(conf, 2),
                            "bounding_box": [round(x, 2) for x in xywh],
                            "timestamp": frame_count / fps
                        })
                        # Save the (unannotated) frame as evidence for the report.
                        snapshot_path = os.path.join(OUTPUT_DIR, f"snapshot_{frame_count}_{label}.jpg")
                        cv2.imwrite(snapshot_path, frame)
                        snapshots.append({
                            "violation": label,
                            "frame": frame_count,
                            "snapshot_url": snapshot_path
                        })
                frame_count += 1
                processed_frame_count += 1
                if processed_frame_count >= max_frames:
                    break
                if time.time() - start_time > 30:
                    print("⏰ Exceeded 30 seconds of processing time.")
                    break
        finally:
            # Release the capture even when inference raises mid-loop.
            video.release()
        score = calculate_safety_score(violations)
        pdf_report_path = generate_pdf_report(violations, snapshots, score)
        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "pdf_report_url": pdf_report_path
        }
    except Exception as e:
        # Top-level boundary: report the failure in the payload instead of
        # crashing the Gradio worker.
        print(f"❌ Error processing video: {e}")
        return {
            "violations": [],
            "snapshots": [],
            "score": 0,
            "pdf_report_url": "",
            "error": str(e)
        }
    finally:
        # Clean up the temp file on BOTH success and failure (the original
        # leaked it whenever an exception occurred before os.remove).
        if os.path.exists(video_path):
            os.remove(video_path)
# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return a 0-100 compliance score for a list of violation records.

    Starts from a perfect 100 and subtracts a fixed penalty per detected
    violation (unknown violation labels cost nothing); the score is floored
    at zero.
    """
    penalty_points = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    total_penalty = sum(penalty_points.get(entry["violation"], 0) for entry in violations)
    return max(100 - total_penalty, 0)
# ==========================
# PDF Report Generation
# ==========================
def generate_pdf_report(violations, snapshots, score):
    """Render a one-or-more-page PDF summarizing the compliance run.

    Args:
        violations: List of dicts with "violation", "timestamp", "confidence",
            and "frame" keys (as produced by process_video).
        snapshots: List of dicts with "frame", "violation", "snapshot_url".
        score: Integer compliance score (0-100) shown in the header.

    Returns:
        Filesystem path of the written PDF.
    """
    pdf_path = os.path.join(OUTPUT_DIR, f"report_{int(time.time())}.pdf")
    c = canvas.Canvas(pdf_path, pagesize=letter)
    width, height = letter
    # Title
    c.setFont("Helvetica-Bold", 16)
    c.drawString(50, height - 50, "Worksite Safety Compliance Report")
    # Compliance Score
    c.setFont("Helvetica", 12)
    c.drawString(50, height - 80, f"Compliance Score: {score}%")
    # Violations table header
    y = height - 120
    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y, "Detected Violations:")
    y -= 20
    for v in violations:
        # Look up this violation's snapshot (matched on frame AND label).
        snapshot = next(
            (s for s in snapshots
             if s["frame"] == v["frame"] and s["violation"] == v["violation"]),
            None
        )
        has_image = snapshot is not None and os.path.exists(snapshot["snapshot_url"])
        # Break the page BEFORE drawing so an entry never runs off the bottom
        # edge (the original checked only after drawing, clipping entries).
        needed = 20 + (170 if has_image else 0)
        if y - needed < 50:
            c.showPage()
            y = height - 50
        c.setFont("Helvetica", 10)
        text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
        c.drawString(50, y, text)
        y -= 20
        if has_image:
            img = ImageReader(snapshot["snapshot_url"])
            # Anchor the image fully below the text line (drawImage's y is the
            # image BOTTOM; the original's y-100 put the top over the text).
            c.drawImage(img, 50, y - 150, width=200, height=150)
            y -= 170
    c.save()
    return pdf_path
# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: read the uploaded video path and run the analysis.

    Args:
        video_file: Filesystem path supplied by gr.Video, or None/"" when
            no file was uploaded.

    Returns:
        A 4-tuple matching the interface's four declared output components:
        (violations JSON, score text, PDF path text, snapshots JSON).
    """
    if not video_file:
        # Must return FOUR values — the original returned only three here,
        # which breaks against the four declared gr.Interface outputs.
        return {"error": "Please upload a video file."}, "", "", []
    with open(video_file, "rb") as f:
        video_data = f.read()
    result = process_video(video_data)
    return (
        result["violations"],
        f"Safety Score: {result['score']}%",
        result["pdf_report_url"],
        result["snapshots"]
    )
# Gradio UI wiring: one video input, four outputs that must line up 1:1 with
# the tuple returned by gradio_interface.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.JSON(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Textbox(label="PDF Report URL"),
        gr.JSON(label="Snapshots")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
)

# Launch only when executed directly (not when imported as a module).
if __name__ == "__main__":
    print("🚀 Launching Safety Analyzer App...")
    interface.launch()