PrashanthB461 commited on
Commit
550ca2a
·
verified ·
1 Parent(s): d7cc76c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -85
app.py CHANGED
@@ -51,22 +51,22 @@ CONFIG = {
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
  "FRAME_SKIP": {
54
- "no_helmet": 3,
55
- "no_harness": 2,
56
- "unsafe_posture": 2,
57
- "unsafe_zone": 2,
58
- "improper_tool_use": 2
59
  },
60
- "MAX_PROCESSING_TIME": 60,
61
  "CONFIDENCE_THRESHOLDS": {
62
- "no_helmet": 0.3,
63
- "no_harness": 0.2,
64
- "unsafe_posture": 0.2,
65
- "unsafe_zone": 0.2,
66
- "improper_tool_use": 0.2
67
  },
68
  "IOU_THRESHOLD": 0.4,
69
- "MIN_VIOLATION_FRAMES": 3
 
70
  }
71
 
72
  # Setup logging
@@ -296,35 +296,35 @@ def process_video(video_data):
296
  violations = []
297
  snapshots = []
298
  frame_count = 0
299
- start_time = time.time()
300
  fps = video.get(cv2.CAP_PROP_FPS)
301
  if fps <= 0:
302
  fps = 30 # Default assumption if FPS cannot be determined
 
 
303
 
304
  workers = []
305
  violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
306
  confirmed_violations = {} # Track confirmed violations per worker
307
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
308
  helmet_compliance = {} # Track workers with helmets
309
-
310
- logger.info(f"Processing video with FPS: {fps}")
311
- logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
 
313
  while True:
314
  ret, frame = video.read()
315
  if not ret:
316
  break
317
 
318
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
319
- logger.info("Processing time limit reached")
320
- break
321
-
322
  current_time = frame_count / fps
323
  min_frame_skip = min(CONFIG["FRAME_SKIP"].values())
324
  if frame_count % min_frame_skip != 0:
325
  frame_count += 1
326
  continue
327
 
 
 
 
 
328
  # Run detection on this frame
329
  results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], agnostic_nms=True)
330
 
@@ -348,11 +348,12 @@ def process_video(video_data):
348
 
349
  current_detections.append({
350
  "frame": frame_count,
351
- "violation": label, # Corrected key
352
  "confidence": round(conf, 2),
353
  "bounding_box": bbox,
354
  "timestamp": current_time
355
  })
 
356
 
357
  logger.debug(f"Frame {frame_count}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
358
 
@@ -363,7 +364,7 @@ def process_video(video_data):
363
  logger.error(f"Invalid detection, missing 'violation' key: {detection}")
364
  continue
365
 
366
- # Skip No Helmet detection if worker is compliant
367
  if violation_type == "no_helmet":
368
  matched_worker = None
369
  max_iou = 0
@@ -373,9 +374,18 @@ def process_video(video_data):
373
  max_iou = iou
374
  matched_worker = worker
375
 
376
- if matched_worker and matched_worker["id"] in helmet_compliance:
377
- logger.debug(f"Worker {matched_worker['id']} has helmet, skipping no_helmet violation")
378
- continue
 
 
 
 
 
 
 
 
 
379
 
380
  # Find or create worker
381
  matched_worker = None
@@ -399,6 +409,9 @@ def process_video(video_data):
399
  "first_seen": current_time,
400
  "last_seen": current_time
401
  })
 
 
 
402
 
403
  # Skip if this violation type is already confirmed for this worker
404
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
@@ -407,20 +420,14 @@ def process_video(video_data):
407
 
408
  detection["worker_id"] = worker_id
409
  violation_history[violation_type].append(detection)
410
-
411
- # Update helmet compliance (simulate by checking if No Helmet confidence is low)
412
- if violation_type == "no_helmet" and detection["confidence"] < 0.5:
413
- helmet_compliance[worker_id] = True
414
- logger.debug(f"Worker {worker_id} marked as helmet compliant")
415
 
416
  # Clean up old workers
417
  workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
418
 
419
  frame_count += 1
420
 
421
- video.release()
422
- os.remove(video_path)
423
-
424
  # Process violation history to confirm persistent violations
425
  for violation_type, detections in violation_history.items():
426
  if not detections:
@@ -440,9 +447,14 @@ def process_video(video_data):
440
  continue
441
 
442
  # Skip No Helmet if worker is compliant
443
- if violation_type == "no_helmet" and worker_id in helmet_compliance:
444
- logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
445
- continue
 
 
 
 
 
446
 
447
  best_detection = max(worker_dets, key=lambda x: x["confidence"])
448
  violations.append(best_detection)
@@ -455,22 +467,29 @@ def process_video(video_data):
455
  cap = cv2.VideoCapture(video_path)
