PrashanthB461 commited on
Commit
d195ce0
·
verified ·
1 Parent(s): 4b68be9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -112
app.py CHANGED
@@ -14,9 +14,11 @@ import base64
14
  import logging
15
  from retrying import retry
16
  import uuid
 
 
17
 
18
  # ==========================
19
- # Enhanced Configuration
20
  # ==========================
21
  CONFIG = {
22
  "MODEL_PATH": "yolov8_safety.pt",
@@ -50,7 +52,6 @@ CONFIG = {
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
- "FRAME_SKIP": 2, # Process every 2nd frame for balance of speed/accuracy
54
  "CONFIDENCE_THRESHOLDS": {
55
  "no_helmet": 0.6,
56
  "no_harness": 0.15,
@@ -59,12 +60,12 @@ CONFIG = {
59
  "improper_tool_use": 0.15
60
  },
61
  "IOU_THRESHOLD": 0.4,
62
- "MIN_VIOLATION_FRAMES": 3, # Require more consistent detections
63
  "HELMET_CONFIDENCE_THRESHOLD": 0.65,
64
- "WORKER_TRACKING_DURATION": 3.0, # Seconds to track a worker
65
- "MIN_FRAME_RATE": 5, # Minimum frames per second to process
66
- "MAX_FRAME_RATE": 15, # Maximum frames per second to process
67
- "BATCH_SIZE": 8 # Number of frames to process at once
68
  }
69
 
70
  # Setup logging
@@ -83,7 +84,7 @@ def load_model():
83
  logger.info(f"Model loaded: {model_path}")
84
  else:
85
  model_path = CONFIG["FALLBACK_MODEL"]
86
- logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
87
  if not os.path.isfile(model_path):
88
  logger.info(f"Downloading fallback model: {model_path}")
89
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
@@ -121,7 +122,7 @@ def calculate_iou(box1, box2):
121
  x1, y1, w1, h1 = box1
122
  x2, y2, w2, h2 = box2
123
 
124
- # Calculate coordinates of the intersection rectangle
125
  x_left = max(x1 - w1/2, x2 - w2/2)
126
  y_top = max(y1 - h1/2, y2 - h2/2)
127
  x_right = min(x1 + w1/2, x2 + w2/2)
@@ -137,6 +138,36 @@ def calculate_iou(box1, box2):
137
 
138
  return intersection_area / union_area
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  def generate_violation_pdf(violations, score):
141
  try:
142
  pdf_filename = f"violations_{int(time.time())}.pdf"
@@ -199,7 +230,7 @@ def calculate_safety_score(violations):
199
  return max(score, 0)
200
 
201
  # ==========================
202
- # Optimized Video Processing
203
  # ==========================
204
  def process_video(video_data):
205
  try:
@@ -218,120 +249,104 @@ def process_video(video_data):
218
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
219
  fps = cap.get(cv2.CAP_PROP_FPS)
220
  if fps <= 0:
221
- fps = 30 # Default assumption if FPS not available
222
  duration = total_frames / fps
223
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
224
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
225
 
226
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
227
 
228
- # Calculate optimal frame skipping
229
- original_frame_skip = CONFIG["FRAME_SKIP"]
230
- target_fps = min(max(fps / original_frame_skip, CONFIG["MIN_FRAME_RATE"]), CONFIG["MAX_FRAME_RATE"])
231
- actual_frame_skip = max(1, int(fps / target_fps))
232
- frames_to_process = total_frames // actual_frame_skip
233
- logger.info(f"Processing strategy: Frame skip={actual_frame_skip}, Target FPS={target_fps:.1f}, Frames to process={frames_to_process}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
 
 
 
235
  workers = []
236
  violations = []
237
  helmet_violations = {}
238
  snapshots = []
239
  start_time = time.time()
240
- processed_frames = 0
241
- last_progress_update = 0
 
 
 
 
 
 
 
 
242
 
243
- # Process frames in batches
244
- while True:
245
- batch_frames = []
246
- batch_indices = []
247
 
248
- # Collect frames for this batch
249
- for _ in range(CONFIG["BATCH_SIZE"]):
250
- frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
251
- if frame_idx >= total_frames:
252
- break
253
-
254
- ret, frame = cap.read()
255
- if not ret:
256
- break
257
-
258
- batch_frames.append(frame)
259
- batch_indices.append(frame_idx)
260
- processed_frames += 1
261
-
262
- # Skip frames according to our strategy
263
- for _ in range(actual_frame_skip - 1):
264
- if not cap.grab():
265
- break
 
 
 
 
 
 
 
 
266
 
267
- # Break if no more frames
268
- if not batch_frames:
269
- break
270
 
271
- # Run batch detection
272
- results = model(batch_frames, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
273
-
274
- # Process results for each frame in batch
275
- for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
276
- current_time = frame_idx / fps
277
-
278
- # Update progress periodically
279
- if time.time() - last_progress_update > 1.0: # Update every second
280
- progress = (frame_idx / total_frames) * 100
281
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
282
- last_progress_update = time.time()
283
-
284
- # Process detections in this frame
285
- boxes = result.boxes
286
- for box in boxes:
287
- cls = int(box.cls)
288
- conf = float(box.conf)
289
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
290
-
291
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
292
- continue
293
-
294
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
295
- detection = {
296
- "frame": frame_idx,
297
- "violation": label,
298
- "confidence": round(conf, 2),
299
- "bounding_box": bbox,
300
- "timestamp": current_time
301
- }
302
-
303
- # Worker tracking
304
- worker_id = None
305
- max_iou = 0
306
- for idx, worker in enumerate(workers):
307
- iou = calculate_iou(bbox, worker["bbox"])
308
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
309
- max_iou = iou
310
- worker_id = worker["id"]
311
- workers[idx]["bbox"] = bbox # Update worker position
312
- workers[idx]["last_seen"] = current_time
313
-
314
- if worker_id is None:
315
- worker_id = len(workers) + 1
316
- workers.append({
317
- "id": worker_id,
318
- "bbox": bbox,
319
- "first_seen": current_time,
320
- "last_seen": current_time
321
- })
322
-
323
- detection["worker_id"] = worker_id
324
-
325
- # Special handling for helmet violations
326
- if label == "no_helmet":
327
- if worker_id not in helmet_violations:
328
- helmet_violations[worker_id] = []
329
- helmet_violations[worker_id].append(detection)
330
- else:
331
- violations.append(detection)
332
-
333
- # Remove workers not seen recently
334
- workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
335
 
336
  # Process helmet violations (require consistent detections)
337
  for worker_id, detections in helmet_violations.items():
@@ -341,6 +356,7 @@ def process_video(video_data):
341
  violations.append(best_detection)
342
 
343
  # Capture snapshot for this violation
 
344
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
345
  ret, snapshot_frame = cap.read()
346
  if ret:
@@ -354,8 +370,8 @@ def process_video(video_data):
354
  "snapshot_path": snapshot_path,
355
  "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
356
  })
 
357
 
358
- cap.release()
359
  os.remove(video_path)
360
  processing_time = time.time() - start_time
361
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
@@ -472,11 +488,11 @@ interface = gr.Interface(
472
  gr.Textbox(label="Salesforce Record ID"),
473
  gr.Textbox(label="Violation Details URL")
474
  ],
475
- title="Worksite Safety Violation Analyzer",
476
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
477
  allow_flagging="never"
478
  )
479
 
480
  if __name__ == "__main__":
481
- logger.info("Launching Enhanced Safety Analyzer App...")
482
  interface.launch()
 
14
  import logging
15
  from retrying import retry
16
  import uuid
17
+ from multiprocessing import Pool, cpu_count
18
+ from functools import partial
19
 
20
  # ==========================
21
+ # Ultra-Fast Configuration
22
  # ==========================
23
  CONFIG = {
24
  "MODEL_PATH": "yolov8_safety.pt",
 
52
  "domain": "login"
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
 
55
  "CONFIDENCE_THRESHOLDS": {
56
  "no_helmet": 0.6,
57
  "no_harness": 0.15,
 
60
  "improper_tool_use": 0.15
61
  },
62
  "IOU_THRESHOLD": 0.4,
63
+ "MIN_VIOLATION_FRAMES": 3,
64
  "HELMET_CONFIDENCE_THRESHOLD": 0.65,
65
+ "WORKER_TRACKING_DURATION": 3.0,
66
+ "MAX_PROCESSING_TIME": 30, # 30 second hard limit
67
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all but one CPU core
68
+ "CHUNK_SIZE": 10 # Frames per parallel batch
69
  }
70
 
71
  # Setup logging
 
84
  logger.info(f"Model loaded: {model_path}")
85
  else:
86
  model_path = CONFIG["FALLBACK_MODEL"]
87
+ logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
88
  if not os.path.isfile(model_path):
89
  logger.info(f"Downloading fallback model: {model_path}")
90
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
 
122
  x1, y1, w1, h1 = box1
123
  x2, y2, w2, h2 = box2
124
 
125
+ # Calculate intersection coordinates
126
  x_left = max(x1 - w1/2, x2 - w2/2)
127
  y_top = max(y1 - h1/2, y2 - h2/2)
128
  x_right = min(x1 + w1/2, x2 + w2/2)
 
138
 
139
  return intersection_area / union_area
140
 
141
+ def process_frame_batch(frame_batch, frame_indices, fps):
142
+ batch_results = []
143
+ results = model(frame_batch, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
144
+
145
+ for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
146
+ current_time = frame_idx / fps
147
+ detections = []
148
+
149
+ boxes = result.boxes
150
+ for box in boxes:
151
+ cls = int(box.cls)
152
+ conf = float(box.conf)
153
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
154
+
155
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
156
+ continue
157
+
158
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
159
+ detections.append({
160
+ "frame": frame_idx,
161
+ "violation": label,
162
+ "confidence": round(conf, 2),
163
+ "bounding_box": bbox,
164
+ "timestamp": current_time
165
+ })
166
+
167
+ batch_results.append((frame_idx, detections))
168
+
169
+ return batch_results
170
+
171
  def generate_violation_pdf(violations, score):
172
  try:
173
  pdf_filename = f"violations_{int(time.time())}.pdf"
 
230
  return max(score, 0)
231
 
232
  # ==========================
233
+ # Ultra-Fast Video Processing
234
  # ==========================
235
  def process_video(video_data):
236
  try:
 
249
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
250
  fps = cap.get(cv2.CAP_PROP_FPS)
251
  if fps <= 0:
252
+ fps = 30
253
  duration = total_frames / fps
254
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
255
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
256
 
257
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
258
 
259
+ # Prepare for parallel processing
260
+ frame_batches = []
261
+ frame_indices_batches = []
262
+ current_batch = []
263
+ current_indices = []
264
+
265
+ # Read all frames upfront for parallel processing
266
+ all_frames = []
267
+ all_indices = []
268
+ for frame_idx in range(total_frames):
269
+ ret, frame = cap.read()
270
+ if not ret:
271
+ break
272
+ all_frames.append(frame)
273
+ all_indices.append(frame_idx)
274
+
275
+ # Organize into batches
276
+ if len(current_batch) >= CONFIG["CHUNK_SIZE"]:
277
+ frame_batches.append(current_batch)
278
+ frame_indices_batches.append(current_indices)
279
+ current_batch = []
280
+ current_indices = []
281
+
282
+ # Add remaining frames
283
+ if current_batch:
284
+ frame_batches.append(current_batch)
285
+ frame_indices_batches.append(current_indices)
286
 
287
+ cap.release()
288
+
289
+ # Process frames in parallel
290
  workers = []
291
  violations = []
292
  helmet_violations = {}
293
  snapshots = []
294
  start_time = time.time()
295
+
296
+ # Use multiprocessing Pool
297
+ with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
298
+ process_func = partial(process_frame_batch, fps=fps)
299
+ results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
300
+
301
+ # Flatten results
302
+ all_detections = []
303
+ for batch_result in results:
304
+ all_detections.extend(batch_result)
305
 
306
+ # Process detections and track workers
307
+ workers = []
308
+ for frame_idx, detections in sorted(all_detections, key=lambda x: x[0]):
309
+ current_time = frame_idx / fps
310
 
311
+ # Update progress periodically
312
+ if time.time() - start_time > 1.0: # Update every second
313
+ progress = (frame_idx / total_frames) * 100
314
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
315
+ start_time = time.time()
316
+
317
+ for detection in detections:
318
+ # Worker tracking
319
+ worker_id = None
320
+ max_iou = 0
321
+ for idx, worker in enumerate(workers):
322
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
323
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
324
+ max_iou = iou
325
+ worker_id = worker["id"]
326
+ workers[idx]["bbox"] = detection["bounding_box"]
327
+ workers[idx]["last_seen"] = current_time
328
+
329
+ if worker_id is None:
330
+ worker_id = len(workers) + 1
331
+ workers.append({
332
+ "id": worker_id,
333
+ "bbox": detection["bounding_box"],
334
+ "first_seen": current_time,
335
+ "last_seen": current_time
336
+ })
337
 
338
+ detection["worker_id"] = worker_id
 
 
339
 
340
+ # Special handling for helmet violations
341
+ if detection["violation"] == "no_helmet":
342
+ if worker_id not in helmet_violations:
343
+ helmet_violations[worker_id] = []
344
+ helmet_violations[worker_id].append(detection)
345
+ else:
346
+ violations.append(detection)
347
+
348
+ # Remove workers not seen recently
349
+ workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  # Process helmet violations (require consistent detections)
352
  for worker_id, detections in helmet_violations.items():
 
356
  violations.append(best_detection)
357
 
358
  # Capture snapshot for this violation
359
+ cap = cv2.VideoCapture(video_path)
360
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
361
  ret, snapshot_frame = cap.read()
362
  if ret:
 
370
  "snapshot_path": snapshot_path,
371
  "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
372
  })
373
+ cap.release()
374
 
 
375
  os.remove(video_path)
376
  processing_time = time.time() - start_time
377
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
 
488
  gr.Textbox(label="Salesforce Record ID"),
489
  gr.Textbox(label="Violation Details URL")
490
  ],
491
+ title="Ultra-Fast Safety Violation Analyzer",
492
+ description="Upload site videos to detect safety violations in under 30 seconds (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use).",
493
  allow_flagging="never"
494
  )
495
 
496
  if __name__ == "__main__":
497
+ logger.info("Launching Ultra-Fast Safety Analyzer App...")
498
  interface.launch()