PrashanthB461 commited on
Commit
457903c
·
verified ·
1 Parent(s): 4e7fbac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -19
app.py CHANGED
@@ -7,11 +7,15 @@ import numpy as np
7
  try:
8
  from ultralytics import YOLO
9
  except ImportError as e:
10
- print(f"Error importing ultralytics: {e}")
11
  raise
12
 
13
- # ========== Configuration ==========
14
- MODEL_PATH = "models/yolov8_safety.pt" # Your custom safety model
 
 
 
 
15
  VIOLATION_LABELS = {
16
  0: "no_helmet",
17
  1: "no_harness",
@@ -19,41 +23,49 @@ VIOLATION_LABELS = {
19
  3: "unsafe_zone"
20
  }
21
 
22
- # ========== Device Setup ==========
 
 
23
  try:
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- print(f"Using device: {device}")
26
  except Exception as e:
27
- print(f"Device error: {e}")
28
  device = torch.device("cpu")
29
 
30
- # ========== Load Model ==========
 
 
31
  if not os.path.isfile(MODEL_PATH):
32
  raise FileNotFoundError(
33
- f"ERROR: Model file '{MODEL_PATH}' not found. Please upload it to the 'models/' folder.")
 
 
34
 
35
  try:
36
  model = YOLO(MODEL_PATH)
 
37
  except Exception as e:
38
- print(f"Error loading YOLO model: {e}")
39
  raise
40
 
41
- # ========== Core Logic ==========
 
 
42
  def process_video(video_path):
43
  try:
44
  video = cv2.VideoCapture(video_path)
45
  if not video.isOpened():
46
  raise ValueError("Could not open video file.")
47
 
48
- violations = []
49
  frame_count = 0
 
50
 
51
  while True:
52
  ret, frame = video.read()
53
  if not ret:
54
  break
55
 
56
- # YOLOv8 inference
57
  results = model(frame, device=device)
58
 
59
  for result in results:
@@ -73,14 +85,16 @@ def process_video(video_path):
73
  frame_count += 1
74
 
75
  video.release()
76
- safety_score = calculate_safety_score(violations)
77
- return violations, safety_score
78
 
79
  except Exception as e:
80
- print(f"Error processing video: {e}")
81
  return [], f"Error: {e}"
82
 
83
- # ========== Score Calculation ==========
 
 
84
  def calculate_safety_score(violations):
85
  base_score = 100
86
  penalties = {
@@ -93,7 +107,9 @@ def calculate_safety_score(violations):
93
  base_score -= penalties.get(v["violation"], 0)
94
  return max(base_score, 0)
95
 
96
- # ========== Gradio Interface ==========
 
 
97
  def gradio_interface(video_file):
98
  if not video_file:
99
  return "Please upload a video file.", ""
@@ -109,9 +125,9 @@ interface = gr.Interface(
109
  gr.Textbox(label="Compliance Score")
110
  ],
111
  title="Worksite Safety Violation Analyzer",
112
- description="Upload a short site video to detect safety compliance violations like missing helmets, no harness, and unsafe behavior."
113
  )
114
 
115
  if __name__ == "__main__":
116
- print("Launching Safety Analyzer App...")
117
  interface.launch()
 
7
  try:
8
  from ultralytics import YOLO
9
  except ImportError as e:
10
+ print(" Ultralytics not installed.")
11
  raise
12
 
13
+ # ==========================
14
+ # Configuration
15
+ # ==========================
16
+ DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
17
+ MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
18
+
19
  VIOLATION_LABELS = {
20
  0: "no_helmet",
21
  1: "no_harness",
 
23
  3: "unsafe_zone"
24
  }
25
 
26
+ # ==========================
27
+ # Device Setup
28
+ # ==========================
29
  try:
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ print(f"Using device: {device}")
32
  except Exception as e:
33
+ print(f"⚠️ Error setting device: {e}")
34
  device = torch.device("cpu")
35
 
36
+ # ==========================
37
+ # Load Model
38
+ # ==========================
39
  if not os.path.isfile(MODEL_PATH):
40
  raise FileNotFoundError(
41
+ f"ERROR: Model file '{MODEL_PATH}' not found.\n"
42
+ f"👉 Please upload it to the 'models/' folder or set SAFETY_MODEL_PATH env variable."
43
+ )
44
 
45
  try:
46
  model = YOLO(MODEL_PATH)
47
+ print(f"✅ Model loaded from {MODEL_PATH}")
48
  except Exception as e:
49
+ print(f" Failed to load model: {e}")
50
  raise
51
 
52
+ # ==========================
53
+ # Video Processing
54
+ # ==========================
55
  def process_video(video_path):
56
  try:
57
  video = cv2.VideoCapture(video_path)
58
  if not video.isOpened():
59
  raise ValueError("Could not open video file.")
60
 
 
61
  frame_count = 0
62
+ violations = []
63
 
64
  while True:
65
  ret, frame = video.read()
66
  if not ret:
67
  break
68
 
 
69
  results = model(frame, device=device)
70
 
71
  for result in results:
 
85
  frame_count += 1
86
 
87
  video.release()
88
+ score = calculate_safety_score(violations)
89
+ return violations, score
90
 
91
  except Exception as e:
92
+ print(f"Error processing video: {e}")
93
  return [], f"Error: {e}"
94
 
95
+ # ==========================
96
+ # Safety Score
97
+ # ==========================
98
  def calculate_safety_score(violations):
99
  base_score = 100
100
  penalties = {
 
107
  base_score -= penalties.get(v["violation"], 0)
108
  return max(base_score, 0)
109
 
110
+ # ==========================
111
+ # Gradio Interface
112
+ # ==========================
113
  def gradio_interface(video_file):
114
  if not video_file:
115
  return "Please upload a video file.", ""
 
125
  gr.Textbox(label="Compliance Score")
126
  ],
127
  title="Worksite Safety Violation Analyzer",
128
+ description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
129
  )
130
 
131
  if __name__ == "__main__":
132
+ print("🚀 Launching Safety Analyzer App...")
133
  interface.launch()