PrashanthB461 commited on
Commit
15e98c2
·
verified ·
1 Parent(s): 77abc00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -130
app.py CHANGED
@@ -19,7 +19,6 @@ from retrying import retry
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
- import face_recognition
23
  from collections import defaultdict
24
 
25
  # ========================== # Configuration and Setup # ==========================
@@ -29,98 +28,134 @@ os.makedirs('/tmp/Ultralytics', exist_ok=True)
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
- # ========================== # Face Recognition Setup # ==========================
33
- class FaceTracker:
34
- def __init__(self):
35
- self.known_faces = {}
36
- self.next_face_id = 1
37
- self.tolerance = 0.6
38
- self.frame_skip = 5 # Process face recognition every N frames
39
-
40
- def get_face_encoding(self, frame, box):
41
- """Extract face encoding from bounding box"""
42
- x, y, w, h = box
43
- x1, y1 = int(x - w/2), int(y - h/2)
44
- x2, y2 = int(x + w/2), int(y + h/2)
45
-
46
- # Expand the face area slightly
47
- expand = 0.2
48
- h_expand = int((y2 - y1) * expand)
49
- w_expand = int((x2 - x1) * expand)
50
 
51
- y1 = max(0, y1 - h_expand)
52
- y2 = min(frame.shape[0], y2 + h_expand)
53
- x1 = max(0, x1 - w_expand)
54
- x2 = min(frame.shape[1], x2 + w_expand)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- face_frame = frame[y1:y2, x1:x2]
 
 
57
 
58
- if face_frame.size == 0:
59
- return None
 
 
60
 
61
- # Convert to RGB (face_recognition uses RGB)
62
- rgb_frame = cv2.cvtColor(face_frame, cv2.COLOR_BGR2RGB)
 
 
 
 
63
 
64
- # Get face encodings
65
- encodings = face_recognition.face_encodings(rgb_frame)
66
- return encodings[0] if encodings else None
 
 
67
 
68
- def identify_face(self, frame, box):
69
- """Identify or register a new face"""
70
- encoding = self.get_face_encoding(frame, box)
71
- if encoding is None:
72
- return None
73
 
74
- # Compare with known faces
75
- for face_id, known_encoding in self.known_faces.items():
76
- matches = face_recognition.compare_faces([known_encoding], encoding, tolerance=self.tolerance)
77
- if matches[0]:
78
- return face_id
79
-
80
- # If no match, register new face
81
- face_id = f"face_{self.next_face_id}"
82
- self.known_faces[face_id] = encoding
83
- self.next_face_id += 1
84
- return face_id
85
-
86
- # ========================== # Position-Based Tracker # ==========================
87
- class PositionTracker:
88
- def __init__(self, distance_threshold=100, cooldown=30):
89
- self.workers = {}
90
- self.distance_threshold = distance_threshold
91
- self.cooldown = cooldown
92
- self.next_id = 1
93
 
94
- def track(self, position, violation_type, current_time):
95
- """Track worker position and return worker ID"""
96
- # Check if this is a known worker
97
- for worker_id, worker_data in self.workers.items():
98
- last_pos = worker_data['position']
99
- last_time = worker_data['last_seen']
100
-
101
- # Calculate distance and time difference
102
- distance = np.sqrt((position[0] - last_pos[0])**2 + (position[1] - last_pos[1])**2)
103
- time_diff = current_time - last_time
 
104
 
105
- # If close enough and not too much time has passed
106
- if distance < self.distance_threshold and time_diff < self.cooldown:
107
- # Check if this violation type was already recorded
108
- if violation_type not in worker_data['violations']:
109
- worker_data['position'] = position
110
- worker_data['last_seen'] = current_time
111
- worker_data['violations'].add(violation_type)
112
- return worker_id
113
- return None # Violation already recorded
114
 
115
- # If no match, create new worker
116
- worker_id = f"worker_{self.next_id}"
117
- self.workers[worker_id] = {
118
- 'position': position,
119
- 'last_seen': current_time,
120
- 'violations': {violation_type}
121
- }
122
- self.next_id += 1
123
- return worker_id
 
124
 
125
  # ========================== # Optimized Configuration # ==========================
126
  CONFIG = {
@@ -139,7 +174,7 @@ CONFIG = {
139
  "no_harness": (0, 165, 255), # Orange
140
  "unsafe_posture": (0, 255, 0), # Green
141
  "unsafe_zone": (255, 0, 0), # Blue
142
- "improper_tool_use": (255, 255, 0) # Cyan
143
  },
144
  "DISPLAY_NAMES": {
145
  "no_helmet": "No Helmet Violation",
@@ -162,15 +197,16 @@ CONFIG = {
162
  "unsafe_zone": 0.3,
163
  "improper_tool_use": 0.3
164
  },
165
- "MIN_VIOLATION_FRAMES": 1,
166
- "VIOLATION_COOLDOWN": 30.0,
167
  "WORKER_TRACKING_DURATION": 5.0,
168
  "MAX_PROCESSING_TIME": 60,
169
  "FRAME_SKIP": 2,
170
  "BATCH_SIZE": 16,
171
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
172
- "FACE_RECOGNITION_INTERVAL": 5, # Process face recognition every N frames
173
- "POSITION_TRACKING_THRESHOLD": 100, # pixels
 
174
  "SNAPSHOT_QUALITY": 95,
175
  "MAX_WORKER_DISTANCE": 100
176
  }
@@ -248,14 +284,20 @@ def calculate_safety_score(violations):
248
  }
249
 
250
  # Count unique violation types per worker
251
- worker_violations = defaultdict(set)
252
  for v in violations:
253
  worker_id = v.get("worker_id", "Unknown")
254
  violation_type = v.get("violation", "Unknown")
 
 
 
255
  worker_violations[worker_id].add(violation_type)
256
 
257
  # Calculate total penalty
258
- total_penalty = sum(penalties.get(v, 0) for violations_set in worker_violations.values() for v in violations_set)
 
 
 
259
 
260
  score = max(0, 100 - total_penalty)
261
  return score
@@ -288,9 +330,11 @@ def generate_violation_pdf(violations, score):
288
  y_position -= 0.3 * inch
289
 
290
  # Group violations by worker
291
- worker_violations = defaultdict(list)
292
  for v in violations:
293
  worker_id = v.get("worker_id", "Unknown")
 
 
294
  worker_violations[worker_id].append(v)
295
 
296
  c.setFont("Helvetica", 10)
@@ -441,7 +485,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
441
  return None, ""
442
 
443
  def process_video(video_data):
444
- """Process video to detect safety violations"""
445
  try:
446
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
447
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
@@ -463,11 +507,11 @@ def process_video(video_data):
463
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
464
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
465
 
466
- # Initialize trackers
467
- face_tracker = FaceTracker()
468
- position_tracker = PositionTracker(
469
- distance_threshold=CONFIG["POSITION_TRACKING_THRESHOLD"],
470
- cooldown=CONFIG["VIOLATION_COOLDOWN"]
471
  )
472
 
473
  violations = []
@@ -475,7 +519,6 @@ def process_video(video_data):
475
  start_time = time.time()
476
  frame_skip = CONFIG["FRAME_SKIP"]
477
  processed_frames = 0
478
- frame_count = 0
479
 
480
  while processed_frames < total_frames:
481
  batch_frames = []
@@ -500,7 +543,6 @@ def process_video(video_data):
500
  batch_frames.append(frame)
501
  batch_indices.append(frame_idx)
502
  processed_frames += 1
503
- frame_count += 1
504
 
505
  if not batch_frames:
506
  break
@@ -518,7 +560,7 @@ def process_video(video_data):
518
  start_time = time.time()
519
 
520
  boxes = result.boxes
521
- detections = []
522
 
523
  for box in boxes:
524
  cls = int(box.cls)
@@ -532,35 +574,42 @@ def process_video(video_data):
532
  continue
533
 
534
  bbox = box.xywh.cpu().numpy()[0]
 
 
 
 
 
 
 
 
535
 
536
- # For helmet violations, use face recognition
537
- if label == "no_helmet" and frame_count % CONFIG["FACE_RECOGNITION_INTERVAL"] == 0:
538
- worker_id = face_tracker.identify_face(batch_frames[i], bbox)
539
- else:
540
- # For other violations, use position tracking
541
- position = (bbox[0], bbox[1])
542
- worker_id = position_tracker.track(position, label, current_time)
 
 
 
 
 
 
543
 
544
- if worker_id is None:
545
- continue # Skip if this is a duplicate violation
546
 
547
- detection = {
548
- "worker_id": worker_id,
549
- "violation": label,
550
- "confidence": round(conf, 2),
551
- "bounding_box": bbox,
552
- "timestamp": current_time
553
- }
554
- detections.append(detection)
555
-
556
- # Process new violations
557
- for detection in detections:
558
- # Check if we already have this violation for this worker
559
- existing = next((v for v in violations
560
- if v["worker_id"] == detection["worker_id"]
561
- and v["violation"] == detection["violation"]), None)
562
-
563
- if not existing:
564
  violations.append(detection)
565
 
566
  # Take snapshot for the new violation
@@ -579,7 +628,7 @@ def process_video(video_data):
579
  )
580
 
581
  # Save snapshot with high quality
582
- snapshot_filename = f"violation_{detection['violation']}_worker{detection['worker_id']}_{int(current_time*100)}.jpg"
583
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
584
 
585
  cv2.imwrite(
@@ -589,14 +638,14 @@ def process_video(video_data):
589
  )
590
 
591
  snapshots.append({
592
- "violation": detection["violation"],
593
- "worker_id": detection["worker_id"],
594
  "timestamp": current_time,
595
  "snapshot_path": snapshot_path,
596
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
597
  })
598
 
599
- logger.info(f"Captured snapshot for {detection['violation']} violation by worker {detection['worker_id']} at {current_time:.2f}s")
600
 
601
  cap.release()
602
  if os.path.exists(video_path):
 
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
 
22
  from collections import defaultdict
23
 
24
  # ========================== # Configuration and Setup # ==========================
 
28
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
29
  logger = logging.getLogger(__name__)
30
 
31
+ # ========================== # Enhanced BYTETracker Implementation # ==========================
32
+ class EnhancedBYTETracker:
33
+ def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
34
+ self.track_thresh = track_thresh
35
+ self.track_buffer = track_buffer
36
+ self.match_thresh = match_thresh
37
+ self.frame_rate = frame_rate
38
+ self.next_id = 1
39
+ self.tracks = {} # Store active tracks
40
+ self.violation_history = defaultdict(dict) # Track violations per worker
41
+ self.last_positions = {} # Last known positions of workers
42
+ self.violation_cooldown = {} # Cooldown periods for violations
43
+
44
+ def update(self, dets, scores, cls, frame_idx):
45
+ tracks = []
46
+ current_time = frame_idx / self.frame_rate
 
 
47
 
48
+ # Update existing tracks with new detections
49
+ for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
50
+ if score < self.track_thresh:
51
+ continue
52
+
53
+ x, y, w, h = det
54
+ matched = False
55
+ best_iou = 0
56
+ best_track_id = None
57
+
58
+ # Try to match with existing tracks
59
+ for track_id, track_info in self.tracks.items():
60
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
61
+ continue
62
+
63
+ tx, ty, tw, th = track_info['bbox']
64
+ iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
65
+
66
+ if iou > self.match_thresh and iou > best_iou:
67
+ best_iou = iou
68
+ best_track_id = track_id
69
+ matched = True
70
+
71
+ if matched:
72
+ # Update existing track
73
+ self.tracks[best_track_id].update({
74
+ 'bbox': [x, y, w, h],
75
+ 'score': score,
76
+ 'cls': cl,
77
+ 'last_seen': current_time
78
+ })
79
+
80
+ # Update position
81
+ self.last_positions[best_track_id] = [x, y]
82
+
83
+ tracks.append({
84
+ 'id': best_track_id,
85
+ 'bbox': [x, y, w, h],
86
+ 'score': score,
87
+ 'cls': cl
88
+ })
89
+ else:
90
+ # Create new track
91
+ self.tracks[self.next_id] = {
92
+ 'bbox': [x, y, w, h],
93
+ 'score': score,
94
+ 'cls': cl,
95
+ 'last_seen': current_time
96
+ }
97
+ self.last_positions[self.next_id] = [x, y]
98
+ tracks.append({
99
+ 'id': self.next_id,
100
+ 'bbox': [x, y, w, h],
101
+ 'score': score,
102
+ 'cls': cl
103
+ })
104
+ self.next_id += 1
105
 
106
+ # Clean up old tracks
107
+ stale_ids = [tid for tid, track in self.tracks.items()
108
+ if current_time - track['last_seen'] > self.track_buffer / self.frame_rate]
109
 
110
+ for track_id in stale_ids:
111
+ del self.tracks[track_id]
112
+ if track_id in self.last_positions:
113
+ del self.last_positions[track_id]
114
 
115
+ return tracks
116
+
117
+ def _calculate_iou(self, box1, box2):
118
+ """Calculate IOU between two boxes"""
119
+ x1, y1, w1, h1 = box1
120
+ x2, y2, w2, h2 = box2
121
 
122
+ # Calculate intersection coordinates
123
+ x_left = max(x1 - w1/2, x2 - w2/2)
124
+ y_top = max(y1 - h1/2, y2 - h2/2)
125
+ x_right = min(x1 + w1/2, x2 + w2/2)
126
+ y_bottom = min(y1 + h1/2, y2 + h2/2)
127
 
128
+ if x_right < x_left or y_bottom < y_top:
129
+ return 0.0
 
 
 
130
 
131
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ box1_area = w1 * h1
134
+ box2_area = w2 * h2
135
+
136
+ iou = intersection_area / (box1_area + box2_area - intersection_area)
137
+ return iou
138
+
139
+ def is_new_violation(self, worker_id, violation_type, bbox, current_time):
140
+ """Check if this is a new violation that should be reported"""
141
+ # Check if this worker has had this violation type before
142
+ if violation_type in self.violation_history[worker_id]:
143
+ last_time, last_bbox = self.violation_history[worker_id][violation_type]
144
 
145
+ # For no_helmet, we only report once per worker
146
+ if violation_type == "no_helmet":
147
+ return False
 
 
 
 
 
 
148
 
149
+ # For other violations, check if it's in a new location or after cooldown
150
+ if current_time - last_time < CONFIG["VIOLATION_COOLDOWN"]:
151
+ # Check if it's in the same area (using IOU)
152
+ iou = self._calculate_iou(bbox, last_bbox)
153
+ if iou > 0.3: # If overlap > 30%, consider it the same violation
154
+ return False
155
+
156
+ # If we get here, it's a new violation
157
+ self.violation_history[worker_id][violation_type] = (current_time, bbox)
158
+ return True
159
 
160
  # ========================== # Optimized Configuration # ==========================
161
  CONFIG = {
 
174
  "no_harness": (0, 165, 255), # Orange
175
  "unsafe_posture": (0, 255, 0), # Green
176
  "unsafe_zone": (255, 0, 0), # Blue
177
+ "improper_tool_use": (255, 255, 0) # Cyan
178
  },
179
  "DISPLAY_NAMES": {
180
  "no_helmet": "No Helmet Violation",
 
197
  "unsafe_zone": 0.3,
198
  "improper_tool_use": 0.3
199
  },
200
+ "MIN_VIOLATION_FRAMES": 5,
201
+ "VIOLATION_COOLDOWN": 30.0, # 30 seconds cooldown for same violation type in same area
202
  "WORKER_TRACKING_DURATION": 5.0,
203
  "MAX_PROCESSING_TIME": 60,
204
  "FRAME_SKIP": 2,
205
  "BATCH_SIZE": 16,
206
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
207
+ "TRACK_BUFFER": 30,
208
+ "TRACK_THRESH": 0.3,
209
+ "MATCH_THRESH": 0.7,
210
  "SNAPSHOT_QUALITY": 95,
211
  "MAX_WORKER_DISTANCE": 100
212
  }
 
284
  }
285
 
286
  # Count unique violation types per worker
287
+ worker_violations = {}
288
  for v in violations:
289
  worker_id = v.get("worker_id", "Unknown")
290
  violation_type = v.get("violation", "Unknown")
291
+
292
+ if worker_id not in worker_violations:
293
+ worker_violations[worker_id] = set()
294
  worker_violations[worker_id].add(violation_type)
295
 
296
  # Calculate total penalty
297
+ total_penalty = 0
298
+ for worker_violations_set in worker_violations.values():
299
+ worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
300
+ total_penalty += worker_penalty
301
 
302
  score = max(0, 100 - total_penalty)
303
  return score
 
330
  y_position -= 0.3 * inch
331
 
332
  # Group violations by worker
333
+ worker_violations = {}
334
  for v in violations:
335
  worker_id = v.get("worker_id", "Unknown")
336
+ if worker_id not in worker_violations:
337
+ worker_violations[worker_id] = []
338
  worker_violations[worker_id].append(v)
339
 
340
  c.setFont("Helvetica", 10)
 
485
  return None, ""
486
 
487
  def process_video(video_data):
488
+ """Process video to detect safety violations with enhanced tracking"""
489
  try:
490
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
491
  logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
 
507
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
508
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
509
 
510
+ tracker = EnhancedBYTETracker(
511
+ track_thresh=CONFIG["TRACK_THRESH"],
512
+ track_buffer=CONFIG["TRACK_BUFFER"],
513
+ match_thresh=CONFIG["MATCH_THRESH"],
514
+ frame_rate=fps
515
  )
516
 
517
  violations = []
 
519
  start_time = time.time()
520
  frame_skip = CONFIG["FRAME_SKIP"]
521
  processed_frames = 0
 
522
 
523
  while processed_frames < total_frames:
524
  batch_frames = []
 
543
  batch_frames.append(frame)
544
  batch_indices.append(frame_idx)
545
  processed_frames += 1
 
546
 
547
  if not batch_frames:
548
  break
 
560
  start_time = time.time()
561
 
562
  boxes = result.boxes
563
+ track_inputs = []
564
 
565
  for box in boxes:
566
  cls = int(box.cls)
 
574
  continue
575
 
576
  bbox = box.xywh.cpu().numpy()[0]
577
+ track_inputs.append({
578
+ "bbox": bbox,
579
+ "conf": conf,
580
+ "cls": cls
581
+ })
582
+
583
+ if not track_inputs:
584
+ continue
585
 
586
+ tracked_objects = tracker.update(
587
+ np.array([t["bbox"] for t in track_inputs]),
588
+ np.array([t["conf"] for t in track_inputs]),
589
+ np.array([t["cls"] for t in track_inputs]),
590
+ frame_idx
591
+ )
592
+
593
+ # Process tracked objects for violations
594
+ for obj in tracked_objects:
595
+ worker_id = obj['id']
596
+ label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
597
+ conf = obj['score']
598
+ bbox = obj['bbox']
599
 
600
+ if label is None:
601
+ continue
602
 
603
+ # Check if this is a new violation that should be reported
604
+ if tracker.is_new_violation(worker_id, label, bbox, current_time):
605
+ # This is a new violation that should be reported
606
+ detection = {
607
+ "worker_id": worker_id,
608
+ "violation": label,
609
+ "confidence": round(conf, 2),
610
+ "bounding_box": bbox,
611
+ "timestamp": current_time
612
+ }
 
 
 
 
 
 
 
613
  violations.append(detection)
614
 
615
  # Take snapshot for the new violation
 
628
  )
629
 
630
  # Save snapshot with high quality
631
+ snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
632
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
633
 
634
  cv2.imwrite(
 
638
  )
639
 
640
  snapshots.append({
641
+ "violation": label,
642
+ "worker_id": worker_id,
643
  "timestamp": current_time,
644
  "snapshot_path": snapshot_path,
645
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
646
  })
647
 
648
+ logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
649
 
650
  cap.release()
651
  if os.path.exists(video_path):