# AI_Safety_Demo7 / app.py
# (Hugging Face Space page residue, kept as a comment so the file parses:
#  author PrashanthB461 · commit 7a2fc1f "Update app.py" (verified) · 3.5 kB)
#import cv2
import gradio as gr
import torch # Moved torch import to the top
try:
from ultralytics import YOLO
except ImportError as e:
print(f"Error importing ultralytics: {e}")
print("Ensure 'ultralytics' is listed in requirements.txt and installed.")
raise
import numpy as np
# Choose the inference device once at import time: CUDA when PyTorch reports a
# usable GPU, otherwise CPU. Any failure while probing falls back to CPU.
try:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
except Exception as err:
    print(f"Error setting device: {err}")
    device = torch.device("cpu")  # Safe default when CUDA probing fails
    print("Falling back to CPU")
# Load the YOLOv8 nano checkpoint once at import time so every request reuses
# the same model object; a load failure is reported and aborts start-up.
try:
    model = YOLO("yolov8n.pt")
except Exception as load_error:
    print(f"Error loading YOLO model: {load_error}")
    raise
# Function to process the video file
def process_video(video_path):
try:
# Load the video
video = cv2.VideoCapture(video_path)
if not video.isOpened():
raise ValueError("Could not open video file.")
frame_count = 0
violations = []
while True:
ret, frame = video.read()
if not ret:
break # End of video
# Run YOLOv8 inference on the frame
results = model(frame, device=device)
# Process detected objects
for result in results:
boxes = result.boxes
for box in boxes:
cls = int(box.cls)
conf = float(box.conf)
xywh = box.xywh.cpu().numpy()[0]
# Map class IDs to violation types (adjust as needed)
violation_labels = {0: "person", 1: "bicycle", 2: "car"}
if cls in violation_labels:
violations.append({
"frame": frame_count,
"violation": violation_labels.get(cls, "unknown"),
"confidence": conf,
"bounding_box": xywh.tolist()
})
frame_count += 1
video.release()
safety_score = calculate_safety_score(violations)
return violations, safety_score
except Exception as e:
print(f"Error processing video: {e}")
return [], f"Error: {e}"
# Function to calculate safety score
def calculate_safety_score(violations):
    """Return a safety score: 100 minus a fixed penalty per violation, floored at 0.

    Each entry's ``"violation"`` label is looked up in the penalty table;
    labels missing from the table cost nothing. Every entry is charged, so
    the same object detected on many frames is penalised once per frame.
    """
    penalty_for = {
        "person": 20,
        "bicycle": 15,
        "car": 30,
        "unknown": 10,
    }
    deducted = sum(penalty_for.get(item["violation"], 0) for item in violations)
    return max(100 - deducted, 0)
# Gradio Interface
def gradio_interface(video_file):
    """Gradio callback: analyse an uploaded video and format the results.

    Args:
        video_file: Value from the ``gr.Video`` input, passed straight to
            ``process_video`` as a path; ``None`` when nothing was uploaded.

    Returns:
        A ``(violations, message)`` pair for the JSON and Textbox outputs.
        ``violations`` is always a list so the JSON panel stays well-formed.
        ``message`` is ``"Safety Score: N%"`` on success, otherwise an upload
        prompt or an ``"Error: ..."`` text.
    """
    if video_file is None:
        # Keep the JSON output a list and show the prompt in the status box.
        return [], "Please upload a video file."
    try:
        violations, safety_score = process_video(video_file)
    except Exception as e:
        print(f"Gradio interface error: {e}")
        return [], f"Error: {e}"
    if isinstance(safety_score, str):
        # process_video reports failures as ([], "Error: ...") instead of
        # raising; pass that text through rather than formatting it as a score.
        return violations, safety_score
    return violations, f"Safety Score: {safety_score}%"
# Build the web UI: one video upload feeding gradio_interface, whose two return
# values populate a JSON panel (detections) and a textbox (score/status).
interface = gr.Interface(
    gradio_interface,
    gr.Video(label="Upload Video"),
    [
        gr.JSON(label="Detected Violations"),
        gr.Textbox(label="Safety Score"),
    ],
    title="Safety Violation Detection",
    description="Upload a video to detect safety violations and calculate a safety score.",
)
# Script entry point: start the Gradio server only when run directly, not on import.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    interface.launch()