PrashanthB461 committed on
Commit
64f085e
·
verified ·
1 Parent(s): 6c6d7c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -172
app.py CHANGED
@@ -5,14 +5,20 @@ import torch
5
  import numpy as np
6
  from ultralytics import YOLO
7
  import time
8
- from concurrent.futures import ThreadPoolExecutor
 
 
 
 
 
9
  import logging
 
10
 
11
  # ==========================
12
  # Optimized Configuration
13
  # ==========================
14
  CONFIG = {
15
- "MODEL_PATH": "yolov8_safety.pt", # Your trained model
16
  "OUTPUT_DIR": "static/output",
17
  "VIOLATION_LABELS": {
18
  0: "no_helmet",
@@ -22,17 +28,17 @@ CONFIG = {
22
  4: "improper_tool_use"
23
  },
24
  "CLASS_COLORS": {
25
- "no_helmet": (0, 0, 255), # Red
26
- "no_harness": (0, 165, 255), # Orange
27
- "unsafe_posture": (0, 255, 0), # Green
28
- "unsafe_zone": (255, 0, 0), # Blue
29
- "improper_tool_use": (255, 255, 0) # Yellow
30
  },
31
- "FRAME_SKIP": 8, # Process every 8th frame
32
- "MAX_PROCESSING_TIME": 45, # Max processing time (seconds)
33
- "CONFIDENCE_THRESHOLD": 0.35, # Balanced threshold
34
- "MIN_VIOLATION_FRAMES": 2, # Reduced from 3 to 2
35
- "GPU_ACCELERATION": True # Enable GPU if available
36
  }
37
 
38
  # Setup logging
@@ -40,195 +46,177 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
40
  logger = logging.getLogger(__name__)
41
 
42
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
43
- device = torch.device("cuda" if torch.cuda.is_available() and CONFIG["GPU_ACCELERATION"] else "cpu")
44
  logger.info(f"Using device: {device}")
45
 
46
- # Load model
47
- model = YOLO(CONFIG["MODEL_PATH"]).to(device)
48
-
49
- def draw_detections(frame, detections):
50
- """Draw bounding boxes with labels on frame."""
51
- for det in detections:
52
- label = det["violation"]
53
- conf = det["confidence"]
54
- x, y, w, h = det["bounding_box"]
55
- x1, y1 = int(x - w/2), int(y - h/2)
56
- x2, y2 = int(x + w/2), int(y + h/2)
57
-
58
- color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
59
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
60
- cv2.putText(frame,
61
- f"{label}: {conf:.2f}",
62
- (x1, y1-10),
63
- cv2.FONT_HERSHEY_SIMPLEX,
64
- 0.5, color, 2)
65
- return frame
66
-
67
- def process_frame(frame, frame_count, fps):
68
- """Process a single frame for violations (optimized)"""
69
- results = model(frame, device=device, verbose=False)
70
- detections = []
71
-
72
- for result in results:
73
- for box in result.boxes:
74
- cls = int(box.cls)
75
- conf = float(box.conf)
76
- if conf > CONFIG["CONFIDENCE_THRESHOLD"] and cls in CONFIG["VIOLATION_LABELS"]:
77
- detections.append({
78
- "frame": frame_count,
79
- "violation": CONFIG["VIOLATION_LABELS"][cls],
80
- "confidence": round(conf, 2),
81
- "bounding_box": box.xywh.cpu().numpy()[0],
82
- "timestamp": frame_count / fps
83
- })
84
- return detections
85
 
 
 
 
86
  def process_video(video_path):
87
- """Optimized video processing with parallel execution"""
88
- start_time = time.time()
89
- cap = cv2.VideoCapture(video_path)
90
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
91
- violations = []
92
- snapshots = {}
93
- frame_count = 0
94
-
95
- with ThreadPoolExecutor(max_workers=4) as executor:
96
- futures = []
97
  while cap.isOpened():
98
  ret, frame = cap.read()
99
  if not ret:
100
  break
101
-
102
- if frame_count % CONFIG["FRAME_SKIP"] == 0:
103
- futures.append(executor.submit(
104
- process_frame,
105
- frame.copy(),
106
- frame_count,
107
- fps
108
- ))
109
 
110
- frame_count += 1
 
 
 
 
 
111
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
112
  logger.info("Processing time limit reached")
113
  break
114
-
115
- # Process results as they complete
116
- for future in futures:
117
- frame_detections = future.result()
118
- violations.extend(frame_detections)
119
 
120
- # Capture first occurrence of each violation type
121
- for det in frame_detections:
122
- if det["violation"] not in snapshots:
123
- snapshots[det["violation"]] = {
124
- "frame": det["frame"],
125
- "timestamp": det["timestamp"],
126
- "image": draw_detections(
127
- cv2.cvtColor(cap.read()[1], cv2.COLOR_BGR2RGB),
128
- [det]
129
- )
 
 
 
 
 
 
 
 
 
130
  }
131
-
132
- cap.release()
133
-
134
- # Filter violations by frequency
135
- violation_counts = {}
136
- for v in violations:
137
- key = (v["violation"], int(v["timestamp"]))
138
- violation_counts[key] = violation_counts.get(key, 0) + 1
139
-
140
- filtered_violations = [
141
- v for v in violations
142
- if violation_counts.get((v["violation"], int(v["timestamp"])), 0) >= CONFIG["MIN_VIOLATION_FRAMES"]
143
- ]
144
-
145
- # Prepare snapshot outputs
146
- snapshot_outputs = []
147
- for violation_type, data in snapshots.items():
148
- snapshot_path = os.path.join(
149
- CONFIG["OUTPUT_DIR"],
150
- f"{violation_type}_{data['frame']}.jpg"
151
- )
152
- cv2.imwrite(snapshot_path, data["image"])
153
- snapshot_outputs.append({
154
- "violation": violation_type,
155
- "frame": data["frame"],
156
- "timestamp": data["timestamp"],
157
- "path": snapshot_path
158
- })
159
-
160
- return {
161
- "violations": filtered_violations,
162
- "snapshots": snapshot_outputs,
163
- "processing_time": time.time() - start_time
164
- }
165
-
166
def calculate_safety_score(violations):
    """Return a 0-100 safety score, deducting a fixed penalty per violation type.

    Each distinct violation type is penalized once, no matter how many times
    it occurs; unknown types cost nothing and the score never drops below 0.
    """
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    score = 100
    seen = set()
    for record in violations:
        kind = record["violation"]
        if kind in seen:
            continue  # duplicates of the same type are only counted once
        seen.add(kind)
        score -= penalties.get(kind, 0)
    return score if score > 0 else 0
178
-
179
- def format_output(result):
180
- """Format results for Gradio output"""
181
- # Violation table
182
- violation_table = (
183
- "| Violation Type | Timestamp (s) | Confidence |\n"
184
- "|----------------|---------------|------------|\n" +
185
- "\n".join(
186
- f"| {v['violation']:<14} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
187
- for v in result["violations"]
188
- ) if result["violations"] else "No violations detected."
189
- )
190
-
191
- # Snapshots
192
- snapshots_md = "\n".join(
193
- f"**{s['violation']}** at {s['timestamp']:.1f}s: "
194
- f"![](file/{s['path']})"
195
- for s in result["snapshots"]
196
- ) if result["snapshots"] else "No snapshots available."
197
-
198
- # Safety score
199
- safety_score = calculate_safety_score(result["violations"])
200
-
201
- return (
202
- violation_table,
203
- f"Safety Score: {safety_score}%",
204
- snapshots_md,
205
- f"Processed in {result['processing_time']:.1f}s"
206
- )
207
 
 
 
 
208
  def analyze_video(video_file):
209
- """Gradio interface function"""
210
  if not video_file:
211
- return "No video uploaded", "", "", ""
 
 
 
212
 
213
  try:
 
214
  result = process_video(video_file)
215
- return format_output(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  except Exception as e:
217
- logger.error(f"Error: {str(e)}")
218
- return f"Error: {str(e)}", "", "", ""
219
 
220
- # Gradio Interface
221
  interface = gr.Interface(
222
  fn=analyze_video,
223
  inputs=gr.Video(label="Upload Site Video"),
224
  outputs=[
225
  gr.Markdown("## Detected Violations"),
226
- gr.Textbox(label="Safety Score"),
227
  gr.Markdown("## Violation Snapshots"),
228
- gr.Textbox(label="Processing Info")
 
229
  ],
230
- title="AI Safety Compliance Analyzer",
231
- description="Optimized for fast detection of safety violations",
232
  allow_flagging="never"
233
  )
234
 
 
5
  import numpy as np
6
  from ultralytics import YOLO
7
  import time
8
+ from simple_salesforce import Salesforce
9
+ from reportlab.lib.pagesizes import letter
10
+ from reportlab.pdfgen import canvas
11
+ from reportlab.lib.units import inch
12
+ from io import BytesIO
13
+ import base64
14
  import logging
15
+ from retrying import retry
16
 
17
  # ==========================
18
  # Optimized Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt",
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
 
28
  4: "improper_tool_use"
29
  },
30
  "CLASS_COLORS": {
31
+ "no_helmet": (0, 0, 255),
32
+ "no_harness": (0, 165, 255),
33
+ "unsafe_posture": (0, 255, 0),
34
+ "unsafe_zone": (255, 0, 0),
35
+ "improper_tool_use": (255, 255, 0)
36
  },
37
+ "FRAME_SKIP": 10, # Increased from 5 to 10 for faster processing
38
+ "MAX_PROCESSING_TIME": 30, # Reduced from 60 to 30 seconds
39
+ "CONFIDENCE_THRESHOLD": 0.35, # Slightly increased for faster filtering
40
+ "IOU_THRESHOLD": 0.5,
41
+ "MIN_VIOLATION_FRAMES": 2 # Reduced from 3 to 2 frames
42
  }
