PrashanthB461 commited on
Commit
5b97100
·
verified ·
1 Parent(s): 508af1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -171
app.py CHANGED
@@ -36,59 +36,96 @@ class BYTETracker:
36
  self.frame_rate = frame_rate
37
  self.next_id = 1
38
  self.tracks = {} # Store active tracks
 
 
39
 
40
  def update(self, dets, scores, cls):
41
  tracks = []
 
42
 
43
  # Update existing tracks with new detections
44
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
45
  if score < self.track_thresh:
46
- logger.debug(f"Skipping detection with score {score} below threshold {self.track_thresh}")
47
  continue
48
 
49
  x, y, w, h = det
 
 
 
50
 
51
  # Try to match with existing tracks
52
- matched = False
53
  for track_id, track_info in self.tracks.items():
54
- # Simple IOU-based matching
 
 
55
  tx, ty, tw, th = track_info['bbox']
56
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
57
 
58
- if iou > self.match_thresh and track_info['cls'] == cl:
59
- # Update existing track
60
- self.tracks[track_id] = {
61
- 'bbox': [x, y, w, h],
62
- 'score': score,
63
- 'cls': cl,
64
- 'last_seen': time.time()
65
- }
66
- tracks.append({
67
- 'id': track_id,
68
- 'bbox': [x, y, w, h],
69
- 'score': score,
70
- 'cls': cl
71
- })
72
  matched = True
73
- break
74
 
75
- if not matched:
76
- # Create new track
77
- self.tracks[self.next_id] = {
78
  'bbox': [x, y, w, h],
79
  'score': score,
80
  'cls': cl,
81
- 'last_seen': time.time()
82
- }
 
 
 
 
 
 
 
83
  tracks.append({
84
- 'id': self.next_id,
85
  'bbox': [x, y, w, h],
86
  'score': score,
87
  'cls': cl
88
  })
89
- self.next_id += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- # Remove stale tracks
92
  current_time = time.time()
93
  stale_ids = []
94
  for track_id, track_info in self.tracks.items():
@@ -97,38 +134,41 @@ class BYTETracker:
97
 
98
  for track_id in stale_ids:
99
  del self.tracks[track_id]
 
 
 
 
100
 
101
  return tracks
102
 
103
  def _calculate_iou(self, box1, box2):
104
- """Calculate IOU between two boxes in format [x, y, w, h]"""
105
  x1, y1, w1, h1 = box1
106
  x2, y2, w2, h2 = box2
107
 
108
- # Convert to xmin, ymin, xmax, ymax
109
- xmin1, ymin1 = x1 - w1/2, y1 - h1/2
110
- xmax1, ymax1 = x1 + w1/2, y1 + h1/2
111
- xmin2, ymin2 = x2 - w2/2, y2 - h2/2
112
- xmax2, ymax2 = x2 + w2/2, y2 + h2/2
113
-
114
- # Calculate area of intersection
115
- x_left = max(xmin1, xmin2)
116
- y_top = max(ymin1, ymin2)
117
- x_right = min(xmax1, xmax2)
118
- y_bottom = min(ymax1, ymax2)
119
 
120
  if x_right < x_left or y_bottom < y_top:
121
  return 0.0
122
 
123
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
124
 
125
- # Calculate area of both boxes
126
  box1_area = w1 * h1
127
  box2_area = w2 * h2
128
 
129
- # Calculate IOU
130
  iou = intersection_area / (box1_area + box2_area - intersection_area)
131
  return iou
 
 
 
 
 
 
 
132
 
133
  # ========================== # Optimized Configuration # ==========================
