# AI_Safety_Demo7 / app.py
# Author: PrashanthB461
# Last commit: 4e7fbac "Update app.py" (verified)
# Size: 3.34 kB
import os
import cv2
import gradio as gr
import torch
import numpy as np
try:
    from ultralytics import YOLO
except ImportError as e:
    # Print the underlying import failure to stdout before re-raising, so a
    # missing or broken ultralytics install is visible in the app's log.
    print(f"Error importing ultralytics: {e}")
    raise
# ========== Configuration ==========
# Custom-trained safety weights; a relative path, so it is resolved against
# the process's current working directory.
MODEL_PATH = "models/yolov8_safety.pt"

# Maps model class id -> violation label. The list position IS the class id,
# and these keys are looked up with the model's predicted class (box.cls), so
# the order must match the class order of the weights at MODEL_PATH.
VIOLATION_LABELS = dict(enumerate([
    "no_helmet",       # class 0
    "no_harness",      # class 1
    "unsafe_posture",  # class 2
    "unsafe_zone",     # class 3
]))
# ========== Device Setup ==========
try:
    # Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
except Exception as device_err:
    # Any failure while probing CUDA falls back to the CPU rather than
    # aborting app startup.
    print(f"Device error: {device_err}")
    device = torch.device("cpu")
# ========== Load Model ==========
# Fail fast with an actionable message when the weights file is absent,
# instead of letting YOLO() raise a less specific error.
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(
        f"ERROR: Model file '{MODEL_PATH}' not found. "
        "Please upload it to the 'models/' folder."
    )

try:
    model = YOLO(MODEL_PATH)
except Exception as load_err:
    # Log the load failure, then re-raise the original exception unchanged.
    print(f"Error loading YOLO model: {load_err}")
    raise
# ========== Core Logic ==========
def _collect_violations(results, frame_index):
    """Convert one frame's Ultralytics results into violation records.

    Only boxes whose class id is a key of VIOLATION_LABELS are kept.

    Args:
        results: Iterable of Ultralytics ``Results`` objects for one frame.
        frame_index: Zero-based index of that frame in the video.

    Returns:
        A list of dicts with keys ``frame``, ``violation``, ``confidence``
        and ``bounding_box``.
    """
    found = []
    for result in results:
        for box in result.boxes:
            cls = int(box.cls)
            if cls not in VIOLATION_LABELS:
                continue
            # Ultralytics Boxes.xywh: [x_center, y_center, width, height].
            xywh = box.xywh.cpu().numpy()[0]
            found.append({
                "frame": frame_index,
                "violation": VIOLATION_LABELS[cls],
                "confidence": round(float(box.conf), 2),
                # float() first: rounding an np.float32 yields another
                # np.float32, which the stdlib json module cannot serialize.
                "bounding_box": [round(float(v), 2) for v in xywh],
            })
    return found


def process_video(video_path):
    """Run the safety model on every frame of a video and score the result.

    Args:
        video_path: Path to a video file, passed to ``cv2.VideoCapture``.

    Returns:
        A ``(violations, safety_score)`` tuple. ``violations`` is a list of
        violation dicts (one per matching detection per frame) and
        ``safety_score`` is ``calculate_safety_score(violations)``. On any
        error the exception is printed and ``([], "Error: <message>")`` is
        returned instead of raising.
    """
    video = None
    try:
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file.")

        violations = []
        frame_count = 0
        while True:
            ret, frame = video.read()
            if not ret:
                break  # end of stream (or unreadable frame)
            # verbose=False stops Ultralytics from logging every frame.
            results = model(frame, device=device, verbose=False)
            violations.extend(_collect_violations(results, frame_count))
            frame_count += 1

        safety_score = calculate_safety_score(violations)
        return violations, safety_score
    except Exception as e:
        print(f"Error processing video: {e}")
        return [], f"Error: {e}"
    finally:
        # Release the capture on every path, including errors raised
        # mid-stream by inference, which previously leaked it.
        if video is not None:
            video.release()
# ========== Score Calculation ==========
def calculate_safety_score(violations):
    """Return an integer compliance score between 0 and 100.

    Starts from 100 and subtracts a fixed penalty for every record, looked up
    by its ``"violation"`` label (labels without a penalty cost nothing).
    Each record is penalized independently, so the same violation detected
    across many frames keeps lowering the score. The result is clamped at 0.

    Args:
        violations: Iterable of dicts, each with a ``"violation"`` key.
    """
    penalty_by_label = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    total_penalty = sum(
        penalty_by_label.get(record["violation"], 0) for record in violations
    )
    return max(100 - total_penalty, 0)
# ========== Gradio Interface ==========
def gradio_interface(video_file):
    """Gradio callback: analyze an uploaded video and format the outputs.

    Args:
        video_file: Value from the ``gr.Video`` input, passed straight to
            ``process_video``. It is falsy when nothing was uploaded.

    Returns:
        A ``(violations, message)`` pair feeding the JSON and Textbox
        outputs. ``violations`` is always a list. ``message`` is either
        ``"Safety Score: N%"``, an upload prompt, or the error text reported
        by ``process_video``.
    """
    if not video_file:
        # Keep the JSON panel a (empty) list; show the prompt in the text box.
        return [], "Please upload a video file."
    violations, score = process_video(video_file)
    if isinstance(score, str):
        # process_video reports failures as an "Error: ..." string instead of
        # a number; show it as-is rather than "Safety Score: Error: ...%".
        return violations, score
    return violations, f"Safety Score: {score}%"
# Single-page UI: one video upload in; two outputs that line up positionally
# with gradio_interface's (violations, message) return pair.
interface = gr.Interface(
    title="Worksite Safety Violation Analyzer",
    description=(
        "Upload a short site video to detect safety compliance violations "
        "like missing helmets, no harness, and unsafe behavior."
    ),
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.JSON(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
    ],
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    print("Launching Safety Analyzer App...")
    interface.launch()