456
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
457
  ret, snapshot_frame = cap.read()
458
- cap.release()
 
 
 
 
459
 
460
- if ret:
461
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
462
-
463
- snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
464
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
465
- cv2.imwrite(snapshot_path, snapshot_frame)
466
- snapshots.append({
467
- "violation": violation_type,
468
- "frame": best_detection["frame"],
469
- "snapshot_path": snapshot_path,
470
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
471
- })
472
- snapshot_taken[violation_type] = True
473
- logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
 
 
 
474
 
475
  if not violations:
476
  logger.info("No persistent violations detected")
@@ -514,44 +533,46 @@ def gradio_interface(video_file):
514
  if not video_file:
515
  return "No file uploaded.", "", "No file uploaded.", "", ""
516
  try:
517
- yield "Processing video... please wait.", "", "", "", ""
518
-
519
  with open(video_file, "rb") as f:
520
  video_data = f.read()
521
 
522
- result = process_video(video_data)
523
-
524
- if result.get("message"):
525
- yield result["message"], "", "", "", ""
526
- return
527
-
528
- violation_table = "No violations detected."
529
- if result["violations"]:
530
- header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
531
- separator = "|------------------------|---------------|------------|-----------|\n"
532
- rows = []
533
- violation_name_map = CONFIG["DISPLAY_NAMES"]
534
- for v in result["violations"]:
535
- display_name = violation_name_map.get(v.get("violation", "Unknown"), "Unknown")
536
- row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |"
537
- rows.append(row)
538
- violation_table = header + separator + "\n".join(rows)
539
-
540
- snapshots_text = "No snapshots captured."
541
- if result["snapshots"]:
542
- violation_name_map = CONFIG["DISPLAY_NAMES"]
543
- snapshots_text = "\n".join(
544
- f"- Snapshot for {violation_name_map.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: ![]({s.get('snapshot_base64', '')})"
545
- for s in result["snapshots"]
546
- )
547
 
548
- yield (
549
- violation_table,
550
- f"Safety Score: {result['score']}%",
551
- snapshots_text,
552
- f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
553
- result["violation_details_url"] or "N/A"
554
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
  except Exception as e:
556
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
557
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
 
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
  "FRAME_SKIP": {
54
+ "no_helmet": 2, # Reduced to process more frames
55
+ "no_harness": 1,
56
+ "unsafe_posture": 1,
57
+ "unsafe_zone": 1,
58
+ "improper_tool_use": 1
59
  },
 
60
  "CONFIDENCE_THRESHOLDS": {
61
+ "no_helmet": 0.5, # Increased to reduce false positives
62
+ "no_harness": 0.15, # Lowered to improve detection
63
+ "unsafe_posture": 0.15,
64
+ "unsafe_zone": 0.15,
65
+ "improper_tool_use": 0.15
66
  },
67
  "IOU_THRESHOLD": 0.4,
68
+ "MIN_VIOLATION_FRAMES": 3,
69
+ "HELMET_CONFIDENCE_THRESHOLD": 0.7 # Require high confidence for No Helmet violation
70
  }
71
 
72
  # Setup logging
 
296
  violations = []
297
  snapshots = []
298
  frame_count = 0
299
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
300
  fps = video.get(cv2.CAP_PROP_FPS)
301
  if fps <= 0:
302
  fps = 30 # Default assumption if FPS cannot be determined
303
+ video_duration = total_frames / fps
304
+ logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
305
 
306
  workers = []
307
  violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
  confirmed_violations = {} # Track confirmed violations per worker
309
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
310
  helmet_compliance = {} # Track workers with helmets
311
+ detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()} # Track detection counts
 
 
312
 
313
  while True:
314
  ret, frame = video.read()
315
  if not ret:
316
  break
317
 
 
 
 
 
318
  current_time = frame_count / fps
319
  min_frame_skip = min(CONFIG["FRAME_SKIP"].values())
320
  if frame_count % min_frame_skip != 0:
321
  frame_count += 1
322
  continue
323
 
324
+ # Yield progress update
325
+ progress = (frame_count / total_frames) * 100
326
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_count}/{total_frames})", "", "", "", ""
327
+
328
  # Run detection on this frame
329
  results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], agnostic_nms=True)
330
 
 
348
 
349
  current_detections.append({
350
  "frame": frame_count,
351
+ "violation": label,
352
  "confidence": round(conf, 2),
353
  "bounding_box": bbox,
354
  "timestamp": current_time
355
  })
356
+ detection_counts[label] += 1
357
 