134
  CONFIG = {
@@ -143,11 +183,11 @@ CONFIG = {
143
  4: "improper_tool_use"
144
  },
145
  "CLASS_COLORS": {
146
- "no_helmet": (0, 0, 255), # Red in BGR
147
- "no_harness": (0, 165, 255), # Orange in BGR
148
- "unsafe_posture": (0, 255, 0), # Green in BGR
149
- "unsafe_zone": (255, 0, 0), # Blue in BGR
150
- "improper_tool_use": (255, 255, 0) # Cyan in BGR
151
  },
152
  "DISPLAY_NAMES": {
153
  "no_helmet": "No Helmet Violation",
@@ -162,7 +202,7 @@ CONFIG = {
162
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
163
  "domain": "login"
164
  },
165
- "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Sadio2/resolve/main/static/output/",
166
  "CONFIDENCE_THRESHOLDS": {
167
  "no_helmet": 0.5,
168
  "no_harness": 0.3,
@@ -171,16 +211,17 @@ CONFIG = {
171
  "improper_tool_use": 0.3
172
  },
173
  "MIN_VIOLATION_FRAMES": 1,
174
- "VIOLATION_COOLDOWN": 5.0, # Time in seconds before same violation type can be detected again for the same worker
175
  "WORKER_TRACKING_DURATION": 5.0,
176
  "MAX_PROCESSING_TIME": 60,
177
- "FRAME_SKIP": 1,
178
  "BATCH_SIZE": 16,
179
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
180
  "TRACK_BUFFER": 30,
181
  "TRACK_THRESH": 0.3,
182
  "MATCH_THRESH": 0.7,
183
- "SNAPSHOT_QUALITY": 90 # JPEG quality for snapshots (0-100)
 
184
  }
185
 
186
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -210,7 +251,7 @@ model = load_model()
210
  # ========================== # Helper Functions # ==========================
211
  def preprocess_frame(frame):
212
  """Apply basic preprocessing to enhance detection"""
213
- frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20) # Increase contrast
214
  return frame
215
 
216
  def draw_detections(frame, detections):
@@ -221,6 +262,7 @@ def draw_detections(frame, detections):
221
  label = det.get("violation", "Unknown")
222
  confidence = det.get("confidence", 0.0)
223
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
 
224
 
225
  x1 = int(x - w/2)
226
  y1 = int(y - h/2)
@@ -229,20 +271,18 @@ def draw_detections(frame, detections):
229
 
230
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
231
 
232
- # Draw thicker rectangle with border for better visibility
233
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
234
 
235
- # Add a black background behind text for better readability
236
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
237
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
238
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
239
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
240
 
241
- # Draw worker ID
242
- worker_id = det.get("worker_id", "Unknown")
243
- worker_text = f"Worker: {worker_id}"
244
- cv2.putText(result_frame, worker_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
245
- cv2.putText(result_frame, worker_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
246
 
247
  return result_frame
248
 
@@ -256,15 +296,24 @@ def calculate_safety_score(violations):
256
  "improper_tool_use": 25
257
  }
258
 
259
- # Count unique violation types
260
- unique_violations = set()
261
  for v in violations:
262
- unique_violations.add(v.get("violation", "Unknown"))
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- # Calculate penalty based on unique violation types
265
- total_penalty = sum(penalties.get(v, 0) for v in unique_violations)
266
- score = 100 - total_penalty
267
- return max(score, 0)
268
 
269
  def generate_violation_pdf(violations, score):
270
  """Generate a PDF report for the detected violations"""
@@ -273,25 +322,38 @@ def generate_violation_pdf(violations, score):
273
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
274
  pdf_file = BytesIO()
275
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
 
276
  c.setFont("Helvetica-Bold", 16)
277
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
278
 
 
279
  c.setFont("Helvetica", 12)
280
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
281
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
282
 
 
283
  c.setFont("Helvetica-Bold", 14)
284
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
285
 
 
286
  y_position = 8.2 * inch
287
  c.setFont("Helvetica-Bold", 12)
288
  c.drawString(1 * inch, y_position, "Summary:")
289
  y_position -= 0.3 * inch
290
 
 
 
 
 
 
 
 
 
291
  c.setFont("Helvetica", 10)
292
  summary_data = {
 
293
  "Total Violations Found": len(violations),
294
- "Unique Workers with Violations": len(set(v.get("worker_id", "Unknown") for v in violations)),
295
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
296
  }
297
 
@@ -299,45 +361,38 @@ def generate_violation_pdf(violations, score):
299
  c.drawString(1 * inch, y_position, f"{key}: {value}")
300
  y_position -= 0.25 * inch
301
 
302
- if not violations:
303
- y_position -= 0.3 * inch
304
- c.drawString(1 * inch, y_position, "No violations detected.")
305
- else:
306
- y_position -= 0.5 * inch
307
- c.setFont("Helvetica-Bold", 12)
308
- c.drawString(1 * inch, y_position, "Violation Details:")
309
- y_position -= 0.3 * inch
310
-
311
- c.setFont("Helvetica", 10)
312
- # Sort violations by worker ID and type for better organization
313
- sorted_violations = sorted(violations, key=lambda v: (v.get("worker_id", "Unknown"), v.get("violation", "Unknown")))
314
 
315
- for v in sorted_violations:
316
- worker_id = v.get("worker_id", "Unknown")
317
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
318
- start_time = v.get('start_timestamp', 0.0)
319
- end_time = v.get('end_timestamp', 0.0)
320
- confidence = v.get('confidence', 0.0)
321
 
322
- text = f"Worker ID: {worker_id} - {display_name}"
323
- c.drawString(1 * inch, y_position, text)
324
  y_position -= 0.2 * inch
325
 
326
- details = f" Time: {start_time:.2f}s to {end_time:.2f}s (Confidence: {confidence:.2f})"
327
- c.drawString(1.2 * inch, y_position, details)
328
- y_position -= 0.3 * inch
329
-
330
  if y_position < 1 * inch:
331
  c.showPage()
332
  c.setFont("Helvetica", 10)
333
  y_position = 10 * inch
334
 
335
- c.showPage()
336
  c.save()
337
  pdf_file.seek(0)
338
 
 
339
  with open(pdf_path, "wb") as f:
340
  f.write(pdf_file.getvalue())
 
341
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
342
  logger.info(f"PDF generated: {public_url}")
343
  return pdf_path, public_url, pdf_file
@@ -394,12 +449,11 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
394
  violations_text = ""
395
  for v in violations:
396
  display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
397
- worker_id = v.get('worker_id', 'N/A')
398
- start_time = v.get('start_timestamp', 0.0)
399
- end_time = v.get('end_timestamp', 0.0)
400
  confidence = v.get('confidence', 0.0)
401
 
402
- violations_text += f"Worker {worker_id}: {display_name} ({start_time:.2f}s-{end_time:.2f}s, Conf: {confidence:.2f})\n"
403
 
404
  if not violations_text:
405
  violations_text = "No violations detected."
@@ -474,7 +528,7 @@ def process_video(video_data):
474
  )
475
 
476
  # Track unique violations by worker ID
477
- unique_violations = {} # {worker_id: {violation_type: {first_detection, last_detection, best_confidence, best_frame, cooldown}}}
478
  snapshots = []
479
  start_time = time.time()
480
  frame_skip = CONFIG["FRAME_SKIP"]
@@ -502,6 +556,7 @@ def process_video(video_data):
502
 
503
  batch_frames.append(frame)
504
  batch_indices.append(frame_idx)
 
505
 
506
  if not batch_frames:
507
  break
@@ -510,13 +565,12 @@ def process_video(video_data):
510
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
511
 
512
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
513
- processed_frames += 1
514
  current_time = frame_idx / fps
515
 
516
  # Update progress every second
517
  if time.time() - start_time > 1.0:
518
- progress = (frame_idx / total_frames) * 100
519
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
520
  start_time = time.time()
521
 
522
  boxes = result.boxes
@@ -528,11 +582,9 @@ def process_video(video_data):
528
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
529
 
530
  if label is None:
531
- logger.debug(f"Unknown class ID {cls} detected, skipping")
532
  continue
533
 
534
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
535
- logger.debug(f"Detection for {label} with confidence {conf} below threshold {CONFIG['CONFIDENCE_THRESHOLDS'].get(label, 0.25)}")
536
  continue
537
 
538
  bbox = box.xywh.cpu().numpy()[0]
@@ -542,7 +594,6 @@ def process_video(video_data):
542
  "cls": cls
543
  })
