PrashanthB461 commited on
Commit
1aa6cf2
Β·
verified Β·
1 Parent(s): 0b86ea2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +409 -262
app.py CHANGED
@@ -24,7 +24,6 @@ import pytz
24
  import shutil
25
  import tempfile
26
  from scipy.spatial import distance
27
- from scipy.optimize import linear_sum_assignment # <-- Added for improved tracking
28
  import asyncio
29
  from functools import partial
30
  from concurrent.futures import ThreadPoolExecutor
@@ -248,7 +247,7 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
248
  # 2. Upload ContentVersion to Salesforce
249
  title = f"Safety_Report_{datetime.now(IST).strftime('%Y%m%d_%H%M%S')}"
250
  b64_pdf = base64.b64encode(pdf_bytes).decode('utf-8')
251
-
252
  logger.info(f"Uploading PDF '{title}.pdf' to Salesforce...")
253
  cv_result = sf.ContentVersion.create({
254
  'Title': title,
@@ -259,7 +258,7 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
259
  if not cv_result.get('success'):
260
  logger.error(f"Failed to create ContentVersion: {cv_result.get('errors')}")
261
  return None, None
262
-
263
  content_version_id = cv_result['id']
264
  logger.info(f"Successfully created ContentVersion with ID: {content_version_id}")
265
 
@@ -278,7 +277,7 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
278
  'LinkedEntityId': record_id,
279
  'ShareType': 'V' # V = Viewer
280
  } for record_id in record_ids]
281
-
282
  link_success_count = 0
283
  for payload in link_payloads:
284
  try:
@@ -289,7 +288,7 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
289
  logger.warning(f"Failed to link to {payload['LinkedEntityId']}: {link_result.get('errors')}")
290
  except Exception as e:
291
  logger.error(f"Error creating ContentDocumentLink for {payload['LinkedEntityId']}: {e}")
292
-
293
  logger.info(f"Successfully created {link_success_count}/{len(record_ids)} links.")
294
 
295
  # 5. Construct URL and Update records
@@ -299,7 +298,7 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
299
 
300
  update_payloads = [{'Id': record_id, 'PDF_Report_URL__c': pdf_url} for record_id in record_ids]
301
  update_results = sf.bulk.Safety_Violation_Log__c.update(update_payloads)
302
-
303
  successful_updates = sum(1 for res in update_results if res.get('success'))
304
  logger.info(f"Successfully updated {successful_updates}/{len(record_ids)} records with the PDF URL.")
305
 
@@ -307,7 +306,7 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
307
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf', prefix='report_') as temp_pdf:
308
  temp_pdf.write(pdf_bytes)
309
  temp_pdf_path = temp_pdf.name
310
-
311
  logger.info(f"Salesforce report URL: {pdf_url}")
312
  logger.info(f"Temporary local report for download: {temp_pdf_path}")
313
 
@@ -316,128 +315,103 @@ def generate_and_upload_report_to_salesforce(sf, violations, record_ids):
316
  except Exception as e:
317
  logger.error(f"Error in Salesforce PDF report generation/upload: {e}", exc_info=True)
318
  return None, None
319
-
320
- # --- Safety Violation Detector Class (ENHANCED FOR ACCURACY) ---
321
  class SafetyViolationDetector:
322
  def __init__(self):
323
- # Detection thresholds
324
  self.helmet_threshold = 0.75
325
  self.person_threshold = 0.60
326
  self.unsafe_distance = 50 # pixels
 
327
 
328
  # Unauthorized zones (x1, y1, x2, y2)
329
  self.unauthorized_zones = [
330
- [100, 100, 300, 300],
331
- [400, 200, 600, 400]
332
  ]
333
-
334
- # Tracking parameters
335
- self.person_tracker = {}
336
- self.next_person_id = 1
337
- self.max_age = 30 # Max frames to keep a track without a new detection
338
- self.min_hits = 3 # Min consecutive detections to consider a track valid
339
- self.iou_threshold = 0.3 # Min Intersection over Union for a match
340
 
341
- # Violation tracking per session
342
  self.session_violations = {}
 
 
 
 
 
 
 
343
 
344
  def reset_session(self):
345
- """Resets tracking and violation history for a new video or stream."""
 
346
  self.person_tracker = {}
 
347
  self.next_person_id = 1
348
- self.session_violations = {}
349
- logger.info("Tracker and violation session reset.")
350
 
351
  def has_reported_violation(self, person_id, violation_type):
352
- """Checks if a specific violation has already been reported for a person in this session."""
353
- return person_id in self.session_violations and violation_type in self.session_violations[person_id]
 
354
 
355
  def mark_violation_reported(self, person_id, violation_type, timestamp):
356
- """Marks a violation as reported for a person."""
357
  if person_id not in self.session_violations:
358
  self.session_violations[person_id] = {}
359
- self.session_violations[person_id][violation_type] = timestamp
 
 
 
360
 
361
- def _iou(self, box1, box2):
362
- """Calculates Intersection over Union between two bounding boxes."""
363
- x1 = max(box1[0], box2[0])
364
- y1 = max(box1[1], box2[1])
365
- x2 = min(box1[2], box2[2])
366
- y2 = min(box1[3], box2[3])
367
- intersection = max(0, x2 - x1) * max(0, y2 - y1)
368
- area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
369
- area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
370
- union = area1 + area2 - intersection
371
- return intersection / (union + 1e-6)
372
-
373
- def _euclidean_distance(self, point1, point2):
374
- return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)
375
 
376
- def update_trackers(self, detections, current_time):
377
- """
378
- Updates person trackers with new detections using the Hungarian algorithm for matching.
379
- This is the core of stable person tracking.
380
- """
381
- if not self.person_tracker:
382
- for det in detections:
383
- self._create_new_tracker(det, current_time)
384
- return
385
-
386
- tracked_ids = list(self.person_tracker.keys())
387
- tracked_boxes = [self.person_tracker[tid]['box'] for tid in tracked_ids]
388
-
389
- iou_matrix = np.zeros((len(tracked_ids), len(detections)))
390
- for t, trk_box in enumerate(tracked_boxes):
391
- for d, det in enumerate(detections):
392
- iou_matrix[t, d] = self._iou(trk_box, det['box'])
393
 
