PrashanthB461 commited on
Commit
0615d03
·
verified ·
1 Parent(s): 12dad16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -48
app.py CHANGED
@@ -30,11 +30,11 @@ CONFIG = {
30
  4: "improper_tool_use"
31
  },
32
  "CLASS_COLORS": {
33
- "no_helmet": (0, 0, 255), # Red
34
- "no_harness": (0, 165, 255), # Orange
35
- "unsafe_posture": (0, 255, 0), # Green
36
- "unsafe_zone": (255, 0, 0), # Blue
37
- "improper_tool_use": (255, 255, 0) # Yellow
38
  },
39
  "DISPLAY_NAMES": {
40
  "no_helmet": "No Helmet Violation",
@@ -50,13 +50,7 @@ CONFIG = {
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
- "FRAME_SKIP": {
54
- "no_helmet": 2,
55
- "no_harness": 1,
56
- "unsafe_posture": 1,
57
- "unsafe_zone": 1,
58
- "improper_tool_use": 1
59
- },
60
  "CONFIDENCE_THRESHOLDS": {
61
  "no_helmet": 0.5,
62
  "no_harness": 0.15,
@@ -65,8 +59,9 @@ CONFIG = {
65
  "improper_tool_use": 0.15
66
  },
67
  "IOU_THRESHOLD": 0.4,
68
- "MIN_VIOLATION_FRAMES": 3,
69
- "HELMET_CONFIDENCE_THRESHOLD": 0.7
 
70
  }
71
 
72
  # Setup logging
@@ -101,7 +96,6 @@ model = load_model()
101
  # Enhanced Helper Functions
102
  # ==========================
103
  def draw_detections(frame, detections):
104
- """Draw bounding boxes and labels on frame"""
105
  for det in detections:
106
  label = det.get("violation", "Unknown")
107
  confidence = det.get("confidence", 0.0)
@@ -121,7 +115,6 @@ def draw_detections(frame, detections):
121
  return frame
122
 
123
  def calculate_iou(box1, box2):
124
- """Calculate Intersection over Union (IoU) for two bounding boxes."""
125
  x1, y1, w1, h1 = box1
126
  x2, y2, w2, h2 = box2
127
 
@@ -299,34 +292,41 @@ def process_video(video_data):
299
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
300
  fps = video.get(cv2.CAP_PROP_FPS)
301
  if fps <= 0:
302
- fps = 30 # Default assumption if FPS cannot be determined
303
  video_duration = total_frames / fps
304
  logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
305
 
306
  workers = []
307
  violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
- confirmed_violations = {} # Track confirmed violations per worker
309
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
310
- helmet_compliance = {} # Track workers with helmets
311
- detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()} # Track detection counts
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
- while True:
314
  ret, frame = video.read()
315
  if not ret:
316
- break
317
-
318
- current_time = frame_count / fps
319
- min_frame_skip = min(CONFIG["FRAME_SKIP"].values())
320
- if frame_count % min_frame_skip != 0:
321
- frame_count += 1
322
  continue
323
 
324
- # Yield progress update
325
- progress = (frame_count / total_frames) * 100
326
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_count}/{total_frames})", "", "", "", ""
 
327
 
328
  # Run detection on this frame
329
- results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], agnostic_nms=True)
330
 
331
  current_detections = []
332
  for result in results:
@@ -347,7 +347,7 @@ def process_video(video_data):
347
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
348
 
349
  current_detections.append({
350
- "frame": frame_count,
351
  "violation": label,
352
  "confidence": round(conf, 2),
353
  "bounding_box": bbox,
@@ -355,16 +355,14 @@ def process_video(video_data):
355
  })
356
  detection_counts[label] += 1
357
 