544
 
545
- # Skip tracking if no detections
546
  if not track_inputs:
547
  continue
548
 
@@ -551,8 +602,6 @@ def process_video(video_data):
551
  np.array([t["conf"] for t in track_inputs]),
552
  np.array([t["cls"] for t in track_inputs])
553
  )
554
-
555
- logger.debug(f"Frame {frame_idx}: {len(tracked_objects)} objects tracked")
556
 
557
  # Process tracked objects for violations
558
  for obj in tracked_objects:
@@ -567,90 +616,55 @@ def process_video(video_data):
567
  # Initialize worker if not seen before
568
  if worker_id not in unique_violations:
569
  unique_violations[worker_id] = {}
570
-
571
- # Check if this is a new violation type for this worker or if cooldown has passed
572
- is_new_violation = False
573
  if label not in unique_violations[worker_id]:
574
- # New violation type for this worker
575
- unique_violations[worker_id][label] = {
576
- 'first_detection': current_time,
577
- 'last_detection': current_time,
578
- 'best_confidence': conf,
579
- 'best_frame': frame_idx,
580
- 'best_bbox': bbox,
581
- 'cooldown': current_time + CONFIG["VIOLATION_COOLDOWN"]
582
- }
583
- is_new_violation = True
584
- elif current_time > unique_violations[worker_id][label]['cooldown']:
585
- # Cooldown period has passed, treat as a new violation
586
- unique_violations[worker_id][label] = {
587
- 'first_detection': current_time,
588
- 'last_detection': current_time,
589
- 'best_confidence': conf,
590
- 'best_frame': frame_idx,
591
- 'best_bbox': bbox,
592
- 'cooldown': current_time + CONFIG["VIOLATION_COOLDOWN"]
593
- }
594
- is_new_violation = True
595
- else:
596
- # Update existing violation
597
- violation_info = unique_violations[worker_id][label]
598
- violation_info['last_detection'] = current_time
599
 