394
- # Use Hungarian algorithm to find optimal matches (maximize IOU)
395
- row_ind, col_ind = linear_sum_assignment(1 - iou_matrix)
396
-
397
- matched_indices = {r: c for r, c in zip(row_ind, col_ind) if iou_matrix[r, c] >= self.iou_threshold}
398
- unmatched_trackers = set(range(len(tracked_ids))) - set(matched_indices.keys())
399
- unmatched_detections = set(range(len(detections))) - set(matched_indices.values())
400
-
401
- # Update matched trackers
402
- for t_idx, d_idx in matched_indices.items():
403
- track_id = tracked_ids[t_idx]
404
- self.person_tracker[track_id].update({
405
- 'box': detections[d_idx]['box'],
406
- 'center': self._get_center(detections[d_idx]['box']),
407
- 'time_since_update': 0,
408
- 'hits': self.person_tracker[track_id]['hits'] + 1,
409
- 'hit_streak': self.person_tracker[track_id]['hit_streak'] + 1,
410
- 'last_seen': current_time
411
- })
412
-
413
- # Handle unmatched trackers (potential disappearance)
414
- for t_idx in unmatched_trackers:
415
- track_id = tracked_ids[t_idx]
416
- self.person_tracker[track_id]['time_since_update'] += 1
417
- self.person_tracker[track_id]['hit_streak'] = 0
418
-
419
- # Create new trackers for unmatched detections
420
- for d_idx in unmatched_detections:
421
- self._create_new_tracker(detections[d_idx], current_time)
422
-
423
- # Remove stale trackers
424
- self.person_tracker = {tid: tdata for tid, tdata in self.person_tracker.items() if tdata['time_since_update'] <= self.max_age}
425
-
426
- def _create_new_tracker(self, detection, current_time):
427
- self.person_tracker[self.next_person_id] = {
428
- 'id': self.next_person_id,
429
- 'box': detection['box'],
430
- 'center': self._get_center(detection['box']),
431
- 'time_since_update': 0,
432
- 'hits': 1,
433
- 'hit_streak': 1,
434
- 'first_seen': current_time,
435
- 'last_seen': current_time,
436
- }
437
- self.next_person_id += 1
438
-
439
- def _get_center(self, box):
440
- return ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
441
 
442
  def detect_violations(self, results, frame):
443
  current_time = time.time()
@@ -447,27 +421,49 @@ class SafetyViolationDetector:
447
  class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
448
  class_names = results[0].names
449
 
450
- # Step 1: Extract all person and helmet detections from the current frame
451
- person_detections = []
452
  helmets = []
 
453
  for box, conf, cls_id in zip(boxes, confidences, class_ids):
454
  class_name = class_names[cls_id]
455
  if class_name == "person" and conf >= self.person_threshold:
456
- person_detections.append({'box': box, 'confidence': conf})
 
 
 
 
 
 
457
  elif class_name == "hard hat" and conf >= self.helmet_threshold:
458
- helmets.append({'box': box, 'confidence': conf})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
- # Step 2: Update the tracker with the new person detections
461
- self.update_trackers(person_detections, current_time)
462
-
463
- # Step 3: Get a list of *confirmed* persons from the tracker
464
- persons_to_check = [
465
- data for data in self.person_tracker.values()
466
- if data['hit_streak'] >= self.min_hits and data['time_since_update'] == 0
467
- ]
468
 
469
- # Step 4: Check for violations for each confirmed person
470
- for person in persons_to_check:
471
  helmet_violation = self._check_helmet_violation(person, helmets, frame, current_time)
472
  if helmet_violation:
473
  violations.append(helmet_violation)
@@ -476,34 +472,77 @@ class SafetyViolationDetector:
476
  if unauthorized_violation:
477
  violations.append(unauthorized_violation)
478
 
479
- # Step 5: Check for pairwise violations like unsafe distance
480
- distance_violations = self._check_distance_violations(persons_to_check, frame, current_time)
481
  violations.extend(distance_violations)
482
 
 
 
 
483
  return violations
484
 
485
  def _check_helmet_violation(self, person, helmets, frame, current_time):
486
  person_id = person['id']
 
487
  violation_type = 'no_helmet'
488
 
489
  if self.has_reported_violation(person_id, violation_type):
490
  return None
491
 
492
- person_box = person['box']
493
  head_region = [
494
- person_box[0], person_box[1], person_box[2], person_box[1] + (person_box[3] - person_box[1]) * 0.4
 
 
 
495
  ]
496
 
497
- has_helmet = any(self._iou(helmet['box'], head_region) > 0.1 for helmet in helmets)
498
-
 
 
 
 
 
 
499
  if not has_helmet:
500
- self.mark_violation_reported(person_id, violation_type, current_time)
501
- self._annotate_frame(frame, person_box, person_id, "NO HELMET", (0, 0, 255))
502
- logger.info(f"NEW VIOLATION: No helmet detected for person {person_id}")
503
- return {
504
- 'type': violation_type, 'severity': 'Critical', 'person_id': person_id,
505
- 'timestamp': current_time
506
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  return None
508
 
509
  def _check_unauthorized_area(self, person, frame, current_time):
@@ -513,73 +552,190 @@ class SafetyViolationDetector:
513
  if self.has_reported_violation(person_id, violation_type):
514
  return None
515
 
 
 
 
516
  for zone in self.unauthorized_zones:
517
  zx1, zy1, zx2, zy2 = zone
518
- if (zx1 <= person['center'][0] <= zx2 and zy1 <= person['center'][1] <= zy2):
519
- self.mark_violation_reported(person_id, violation_type, current_time)
520
- cv2.rectangle(frame, (zx1, zy1), (zx2, zy2), (255, 0, 255), 2)
521
- self._annotate_frame(frame, person['box'], person_id, "UNAUTHORIZED", (255, 0, 255))
522
- logger.info(f"NEW VIOLATION: Unauthorized area detected for person {person_id}")
523
- return {
524
- 'type': violation_type, 'severity': 'High', 'person_id': person_id,
525
- 'zone': zone, 'timestamp': current_time
526
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  return None
528
 
529
  def _check_distance_violations(self, persons, frame, current_time):
530
  violations = []
531
- violation_type = 'unsafe_distance'
 
532
 
533
  for i in range(len(persons)):
534
- for j in range(i + 1, len(persons)):
535
- p1 = persons[i]
536
- p2 = persons[j]
537
- dist = self._euclidean_distance(p1['center'], p2['center'])
538
-
539
  if dist < self.unsafe_distance:
540
- # Report once per pair of individuals
541
- if self.has_reported_violation(p1['id'], f"{violation_type}_{p2['id']}") or \
542
- self.has_reported_violation(p2['id'], f"{violation_type}_{p1['id']}"):
 
 
 
543
  continue
544
-
545
- self.mark_violation_reported(p1['id'], f"{violation_type}_{p2['id']}", current_time)
546
- self.mark_violation_reported(p2['id'], f"{violation_type}_{p1['id']}", current_time)
547
-
548
- self._annotate_distance(frame, p1['box'], p2['box'], p1['id'], p2['id'], dist)
549
- logger.info(f"NEW VIOLATION: Unsafe distance between persons {p1['id']} and {p2['id']}")
550
- violations.append({
551
- 'type': violation_type, 'severity': 'Moderate', 'distance': dist,
552
- 'person1_id': p1['id'], 'person2_id': p2['id'], 'timestamp': current_time
553
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  return violations
555
 
556
- def _annotate_frame(self, frame, box, person_id, label_text, color):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  x1, y1, x2, y2 = map(int, box)
558
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
559
- label = f"ID:{person_id:02d} {label_text}"
560
- cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
 
561
 
562
  def _annotate_distance(self, frame, box1, box2, id1, id2, dist):
563
- for box, pid in [(box1, id1), (box2, id2)]:
564
- x1, y1, x2, y2 = map(int, box)
565
- color = (0, 165, 255) # Orange
566
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
567
- cv2.putText(frame, f"ID:{pid:02d}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
568
-
569
- center1 = self._get_center(box1)
570
- center2 = self._get_center(box2)
571
- cv2.line(frame, (int(center1[0]), int(center1[1])), (int(center2[0]), int(center2[1])), (0, 165, 255), 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
 
573
  def get_session_summary(self):
574
  summary = {
575
- 'total_persons_tracked': self.next_person_id - 1,
576
  'violations_by_type': {},
577
- 'persons_with_violations': list(self.session_violations.keys())
578
  }
 
579
  for person_id, violations in self.session_violations.items():
580
- for violation_type in violations:
581
- vt = violation_type.split('_')[0] # Group distance violations
582
- summary['violations_by_type'][vt] = summary['violations_by_type'].get(vt, 0) + 1
 
 
 
 
 
 
 
 
 
583
  return summary
584
 
585
  # --- Frame Processing Functions ---
@@ -642,7 +798,7 @@ async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
642
  frame_count = 0
643
  processed_frames = 0
644
  violation_count = 0
645
-
646
  # Get Salesforce connection once at the beginning
647
  sf = None
648
  if SALESFORCE_ENABLED:
@@ -650,7 +806,7 @@ async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
650
  sf = get_salesforce_connection()
651
  except Exception as e:
652
  logger.error(f"Could not connect to Salesforce at start: {e}")
653
-
654
  fps = cap.get(cv2.CAP_PROP_FPS)
655
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
656
  duration = total_frames / fps if fps > 0 else 0
@@ -666,53 +822,50 @@ async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
666
  if frame_count % frame_skip != 0:
667
  continue
668
 
669
- original_frame_copy = frame.copy()
670
  processed_frames += 1
671
  timestamp = datetime.now(IST).isoformat()
672
 
673
  progress_percent = min(100, (frame_count / total_frames) * 100)
674
  progress(progress_percent / 100, desc=f"Processing frame {frame_count}/{total_frames}")
675
 
676
- # Note: YOLO model expects a standard image, not pre-resized
677
- results = yolo_model(original_frame_copy)
678
 
679
- # The detect_violations method now annotates the frame directly
680
- violations = tracker.detect_violations(results, original_frame_copy)
681
 
682
- if violations:
683
- frames.append(original_frame_copy)
684
-
685
- for violation in violations:
686
- violation_count += 1
687
- snapshot_url = save_snapshot(original_frame_copy)
688
- worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
689
- if violation['type'] == 'unsafe_distance':
690
- worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
691
-
692
- violation_data = {
693
- 'violation_type': violation['type'].replace('_', ' ').title(),
694
- 'severity': violation['severity'],
695
- 'timestamp': timestamp,
696
- 'snapshot_url': snapshot_url,
697
- 'site_id': 'SITE001',
698
- 'camera_id': 'CAM001',
699
- 'worker_id': worker_id,
700
- 'frame_number': frame_count
701
- }
702
 
703
- if violation['type'] == 'unsafe_distance':
704
- violation_data['distance'] = f"{violation['distance']:.1f}px"
 
 
 
 
705
 
706
- current_run_violations.append(violation_data)
707
- log_violation(violation_data)
708
- send_alert(violation_data)
 
 
 
709
 
710
- if sf:
711
- record_id, message = create_salesforce_violation_record(sf, violation_data)
712
- if record_id:
713
- new_sf_record_ids.append(record_id)
714
- else:
715
- logger.error(f"Salesforce push failed for violation: {message}")
716
 
717
  cap.release()
718
  processing_time = time.time() - start_time
@@ -727,12 +880,12 @@ async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
727
  logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
728
  pdf_temp_path, pdf_sf_url = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
729
  if not pdf_temp_path:
730
- logger.error("Failed to generate and upload Salesforce report.")
731
  elif not current_run_violations:
732
  logger.info("No violations detected, skipping report generation.")
733
  else:
734
  logger.warning("Salesforce not configured or no violations recorded. Skipping Salesforce report upload.")
735
-
736
  session_summary = tracker.get_session_summary()
737
  logger.info(f"Video analysis complete. Session summary: {session_summary}")
738
 
@@ -741,16 +894,13 @@ async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
741
  frame_count,
742
  processed_frames,
743
  duration,
744
- len(current_run_violations), # Use length of unique violations
745
  processing_time,
746
  actual_fps,
747
  session_summary
748
  )
749
 
750
- # Convert processed frames to RGB for Gradio display
751
- display_frames = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
752
-
753
- return display_frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)
754
  except Exception as e:
755
  logger.error(f"Video processing error: {e}", exc_info=True)
756
  error_message = f"Video processing failed: {str(e)}"
@@ -760,7 +910,7 @@ async def process_video(video_path, frame_skip=1, progress=gr.Progress()):
760
  cv2.destroyAllWindows()
761
 
762
  # --- RTSP Processing ---
763
- async def process_rtsp_stream(rtsp_url, max_frames=300, frame_skip=5, progress=gr.Progress()):
764
  global processing_active
765
  processing_active = True
766
 
@@ -782,39 +932,33 @@ async def process_rtsp_stream(rtsp_url, max_frames=300, frame_skip=5, progress=g
782
  sf = get_salesforce_connection()
783
  except Exception as e:
784
  logger.error(f"Could not connect to Salesforce at start: {e}")
785
-
786
- processed_frames = []
787
  violation_count = 0
788
 
789
  progress(0, desc="Connecting to RTSP stream...")
790
- frame_gen = capture_rtsp_frames(rtsp_url, max_frames)
791
-
792
- fc = 0
793
- for frame, timestamp, _, _ in frame_gen:
794
- fc += 1
795
  if not processing_active:
796
  break
797
 
798
  if fc % frame_skip != 0:
799
  continue
800
-
801
- original_frame_copy = frame.copy()
802
 
803
- progress(fc / max_frames, desc=f"Processing frame {fc}/{max_frames}")
 
804
 
805
- results = yolo_model(original_frame_copy)
806
- violations = tracker.detect_violations(results, original_frame_copy)
 
 
807
 
808
- if violations:
809
- processed_frames.append(original_frame_copy)
810
-
811
  for violation in violations:
812
  violation_count += 1
813
- snapshot_url = save_snapshot(original_frame_copy)
814
  worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
815
  if violation['type'] == 'unsafe_distance':
816
  worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
817
-
818
  violation_data = {
819
  'violation_type': violation['type'].replace('_', ' ').title(),
820
  'severity': violation['severity'],
@@ -840,6 +984,9 @@ async def process_rtsp_stream(rtsp_url, max_frames=300, frame_skip=5, progress=g
840
  else:
841
  logger.error(f"Salesforce push failed for violation: {message}")
842
 
 
 
 
843
  if not processing_active:
844
  logger.info("Processing cancelled by user.")
845
 
@@ -849,20 +996,21 @@ async def process_rtsp_stream(rtsp_url, max_frames=300, frame_skip=5, progress=g
849
  logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
850
  pdf_temp_path, pdf_sf_url = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
851
  if not pdf_temp_path:
852
- logger.error("Failed to generate and upload Salesforce report.")
853
  elif not current_run_violations:
854
  logger.info("No violations detected, skipping report generation.")
855
  else:
856
  logger.warning("Salesforce not configured or no violations recorded. Skipping Salesforce report upload.")
857
 
 
 
 
858
  session_summary = tracker.get_session_summary()
859
  logger.info(f"RTSP analysis complete. Session summary: {session_summary}")
860
 
861
- status_message = f"Processed {fc} frames with {violation_count} unique violations. Persons tracked: {session_summary['total_persons_tracked']}"
862
-
863
- display_frames = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in processed_frames]
864
 
865
- return status_message, display_frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations), pdf_temp_path
866
  except Exception as e:
867
  logger.error(f"RTSP processing error: {e}", exc_info=True)
868
  error_message = f"RTSP processing failed: {str(e)}"
@@ -884,23 +1032,23 @@ def generate_status_message(has_violations, total_frames, processed_frames, dura
884
 
885
  if session_summary:
886
  base_message += f"""
887
- πŸ‘₯ TOTAL PERSONS TRACKED: {session_summary['total_persons_tracked']}
888
  πŸ” VIOLATION TYPES: {', '.join(session_summary['violations_by_type'].keys()) if session_summary['violations_by_type'] else 'None'}"""
889
 
890
  if has_violations:
891
  return f"""{base_message}
892
- 🚨 UNIQUE VIOLATIONS LOGGED: {violation_count}
893
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
894
- Each violation type is logged only once per person."""
895
  else:
896
  return f"""{base_message}
897
  βœ… NO VIOLATIONS DETECTED
898
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
899
- All safety protocols followed."""
900
 
901
  def save_snapshot(frame):
902
  try:
903
- filename = f"snapshot_{int(time.time())}_{np.random.randint(100,999)}.jpg"
904
  snapshot_dir = "./snapshots"
905
  os.makedirs(snapshot_dir, exist_ok=True)
906
  snapshot_path = os.path.join(snapshot_dir, filename)
@@ -960,7 +1108,7 @@ Note: Each violation type reported only once per person
960
  text += f"""
961
  β”Œβ”€ ALERT #{i:02d} ─ {severity_emoji} {violation['violation_type'].upper()}
962
  β”‚
963
- β”œβ”€ πŸ• Time: {datetime.fromisoformat(violation['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}
964
  β”œβ”€ ⚠️ Severity: {violation['severity']}
965
  β”œβ”€ πŸ“ Location: Site {violation['site_id']} | Camera {violation['camera_id']}
966
  β”œβ”€ πŸ‘· Worker: {violation.get('worker_id', 'UNKNOWN')}
@@ -974,15 +1122,14 @@ Note: Each violation type reported only once per person
974
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
975
  β€’ Total Violations: {len(violations)}
976
  β€’ Critical: {sum(1 for v in violations if v['severity'] == 'Critical')}
977
- β€’ High: {sum(1 for v in violations if v['severity'] == 'High')}
978
  β€’ Moderate: {sum(1 for v in violations if v['severity'] == 'Moderate')}
979
- β€’ Last Alert: {datetime.fromisoformat(violations[-1]['timestamp']).strftime('%H:%M:%S') if violations else 'N/A'}
980
 
981
  πŸ”„ System Status: ACTIVELY MONITORING
982
  ⚑ Response Time: <5 seconds
983
  🎯 Detection Accuracy: >90% confidence"""
984
  return text
985
-
986
  def generate_heatmap(violations):
987
  if not violations:
988
  return None
@@ -1508,7 +1655,7 @@ with gr.Blocks(
1508
  animateStars();
1509
  </script>
1510
  """)
1511
-
1512
  # Professional Header
1513
  gr.HTML("""
1514
  <div class="main-header">
@@ -1516,7 +1663,7 @@ with gr.Blocks(
1516
  <p class="header-subtitle">Enhanced Multi-Person Tracking - Each violation type detected only once per person per video</p>
1517
  </div>
1518
  """)
1519
-
1520
  # Smart Media Analysis Section
1521
  gr.HTML('<div class="section-header">πŸ“· Smart Media Analysis</div>')
1522
  with gr.Row():
@@ -1535,7 +1682,7 @@ with gr.Blocks(
1535
  elem_classes=["btn-primary"],
1536
  size="lg"
1537
  )
1538
-
1539
  # Analysis Results Section
1540
  gr.HTML('<div class="section-header">πŸ“Š Analysis Results & Violation Details</div>')
1541
  with gr.Row():
@@ -1560,7 +1707,7 @@ with gr.Blocks(
1560
  label="πŸ“₯ Download Professional Report",
1561
  elem_classes=["file-component"]
1562
  )
1563
-
1564
  # Violation Details Section
1565
  gr.HTML('<div class="section-header">🚨 Real-time Violation Monitoring</div>')
1566
  with gr.Group(elem_classes=["professional-card", "alert-panel"]):
@@ -1572,7 +1719,7 @@ with gr.Blocks(
1572
  value=format_violations_as_text(recent_violations),
1573
  interactive=False
1574
  )
1575
-
1576
  # Live Stream Processing Section
1577
  gr.HTML('<div class="section-header">πŸ“Ή Live Stream Monitoring</div>')
1578
  with gr.Row():
@@ -1615,7 +1762,7 @@ with gr.Blocks(
1615
  rows=2,
1616
  object_fit="cover"
1617
  )
1618
-
1619
  # Live Violation Log Section
1620
  gr.HTML('<div class="section-header">πŸ“Š Live Violation Analytics</div>')
1621
  with gr.Row():
@@ -1639,8 +1786,8 @@ with gr.Blocks(
1639
  label="πŸ“₯ Download RTSP Professional Report",
1640
  elem_classes=["file-component"]
1641
  )
1642
-
1643
- # Professional Footer
1644
  gr.HTML(f"""
1645
  <div class="footer-info">
1646
  <h3>πŸ›‘οΈ Dynamic Safety Violation Detection using CCTV + AI</h3>
 
24
  import shutil
25
  import tempfile
26
  from scipy.spatial import distance
 
27
  import asyncio
28
  from functools import partial
29
  from concurrent.futures import ThreadPoolExecutor
 
247
  # 2. Upload ContentVersion to Salesforce
248
  title = f"Safety_Report_{datetime.now(IST).strftime('%Y%m%d_%H%M%S')}"
249
  b64_pdf = base64.b64encode(pdf_bytes).decode('utf-8')
250
+
251
  logger.info(f"Uploading PDF '{title}.pdf' to Salesforce...")
252
  cv_result = sf.ContentVersion.create({
253
  'Title': title,
 
258
  if not cv_result.get('success'):
259
  logger.error(f"Failed to create ContentVersion: {cv_result.get('errors')}")
260
  return None, None
261
+
262
  content_version_id = cv_result['id']
263
  logger.info(f"Successfully created ContentVersion with ID: {content_version_id}")
264
 
 
277
  'LinkedEntityId': record_id,
278
  'ShareType': 'V' # V = Viewer
279
  } for record_id in record_ids]
280
+
281
  link_success_count = 0
282
  for payload in link_payloads:
283
  try:
 
288
  logger.warning(f"Failed to link to {payload['LinkedEntityId']}: {link_result.get('errors')}")
289
  except Exception as e:
290
  logger.error(f"Error creating ContentDocumentLink for {payload['LinkedEntityId']}: {e}")
291
+
292
  logger.info(f"Successfully created {link_success_count}/{len(record_ids)} links.")
293
 
294
  # 5. Construct URL and Update records
 
298
 
299
  update_payloads = [{'Id': record_id, 'PDF_Report_URL__c': pdf_url} for record_id in record_ids]
300
  update_results = sf.bulk.Safety_Violation_Log__c.update(update_payloads)
301
+
302
  successful_updates = sum(1 for res in update_results if res.get('success'))
303
  logger.info(f"Successfully updated {successful_updates}/{len(record_ids)} records with the PDF URL.")
304
 
 
306
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf', prefix='report_') as temp_pdf:
307
  temp_pdf.write(pdf_bytes)
308
  temp_pdf_path = temp_pdf.name
309
+
310
  logger.info(f"Salesforce report URL: {pdf_url}")
311
  logger.info(f"Temporary local report for download: {temp_pdf_path}")
312
 
 
315
  except Exception as e:
316
  logger.error(f"Error in Salesforce PDF report generation/upload: {e}", exc_info=True)
317
  return None, None
318
+
319
+ # --- Safety Violation Detector Class ---
320
  class SafetyViolationDetector:
321
  def __init__(self):
322
+ # Detection thresholds (fine-tuned for better accuracy)
323
  self.helmet_threshold = 0.75
324
  self.person_threshold = 0.60
325
  self.unsafe_distance = 50 # pixels
326
+ self.violation_cooldown = 20 # seconds
327
 
328
  # Unauthorized zones (x1, y1, x2, y2)
329
  self.unauthorized_zones = [
330
+ [100, 100, 300, 300], # Example zone 1
331
+ [400, 200, 600, 400] # Example zone 2
332
  ]
 
 
 
 
 
 
 
333
 
334
+ self.active_violations = {}
335
  self.session_violations = {}
336
+ self.person_tracker = {}
337
+ self.person_positions_history = {}
338
+ self.next_person_id = 1
339
+ self.max_tracking_distance = 120
340
+ self.max_history_length = 20 # Increased for better tracking continuity
341
+ self.min_iou_for_match = 0.4 # Adjusted for stricter matching
342
+ self.min_score_for_match = 0.5 # Adjusted for stricter matching
343
 
344
  def reset_session(self):
345
+ self.session_violations = {}
346
+ self.active_violations = {}
347
  self.person_tracker = {}
348
+ self.person_positions_history = {}
349
  self.next_person_id = 1
350
+ logger.info("Session violation tracking reset for new video")
 
351
 
352
  def has_reported_violation(self, person_id, violation_type):
353
+ if person_id not in self.session_violations:
354
+ return False
355
+ return violation_type in self.session_violations[person_id]
356
 
357
  def mark_violation_reported(self, person_id, violation_type, timestamp):
 
358
  if person_id not in self.session_violations:
359
  self.session_violations[person_id] = {}
360
+ self.session_violations[person_id][violation_type] = {
361
+ 'first_detected': timestamp,
362
+ 'count': self.session_violations[person_id].get(violation_type, {}).get('count', 0) + 1
363
+ }
364
 
365
+ def _get_stable_person_id(self, box, current_time):
366
+ center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
367
+ box_area = (box[2] - box[0]) * (box[3] - box[1])
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ best_match_id = None
370
+ best_match_score = 0
371
+ min_distance = float('inf')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
+ for person_id, history in self.person_positions_history.items():
374
+ if not history['positions']:
375
+ continue
376
+
377
+ last_position = history['positions'][-1]
378
+ last_box = history['boxes'][-1]
379
+
380
+ dist = np.sqrt((center[0] - last_position[0])**2 + (center[1] - last_position[1])**2)
381
+ iou = self._iou(box, last_box)
382
+
383
+ # Weighted scoring: more emphasis on IoU for stable tracking
384
+ score = (1.0 / (1.0 + dist/50)) * 0.3 + iou * 0.7
385
+
386
+ if score > best_match_score and score > self.min_score_for_match and iou > self.min_iou_for_match:
387
+ best_match_score = score
388
+ best_match_id = person_id
389
+ min_distance = dist
390
+
391
+ if best_match_id is not None:
392
+ person_id = best_match_id
393
+ else:
394
+ person_id = self.next_person_id
395
+ self.next_person_id += 1
396
+ self.person_positions_history[person_id] = {
397
+ 'positions': [],
398
+ 'boxes': [],
399
+ 'first_seen': current_time,
400
+ 'last_seen': current_time,
401
+ 'features': [] # Store features for enhanced tracking
402
+ }
403
+
404
+ # Update tracking history
405
+ self.person_positions_history[person_id]['positions'].append(center)
406
+ self.person_positions_history[person_id]['boxes'].append(box)
407
+ self.person_positions_history[person_id]['last_seen'] = current_time
408
+
409
+ # Limit history length to prevent memory issues
410
+ if len(self.person_positions_history[person_id]['positions']) > self.max_history_length:
411
+ self.person_positions_history[person_id]['positions'].pop(0)
412
+ self.person_positions_history[person_id]['boxes'].pop(0)
413
+
414
+ return person_id
 
 
 
 
 
415
 
416
  def detect_violations(self, results, frame):
417
  current_time = time.time()
 
421
  class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
422
  class_names = results[0].names
423
 
424
+ persons = []
 
425
  helmets = []
426
+
427
  for box, conf, cls_id in zip(boxes, confidences, class_ids):
428
  class_name = class_names[cls_id]
429
  if class_name == "person" and conf >= self.person_threshold:
430
+ person_id = self._get_stable_person_id(box, current_time)
431
+ persons.append({
432
+ 'box': box,
433
+ 'confidence': conf,
434
+ 'center': ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2),
435
+ 'id': person_id
436
+ })
437
  elif class_name == "hard hat" and conf >= self.helmet_threshold:
438
+ helmets.append({
439
+ 'box': box,
440
+ 'confidence': conf,
441
+ 'area': (box[2] - box[0]) * (box[3] - box[1])
442
+ })
443
+
444
+ current_person_ids = set()
445
+ for person in persons:
446
+ person_id = person['id']
447
+ current_person_ids.add(person_id)
448
+
449
+ if person_id not in self.person_tracker:
450
+ self.person_tracker[person_id] = {
451
+ 'first_seen': current_time,
452
+ 'last_seen': current_time,
453
+ 'positions': [person['center']],
454
+ 'helmet_status': False,
455
+ 'violations': {},
456
+ 'last_violation_frame': {}
457
+ }
458
+ else:
459
+ self.person_tracker[person_id]['last_seen'] = current_time
460
+ self.person_tracker[person_id]['positions'].append(person['center'])
461
+ if len(self.person_tracker[person_id]['positions']) > self.max_history_length:
462
+ self.person_tracker[person_id]['positions'].pop(0)
463
 
464
+ for person in persons:
465
+ person_id = person['id']
 
 
 
 
 
 
466
 
 
 
467
  helmet_violation = self._check_helmet_violation(person, helmets, frame, current_time)
468
  if helmet_violation:
469
  violations.append(helmet_violation)
 
472
  if unauthorized_violation:
473
  violations.append(unauthorized_violation)
474
 
475
+ distance_violations = self._check_distance_violations(persons, frame, current_time)
 
476
  violations.extend(distance_violations)
477
 
478
+ self._cleanup_violations(current_time)
479
+ self._cleanup_inactive_persons(current_person_ids, current_time)
480
+
481
  return violations
482
 
483
  def _check_helmet_violation(self, person, helmets, frame, current_time):
484
  person_id = person['id']
485
+ person_box = person['box']
486
  violation_type = 'no_helmet'
487
 
488
  if self.has_reported_violation(person_id, violation_type):
489
  return None
490
 
 
491
  head_region = [
492
+ person_box[0],
493
+ max(person_box[1], person_box[1] + (person_box[3] - person_box[1]) * 0.3),
494
+ person_box[2],
495
+ person_box[1] + (person_box[3] - person_box[1]) * 0.3
496
  ]
497
 
498
+ has_helmet = False
499
+ for helmet in helmets:
500
+ if self._iou(helmet['box'], head_region) > 0.1:
501
+ has_helmet = True
502
+ break
503
+
504
+ self.person_tracker[person_id]['helmet_status'] = has_helmet
505
+
506
  if not has_helmet:
507
+ violation_key = f"no_helmet_{person_id}"
508
+
509
+ # Ensure persistent violation check across frames
510
+ if violation_key not in self.active_violations:
511
+ self.active_violations[violation_key] = {
512
+ 'type': 'no_helmet',
513
+ 'person_id': person_id,
514
+ 'first_detected': current_time,
515
+ 'last_detected': current_time,
516
+ 'count': 0,
517
+ 'confirmed': False
518
+ }
519
+
520
+ self.active_violations[violation_key]['count'] += 1
521
+ self.active_violations[violation_key]['last_detected'] = current_time
522
+
523
+ # Require consistent detection over multiple frames for confirmation
524
+ if self.active_violations[violation_key]['count'] >= 5 and not self.active_violations[violation_key]['confirmed']:
525
+ self.mark_violation_reported(person_id, violation_type, current_time)
526
+ self.active_violations[violation_key]['confirmed'] = True
527
+
528
+ if 'no_helmet' not in self.person_tracker[person_id]['violations']:
529
+ self.person_tracker[person_id]['violations']['no_helmet'] = {
530
+ 'count': 0,
531
+ 'last_time': 0
532
+ }
533
+ self.person_tracker[person_id]['violations']['no_helmet']['count'] += 1
534
+ self.person_tracker[person_id]['violations']['no_helmet']['last_time'] = current_time
535
+
536
+ self._annotate_frame(frame, person_box, person_id, "NO HELMET", (0, 0, 255))
537
+ logger.info(f"CONFIRMED VIOLATION: No helmet detected for person {person_id}")
538
+
539
+ return {
540
+ 'type': 'no_helmet',
541
+ 'severity': 'Critical',
542
+ 'person': person,
543
+ 'person_id': person_id,
544
+ 'timestamp': current_time
545
+ }
546
  return None
547
 
548
  def _check_unauthorized_area(self, person, frame, current_time):
 
552
  if self.has_reported_violation(person_id, violation_type):
553
  return None
554
 
555
+ x1, y1, x2, y2 = person['box']
556
+ person_center = ((x1 + x2) / 2, (y1 + y2) / 2)
557
+
558
  for zone in self.unauthorized_zones:
559
  zx1, zy1, zx2, zy2 = zone
560
+ if (zx1 <= person_center[0] <= zx2 and zy1 <= person_center[1] <= zy2):
561
+ violation_key = f"unauthorized_area_{person_id}_{zx1}_{zy1}"
562
+
563
+ if violation_key not in self.active_violations:
564
+ self.active_violations[violation_key] = {
565
+ 'type': 'unauthorized_area',
566
+ 'person_id': person_id,
567
+ 'zone': zone,
568
+ 'first_detected': current_time,
569
+ 'last_detected': current_time,
570
+ 'count': 0,
571
+ 'confirmed': False
572
+ }
573
+
574
+ self.active_violations[violation_key]['count'] += 1
575
+ self.active_violations[violation_key]['last_detected'] = current_time
576
+
577
+ if self.active_violations[violation_key]['count'] >= 5 and not self.active_violations[violation_key]['confirmed']:
578
+ self.mark_violation_reported(person_id, violation_type, current_time)
579
+ self.active_violations[violation_key]['confirmed'] = True
580
+
581
+ if 'unauthorized_area' not in self.person_tracker[person_id]['violations']:
582
+ self.person_tracker[person_id]['violations']['unauthorized_area'] = {
583
+ 'count': 0,
584
+ 'last_time': 0
585
+ }
586
+ self.person_tracker[person_id]['violations']['unauthorized_area']['count'] += 1
587
+ self.person_tracker[person_id]['violations']['unauthorized_area']['last_time'] = current_time
588
+
589
+ cv2.rectangle(frame, (zx1, zy1), (zx2, zy2), (255, 0, 255), 2)
590
+ self._annotate_frame(frame, person['box'], person_id, "UNAUTHORIZED", (255, 0, 255))
591
+ logger.info(f"CONFIRMED VIOLATION: Unauthorized area detected for person {person_id}")
592
+
593
+ return {
594
+ 'type': 'unauthorized_area',
595
+ 'severity': 'High',
596
+ 'person': person,
597
+ 'person_id': person_id,
598
+ 'zone': zone,
599
+ 'timestamp': current_time
600
+ }
601
  return None
602
 
603
  def _check_distance_violations(self, persons, frame, current_time):
604
  violations = []
605
+ if len(persons) < 2:
606
+ return violations
607
 
608
  for i in range(len(persons)):
609
+ for j in range(i+1, len(persons)):
610
+ dist = self._euclidean_distance(persons[i]['center'], persons[j]['center'])
 
 
 
611
  if dist < self.unsafe_distance:
612
+ person1_id = persons[i]['id']
613
+ person2_id = persons[j]['id']
614
+ violation_type = 'unsafe_distance'
615
+
616
+ if (self.has_reported_violation(person1_id, violation_type) or
617
+ self.has_reported_violation(person2_id, violation_type)):
618
  continue
619
+
620
+ pair_key = f"{min(person1_id, person2_id)}_{max(person1_id, person2_id)}"
621
+ violation_key = f"unsafe_distance_{pair_key}"
622
+
623
+ if violation_key not in self.active_violations:
624
+ self.active_violations[violation_key] = {
625
+ 'type': 'unsafe_distance',
626
+ 'person1_id': person1_id,
627
+ 'person2_id': person2_id,
628
+ 'first_detected': current_time,
629
+ 'last_detected': current_time,
630
+ 'count': 0,
631
+ 'confirmed': False
632
+ }
633
+
634
+ self.active_violations[violation_key]['count'] += 1
635
+ self.active_violations[violation_key]['last_detected'] = current_time
636
+
637
+ if self.active_violations[violation_key]['count'] >= 5 and not self.active_violations[violation_key]['confirmed']:
638
+ self.mark_violation_reported(person1_id, violation_type, current_time)
639
+ self.mark_violation_reported(person2_id, violation_type, current_time)
640
+ self.active_violations[violation_key]['confirmed'] = True
641
+
642
+ for pid in [person1_id, person2_id]:
643
+ if 'unsafe_distance' not in self.person_tracker[pid]['violations']:
644
+ self.person_tracker[pid]['violations']['unsafe_distance'] = {
645
+ 'count': 0,
646
+ 'last_time': 0
647
+ }
648
+ self.person_tracker[pid]['violations']['unsafe_distance']['count'] += 1
649
+ self.person_tracker[pid]['violations']['unsafe_distance']['last_time'] = current_time
650
+
651
+ self._annotate_distance(frame, persons[i]['box'], persons[j]['box'],
652
+ person1_id, person2_id, dist)
653
+ logger.info(f"CONFIRMED VIOLATION: Unsafe distance detected between persons {person1_id} and {person2_id}")
654
+
655
+ violations.append({
656
+ 'type': 'unsafe_distance',
657
+ 'severity': 'Moderate',
658
+ 'person1': persons[i],
659
+ 'person2': persons[j],
660
+ 'distance': dist,
661
+ 'person1_id': person1_id,
662
+ 'person2_id': person2_id,
663
+ 'timestamp': current_time
664
+ })
665
  return violations
666
 
667
+ def _cleanup_violations(self, current_time):
668
+ expired_violations = [
669
+ k for k, v in self.active_violations.items()
670
+ if current_time - v['last_detected'] > self.violation_cooldown
671
+ ]
672
+ for key in expired_violations:
673
+ del self.active_violations[key]
674
+
675
+ def _cleanup_inactive_persons(self, current_person_ids, current_time):
676
+ inactive_timeout = 60
677
+ expired_persons = [
678
+ pid for pid, data in self.person_tracker.items()
679
+ if pid not in current_person_ids and
680
+ current_time - data['last_seen'] > inactive_timeout
681
+ ]
682
+ for pid in expired_persons:
683
+ del self.person_tracker[pid]
684
+ if pid in self.person_positions_history:
685
+ del self.person_positions_history[pid]
686
+
687
+ def _annotate_frame(self, frame, box, person_id, violation_type, color):
688
  x1, y1, x2, y2 = map(int, box)
689
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
690
+ label = f"ID:{person_id:03d} {violation_type}"
691
+ cv2.putText(frame, label, (x1, y1 - 10),
692
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
693
 
694
  def _annotate_distance(self, frame, box1, box2, id1, id2, dist):
695
+ x1, y1, x2, y2 = map(int, box1)
696
+ x3, y3, x4, y4 = map(int, box2)
697
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 165, 255), 2)
698
+ cv2.rectangle(frame, (x3, y3), (x4, y4), (0, 165, 255), 2)
699
+ center1 = ((x1 + x2) // 2, (y1 + y2) // 2)
700
+ center2 = ((x3 + x4) // 2, (y3 + y4) // 2)
701
+ cv2.line(frame, center1, center2, (0, 165, 255), 2)
702
+ mid_point = ((center1[0] + center2[0]) // 2, (center1[1] + center2[1]) // 2)
703
+ cv2.putText(frame, f"{dist:.1f}px", mid_point,
704
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 2)
705
+
706
+ def _iou(self, box1, box2):
707
+ x1 = max(box1[0], box2[0])
708
+ y1 = max(box1[1], box2[1])
709
+ x2 = min(box1[2], box2[2])
710
+ y2 = min(box1[3], box2[3])
711
+ intersection = max(0, x2 - x1) * max(0, y2 - y1)
712
+ area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
713
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
714
+ return intersection / (area1 + area2 - intersection + 1e-6)
715
+
716
+ def _euclidean_distance(self, point1, point2):
717
+ return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)
718
 
719
  def get_session_summary(self):
720
  summary = {
721
+ 'total_persons': len(self.session_violations),
722
  'violations_by_type': {},
723
+ 'persons_with_violations': []
724
  }
725
+
726
  for person_id, violations in self.session_violations.items():
727
+ person_info = {
728
+ 'person_id': person_id,
729
+ 'violations': list(violations.keys()),
730
+ 'violation_count': len(violations)
731
+ }
732
+ summary['persons_with_violations'].append(person_info)
733
+
734
+ for violation_type in violations.keys():
735
+ if violation_type not in summary['violations_by_type']:
736
+ summary['violations_by_type'][violation_type] = 0
737
+ summary['violations_by_type'][violation_type] += 1
738
+
739
  return summary
740
 
741
  # --- Frame Processing Functions ---
 
798
  frame_count = 0
799
  processed_frames = 0
800
  violation_count = 0
801
+
802
  # Get Salesforce connection once at the beginning
803
  sf = None
804
  if SALESFORCE_ENABLED:
 
806
  sf = get_salesforce_connection()
807
  except Exception as e:
808
  logger.error(f"Could not connect to Salesforce at start: {e}")
809
+
810
  fps = cap.get(cv2.CAP_PROP_FPS)
811
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
812
  duration = total_frames / fps if fps > 0 else 0
 
822
  if frame_count % frame_skip != 0:
823
  continue
824
 
 
825
  processed_frames += 1
826
  timestamp = datetime.now(IST).isoformat()
827
 
828
  progress_percent = min(100, (frame_count / total_frames) * 100)
829
  progress(progress_percent / 100, desc=f"Processing frame {frame_count}/{total_frames}")
830
 
831
+ processed_frame = preprocess_frame(frame)
832
+ results = yolo_model.predict(processed_frame)
833
 
834
+ violations = tracker.detect_violations(results, frame)
 
835
 
836
+ for violation in violations:
837
+ violation_count += 1
838
+ snapshot_url = save_snapshot(frame)
839
+ worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
840
+ if violation['type'] == 'unsafe_distance':
841
+ worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
842
+ violation_data = {
843
+ 'violation_type': violation['type'].replace('_', ' ').title(),
844
+ 'severity': violation['severity'],
845
+ 'timestamp': timestamp,
846
+ 'snapshot_url': snapshot_url,
847
+ 'site_id': 'SITE001',
848
+ 'camera_id': 'CAM001',
849
+ 'worker_id': worker_id,
850
+ 'frame_number': frame_count
851
+ }
 
 
 
 
852
 
853
+ if violation['type'] == 'unsafe_distance':
854
+ violation_data['distance'] = f"{violation['distance']:.1f}px"
855
+
856
+ current_run_violations.append(violation_data)
857
+ log_violation(violation_data)
858
+ send_alert(violation_data)
859
 
860
+ if sf:
861
+ record_id, message = create_salesforce_violation_record(sf, violation_data)
862
+ if record_id:
863
+ new_sf_record_ids.append(record_id)
864
+ else:
865
+ logger.error(f"Salesforce push failed for violation: {message}")
866
 
867
+ if violations:
868
+ frames.append(frame)
 
 
 
 
869
 
870
  cap.release()
871
  processing_time = time.time() - start_time
 
880
  logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
881
  pdf_temp_path, pdf_sf_url = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
882
  if not pdf_temp_path:
883
+ logger.error("Failed to generate and upload Salesforce report.")
884
  elif not current_run_violations:
885
  logger.info("No violations detected, skipping report generation.")
886
  else:
887
  logger.warning("Salesforce not configured or no violations recorded. Skipping Salesforce report upload.")
888
+
889
  session_summary = tracker.get_session_summary()
890
  logger.info(f"Video analysis complete. Session summary: {session_summary}")
891
 
 
894
  frame_count,
895
  processed_frames,
896
  duration,
897
+ violation_count,
898
  processing_time,
899
  actual_fps,
900
  session_summary
901
  )
902
 
903
+ return frames, status_message, pdf_temp_path, format_violations_as_text(current_run_violations)
 
 
 
904
  except Exception as e:
905
  logger.error(f"Video processing error: {e}", exc_info=True)
906
  error_message = f"Video processing failed: {str(e)}"
 
910
  cv2.destroyAllWindows()
911
 
912
  # --- RTSP Processing ---
913
+ async def process_rtsp_stream(rtsp_url, max_frames=None, frame_skip=1, progress=gr.Progress()):
914
  global processing_active
915
  processing_active = True
916
 
 
932
  sf = get_salesforce_connection()
933
  except Exception as e:
934
  logger.error(f"Could not connect to Salesforce at start: {e}")
935
+
936
+ frames = []
937
  violation_count = 0
938
 
939
  progress(0, desc="Connecting to RTSP stream...")
940
+
941
+ for frame, timestamp, fc, _ in capture_rtsp_frames(rtsp_url, max_frames):
 
 
 
942
  if not processing_active:
943
  break
944
 
945
  if fc % frame_skip != 0:
946
  continue
 
 
947
 
948
+ progress_percent = min(100, (fc / (max_frames if max_frames else 100)) * 100)
949
+ progress(progress_percent / 100, desc=f"Processing frame {fc}")
950
 
951
+ processed_frame = preprocess_frame(frame)
952
+ results = yolo_model.predict(processed_frame)
953
+
954
+ violations = tracker.detect_violations(results, frame)
955
 
 
 
 
956
  for violation in violations:
957
  violation_count += 1
958
+ snapshot_url = save_snapshot(frame)
959
  worker_id = f"WORKER{violation.get('person_id', 'UNKNOWN')}"
960
  if violation['type'] == 'unsafe_distance':
961
  worker_id = f"WORKER{violation['person1_id']} & WORKER{violation['person2_id']}"
 
962
  violation_data = {
963
  'violation_type': violation['type'].replace('_', ' ').title(),
964
  'severity': violation['severity'],
 
984
  else:
985
  logger.error(f"Salesforce push failed for violation: {message}")
986
 
987
+ if violations:
988
+ frames.append(frame)
989
+
990
  if not processing_active:
991
  logger.info("Processing cancelled by user.")
992
 
 
996
  logger.info(f"Generating PDF report and uploading to Salesforce for {len(new_sf_record_ids)} violations...")
997
  pdf_temp_path, pdf_sf_url = generate_and_upload_report_to_salesforce(sf, current_run_violations, new_sf_record_ids)
998
  if not pdf_temp_path:
999
+ logger.error("Failed to generate and upload Salesforce report.")
1000
  elif not current_run_violations:
1001
  logger.info("No violations detected, skipping report generation.")
1002
  else:
1003
  logger.warning("Salesforce not configured or no violations recorded. Skipping Salesforce report upload.")
1004
 
1005
+ if not processing_active:
1006
+ return "Processing cancelled.", frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations), pdf_temp_path
1007
+
1008
  session_summary = tracker.get_session_summary()
1009
  logger.info(f"RTSP analysis complete. Session summary: {session_summary}")
1010
 
1011
+ status_message = f"Processed {len(frames)} frames with {violation_count} unique violations. Persons tracked: {session_summary['total_persons']}"
 
 
1012
 
1013
+ return status_message, frames, format_violations_as_text(current_run_violations), generate_heatmap(current_run_violations), pdf_temp_path
1014
  except Exception as e:
1015
  logger.error(f"RTSP processing error: {e}", exc_info=True)
1016
  error_message = f"RTSP processing failed: {str(e)}"
 
1032
 
1033
  if session_summary:
1034
  base_message += f"""
1035
+ πŸ‘₯ UNIQUE PERSONS TRACKED: {session_summary['total_persons']}
1036
  πŸ” VIOLATION TYPES: {', '.join(session_summary['violations_by_type'].keys()) if session_summary['violations_by_type'] else 'None'}"""
1037
 
1038
  if has_violations:
1039
  return f"""{base_message}
1040
+ 🚨 UNIQUE VIOLATIONS: {violation_count}
1041
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1042
+ Each violation reported only once per person"""
1043
  else:
1044
  return f"""{base_message}
1045
  βœ… NO VIOLATIONS DETECTED
1046
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1047
+ All safety protocols followed"""
1048
 
1049
  def save_snapshot(frame):
1050
  try:
1051
+ filename = f"snapshot_{int(time.time())}.jpg"
1052
  snapshot_dir = "./snapshots"
1053
  os.makedirs(snapshot_dir, exist_ok=True)
1054
  snapshot_path = os.path.join(snapshot_dir, filename)
 
1108
  text += f"""
1109
  β”Œβ”€ ALERT #{i:02d} ─ {severity_emoji} {violation['violation_type'].upper()}
1110
  β”‚
1111
+ β”œβ”€ πŸ• Time: {violation['timestamp']}
1112
  β”œβ”€ ⚠️ Severity: {violation['severity']}
1113
  β”œβ”€ πŸ“ Location: Site {violation['site_id']} | Camera {violation['camera_id']}
1114
  β”œβ”€ πŸ‘· Worker: {violation.get('worker_id', 'UNKNOWN')}
 
1122
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1123
  β€’ Total Violations: {len(violations)}
1124
  β€’ Critical: {sum(1 for v in violations if v['severity'] == 'Critical')}
 
1125
  β€’ Moderate: {sum(1 for v in violations if v['severity'] == 'Moderate')}
1126
+ β€’ Last Alert: {violations[-1]['timestamp'] if violations else 'N/A'}
1127
 
1128
  πŸ”„ System Status: ACTIVELY MONITORING
1129
  ⚑ Response Time: <5 seconds
1130
  🎯 Detection Accuracy: >90% confidence"""
1131
  return text
1132
+
1133
  def generate_heatmap(violations):
1134
  if not violations:
1135
  return None
 
1655
  animateStars();
1656
  </script>
1657
  """)
1658
+
1659
  # Professional Header
1660
  gr.HTML("""
1661
  <div class="main-header">
 
1663
  <p class="header-subtitle">Enhanced Multi-Person Tracking - Each violation type detected only once per person per video</p>
1664
  </div>
1665
  """)
1666
+
1667
  # Smart Media Analysis Section
1668
  gr.HTML('<div class="section-header">πŸ“· Smart Media Analysis</div>')
1669
  with gr.Row():
 
1682
  elem_classes=["btn-primary"],
1683
  size="lg"
1684
  )
1685
+
1686
  # Analysis Results Section
1687
  gr.HTML('<div class="section-header">πŸ“Š Analysis Results & Violation Details</div>')
1688
  with gr.Row():
 
1707
  label="πŸ“₯ Download Professional Report",
1708
  elem_classes=["file-component"]
1709
  )
1710
+
1711
  # Violation Details Section
1712
  gr.HTML('<div class="section-header">🚨 Real-time Violation Monitoring</div>')
1713
  with gr.Group(elem_classes=["professional-card", "alert-panel"]):
 
1719
  value=format_violations_as_text(recent_violations),
1720
  interactive=False
1721
  )
1722
+
1723
  # Live Stream Processing Section
1724
  gr.HTML('<div class="section-header">πŸ“Ή Live Stream Monitoring</div>')
1725
  with gr.Row():
 
1762
  rows=2,
1763
  object_fit="cover"
1764
  )
1765
+
1766
  # Live Violation Log Section
1767
  gr.HTML('<div class="section-header">πŸ“Š Live Violation Analytics</div>')
1768
  with gr.Row():
 
1786
  label="πŸ“₯ Download RTSP Professional Report",
1787
  elem_classes=["file-component"]
1788
  )
1789
+
1790
+ # Professional Footer
1791
  gr.HTML(f"""
1792
  <div class="footer-info">
1793
  <h3>πŸ›‘οΈ Dynamic Safety Violation Detection using CCTV + AI</h3>