PrashanthB461 commited on
Commit
40bf7bd
·
verified ·
1 Parent(s): 7a2fc1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -52
app.py CHANGED
@@ -1,65 +1,73 @@
1
- #import cv2
 
2
  import gradio as gr
3
- import torch # Moved torch import to the top
 
 
4
  try:
5
  from ultralytics import YOLO
6
  except ImportError as e:
7
  print(f"Error importing ultralytics: {e}")
8
- print("Ensure 'ultralytics' is listed in requirements.txt and installed.")
9
  raise
10
- import numpy as np
11
 
12
- # Set device for model inference
 
 
 
 
 
 
 
 
 
13
  try:
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  print(f"Using device: {device}")
16
  except Exception as e:
17
- print(f"Error setting device: {e}")
18
- device = torch.device("cpu") # Fallback to CPU
19
- print("Falling back to CPU")
 
 
 
 
20
 
21
- # Load the YOLOv8 model
22
  try:
23
- model = YOLO("yolov8n.pt") # Use YOLOv8 nano model
24
  except Exception as e:
25
  print(f"Error loading YOLO model: {e}")
26
  raise
27
 
28
- # Function to process the video file
29
  def process_video(video_path):
30
  try:
31
- # Load the video
32
  video = cv2.VideoCapture(video_path)
33
  if not video.isOpened():
34
  raise ValueError("Could not open video file.")
35
-
36
- frame_count = 0
37
  violations = []
 
38
 
39
  while True:
40
  ret, frame = video.read()
41
  if not ret:
42
- break # End of video
43
 
44
- # Run YOLOv8 inference on the frame
45
  results = model(frame, device=device)
46
 
47
- # Process detected objects
48
  for result in results:
49
- boxes = result.boxes
50
- for box in boxes:
51
  cls = int(box.cls)
52
  conf = float(box.conf)
53
  xywh = box.xywh.cpu().numpy()[0]
54
 
55
- # Map class IDs to violation types (adjust as needed)
56
- violation_labels = {0: "person", 1: "bicycle", 2: "car"}
57
- if cls in violation_labels:
58
  violations.append({
59
  "frame": frame_count,
60
- "violation": violation_labels.get(cls, "unknown"),
61
- "confidence": conf,
62
- "bounding_box": xywh.tolist()
63
  })
64
 
65
  frame_count += 1
@@ -67,47 +75,43 @@ def process_video(video_path):
67
  video.release()
68
  safety_score = calculate_safety_score(violations)
69
  return violations, safety_score
 
70
  except Exception as e:
71
  print(f"Error processing video: {e}")
72
  return [], f"Error: {e}"
73
 
74
- # Function to calculate safety score
75
  def calculate_safety_score(violations):
76
- total_score = 100
77
- violation_penalties = {
78
- "person": 20,
79
- "bicycle": 15,
80
- "car": 30,
81
- "unknown": 10
82
  }
83
- for violation in violations:
84
- total_score -= violation_penalties.get(violation["violation"], 0)
85
- return max(total_score, 0)
86
 
87
- # Gradio Interface
88
  def gradio_interface(video_file):
89
- if video_file is None:
90
  return "Please upload a video file.", ""
91
-
92
- try:
93
- violations, safety_score = process_video(video_file)
94
- return violations, f"Safety Score: {safety_score}%"
95
- except Exception as e:
96
- print(f"Gradio interface error: {e}")
97
- return [], f"Error: {e}"
98
 
99
- # Define Gradio interface
 
 
100
  interface = gr.Interface(
101
  fn=gradio_interface,
102
- inputs=gr.Video(label="Upload Video"),
103
  outputs=[
104
- gr.JSON(label="Detected Violations"),
105
- gr.Textbox(label="Safety Score")
106
  ],
107
- title="Safety Violation Detection",
108
- description="Upload a video to detect safety violations and calculate a safety score."
109
  )
110
 
111
  if __name__ == "__main__":