600
- # Update if this is a better detection (higher confidence)
601
- if conf > violation_info['best_confidence']:
602
- violation_info['best_confidence'] = conf
603
- violation_info['best_frame'] = frame_idx
604
- violation_info['best_bbox'] = bbox
605
-
606
- # If this is a new violation, capture a snapshot
607
- if is_new_violation:
608
- # Create a detection object for the snapshot
609
  detection = {
610
- "frame": frame_idx,
611
  "violation": label,
612
  "confidence": round(conf, 2),
613
  "bounding_box": bbox,
614
- "timestamp": current_time,
615
- "worker_id": worker_id
616
  }
617
 
618
- # Take a snapshot for the new violation
619
  snapshot_frame = batch_frames[i].copy()
620
  snapshot_frame = draw_detections(snapshot_frame, [detection])
621
 
622
- # Add timestamp to the image
623
  cv2.putText(
624
- snapshot_frame,
625
- f"Time: {current_time:.2f}s",
626
- (10, 30),
627
- cv2.FONT_HERSHEY_SIMPLEX,
628
- 0.7,
629
- (255, 255, 255),
630
  2
631
  )
632
 
633
  # Save snapshot with high quality
634
- snapshot_filename = f"{label}_worker{worker_id}_{int(current_time)}_{frame_idx}.jpg"
635
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
636
 
637
- # Use higher quality for JPEG to ensure better visibility
638
  cv2.imwrite(
639
- snapshot_path,
640
- snapshot_frame,
641
  [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
642
  )
643
 
644
  snapshots.append({
645
  "violation": label,
646
  "worker_id": worker_id,
647
- "frame": frame_idx,
648
  "timestamp": current_time,
649
  "snapshot_path": snapshot_path,
650
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
651
  })
652
 
653
- logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at frame {frame_idx}")
654
 
655
  cap.release()
656
  if os.path.exists(video_path):
@@ -662,15 +676,11 @@ def process_video(video_data):
662
  # Convert tracked violations to final violation list
663
  violations = []
664
  for worker_id, worker_violations in unique_violations.items():
665
- for label, violation_info in worker_violations.items():
666
  violation = {
667
  "worker_id": worker_id,
668
  "violation": label,
669
- "confidence": violation_info['best_confidence'],
670
- "start_timestamp": violation_info['first_detection'],
671
- "end_timestamp": violation_info['last_detection'],
672
- "frame": violation_info['best_frame'],
673
- "bounding_box": violation_info['best_bbox']
674
  }
675
  violations.append(violation)
676
 
@@ -692,20 +702,18 @@ def process_video(video_data):
692
  violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
693
  violation_table += "|-----------|-----------|----------|------------|\n"
694
 
695
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("start_timestamp", 0.0))):
696
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
697
  worker_id = v.get("worker_id", "Unknown")
