PrashanthB461 commited on
Commit
83ad321
·
verified ·
1 Parent(s): 5b58187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -55
app.py CHANGED
@@ -65,7 +65,7 @@ CONFIG = {
65
  "FRAME_SKIP": 2, # Process every 2nd frame for speed
66
  "BATCH_SIZE": 16, # Frames per batch
67
  "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all CPU cores except one
68
- "VIOLATION_COOLDOWN": 5.0 # Seconds before same violation can be detected again for same worker
69
  }
70
 
71
  # Setup logging
@@ -307,12 +307,10 @@ def process_video(video_data):
307
 
308
  workers = []
309
  violations = []
 
310
  snapshots = []
311
  start_time = time.time()
312
  frame_skip = CONFIG["FRAME_SKIP"]
313
-
314
- # Track active violations to prevent duplicate detections
315
- active_violations = {} # Format: {worker_id: {violation_type: last_detection_time}}
316
 
317
  # Process frames in batches
318
  while True:
@@ -386,54 +384,60 @@ def process_video(video_data):
386
  "last_seen": current_time
387
  })
388
 
389
- # Check if this worker already has this violation type in cooldown period
390
- if worker_id in active_violations:
391
- if label in active_violations[worker_id]:
392
- last_detection_time = active_violations[worker_id][label]
393
- if current_time - last_detection_time < CONFIG["VIOLATION_COOLDOWN"]:
394
- continue # Skip this detection as it's within cooldown period
395
-
396
- # Create detection record
397
- detection = {
398
- "frame": frame_idx,
399
- "violation": label,
400
- "confidence": round(conf, 2),
401
- "bounding_box": bbox,
402
- "timestamp": current_time,
403
- "worker_id": worker_id
404
- }
405
-
406
- # Add to violations list
407
- violations.append(detection)
408
 
409
- # Update active violations
410
- if worker_id not in active_violations:
411
- active_violations[worker_id] = {}
412
- active_violations[worker_id][label] = current_time
413
-
414
- # Capture snapshot for this violation
415
- snapshot_frame = batch_frames[i].copy()
416
- snapshot_frame = draw_detections(snapshot_frame, [detection])
417
- snapshot_filename = f"{label}_{worker_id}_{frame_idx}.jpg"
418
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
419
- cv2.imwrite(snapshot_path, snapshot_frame)
420
- snapshots.append({
421
- "violation": label,
422
- "frame": frame_idx,
423
- "worker_id": worker_id,
424
- "snapshot_path": snapshot_path,
425
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
426
- })
427
-
428
- # Remove inactive workers and their violations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
430
-
431
- # Clean up active_violations for workers no longer tracked
432
- active_violations = {
433
- worker_id: violations
434
- for worker_id, violations in active_violations.items()
435
- if any(w["id"] == worker_id for w in workers)
436
- }
437
 
438
  cap.release()
439
  os.remove(video_path)
@@ -449,16 +453,26 @@ def process_video(video_data):
449
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
450
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
451
 
452
- violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
453
- violation_table += "|------------------------|---------------|------------|-----------|\n"
454
  for v in sorted(violations, key=lambda x: x["timestamp"]):
455
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
456
- row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
 
 
 
 
 
 
 
 
 
457
  violation_table += row
458
 
459
  snapshots_text = "\n".join(
460
- f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} (Worker {s['worker_id']}) at frame {s['frame']}: ![]({s['snapshot_base64']})"
461
- for s in snapshots
 
462
  ) if snapshots else "No snapshots captured."
463
 
464
  yield (
 
65
  "FRAME_SKIP": 2, # Process every 2nd frame for speed
66
  "BATCH_SIZE": 16, # Frames per batch
67
  "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all CPU cores except one
68
+ "VIOLATION_COOLDOWN": 5.0 # Seconds before same violation can be reported again for same worker
69
  }
70
 
71
  # Setup logging
 
307
 
308
  workers = []
309
  violations = []
310
+ violation_history = {} # Track when violations were last reported for each worker
311
  snapshots = []
312
  start_time = time.time()
313
  frame_skip = CONFIG["FRAME_SKIP"]
 
 
 
314
 
315
  # Process frames in batches
316
  while True:
 
384
  "last_seen": current_time
385
  })
386
 
387
+ # Check if we should report this violation (cooldown period)
388
+ violation_key = f"{worker_id}_{label}"
389
+ last_reported = violation_history.get(violation_key, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
+ if current_time - last_reported >= CONFIG["VIOLATION_COOLDOWN"]:
392
+ detection = {
393
+ "frame": frame_idx,
394
+ "violation": label,
395
+ "confidence": round(conf, 2),
396
+ "bounding_box": bbox,
397
+ "timestamp": current_time,
398
+ "worker_id": worker_id
399
+ }
400
+
401
+ # Track helmet violations with stricter criteria
402
+ if detection["violation"] == "no_helmet":
403
+ if conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
404
+ # Only report if we have multiple detections
405
+ if worker_id not in violation_history:
406
+ violation_history[worker_id] = []
407
+ violation_history[worker_id].append(detection)
408
+
409
+ # Check if we have enough detections to confirm
410
+ if len(violation_history[worker_id]) >= CONFIG["MIN_VIOLATION_FRAMES"]:
411
+ # Select the detection with the highest confidence
412
+ best_detection = max(violation_history[worker_id], key=lambda x: x["confidence"])
413
+ violations.append(best_detection)
414
+ violation_history[violation_key] = current_time
415
+
416
+ # Capture snapshot for confirmed no_helmet violation
417
+ cap_snapshot = cv2.VideoCapture(video_path)
418
+ cap_snapshot.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
419
+ ret, snapshot_frame = cap_snapshot.read()
420
+ if ret:
421
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
422
+ snapshot_filename = f"violation_{worker_id}_{label}_{best_detection['frame']}.jpg"
423
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
+ cv2.imwrite(snapshot_path, snapshot_frame)
425
+ snapshots.append({
426
+ "violation": label,
427
+ "frame": best_detection["frame"],
428
+ "timestamp": best_detection["timestamp"],
429
+ "worker_id": worker_id,
430
+ "snapshot_path": snapshot_path,
431
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
432
+ })
433
+ cap_snapshot.release()
434
+ del violation_history[worker_id] # Reset for this worker
435
+ else:
436
+ violations.append(detection)
437
+ violation_history[violation_key] = current_time
438
+
439
+ # Remove inactive workers
440
  workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
 
 
 
 
 
 
 
441
 
442
  cap.release()
443
  os.remove(video_path)
 
453
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
454
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
455
 
456
+ violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID | Snapshot |\n"
457
+ violation_table += "|------------------------|---------------|------------|-----------|----------|\n"
458
  for v in sorted(violations, key=lambda x: x["timestamp"]):
459
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
460
+ # Find matching snapshot if exists
461
+ snapshot_markdown = ""
462
+ for s in snapshots:
463
+ if (s["worker_id"] == v.get("worker_id") and
464
+ s["violation"] == v.get("violation") and
465
+ abs(s["timestamp"] - v.get("timestamp", 0)) < 0.5):
466
+ snapshot_markdown = f"[View]({s['snapshot_url']})"
467
+ break
468
+
469
+ row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} | {snapshot_markdown} |\n"
470
  violation_table += row
471
 
472
  snapshots_text = "\n".join(
473
+ f"### {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} (Worker {s['worker_id']}) at {s['timestamp']:.2f}s\n"
474
+ f"![Violation]({s['snapshot_url']})"
475
+ for s in sorted(snapshots, key=lambda x: x["timestamp"])
476
  ) if snapshots else "No snapshots captured."
477
 
478
  yield (