PrashanthB461 commited on
Commit
e04e491
·
verified ·
1 Parent(s): caad7aa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import torch # Moved torch import to the top
4
+ try:
5
+ from ultralytics import YOLO
6
+ except ImportError as e:
7
+ print(f"Error importing ultralytics: {e}")
8
+ print("Ensure 'ultralytics' is listed in requirements.txt and installed.")
9
+ raise
10
+ import numpy as np
11
+
12
+ # Set device for model inference
13
+ try:
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ print(f"Using device: {device}")
16
+ except Exception as e:
17
+ print(f"Error setting device: {e}")
18
+ device = torch.device("cpu") # Fallback to CPU
19
+ print("Falling back to CPU")
20
+
21
+ # Load the YOLOv8 model
22
+ try:
23
+ model = YOLO("yolov8n.pt") # Use YOLOv8 nano model
24
+ except Exception as e:
25
+ print(f"Error loading YOLO model: {e}")
26
+ raise
27
+
28
+ # Function to process the video file
29
+ def process_video(video_path):
30
+ try:
31
+ # Load the video
32
+ video = cv2.VideoCapture(video_path)
33
+ if not video.isOpened():
34
+ raise ValueError("Could not open video file.")
35
+
36
+ frame_count = 0
37
+ violations = []
38
+
39
+ while True:
40
+ ret, frame = video.read()
41
+ if not ret:
42
+ break # End of video
43
+
44
+ # Run YOLOv8 inference on the frame
45
+ results = model(frame, device=device)
46
+
47
+ # Process detected objects
48
+ for result in results:
49
+ boxes = result.boxes
50
+ for box in boxes:
51
+ cls = int(box.cls)
52
+ conf = float(box.conf)
53
+ xywh = box.xywh.cpu().numpy()[0]
54
+
55
+ # Map class IDs to violation types (adjust as needed)
56
+ violation_labels = {0: "person", 1: "bicycle", 2: "car"}
57
+ if cls in violation_labels:
58
+ violations.append({
59
+ "frame": frame_count,
60
+ "violation": violation_labels.get(cls, "unknown"),
61
+ "confidence": conf,
62
+ "bounding_box": xywh.tolist()
63
+ })
64
+
65
+ frame_count += 1
66
+
67
+ video.release()
68
+ safety_score = calculate_safety_score(violations)
69
+ return violations, safety_score
70
+ except Exception as e:
71
+ print(f"Error processing video: {e}")
72
+ return [], f"Error: {e}"
73
+
74
+ # Function to calculate safety score
75
+ def calculate_safety_score(violations):
76
+ total_score = 100
77
+ violation_penalties = {
78
+ "person": 20,
79
+ "bicycle": 15,
80
+ "car": 30,
81
+ "unknown": 10
82
+ }
83
+ for violation in violations:
84
+ total_score -= violation_penalties.get(violation["violation"], 0)
85
+ return max(total_score, 0)
86
+
87
+ # Gradio Interface
88
+ def gradio_interface(video_file):
89
+ if video_file is None:
90
+ return "Please upload a video file.", ""
91
+
92
+ try:
93
+ violations, safety_score = process_video(video_file)
94
+ return violations, f"Safety Score: {safety_score}%"
95
+ except Exception as e:
96
+ print(f"Gradio interface error: {e}")
97
+ return [], f"Error: {e}"
98
+
99
+ # Define Gradio interface
100
+ interface = gr.Interface(
101
+ fn=gradio_interface,
102
+ inputs=gr.Video(label="Upload Video"),
103
+ outputs=[
104
+ gr.JSON(label="Detected Violations"),
105
+ gr.Textbox(label="Safety Score")
106
+ ],
107
+ title="Safety Violation Detection",
108
+ description="Upload a video to detect safety violations and calculate a safety score."
109
+ )
110
+
111
+ if __name__ == "__main__":
112
+ print("Launching Gradio interface...")
113
+ interface.launch()