698
- start_time = v.get('start_timestamp', 0.0)
699
- end_time = v.get('end_timestamp', 0.0)
700
- confidence = v.get('confidence', 0.0)
701
 
702
- row = f"| {display_name} | {worker_id} | {start_time:.2f}-{end_time:.2f} | {confidence:.2f} |\n"
703
- violation_table += row
704
 
705
  # Format snapshots for display
706
  snapshots_text = ""
707
- for i, s in enumerate(snapshots):
708
- display_name = CONFIG["DISPLAY_NAMES"].get(s['violation'], "Unknown")
709
  worker_id = s.get("worker_id", "Unknown")
710
  timestamp = s.get("timestamp", 0.0)
711
 
 
36
  self.frame_rate = frame_rate
37
  self.next_id = 1
38
  self.tracks = {} # Store active tracks
39
+ self.worker_history = {} # Track worker positions over time
40
+ self.last_positions = {} # Last known positions of workers
41
 
42
  def update(self, dets, scores, cls):
43
  tracks = []
44
+ current_time = time.time()
45
 
46
  # Update existing tracks with new detections
47
  for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
48
  if score < self.track_thresh:
 
49
  continue
50
 
51
  x, y, w, h = det
52
+ matched = False
53
+ best_iou = 0
54
+ best_track_id = None
55
 
56
  # Try to match with existing tracks
 
57
  for track_id, track_info in self.tracks.items():
58
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
59
+ continue
60
+
61
  tx, ty, tw, th = track_info['bbox']
62
  iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
63
 
64
+ if iou > self.match_thresh and iou > best_iou:
65
+ best_iou = iou
66
+ best_track_id = track_id
 
 
 
 
 
 
 
 
 
 
 
67
  matched = True
 
68
 
69
+ if matched:
70
+ # Update existing track
71
+ self.tracks[best_track_id].update({
72
  'bbox': [x, y, w, h],
73
  'score': score,
74
  'cls': cl,
75
+ 'last_seen': current_time
76
+ })
77
+
78
+ # Update position history
79
+ if best_track_id not in self.worker_history:
80
+ self.worker_history[best_track_id] = []
81
+ self.worker_history[best_track_id].append([x, y])
82
+ self.last_positions[best_track_id] = [x, y]
83
+
84
  tracks.append({
85
+ 'id': best_track_id,
86
  'bbox': [x, y, w, h],
87
  'score': score,
88
  'cls': cl
89
  })
90
+ else:
91
+ # Create new track
92
+ # Check if this detection might be the same worker from a different angle
93
+ same_worker = False
94
+ for worker_id, last_pos in self.last_positions.items():
95
+ if self._is_same_worker([x, y], last_pos):
96
+ self.tracks[worker_id] = {
97
+ 'bbox': [x, y, w, h],
98
+ 'score': score,
99
+ 'cls': cl,
100
+ 'last_seen': current_time
101
+ }
102
+ tracks.append({
103
+ 'id': worker_id,
104
+ 'bbox': [x, y, w, h],
105
+ 'score': score,
106
+ 'cls': cl
107
+ })
108
+ same_worker = True
109
+ break
110
+
111
+ if not same_worker:
112
+ self.tracks[self.next_id] = {
113
+ 'bbox': [x, y, w, h],
114
+ 'score': score,
115
+ 'cls': cl,
116
+ 'last_seen': current_time
117
+ }
118
+ self.worker_history[self.next_id] = [[x, y]]
119
+ self.last_positions[self.next_id] = [x, y]
120
+ tracks.append({
121
+ 'id': self.next_id,
122
+ 'bbox': [x, y, w, h],
123
+ 'score': score,
124
+ 'cls': cl
125
+ })
126
+ self.next_id += 1
127
 
