# AI_Safety / app.py
# Committed by: PrashanthB461
# Last commit: dd7321a "Update app.py" (verified)
import os
import cv2
import gradio as gr
import torch
import numpy as np
from salesforce import get_salesforce_connection
from simple_salesforce import Salesforce
# Open the Salesforce connection once, at import time.
# NOTE(review): `sf` (and the `Salesforce` import) is not referenced anywhere
# else in this file as seen here; confirm it is needed, otherwise this
# import-time connection attempt could be removed.
sf = get_salesforce_connection()

# Ultralytics provides the YOLO detector; fail fast with an install hint.
try:
    from ultralytics import YOLO
except ImportError:
    print("❌ Ultralytics not installed. Run: pip install ultralytics")
    raise
# ==========================
# Configuration
# ==========================
# Path of the custom safety checkpoint; the SAFETY_MODEL_PATH environment
# variable overrides it.
DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
# Generic nano checkpoint loaded when the custom checkpoint file is missing.
FALLBACK_MODEL = "yolov8n.pt"
MODEL_PATH = os.environ.get("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)

# Detector class index -> violation name for the custom checkpoint
# (assumes the model was trained with its classes in exactly this order — confirm).
VIOLATION_LABELS = dict(
    enumerate(("no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"))
)
# ==========================
# Device Setup
# ==========================
# Use CUDA when PyTorch reports it as available, otherwise the CPU.
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
except Exception as e:
    # Best-effort: a broken CUDA setup should degrade to CPU instead of
    # preventing the app from starting.
    print(f"⚠️ Error setting device: {e}")
    device = torch.device("cpu")
# Reported on both paths so the log always shows the device in use.
print(f"✅ Using device: {device}")
# ==========================
# Load Model
# ==========================
# Use the configured checkpoint when the file exists; otherwise pass the
# fallback checkpoint name to Ultralytics (which resolves bare names itself,
# presumably downloading them if absent — confirm).
if os.path.isfile(MODEL_PATH):
    selected_model = MODEL_PATH
    print(f"✅ Found model at: {selected_model}")
else:
    print(f"⚠️ Model file '{MODEL_PATH}' not found. Falling back to: {FALLBACK_MODEL}")
    selected_model = FALLBACK_MODEL

# Loading is fatal on failure: log the reason, then re-raise.
try:
    model = YOLO(selected_model)
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise
print(f"✅ Model loaded: {selected_model}")
# ==========================
# Video Processing
# ==========================
def process_video(video_path):
    """Run the detector on every frame of a video and collect its detections.

    Args:
        video_path: Path of the video file to analyse (as handed over by the
            Gradio ``Video`` input).

    Returns:
        A ``(violations, score)`` tuple. ``violations`` is a list of dicts with
        keys ``frame`` (0-based frame index), ``violation`` (label),
        ``confidence`` (rounded to 2 decimals) and ``bounding_box`` (the four
        values of ``box.xywh``, rounded to 2 decimals — presumably centre-x,
        centre-y, width, height in pixels). ``score`` is the value of
        ``calculate_safety_score(violations)``.

        This function never raises: on any error it logs the problem and
        returns ``([], "Error: <message>")``.
    """
    # VIOLATION_LABELS maps class ids of the custom safety checkpoint only.
    # The fallback checkpoint is a generic detector whose ids mean other
    # things (for yolov8n.pt, class 0 is presumably "person"), and mapping
    # those ids would report e.g. every class-0 box as "no_helmet". Generic
    # detections are therefore labelled with the model's own class names.
    use_violation_labels = selected_model != FALLBACK_MODEL

    video = None
    try:
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file.")

        violations = []
        frame_count = 0
        while True:
            ret, frame = video.read()
            if not ret:
                break
            # verbose=False stops Ultralytics from logging one line per frame.
            results = model(frame, device=device, verbose=False)
            for result in results:
                class_names = getattr(result, "names", None) or {}
                for box in result.boxes:
                    cls = int(box.cls)
                    if use_violation_labels:
                        label = VIOLATION_LABELS.get(cls, f"class_{cls}")
                    else:
                        label = class_names.get(cls, f"class_{cls}")
                    # .tolist() gives plain Python floats; numpy float32
                    # values are not serialisable by the json module.
                    xywh = box.xywh[0].cpu().tolist()
                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(float(box.conf), 2),
                        "bounding_box": [round(x, 2) for x in xywh],
                    })
            frame_count += 1

        score = calculate_safety_score(violations)
        return violations, score
    except Exception as e:
        # UI-callback boundary: report the failure instead of raising.
        print(f"❌ Error processing video: {e}")
        return [], f"Error: {e}"
    finally:
        # Release the capture on every path, including failures mid-stream.
        if video is not None:
            video.release()
# ==========================
# Safety Score
# ==========================
def calculate_safety_score(violations):
    """Convert a list of detections into a compliance score between 0 and 100.

    Starting from 100, each detection subtracts the penalty of its
    ``"violation"`` label. Labels with no penalty cost nothing. The result is
    clamped at 0.

    NOTE(review): penalties accumulate per detection, and process_video emits
    one entry per box per frame. As a result, four frames of a single
    "no_helmet" detection already bring the score to 0.
    """
    penalty_for = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 25,
    }
    total_penalty = sum(penalty_for.get(item["violation"], 0) for item in violations)
    return max(100 - total_penalty, 0)
# ==========================
# Gradio Interface
# ==========================
def format_compliance_score(score):
    """Render a score returned by process_video for the score textbox.

    Numeric scores are shown as ``"Safety Score: <n>%"``. Anything else, such
    as the ``"Error: ..."`` string process_video returns on failure, is shown
    as-is rather than as ``"Safety Score: Error: ...%"``.
    """
    if isinstance(score, (int, float)):
        return f"Safety Score: {score}%"
    return str(score)


def gradio_interface(video_file):
    """Gradio callback: analyse an uploaded video.

    Returns a ``(violations, score_text)`` pair for the JSON and Textbox
    outputs. With no upload, returns a prompt message and an empty score.
    """
    if not video_file:
        return "Please upload a video file.", ""
    violations, score = process_video(video_file)
    return violations, format_compliance_score(score)
# Gradio UI: one uploaded video in; the detected violations (JSON) and the
# compliance score (text) out, computed by gradio_interface.
_video_input = gr.Video(label="Upload Site Video")
_result_outputs = [
    gr.JSON(label="Detected Safety Violations"),
    gr.Textbox(label="Compliance Score"),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_video_input,
    outputs=_result_outputs,
    title="Worksite Safety Violation Analyzer",
    description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture).",
)
# Script entry point: start the Gradio server.
if __name__ == "__main__":
    print("🚀 Launching Safety Analyzer App...")
    interface.launch()