PrashanthB461 commited on
Commit
6c6d7c5
·
verified ·
1 Parent(s): dc9b74e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -41
app.py CHANGED
@@ -2,16 +2,17 @@ import os
2
  import cv2
3
  import gradio as gr
4
  import torch
 
5
  from ultralytics import YOLO
6
  import time
7
- import logging
8
  from concurrent.futures import ThreadPoolExecutor
 
9
 
10
  # ==========================
11
  # Optimized Configuration
12
  # ==========================
13
  CONFIG = {
14
- "MODEL_PATH": "yolov8_safety.pt",
15
  "OUTPUT_DIR": "static/output",
16
  "VIOLATION_LABELS": {
17
  0: "no_helmet",
@@ -20,11 +21,18 @@ CONFIG = {
20
  3: "unsafe_zone",
21
  4: "improper_tool_use"
22
  },
23
- "FRAME_SKIP": 10, # Increased for faster processing
24
- "MAX_PROCESSING_TIME": 60,
25
- "CONFIDENCE_THRESHOLD": 0.35, # Balanced threshold
26
- "MIN_VIOLATION_FRAMES": 2, # Reduced from 3 to 2
27
- "GPU_ACCELERATION": True # Enable GPU if available
 
 
 
 
 
 
 
28
  }
29
 
30
  # Setup logging
@@ -38,9 +46,27 @@ logger.info(f"Using device: {device}")
38
  # Load model
39
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def process_frame(frame, frame_count, fps):
42
- """Process a single frame for violations"""
43
- results = model(frame, device=device, verbose=False) # Disable verbose logging
44
  detections = []
45
 
46
  for result in results:
@@ -51,21 +77,22 @@ def process_frame(frame, frame_count, fps):
51
  detections.append({
52
  "frame": frame_count,
53
  "violation": CONFIG["VIOLATION_LABELS"][cls],
54
- "confidence": conf,
55
  "bounding_box": box.xywh.cpu().numpy()[0],
56
  "timestamp": frame_count / fps
57
  })
58
  return detections
59
 
60
  def process_video(video_path):
61
- """Optimized video processing with parallel frame processing"""
62
  start_time = time.time()
63
  cap = cv2.VideoCapture(video_path)
64
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
65
  violations = []
 
66
  frame_count = 0
67
 
68
- with ThreadPoolExecutor() as executor:
69
  futures = []
70
  while cap.isOpened():
71
  ret, frame = cap.read()
@@ -73,15 +100,34 @@ def process_video(video_path):
73
  break
74
 
75
  if frame_count % CONFIG["FRAME_SKIP"] == 0:
76
- futures.append(executor.submit(process_frame, frame.copy(), frame_count, fps))
 
 
 
 
 
77
 
78
  frame_count += 1
79
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
80
  logger.info("Processing time limit reached")
81
  break
82
 
 
83
  for future in futures:
84
- violations.extend(future.result())
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  cap.release()
87
 
@@ -96,48 +142,94 @@ def process_video(video_path):
96
  if violation_counts.get((v["violation"], int(v["timestamp"])), 0) >= CONFIG["MIN_VIOLATION_FRAMES"]
97
  ]
98
 
99
- return filtered_violations, time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def analyze_video(video_file):
102
- """Optimized Gradio interface function"""
103
  if not video_file:
104
  return "No video uploaded", "", "", ""
105
 
106
  try:
107
- start_time = time.time()
108
- violations, processing_time = process_video(video_file)
109
-
110
- # Generate simple output (removed PDF generation for speed)
111
- violation_table = (
112
- "| Violation Type | Timestamp (s) | Confidence |\n"
113
- "|----------------|---------------|------------|\n" +
114
- "\n".join(
115
- f"| {v['violation']:<14} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
116
- for v in violations
117
- ) if violations else "No violations detected."
118
- )
119
-
120
- return (
121
- violation_table,
122
- f"Processing Time: {processing_time:.1f}s",
123
- f"Violations Found: {len(violations)}",
124
- f"Analysis Completed in {time.time()-start_time:.1f}s"
125
- )
126
  except Exception as e:
 
127
  return f"Error: {str(e)}", "", "", ""
128
 