128
+ # Clean up old tracks
129
  current_time = time.time()
130
  stale_ids = []
131
  for track_id, track_info in self.tracks.items():
 
134
 
135
  for track_id in stale_ids:
136
  del self.tracks[track_id]
137
+ if track_id in self.worker_history:
138
+ del self.worker_history[track_id]
139
+ if track_id in self.last_positions:
140
+ del self.last_positions[track_id]
141
 
142
  return tracks
143
 
144
  def _calculate_iou(self, box1, box2):
145
+ """Calculate IOU between two boxes"""
146
  x1, y1, w1, h1 = box1
147
  x2, y2, w2, h2 = box2
148
 
149
+ # Calculate intersection coordinates
150
+ x_left = max(x1 - w1/2, x2 - w2/2)
151
+ y_top = max(y1 - h1/2, y2 - h2/2)
152
+ x_right = min(x1 + w1/2, x2 + w2/2)
153
+ y_bottom = min(y1 + h1/2, y2 + h2/2)
 
 
 
 
 
 
154
 
155
  if x_right < x_left or y_bottom < y_top:
156
  return 0.0
157
 
158
  intersection_area = (x_right - x_left) * (y_bottom - y_top)
159
 
 
160
  box1_area = w1 * h1
161
  box2_area = w2 * h2
162
 
 
163
  iou = intersection_area / (box1_area + box2_area - intersection_area)
164
  return iou
165
+
166
+ def _is_same_worker(self, pos1, pos2, threshold=100):
167
+ """Check if two positions likely belong to the same worker"""
168
+ x1, y1 = pos1
169
+ x2, y2 = pos2
170
+ distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
171
+ return distance < threshold
172
 
173
  # ========================== # Optimized Configuration # ==========================
174
  CONFIG = {
 
183
  4: "improper_tool_use"
184
  },
185
  "CLASS_COLORS": {
186
+ "no_helmet": (0, 0, 255), # Red
187
+ "no_harness": (0, 165, 255), # Orange
188
+ "unsafe_posture": (0, 255, 0), # Green
189
+ "unsafe_zone": (255, 0, 0), # Blue
190
+ "improper_tool_use": (255, 255, 0) # Cyan
191
  },
192
  "DISPLAY_NAMES": {
193
  "no_helmet": "No Helmet Violation",
 
202
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
203
  "domain": "login"
204
  },
205
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
206
  "CONFIDENCE_THRESHOLDS": {
207
  "no_helmet": 0.5,
208
  "no_harness": 0.3,
 
211
  "improper_tool_use": 0.3
212
  },
213
  "MIN_VIOLATION_FRAMES": 1,
214
+ "VIOLATION_COOLDOWN": 30.0, # Increased cooldown period
215
  "WORKER_TRACKING_DURATION": 5.0,
216
  "MAX_PROCESSING_TIME": 60,
217
+ "FRAME_SKIP": 2, # Skip more frames for faster processing
218
  "BATCH_SIZE": 16,
219
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
220
  "TRACK_BUFFER": 30,
221
  "TRACK_THRESH": 0.3,
222
  "MATCH_THRESH": 0.7,
223
+ "SNAPSHOT_QUALITY": 95, # Higher quality for better visibility
224
+ "MAX_WORKER_DISTANCE": 100 # Maximum pixel distance to consider same worker
225
  }
226
 
227
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
251
  # ========================== # Helper Functions # ==========================
252
  def preprocess_frame(frame):
253
  """Apply basic preprocessing to enhance detection"""
254
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
255
  return frame
256
 
257
  def draw_detections(frame, detections):
 
262
  label = det.get("violation", "Unknown")
263
  confidence = det.get("confidence", 0.0)
264
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
265
+ worker_id = det.get("worker_id", "Unknown")
266
 
267
  x1 = int(x - w/2)
268
  y1 = int(y - h/2)
 
271
 
272
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
273
 