43
 
44
  # Setup logging
 
46
  logger = logging.getLogger(__name__)
47
 
48
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  logger.info(f"Using device: {device}")
51
 
52
# Lazily-initialized shared model; loaded on first use and reused thereafter.
model = None

def load_model():
    """Return the shared YOLO model, loading it onto `device` on first call."""
    global model
    if model is not None:
        return model
    model = YOLO(CONFIG["MODEL_PATH"]).to(device)
    return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # ==========================
61
+ # Optimized Processing Functions
62
+ # ==========================
63
def process_video(video_path):
    """Detect safety violations in a video, sampling every FRAME_SKIP-th frame.

    Args:
        video_path: Path to the video file to analyze.

    Returns:
        dict with keys:
            "violations": detection dicts (frame, violation, confidence,
                bounding_box, timestamp, worker_id) whose (worker, type) pair
                appeared in at least MIN_VIOLATION_FRAMES sampled frames.
            "snapshots": violation labels for which a raw-frame snapshot was
                written to OUTPUT_DIR.
            "message": "" on success, or an "Error: ..." string on failure.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back when FPS metadata is missing/0
        frame_count = 0
        violations = []
        workers = []  # naive tracker state: [{"id": int, "bbox": xywh}]
        snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
        start_time = time.time()

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Sample only every FRAME_SKIP-th frame for speed.
            if frame_count % CONFIG["FRAME_SKIP"] != 0:
                frame_count += 1
                continue

            # Hard wall-clock budget: terminate early on long videos.
            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info("Processing time limit reached")
                break

            # Run detection on the current frame (verbose logging disabled).
            results = load_model()(frame, device=device, verbose=False)

            for result in results:
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls)

                    if not label or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                        continue

                    bbox = box.xywh.cpu().numpy()[0]
                    detection = {
                        "frame": frame_count,
                        "violation": label,
                        "confidence": conf,
                        "bounding_box": bbox,
                        "timestamp": frame_count / fps,
                    }

                    # Greedy IoU matching to assign detections to workers.
                    # NOTE(review): calculate_iou is not defined in this chunk —
                    # confirm it exists elsewhere in the module.
                    matched = False
                    for worker in workers:
                        if calculate_iou(worker["bbox"], bbox) > CONFIG["IOU_THRESHOLD"]:
                            worker["bbox"] = bbox
                            detection["worker_id"] = worker["id"]
                            matched = True
                            break

                    if not matched:
                        worker_id = len(workers) + 1
                        workers.append({"id": worker_id, "bbox": bbox})
                        detection["worker_id"] = worker_id

                    violations.append(detection)

                    # Save one raw-frame snapshot per violation type.
                    if not snapshot_taken[label]:
                        snapshot_path = os.path.join(
                            CONFIG["OUTPUT_DIR"],
                            f"{label}_{frame_count}.jpg"
                        )
                        cv2.imwrite(snapshot_path, frame)
                        snapshot_taken[label] = True

            frame_count += 1

        # Keep only (worker, violation) pairs seen in enough sampled frames.
        violation_counts = {}
        for v in violations:
            key = (v["worker_id"], v["violation"])
            violation_counts[key] = violation_counts.get(key, 0) + 1

        filtered_violations = [
            v for v in violations
            if violation_counts[(v["worker_id"], v["violation"])] >= CONFIG["MIN_VIOLATION_FRAMES"]
        ]

        return {
            "violations": filtered_violations,
            "snapshots": [k for k, taken in snapshot_taken.items() if taken],
            "message": ""
        }
    except Exception as e:
        logger.error(f"Video processing failed: {e}")
        return {
            "violations": [],
            "snapshots": [],
            "message": f"Error: {str(e)}"
        }
    finally:
        # Fix: release the capture even when processing raises; previously the
        # except path returned without releasing, leaking the video handle.
        if cap is not None:
            cap.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+ # ==========================
164
+ # Fast Gradio Interface
165
+ # ==========================
166
def analyze_video(video_file):
    """Gradio handler: stream progress updates, then the analysis results.

    Yields 5-tuples matching the interface outputs:
    (violations markdown, processing-time text, snapshots markdown,
     Salesforce text, report-URL text).
    """
    if not video_file:
        # Fix: this function is a generator (it yields below), so a plain
        # `return value` is discarded by Gradio — the message must be yielded.
        yield "No video uploaded", "", "", "", ""
        return

    # Fix: start_time was never assigned, so the success path raised
    # NameError (surfaced as a generic "Error: ..." by the except clause).
    start_time = time.time()

    # Immediate feedback while the video is processed.
    yield "Processing started...", "", "", "", ""

    try:
        result = process_video(video_file)

        if result["message"]:
            # process_video reports failures via its "message" field.
            yield result["message"], "", "", "", ""
            return

        # Markdown table of the surviving (frequency-filtered) detections.
        violation_table = (
            "| Violation Type | Timestamp | Confidence | Worker ID |\n"
            "|----------------|-----------|------------|-----------|\n" +
            "\n".join(
                f"| {v['violation']} | {v['timestamp']:.1f}s | {v['confidence']:.2f} | {v['worker_id']} |"
                for v in result["violations"]
            ) if result["violations"] else "No violations detected."
        )

        snapshots_md = "\n".join(
            f"**{violation}** detected"
            for violation in result["snapshots"]
        ) if result["snapshots"] else "No snapshots available."

        yield (
            violation_table,
            f"Analysis complete in {time.time() - start_time:.1f}s",
            snapshots_md,
            "Salesforce integration placeholder",
            "Report URL placeholder"
        )
    except Exception as e:
        yield f"Error: {str(e)}", "", "", "", ""
 
206
 
207
# Launch optimized interface
_output_components = [
    gr.Markdown("## Detected Violations"),
    gr.Textbox(label="Processing Time"),
    gr.Markdown("## Violation Snapshots"),
    gr.Textbox(label="Salesforce Record"),
    gr.Textbox(label="Report URL"),
]

interface = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=_output_components,
    title="Fast Safety Compliance Analyzer",
    description="Optimized for quick safety violation detection in worksite videos.",
    allow_flagging="never",
)
222