358
  logger.debug(f"Frame {frame_count}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
359
 
 
364
  logger.error(f"Invalid detection, missing 'violation' key: {detection}")
365
  continue
366
 
367
+ # Helmet compliance check
368
  if violation_type == "no_helmet":
369
  matched_worker = None
370
  max_iou = 0
 
374
  max_iou = iou
375
  matched_worker = worker
376
 
377
+ if matched_worker:
378
+ worker_id = matched_worker["id"]
379
+ # Require high confidence and persistence for No Helmet violation
380
+ if worker_id not in helmet_compliance:
381
+ helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
382
+ helmet_compliance[worker_id]["no_helmet_frames"] += 1
383
+ if detection["confidence"] < CONFIG["HELMET_CONFIDENCE_THRESHOLD"]:
384
+ helmet_compliance[worker_id]["compliant"] = True
385
+ logger.debug(f"Worker {worker_id} marked as helmet compliant due to low no_helmet confidence")
386
+ if helmet_compliance[worker_id]["compliant"]:
387
+ logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
388
+ continue
389
 
390
  # Find or create worker
391
  matched_worker = None
 
409
  "first_seen": current_time,
410
  "last_seen": current_time
411
  })
412
+ # Initialize helmet compliance for new worker
413
+ if worker_id not in helmet_compliance:
414
+ helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
415
 
416
  # Skip if this violation type is already confirmed for this worker
417
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
 
420
 
421
  detection["worker_id"] = worker_id
422
  violation_history[violation_type].append(detection)
 
 
 
 
 
423
 
424
  # Clean up old workers
425
  workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
426
 
427
  frame_count += 1
428
 
429
+ logger.info(f"Detection counts: {detection_counts}")
430
+
 
431
  # Process violation history to confirm persistent violations
432
  for violation_type, detections in violation_history.items():
433
  if not detections:
 
447
  continue
448
 
449
  # Skip No Helmet if worker is compliant
450
+ if violation_type == "no_helmet":
451
+ if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
452
+ logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
453
+ continue
454
+ # Require persistent No Helmet detections
455
+ if helmet_compliance[worker_id]["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
456
+ logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
457
+ continue
458
 
459
  best_detection = max(worker_dets, key=lambda x: x["confidence"])
460
  violations.append(best_detection)
 
467
  cap = cv2.VideoCapture(video_path)
468
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
469
  ret, snapshot_frame = cap.read()
470
+ if not ret:
471
+ logger.error(f"Failed to capture snapshot for {violation_type} at frame {best_detection['frame']}")
472
+ cap.release()
473
+ continue
474
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
475
 
476
+ snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
477
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
478
+ cv2.imwrite(snapshot_path, snapshot_frame)
479
+ snapshots.append({
480
+ "violation": violation_type,
481
+ "frame": best_detection["frame"],
482
+ "snapshot_path": snapshot_path,
483
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
484
+ })
485
+ snapshot_taken[violation_type] = True
486
+ logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
487
+ cap.release()
488
+
489
+ # Clean up video file after snapshots are captured
490
+ video.release()
491
+ os.remove(video_path)
492
+ logger.info(f"Video file {video_path} removed")
493
 
494
  if not violations:
495
  logger.info("No persistent violations detected")
 
533
  if not video_file:
534
  return "No file uploaded.", "", "No file uploaded.", "", ""
535
  try:
 
 
536
  with open(video_file, "rb") as f:
537
  video_data = f.read()
538
 
539
+ # Use generator to yield progress updates
540
+ for status, violations_table, score, snapshots_text, record_id, details_url in process_video(video_data):
541
+ if status.startswith("Processing video"):
542
+ yield status, "", "", "", ""
543
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
 
545
+ if status.get("message"):
546
+ yield status["message"], "", "", "", ""
547
+ return
548
+
549
+ violation_table = "No violations detected."
550
+ if status["violations"]:
551
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
552
+ separator = "|------------------------|---------------|------------|-----------|\n"
553
+ rows = []
554
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
555
+ for v in status["violations"]:
556
+ display_name = violation_name_map.get(v.get("violation", "Unknown"), "Unknown")
557
+ row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |"
558
+ rows.append(row)
559
+ violation_table = header + separator + "\n".join(rows)
560
+
561
+ snapshots_text = "No snapshots captured."
562
+ if status["snapshots"]:
563
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
564
+ snapshots_text = "\n".join(
565
+ f"- Snapshot for {violation_name_map.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: ![]({s.get('snapshot_base64', '')})"
566
+ for s in status["snapshots"]
567
+ )
568
+
569
+ yield (
570
+ violation_table,
571
+ f"Safety Score: {status['score']}%",
572
+ snapshots_text,
573
+ f"Salesforce Record ID: {status['salesforce_record_id'] or 'N/A'}",
574
+ status["violation_details_url"] or "N/A"
575
+ )
576
  except Exception as e:
577
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
578
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""