274
+ # Draw thicker rectangle with border
275
  cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
276
 
277
+ # Add black background behind text
278
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
279
  text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
280
  cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
281
  cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
282
 
283
+ # Add confidence score
284
+ conf_text = f"Conf: {confidence:.2f}"
285
+ cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
 
 
286
 
287
  return result_frame
288
 
 
296
  "improper_tool_use": 25
297
  }
298
 
299
+ # Count unique violation types per worker
300
+ worker_violations = {}
301
  for v in violations:
302
+ worker_id = v.get("worker_id", "Unknown")
303
+ violation_type = v.get("violation", "Unknown")
304
+
305
+ if worker_id not in worker_violations:
306
+ worker_violations[worker_id] = set()
307
+ worker_violations[worker_id].add(violation_type)
308
+
309
+ # Calculate total penalty
310
+ total_penalty = 0
311
+ for worker_violations_set in worker_violations.values():
312
+ worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
313
+ total_penalty += worker_penalty
314
 
315
+ score = max(0, 100 - total_penalty)
316
+ return score
 
 
317
 
318
  def generate_violation_pdf(violations, score):
319
  """Generate a PDF report for the detected violations"""
 
322
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
323
  pdf_file = BytesIO()
324
  c = canvas.Canvas(pdf_file, pagesize=letter)
325
+
326
+ # Title
327
  c.setFont("Helvetica-Bold", 16)
328
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
329
 
330
+ # Basic Information
331
  c.setFont("Helvetica", 12)
332
  c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
333
  c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
334
 
335
+ # Safety Score
336
  c.setFont("Helvetica-Bold", 14)
337
  c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
338
 
339
+ # Violation Summary
340
  y_position = 8.2 * inch
341
  c.setFont("Helvetica-Bold", 12)
342
  c.drawString(1 * inch, y_position, "Summary:")
343
  y_position -= 0.3 * inch
344
 
345
+ # Group violations by worker
346
+ worker_violations = {}
347
+ for v in violations:
348
+ worker_id = v.get("worker_id", "Unknown")
349
+ if worker_id not in worker_violations:
350
+ worker_violations[worker_id] = []
351
+ worker_violations[worker_id].append(v)
352
+
353
  c.setFont("Helvetica", 10)
354
  summary_data = {
355
+ "Total Workers with Violations": len(worker_violations),
356
  "Total Violations Found": len(violations),
 
357
  "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
358
  }
359
 
 
361
  c.drawString(1 * inch, y_position, f"{key}: {value}")
362
  y_position -= 0.25 * inch
363
 
364
+ # Detailed Violations by Worker
365
+ y_position -= 0.5 * inch
366
+ c.setFont("Helvetica-Bold", 12)
367
+ c.drawString(1 * inch, y_position, "Violations by Worker:")
368
+ y_position -= 0.3 * inch
369
+
370
+ c.setFont("Helvetica", 10)
371
+ for worker_id, worker_vios in worker_violations.items():
372
+ c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
373
+ y_position -= 0.2 * inch
 
 
374
 
375
+ for v in worker_vios:
 
376
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
377
+ time_str = f"{v.get('timestamp', 0.0):.2f}s"
378
+ conf_str = f"{v.get('confidence', 0.0):.2f}"
 
379
 
380
+ violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
381
+ c.drawString(1.2 * inch, y_position, violation_text)
382
  y_position -= 0.2 * inch
383
 
 
 
 
 
384
  if y_position < 1 * inch:
385
  c.showPage()
386
  c.setFont("Helvetica", 10)
387
  y_position = 10 * inch
388
 
 
389
  c.save()
390
  pdf_file.seek(0)
391
 
392
+ # Save PDF file
393
  with open(pdf_path, "wb") as f:
394
  f.write(pdf_file.getvalue())
395
+
396
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
397
  logger.info(f"PDF generated: {public_url}")
398
  return pdf_path, public_url, pdf_file
 
449
  violations_text = ""
450
  for v in violations:
451
  display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
452
+ worker_id = v.get('worker_id', 'Unknown')
453
+ timestamp = v.get('timestamp', 0.0)
 