112
- print("Launching Gradio interface...")
113
- interface.launch()
 
1
+ import os
2
+ import cv2
3
  import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+
7
  try:
8
  from ultralytics import YOLO
9
  except ImportError as e:
10
  print(f"Error importing ultralytics: {e}")
 
11
  raise
 
12
 
13
+ # ========== Configuration ==========
14
+ MODEL_PATH = "models/yolov8_safety.pt" # Your custom safety model
15
+ VIOLATION_LABELS = {
16
+ 0: "no_helmet",
17
+ 1: "no_harness",
18
+ 2: "unsafe_posture",
19
+ 3: "unsafe_zone"
20
+ }
21
+
22
+ # ========== Device Setup ==========
23
  try:
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  print(f"Using device: {device}")
26
  except Exception as e:
27
+ print(f"Device error: {e}")
28
+ device = torch.device("cpu")
29
+
30
+ # ========== Load Model ==========
31
+ if not os.path.isfile(MODEL_PATH):
32
+ raise FileNotFoundError(
33
+ f"🚨 ERROR: Model file '{MODEL_PATH}' not found. Please upload it to the 'models/' folder.")
34
 
 
35
  try:
36
+ model = YOLO(MODEL_PATH)
37
  except Exception as e:
38
  print(f"Error loading YOLO model: {e}")
39
  raise
40
 
41
+ # ========== Core Logic ==========
42
  def process_video(video_path):
43
  try:
 
44
  video = cv2.VideoCapture(video_path)
45
  if not video.isOpened():
46
  raise ValueError("Could not open video file.")
47
+
 
48
  violations = []
49
+ frame_count = 0
50
 
51
  while True:
52
  ret, frame = video.read()
53
  if not ret:
54
+ break
55
 
56
+ # YOLOv8 inference
57
  results = model(frame, device=device)
58
 
 
59
  for result in results:
60
+ for box in result.boxes:
 
61
  cls = int(box.cls)
62
  conf = float(box.conf)
63
  xywh = box.xywh.cpu().numpy()[0]
64
 
65
+ if cls in VIOLATION_LABELS:
 
 
66
  violations.append({
67
  "frame": frame_count,
68
+ "violation": VIOLATION_LABELS[cls],
69
+ "confidence": round(conf, 2),
70
+ "bounding_box": [round(x, 2) for x in xywh]
71
  })
72
 
73
  frame_count += 1
 
75
  video.release()
76
  safety_score = calculate_safety_score(violations)
77
  return violations, safety_score
78
+
79
  except Exception as e:
80
  print(f"Error processing video: {e}")
81
  return [], f"Error: {e}"
82
 
83
+ # ========== Score Calculation ==========
84
  def calculate_safety_score(violations):
85
+ base_score = 100
86
+ penalties = {
87
+ "no_helmet": 25,
88
+ "no_harness": 30,
89
+ "unsafe_posture": 20,
90
+ "unsafe_zone": 25
91
  }
92
+ for v in violations:
93
+ base_score -= penalties.get(v["violation"], 0)
94
+ return max(base_score, 0)
95
 
96
+ # ========== Gradio Interface ==========
97
  def gradio_interface(video_file):
98
+ if not video_file:
99
  return "Please upload a video file.", ""
 
 
 
 
 
 
 
100
 
101
+ violations, score = process_video(video_file)
102
+ return violations, f"Safety Score: {score}%"
103
+
104
  interface = gr.Interface(
105
  fn=gradio_interface,
106
+ inputs=gr.Video(label="Upload Site Video"),
107
  outputs=[
108
+ gr.JSON(label="Detected Safety Violations"),
109
+ gr.Textbox(label="Compliance Score")
110
  ],
111
+ title="Worksite Safety Violation Analyzer",
112
+ description="Upload a short site video to detect safety compliance violations like missing helmets, no harness, and unsafe behavior."
113
  )
114
 
115
  if __name__ == "__main__":
116
+ print("Launching Safety Analyzer App...")
117
+ interface.launch()