129
- # Simplified Gradio Interface
130
  interface = gr.Interface(
131
  fn=analyze_video,
132
  inputs=gr.Video(label="Upload Site Video"),
133
  outputs=[
134
  gr.Markdown("## Detected Violations"),
135
- gr.Textbox(label="Processing Info"),
136
- gr.Textbox(label="Summary"),
137
- gr.Textbox(label="Status")
138
  ],
139
- title="Fast Safety Compliance Analyzer",
140
- description="Optimized version for quick safety violation detection"
 
141
  )
142
 
143
  if __name__ == "__main__":
 
2
  import cv2
3
  import gradio as gr
4
  import torch
5
+ import numpy as np
6
  from ultralytics import YOLO
7
  import time
 
8
  from concurrent.futures import ThreadPoolExecutor
9
+ import logging
10
 
11
  # ==========================
12
  # Optimized Configuration
13
  # ==========================
14
  CONFIG = {
15
+ "MODEL_PATH": "yolov8_safety.pt", # Your trained model
16
  "OUTPUT_DIR": "static/output",
17
  "VIOLATION_LABELS": {
18
  0: "no_helmet",
 
21
  3: "unsafe_zone",
22
  4: "improper_tool_use"
23
  },
24
+ "CLASS_COLORS": {
25
+ "no_helmet": (0, 0, 255), # Red
26
+ "no_harness": (0, 165, 255), # Orange
27
+ "unsafe_posture": (0, 255, 0), # Green
28
+ "unsafe_zone": (255, 0, 0), # Blue
29
+ "improper_tool_use": (255, 255, 0) # Yellow
30
+ },
31
+ "FRAME_SKIP": 8, # Process every 8th frame
32
+ "MAX_PROCESSING_TIME": 45, # Max processing time (seconds)
33
+ "CONFIDENCE_THRESHOLD": 0.35, # Balanced threshold
34
+ "MIN_VIOLATION_FRAMES": 2, # Reduced from 3 to 2
35
+ "GPU_ACCELERATION": True # Enable GPU if available
36
  }
37
 
38
  # Setup logging
 
46
  # Load model
47
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
48
 
49
+ def draw_detections(frame, detections):
50
+ """Draw bounding boxes with labels on frame."""
51
+ for det in detections:
52
+ label = det["violation"]
53
+ conf = det["confidence"]
54
+ x, y, w, h = det["bounding_box"]
55
+ x1, y1 = int(x - w/2), int(y - h/2)
56
+ x2, y2 = int(x + w/2), int(y + h/2)
57
+
58
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
59
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
60
+ cv2.putText(frame,
61
+ f"{label}: {conf:.2f}",
62
+ (x1, y1-10),
63
+ cv2.FONT_HERSHEY_SIMPLEX,
64
+ 0.5, color, 2)
65
+ return frame
66
+
67
  def process_frame(frame, frame_count, fps):
68
+ """Process a single frame for violations (optimized)"""
69
+ results = model(frame, device=device, verbose=False)
70
  detections = []
71
 
72
  for result in results:
 
77
  detections.append({
78
  "frame": frame_count,
79
  "violation": CONFIG["VIOLATION_LABELS"][cls],
80
+ "confidence": round(conf, 2),
81
  "bounding_box": box.xywh.cpu().numpy()[0],
82
  "timestamp": frame_count / fps
83
  })
84
  return detections
85
 
86
  def process_video(video_path):
87
+ """Optimized video processing with parallel execution"""
88
  start_time = time.time()
89
  cap = cv2.VideoCapture(video_path)
90
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
91
  violations = []
92
+ snapshots = {}
93
  frame_count = 0
94
 
95
+ with ThreadPoolExecutor(max_workers=4) as executor:
96
  futures = []
97
  while cap.isOpened():
98
  ret, frame = cap.read()
 
100
  break
101
 
102
  if frame_count % CONFIG["FRAME_SKIP"] == 0:
103
+ futures.append(executor.submit(
104
+ process_frame,
105
+ frame.copy(),
106
+ frame_count,
107
+ fps
108
+ ))
109
 
110
  frame_count += 1
111
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
112
  logger.info("Processing time limit reached")
113
  break
114
 
115
+ # Process results as they complete
116
  for future in futures:
117
+ frame_detections = future.result()
118
+ violations.extend(frame_detections)
119
+
120
+ # Capture first occurrence of each violation type
121
+ for det in frame_detections:
122
+ if det["violation"] not in snapshots:
123
+ snapshots[det["violation"]] = {
124
+ "frame": det["frame"],
125
+ "timestamp": det["timestamp"],
126
+ "image": draw_detections(
127
+ cv2.cvtColor(cap.read()[1], cv2.COLOR_BGR2RGB),
128
+ [det]
129
+ )
130
+ }
131
 
132
  cap.release()
133
 
 
142
  if violation_counts.get((v["violation"], int(v["timestamp"])), 0) >= CONFIG["MIN_VIOLATION_FRAMES"]
143
  ]
144
 
145
+ # Prepare snapshot outputs
146
+ snapshot_outputs = []
147
+ for violation_type, data in snapshots.items():
148
+ snapshot_path = os.path.join(
149
+ CONFIG["OUTPUT_DIR"],
150
+ f"{violation_type}_{data['frame']}.jpg"
151
+ )
152
+ cv2.imwrite(snapshot_path, data["image"])
153
+ snapshot_outputs.append({
154
+ "violation": violation_type,
155
+ "frame": data["frame"],
156
+ "timestamp": data["timestamp"],
157
+ "path": snapshot_path
158
+ })
159
+
160
+ return {
161
+ "violations": filtered_violations,
162
+ "snapshots": snapshot_outputs,
163
+ "processing_time": time.time() - start_time
164
+ }
165
+
166
+ def calculate_safety_score(violations):
167
+ """Calculate safety score (0-100)"""
168
+ penalty_weights = {
169
+ "no_helmet": 25,
170
+ "no_harness": 30,
171
+ "unsafe_posture": 20,
172
+ "unsafe_zone": 35,
173
+ "improper_tool_use": 25
174
+ }
175
+ unique_violations = set((v["violation"]) for v in violations)
176
+ total_penalty = sum(penalty_weights.get(v, 0) for v in unique_violations)
177
+ return max(100 - total_penalty, 0)
178
+
179
+ def format_output(result):
180
+ """Format results for Gradio output"""
181
+ # Violation table
182
+ violation_table = (
183
+ "| Violation Type | Timestamp (s) | Confidence |\n"
184
+ "|----------------|---------------|------------|\n" +
185
+ "\n".join(
186
+ f"| {v['violation']:<14} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
187
+ for v in result["violations"]
188
+ ) if result["violations"] else "No violations detected."
189
+ )
190
+
191
+ # Snapshots
192
+ snapshots_md = "\n".join(
193
+ f"**{s['violation']}** at {s['timestamp']:.1f}s: "
194
+ f"![](file/{s['path']})"
195
+ for s in result["snapshots"]
196
+ ) if result["snapshots"] else "No snapshots available."
197
+
198
+ # Safety score
199
+ safety_score = calculate_safety_score(result["violations"])
200
+
201
+ return (
202
+ violation_table,
203
+ f"Safety Score: {safety_score}%",
204
+ snapshots_md,
205
+ f"Processed in {result['processing_time']:.1f}s"
206
+ )
207
 
208
  def analyze_video(video_file):
209
+ """Gradio interface function"""
210
  if not video_file:
211
  return "No video uploaded", "", "", ""
212
 
213
  try:
214
+ result = process_video(video_file)
215
+ return format_output(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  except Exception as e:
217
+ logger.error(f"Error: {str(e)}")
218
  return f"Error: {str(e)}", "", "", ""
219
 
220
+ # Gradio Interface
221
  interface = gr.Interface(
222
  fn=analyze_video,
223
  inputs=gr.Video(label="Upload Site Video"),
224
  outputs=[
225
  gr.Markdown("## Detected Violations"),
226
+ gr.Textbox(label="Safety Score"),
227
+ gr.Markdown("## Violation Snapshots"),
228
+ gr.Textbox(label="Processing Info")
229
  ],
230
+ title="AI Safety Compliance Analyzer",
231
+ description="Optimized for fast detection of safety violations",
232
+ allow_flagging="never"
233
  )
234
 
235
  if __name__ == "__main__":