454
  confidence = v.get('confidence', 0.0)
455
 
456
+ violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
457
 
458
  if not violations_text:
459
  violations_text = "No violations detected."
 
528
  )
529
 
530
  # Track unique violations by worker ID
531
+ unique_violations = {} # {worker_id: {violation_type: first_detection_time}}
532
  snapshots = []
533
  start_time = time.time()
534
  frame_skip = CONFIG["FRAME_SKIP"]
 
556
 
557
  batch_frames.append(frame)
558
  batch_indices.append(frame_idx)
559
+ processed_frames += 1
560
 
561
  if not batch_frames:
562
  break
 
565
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
566
 
567
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
 
568
  current_time = frame_idx / fps
569
 
570
  # Update progress every second
571
  if time.time() - start_time > 1.0:
572
+ progress = (processed_frames / total_frames) * 100
573
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
574
  start_time = time.time()
575
 
576
  boxes = result.boxes
 
582
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
583
 
584
  if label is None:
 
585
  continue
586
 
587
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
 
588
  continue
589
 
590
  bbox = box.xywh.cpu().numpy()[0]
 
594
  "cls": cls
595
  })
596
 
 
597
  if not track_inputs:
598
  continue
599
 
 
602
  np.array([t["conf"] for t in track_inputs]),
603
  np.array([t["cls"] for t in track_inputs])
604
  )
 
 
605
 
606
  # Process tracked objects for violations
607
  for obj in tracked_objects:
 
616
  # Initialize worker if not seen before
617
  if worker_id not in unique_violations:
618
  unique_violations[worker_id] = {}
619
+
620
+ # Check if this violation type has been recorded for this worker
 
621
  if label not in unique_violations[worker_id]:
622
+ # This is a new violation type for this worker
623
+ unique_violations[worker_id][label] = current_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
+ # Create detection object
 
 
 
 
 
 
 
 
626
  detection = {
627
+ "worker_id": worker_id,
628
  "violation": label,
629
  "confidence": round(conf, 2),
630
  "bounding_box": bbox,
631
+ "timestamp": current_time
 
632
  }
633
 
634
+ # Take snapshot for the new violation
635
  snapshot_frame = batch_frames[i].copy()
636
  snapshot_frame = draw_detections(snapshot_frame, [detection])
637
 
638
+ # Add timestamp to snapshot
639
  cv2.putText(
640
+ snapshot_frame,
641
+ f"Time: {current_time:.2f}s",
642
+ (10, 30),
643
+ cv2.FONT_HERSHEY_SIMPLEX,
644
+ 0.7,
645
+ (255, 255, 255),
646
  2
647
  )
648
 
649
  # Save snapshot with high quality
650
+ snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
651
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
652
 
 
653
  cv2.imwrite(
654
+ snapshot_path,
655
+ snapshot_frame,
656
  [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
657
  )
658
 
659
  snapshots.append({
660
  "violation": label,
661
  "worker_id": worker_id,
 
662
  "timestamp": current_time,
663
  "snapshot_path": snapshot_path,
664
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
665
  })
666
 
667
+ logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
668
 
669
  cap.release()
670
  if os.path.exists(video_path):
 
676
  # Convert tracked violations to final violation list
677
  violations = []
678
  for worker_id, worker_violations in unique_violations.items():
679
+ for label, detection_time in worker_violations.items():
680
  violation = {
681
  "worker_id": worker_id,
682
  "violation": label,
683
+ "timestamp": detection_time
 
 
 
 
684
  }
685
  violations.append(violation)
686
 
 
702
  violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
703
  violation_table += "|-----------|-----------|----------|------------|\n"
704
 
705
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
706
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
707
  worker_id = v.get("worker_id", "Unknown")
708
+ timestamp = v.get("timestamp", 0.0)
709
+ confidence = v.get("confidence", 0.0)
 
710
 
711
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
 
712
 
713
  # Format snapshots for display
714
  snapshots_text = ""
715
+ for s in snapshots:
716
+ display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
717
  worker_id = s.get("worker_id", "Unknown")
718
  timestamp = s.get("timestamp", 0.0)
719