358
- logger.debug(f"Frame {frame_count}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
359
 
360
- # Process detections and associate with workers
361
  for detection in current_detections:
362
  violation_type = detection.get("violation", None)
363
  if violation_type is None:
364
  logger.error(f"Invalid detection, missing 'violation' key: {detection}")
365
  continue
366
 
367
- # Helmet compliance check
368
  if violation_type == "no_helmet":
369
  matched_worker = None
370
  max_iou = 0
@@ -376,7 +374,6 @@ def process_video(video_data):
376
 
377
  if matched_worker:
378
  worker_id = matched_worker["id"]
379
- # Require high confidence and persistence for No Helmet violation
380
  if worker_id not in helmet_compliance:
381
  helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
382
  helmet_compliance[worker_id]["no_helmet_frames"] += 1
@@ -387,7 +384,6 @@ def process_video(video_data):
387
  logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
388
  continue
389
 
390
- # Find or create worker
391
  matched_worker = None
392
  max_iou = 0
393
 
@@ -409,11 +405,9 @@ def process_video(video_data):
409
  "first_seen": current_time,
410
  "last_seen": current_time
411
  })
412
- # Initialize helmet compliance for new worker
413
  if worker_id not in helmet_compliance:
414
  helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
415
 
416
- # Skip if this violation type is already confirmed for this worker
417
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
418
  logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
419
  continue
@@ -421,14 +415,10 @@ def process_video(video_data):
421
  detection["worker_id"] = worker_id
422
  violation_history[violation_type].append(detection)
423
 
424
- # Clean up old workers
425
  workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
426
 
427
- frame_count += 1
428
-
429
  logger.info(f"Detection counts: {detection_counts}")
430
 
431
- # Process violation history to confirm persistent violations
432
  for violation_type, detections in violation_history.items():
433
  if not detections:
434
  logger.info(f"No detections for {violation_type}")
@@ -442,16 +432,13 @@ def process_video(video_data):
442
 
443
  for worker_id, worker_dets in worker_violations.items():
444
  if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
445
- # Skip if already confirmed
446
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
447
  continue
448
 
449
- # Skip No Helmet if worker is compliant
450
  if violation_type == "no_helmet":
451
  if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
452
  logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
453
  continue
454
- # Require persistent No Helmet detections
455
  if helmet_compliance[worker_id]["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
456
  logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
457
  continue
@@ -486,7 +473,6 @@ def process_video(video_data):
486
  logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
487
  cap.release()
488
 
489
- # Clean up video file after snapshots are captured
490
  video.release()
491
  os.remove(video_path)
492
  logger.info(f"Video file {video_path} removed")
@@ -537,7 +523,6 @@ def gradio_interface(video_file):
537
  with open(video_file, "rb") as f:
538
  video_data = f.read()
539
 
540
- # Use generator to yield updates
541
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
542
  yield status, score, snapshots_text, record_id, details_url
543
  except Exception as e:
 
30
  4: "improper_tool_use"
31
  },
32
  "CLASS_COLORS": {
33
+ "no_helmet": (0, 0, 255),
34
+ "no_harness": (0, 165, 255),
35
+ "unsafe_posture": (0, 255, 0),
36
+ "unsafe_zone": (255, 0, 0),
37
+ "improper_tool_use": (255, 255, 0)
38
  },
39
  "DISPLAY_NAMES": {
40
  "no_helmet": "No Helmet Violation",
 
50
  "domain": "login"
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
+ "FRAME_SKIP": 5, # Increased to process fewer frames (1 frame every 5 frames)
 
 
 
 
 
 
54
  "CONFIDENCE_THRESHOLDS": {
55
  "no_helmet": 0.5,
56
  "no_harness": 0.15,
 
59
  "improper_tool_use": 0.15
60
  },
61
  "IOU_THRESHOLD": 0.4,
62
+ "MIN_VIOLATION_FRAMES": 2, # Reduced to ensure violations are detected with fewer frames
63
+ "HELMET_CONFIDENCE_THRESHOLD": 0.7,
64
+ "MAX_PROCESSING_TIME": 30 # Enforce 30-second processing limit
65
  }
66
 
67
  # Setup logging
 
96
  # Enhanced Helper Functions
97
  # ==========================
98
  def draw_detections(frame, detections):
 
99
  for det in detections:
100
  label = det.get("violation", "Unknown")
101
  confidence = det.get("confidence", 0.0)
 
115
  return frame
116
 
117
  def calculate_iou(box1, box2):
 
118
  x1, y1, w1, h1 = box1
119
  x2, y2, w2, h2 = box2
120
 
 
292
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
293
  fps = video.get(cv2.CAP_PROP_FPS)
294
  if fps <= 0:
295
+ fps = 30
296
  video_duration = total_frames / fps
297
  logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
298
 
299
  workers = []
300
  violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
301
+ confirmed_violations = {}
302
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
303
+ helmet_compliance = {}
304
+ detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()}
305
+ start_time = time.time()
306
+
307
+ # Calculate frames to process within 30 seconds
308
+ target_frames = int(total_frames / CONFIG["FRAME_SKIP"])
309
+ frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)
310
+
311
+ processed_frames = 0
312
+ for idx in frame_indices:
313
+ elapsed_time = time.time() - start_time
314
+ if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
315
+ logger.info(f"Processing time limit of {CONFIG['MAX_PROCESSING_TIME']} seconds reached. Processed {processed_frames}/{target_frames} frames.")
316
+ break
317
 
318
+ video.set(cv2.CAP_PROP_POS_FRAMES, idx)
319
  ret, frame = video.read()
320
  if not ret:
 
 
 
 
 
 
321
  continue
322
 
323
+ processed_frames += 1
324
+ current_time = idx / fps
325
+ progress = (processed_frames / target_frames) * 100
326
+ yield f"Processing video... {progress:.1f}% complete (Frame {idx}/{total_frames})", "", "", "", ""
327
 
328
  # Run detection on this frame
329
+ results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"])
330
 
331
  current_detections = []
332
  for result in results:
 
347
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
348
 
349
  current_detections.append({
350
+ "frame": idx,
351
  "violation": label,
352
  "confidence": round(conf, 2),
353
  "bounding_box": bbox,
 
355
  })
356
  detection_counts[label] += 1
357
 
358
+ logger.debug(f"Frame {idx}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
359
 
 
360
  for detection in current_detections:
361
  violation_type = detection.get("violation", None)
362
  if violation_type is None:
363
  logger.error(f"Invalid detection, missing 'violation' key: {detection}")
364
  continue
365
 
 
366
  if violation_type == "no_helmet":
367
  matched_worker = None
368
  max_iou = 0
 
374
 
375
  if matched_worker:
376
  worker_id = matched_worker["id"]
 
377
  if worker_id not in helmet_compliance:
378
  helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
379
  helmet_compliance[worker_id]["no_helmet_frames"] += 1
 
384
  logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
385
  continue
386
 
 
387
  matched_worker = None
388
  max_iou = 0
389
 
 
405
  "first_seen": current_time,
406
  "last_seen": current_time
407
  })
 
408
  if worker_id not in helmet_compliance:
409
  helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
410
 
 
411
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
412
  logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
413
  continue
 
415
  detection["worker_id"] = worker_id
416
  violation_history[violation_type].append(detection)
417
 
 
418
  workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
419
 
 
 
420
  logger.info(f"Detection counts: {detection_counts}")
421
 
 
422
  for violation_type, detections in violation_history.items():
423
  if not detections:
424
  logger.info(f"No detections for {violation_type}")
 
432
 
433
  for worker_id, worker_dets in worker_violations.items():
434
  if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
 
435
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
436
  continue
437
 
 
438
  if violation_type == "no_helmet":
439
  if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
440
  logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
441
  continue
 
442
  if helmet_compliance[worker_id]["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
443
  logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
444
  continue
 
473
  logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
474
  cap.release()
475
 
 
476
  video.release()
477
  os.remove(video_path)
478
  logger.info(f"Video file {video_path} removed")
 
523
  with open(video_file, "rb") as f:
524
  video_data = f.read()
525
 
 
526
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
527
  yield status, score, snapshots_text, record_id, details_url
528
  except Exception as e: