PrashanthB461 commited on
Commit
6da60dd
·
verified ·
1 Parent(s): 125b2ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -111
app.py CHANGED
@@ -135,6 +135,36 @@ def calculate_iou(box1, box2):
135
 
136
  return intersection_area / union_area
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def generate_violation_pdf(violations, score):
139
  try:
140
  pdf_filename = f"violations_{int(time.time())}.pdf"
@@ -304,13 +334,12 @@ def process_video(video_data):
304
 
305
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
306
 
307
- # Track workers only for helmet violations
308
- helmet_workers = {} # {worker_id: {"first_detected": timestamp, "bbox": bbox}}
309
  violations = []
 
310
  snapshots = []
311
  start_time = time.time()
312
  frame_skip = CONFIG["FRAME_SKIP"]
313
- next_worker_id = 1
314
 
315
  # Process frames in batches
316
  while True:
@@ -354,8 +383,6 @@ def process_video(video_data):
354
 
355
  # Process detections in this frame
356
  boxes = result.boxes
357
- frame_violations = set() # Track violations in this frame to avoid duplicates
358
-
359
  for box in boxes:
360
  cls = int(box.cls)
361
  conf = float(box.conf)
@@ -365,99 +392,78 @@ def process_video(video_data):
365
  continue
366
 
367
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
368
-
369
- # For no_helmet violations, track workers and only record first occurrence
370
- if label == "no_helmet":
371
- # Check if this is a known worker
372
- worker_id = None
373
- for w_id, worker in helmet_workers.items():
374
- iou = calculate_iou(bbox, worker["bbox"])
375
- if iou > 0.4: # IOU threshold
376
- worker_id = w_id
377
- # Update worker's position
378
- helmet_workers[w_id]["bbox"] = bbox
379
- helmet_workers[w_id]["last_seen"] = current_time
380
- break
381
-
382
- # If new worker, assign ID and record first violation
383
- if worker_id is None:
384
- worker_id = next_worker_id
385
- next_worker_id += 1
386
- helmet_workers[worker_id] = {
387
- "bbox": bbox,
388
- "first_seen": current_time,
389
- "last_seen": current_time
390
- }
391
-
392
- # Only record first violation for this worker
393
- detection = {
394
- "frame": frame_idx,
395
- "violation": label,
396
- "confidence": round(conf, 2),
397
- "bounding_box": bbox,
398
- "timestamp": current_time,
399
- "worker_id": worker_id
400
- }
401
- violations.append(detection)
402
-
403
- # Capture snapshot
404
- cap_snapshot = cv2.VideoCapture(video_path)
405
- cap_snapshot.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
406
- ret, snapshot_frame = cap_snapshot.read()
407
- if ret:
408
- snapshot_frame = draw_detections(snapshot_frame, [detection])
409
- snapshot_filename = f"no_helmet_{worker_id}_{frame_idx}.jpg"
410
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
411
- cv2.imwrite(snapshot_path, snapshot_frame)
412
- snapshots.append({
413
- "violation": "no_helmet",
414
- "frame": frame_idx,
415
- "worker_id": worker_id,
416
- "snapshot_path": snapshot_path,
417
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
418
- })
419
- cap_snapshot.release()
420
  else:
421
- # For other violations, only record if not already detected in this frame
422
- if label not in frame_violations:
423
- detection = {
424
- "frame": frame_idx,
425
- "violation": label,
426
- "confidence": round(conf, 2),
427
- "bounding_box": bbox,
428
- "timestamp": current_time
429
- }
430
- violations.append(detection)
431
- frame_violations.add(label)
432
-
433
- # Capture snapshot for first occurrence of this violation type
434
- cap_snapshot = cv2.VideoCapture(video_path)
435
- cap_snapshot.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
436
- ret, snapshot_frame = cap_snapshot.read()
437
- if ret:
438
- snapshot_frame = draw_detections(snapshot_frame, [detection])
439
- snapshot_filename = f"{label}_{frame_idx}.jpg"
440
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
441
- cv2.imwrite(snapshot_path, snapshot_frame)
442
- snapshots.append({
443
- "violation": label,
444
- "frame": frame_idx,
445
- "snapshot_path": snapshot_path,
446
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
447
- })
448
- cap_snapshot.release()
449
 
450
  # Remove inactive workers
451
- inactive_workers = [w_id for w_id, worker in helmet_workers.items()
452
- if current_time - worker["last_seen"] > CONFIG["WORKER_TRACKING_DURATION"]]
453
- for w_id in inactive_workers:
454
- del helmet_workers[w_id]
455
 
456
  cap.release()
457
  os.remove(video_path)
458
  processing_time = time.time() - start_time
459
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  # Generate results
462
  if not violations:
463
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
@@ -467,33 +473,22 @@ def process_video(video_data):
467
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
468
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
469
 
470
- violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID (Helmet Only) |\n"
471
- violation_table += "|------------------------|---------------|------------|--------------------------|\n"
472
  for v in sorted(violations, key=lambda x: x["timestamp"]):
473
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
474
- worker_id = v.get("worker_id", "N/A") if v.get("violation") == "no_helmet" else "N/A"
475
- row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {worker_id} |\n"
476
  violation_table += row
477
 
478
- # Create HTML for snapshots with clickable links
479
- snapshots_html = "<div style='display: flex; flex-wrap: wrap; gap: 10px;'>"
480
- for s in snapshots:
481
- display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
482
- worker_text = f"Worker {s['worker_id']}" if "worker_id" in s else ""
483
- snapshots_html += f"""
484
- <div style='text-align: center; margin: 10px;'>
485
- <a href='{s['snapshot_url']}' target='_blank'>
486
- <img src='{s['snapshot_url']}' style='max-width: 200px; max-height: 150px;'/>
487
- </a>
488
- <p>{display_name} at frame {s['frame']} {worker_text}</p>
489
- </div>
490
- """
491
- snapshots_html += "</div>"
492
 
493
  yield (
494
  violation_table,
495
  f"Safety Score: {score}%",
496
- snapshots_html,
497
  f"Salesforce Record ID: {report_id or 'N/A'}",
498
  final_pdf_url or "N/A"
499
  )
@@ -512,8 +507,8 @@ def gradio_interface(video_file):
512
  with open(video_file, "rb") as f:
513
  video_data = f.read()
514
 
515
- for status, score, snapshots_html, record_id, details_url in process_video(video_data):
516
- yield status, score, snapshots_html, record_id, details_url
517
  except Exception as e:
518
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
519
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
@@ -524,7 +519,7 @@ interface = gr.Interface(
524
  outputs=[
525
  gr.Markdown(label="Detected Safety Violations"),
526
  gr.Textbox(label="Compliance Score"),
527
- gr.HTML(label="Violation Snapshots (Click to enlarge)"),
528
  gr.Textbox(label="Salesforce Record ID"),
529
  gr.Textbox(label="Violation Details URL")
530
  ],
 
135
 
136
  return intersection_area / union_area
137
 
138
+ def process_frame_batch(frame_batch, frame_indices, fps):
139
+ batch_results = []
140
+ results = model(frame_batch, device=device, conf=0.1, verbose=False)
141
+
142
+ for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
143
+ current_time = frame_idx / fps
144
+ detections = []
145
+
146
+ boxes = result.boxes
147
+ for box in boxes:
148
+ cls = int(box.cls)
149
+ conf = float(box.conf)
150
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
151
+
152
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
153
+ continue
154
+
155
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
156
+ detections.append({
157
+ "frame": frame_idx,
158
+ "violation": label,
159
+ "confidence": round(conf, 2),
160
+ "bounding_box": bbox,
161
+ "timestamp": current_time
162
+ })
163
+
164
+ batch_results.append((frame_idx, detections))
165
+
166
+ return batch_results
167
+
168
  def generate_violation_pdf(violations, score):
169
  try:
170
  pdf_filename = f"violations_{int(time.time())}.pdf"
 
334
 
335
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
336
 
337
+ workers = []
 
338
  violations = []
339
+ helmet_violations = {}
340
  snapshots = []
341
  start_time = time.time()
342
  frame_skip = CONFIG["FRAME_SKIP"]
 
343
 
344
  # Process frames in batches
345
  while True:
 
383
 
384
  # Process detections in this frame
385
  boxes = result.boxes
 
 
386
  for box in boxes:
387
  cls = int(box.cls)
388
  conf = float(box.conf)
 
392
  continue
393
 
394
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
395
+ detection = {
396
+ "frame": frame_idx,
397
+ "violation": label,
398
+ "confidence": round(conf, 2),
399
+ "bounding_box": bbox,
400
+ "timestamp": current_time
401
+ }
402
+
403
+ # Worker tracking
404
+ worker_id = None
405
+ max_iou = 0
406
+ for idx, worker in enumerate(workers):
407
+ iou = calculate_iou(bbox, worker["bbox"])
408
+ if iou > max_iou and iou > 0.4: # IOU threshold
409
+ max_iou = iou
410
+ worker_id = worker["id"]
411
+ workers[idx]["bbox"] = bbox
412
+ workers[idx]["last_seen"] = current_time
413
+
414
+ if worker_id is None:
415
+ worker_id = len(workers) + 1
416
+ workers.append({
417
+ "id": worker_id,
418
+ "bbox": bbox,
419
+ "first_seen": current_time,
420
+ "last_seen": current_time
421
+ })
422
+
423
+ detection["worker_id"] = worker_id
424
+
425
+ # Track helmet violations with stricter criteria
426
+ if detection["violation"] == "no_helmet":
427
+ # Only include high-confidence no_helmet detections
428
+ if conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
429
+ if worker_id not in helmet_violations:
430
+ helmet_violations[worker_id] = []
431
+ helmet_violations[worker_id].append(detection)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  else:
433
+ violations.append(detection)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  # Remove inactive workers
436
+ workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
 
 
 
437
 
438
  cap.release()
439
  os.remove(video_path)
440
  processing_time = time.time() - start_time
441
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
442
 
443
+ # Confirm helmet violations (require multiple detections)
444
+ for worker_id, detections in helmet_violations.items():
445
+ if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
446
+ # Select the detection with the highest confidence
447
+ best_detection = max(detections, key=lambda x: x["confidence"])
448
+ violations.append(best_detection)
449
+
450
+ # Capture snapshot for confirmed no_helmet violation
451
+ cap = cv2.VideoCapture(video_path)
452
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
453
+ ret, snapshot_frame = cap.read()
454
+ if ret:
455
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
456
+ snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
457
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
458
+ cv2.imwrite(snapshot_path, snapshot_frame)
459
+ snapshots.append({
460
+ "violation": "no_helmet",
461
+ "frame": best_detection["frame"],
462
+ "snapshot_path": snapshot_path,
463
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
464
+ })
465
+ cap.release()
466
+
467
  # Generate results
468
  if not violations:
469
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
 
473
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
474
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
475
 
476
+ violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
477
+ violation_table += "|------------------------|---------------|------------|-----------|\n"
478
  for v in sorted(violations, key=lambda x: x["timestamp"]):
479
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
480
+ row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
 
481
  violation_table += row
482
 
483
+ snapshots_text = "\n".join(
484
+ f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
485
+ for s in snapshots
486
+ ) if snapshots else "No snapshots captured."
 
 
 
 
 
 
 
 
 
 
487
 
488
  yield (
489
  violation_table,
490
  f"Safety Score: {score}%",
491
+ snapshots_text,
492
  f"Salesforce Record ID: {report_id or 'N/A'}",
493
  final_pdf_url or "N/A"
494
  )
 
507
  with open(video_file, "rb") as f:
508
  video_data = f.read()
509
 
510
+ for status, score, snapshots_text, record_id, details_url in process_video(video_data):
511
+ yield status, score, snapshots_text, record_id, details_url
512
  except Exception as e:
513
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
514
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
 
519
  outputs=[
520
  gr.Markdown(label="Detected Safety Violations"),
521
  gr.Textbox(label="Compliance Score"),
522
+ gr.Markdown(label="Snapshots"),
523
  gr.Textbox(label="Salesforce Record ID"),
524
  gr.Textbox(label="Violation Details URL")
525
  ],