PrashanthB461 commited on
Commit
577f505
·
verified ·
1 Parent(s): 357b766

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -464
app.py CHANGED
@@ -1,8 +1,4 @@
1
  import os
2
- import sys
3
- import subprocess
4
- import logging
5
- import warnings
6
  import cv2
7
  import gradio as gr
8
  import torch
@@ -15,162 +11,15 @@ from reportlab.pdfgen import canvas
15
  from reportlab.lib.units import inch
16
  from io import BytesIO
17
  import base64
 
18
  from retrying import retry
19
  import uuid
20
  from multiprocessing import Pool, cpu_count
21
  from functools import partial
22
 
23
- # ========================== # Configuration and Setup # ==========================
24
- os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
25
- os.makedirs('/tmp/Ultralytics', exist_ok=True)
26
-
27
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
- logger = logging.getLogger(__name__)
29
-
30
- # ========================== # ByteTrack Implementation # ==========================
31
- class BYTETracker:
32
- def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
33
- self.track_thresh = track_thresh
34
- self.track_buffer = track_buffer
35
- self.match_thresh = match_thresh
36
- self.frame_rate = frame_rate
37
- self.next_id = 1
38
- self.tracks = {} # Store active tracks
39
- self.worker_history = {} # Track worker positions over time
40
- self.last_positions = {} # Last known positions of workers
41
-
42
- def update(self, dets, scores, cls):
43
- tracks = []
44
- current_time = time.time()
45
-
46
- # Update existing tracks with new detections
47
- for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
48
- if score < self.track_thresh:
49
- continue
50
-
51
- x, y, w, h = det
52
- matched = False
53
- best_iou = 0
54
- best_track_id = None
55
-
56
- # Try to match with existing tracks
57
- for track_id, track_info in self.tracks.items():
58
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
59
- continue
60
-
61
- tx, ty, tw, th = track_info['bbox']
62
- iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
63
-
64
- if iou > self.match_thresh and iou > best_iou:
65
- best_iou = iou
66
- best_track_id = track_id
67
- matched = True
68
-
69
- if matched:
70
- # Update existing track
71
- self.tracks[best_track_id].update({
72
- 'bbox': [x, y, w, h],
73
- 'score': score,
74
- 'cls': cl,
75
- 'last_seen': current_time
76
- })
77
-
78
- # Update position history
79
- if best_track_id not in self.worker_history:
80
- self.worker_history[best_track_id] = []
81
- self.worker_history[best_track_id].append([x, y])
82
- self.last_positions[best_track_id] = [x, y]
83
-
84
- tracks.append({
85
- 'id': best_track_id,
86
- 'bbox': [x, y, w, h],
87
- 'score': score,
88
- 'cls': cl
89
- })
90
- else:
91
- # Create new track
92
- # Check if this detection might be the same worker from a different angle
93
- same_worker = False
94
- for worker_id, last_pos in self.last_positions.items():
95
- if self._is_same_worker([x, y], last_pos):
96
- self.tracks[worker_id] = {
97
- 'bbox': [x, y, w, h],
98
- 'score': score,
99
- 'cls': cl,
100
- 'last_seen': current_time
101
- }
102
- tracks.append({
103
- 'id': worker_id,
104
- 'bbox': [x, y, w, h],
105
- 'score': score,
106
- 'cls': cl
107
- })
108
- same_worker = True
109
- break
110
-
111
- if not same_worker:
112
- self.tracks[self.next_id] = {
113
- 'bbox': [x, y, w, h],
114
- 'score': score,
115
- 'cls': cl,
116
- 'last_seen': current_time
117
- }
118
- self.worker_history[self.next_id] = [[x, y]]
119
- self.last_positions[self.next_id] = [x, y]
120
- tracks.append({
121
- 'id': self.next_id,
122
- 'bbox': [x, y, w, h],
123
- 'score': score,
124
- 'cls': cl
125
- })
126
- self.next_id += 1
127
-
128
- # Clean up old tracks
129
- current_time = time.time()
130
- stale_ids = []
131
- for track_id, track_info in self.tracks.items():
132
- if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
133
- stale_ids.append(track_id)
134
-
135
- for track_id in stale_ids:
136
- del self.tracks[track_id]
137
- if track_id in self.worker_history:
138
- del self.worker_history[track_id]
139
- if track_id in self.last_positions:
140
- del self.last_positions[track_id]
141
-
142
- return tracks
143
-
144
- def _calculate_iou(self, box1, box2):
145
- """Calculate IOU between two boxes"""
146
- x1, y1, w1, h1 = box1
147
- x2, y2, w2, h2 = box2
148
-
149
- # Calculate intersection coordinates
150
- x_left = max(x1 - w1/2, x2 - w2/2)
151
- y_top = max(y1 - h1/2, y2 - h2/2)
152
- x_right = min(x1 + w1/2, x2 + w2/2)
153
- y_bottom = min(y1 + h1/2, y2 + h2/2)
154
-
155
- if x_right < x_left or y_bottom < y_top:
156
- return 0.0
157
-
158
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
159
-
160
- box1_area = w1 * h1
161
- box2_area = w2 * h2
162
-
163
- iou = intersection_area / (box1_area + box2_area - intersection_area)
164
- return iou
165
-
166
- def _is_same_worker(self, pos1, pos2, threshold=100):
167
- """Check if two positions likely belong to the same worker"""
168
- x1, y1 = pos1
169
- x2, y2 = pos2
170
- distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
171
- return distance < threshold
172
-
173
- # ========================== # Optimized Configuration # ==========================
174
  CONFIG = {
175
  "MODEL_PATH": "yolov8_safety.pt",
176
  "FALLBACK_MODEL": "yolov8n.pt",
@@ -183,11 +32,11 @@ CONFIG = {
183
  4: "improper_tool_use"
184
  },
185
  "CLASS_COLORS": {
186
- "no_helmet": (0, 0, 255), # Red
187
- "no_harness": (0, 165, 255), # Orange
188
- "unsafe_posture": (0, 255, 0), # Green
189
- "unsafe_zone": (255, 0, 0), # Blue
190
- "improper_tool_use": (255, 255, 0) # Cyan
191
  },
192
  "DISPLAY_NAMES": {
193
  "no_helmet": "No Helmet Violation",
@@ -204,26 +53,26 @@ CONFIG = {
204
  },
205
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
206
  "CONFIDENCE_THRESHOLDS": {
207
- "no_helmet": 0.5,
208
- "no_harness": 0.3,
209
- "unsafe_posture": 0.3,
210
- "unsafe_zone": 0.3,
211
- "improper_tool_use": 0.3
212
  },
213
- "MIN_VIOLATION_FRAMES": 1,
214
- "VIOLATION_COOLDOWN": 30.0, # Increased cooldown period
215
- "WORKER_TRACKING_DURATION": 5.0,
216
- "MAX_PROCESSING_TIME": 60,
217
- "FRAME_SKIP": 2, # Skip more frames for faster processing
218
- "BATCH_SIZE": 16,
219
- "PARALLEL_WORKERS": max(1, cpu_count() - 1),
220
- "TRACK_BUFFER": 30,
221
- "TRACK_THRESH": 0.3,
222
- "MATCH_THRESH": 0.7,
223
- "SNAPSHOT_QUALITY": 95, # Higher quality for better visibility
224
- "MAX_WORKER_DISTANCE": 100 # Maximum pixel distance to consider same worker
225
  }
226
 
 
 
 
 
 
 
227
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
228
  logger.info(f"Using device: {device}")
229
 
@@ -238,9 +87,7 @@ def load_model():
238
  if not os.path.isfile(model_path):
239
  logger.info(f"Downloading fallback model: {model_path}")
240
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
241
-
242
  model = YOLO(model_path).to(device)
243
- logger.info(f"Model classes: {model.names}")
244
  return model
245
  except Exception as e:
246
  logger.error(f"Failed to load model: {e}")
@@ -248,151 +95,118 @@ def load_model():
248
 
249
  model = load_model()
250
 
251
- # ========================== # Helper Functions # ==========================
252
- def preprocess_frame(frame):
253
- """Apply basic preprocessing to enhance detection"""
254
- frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
255
- return frame
256
-
257
  def draw_detections(frame, detections):
258
- """Draw bounding boxes and labels on detection frame with improved visibility"""
259
- result_frame = frame.copy()
260
-
261
  for det in detections:
262
  label = det.get("violation", "Unknown")
263
  confidence = det.get("confidence", 0.0)
264
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
265
- worker_id = det.get("worker_id", "Unknown")
266
-
267
  x1 = int(x - w/2)
268
  y1 = int(y - h/2)
269
  x2 = int(x + w/2)
270
  y2 = int(y + h/2)
271
 
272
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
 
273
 
274
- # Draw thicker rectangle with border
275
- cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
276
-
277
- # Add black background behind text
278
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
279
- text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
280
- cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
281
- cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
282
-
283
- # Add confidence score
284
- conf_text = f"Conf: {confidence:.2f}"
285
- cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
286
-
287
- return result_frame
288
 
289
- def calculate_safety_score(violations):
290
- """Calculate safety score based on detected violations"""
291
- penalties = {
292
- "no_helmet": 25,
293
- "no_harness": 30,
294
- "unsafe_posture": 20,
295
- "unsafe_zone": 35,
296
- "improper_tool_use": 25
297
- }
298
 
299
- # Count unique violation types per worker
300
- worker_violations = {}
301
- for v in violations:
302
- worker_id = v.get("worker_id", "Unknown")
303
- violation_type = v.get("violation", "Unknown")
304
-
305
- if worker_id not in worker_violations:
306
- worker_violations[worker_id] = set()
307
- worker_violations[worker_id].add(violation_type)
 
 
 
 
 
 
 
 
 
308
 
309
- # Calculate total penalty
310
- total_penalty = 0
311
- for worker_violations_set in worker_violations.values():
312
- worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
313
- total_penalty += worker_penalty
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
- score = max(0, 100 - total_penalty)
316
- return score
317
 
318
  def generate_violation_pdf(violations, score):
319
- """Generate a PDF report for the detected violations"""
320
  try:
321
  pdf_filename = f"violations_{int(time.time())}.pdf"
322
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
323
  pdf_file = BytesIO()
324
  c = canvas.Canvas(pdf_file, pagesize=letter)
325
-
326
- # Title
327
- c.setFont("Helvetica-Bold", 16)
328
- c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
329
-
330
- # Basic Information
331
  c.setFont("Helvetica", 12)
332
- c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
333
- c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
334
-
335
- # Safety Score
336
- c.setFont("Helvetica-Bold", 14)
337
- c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
338
-
339
- # Violation Summary
340
- y_position = 8.2 * inch
341
- c.setFont("Helvetica-Bold", 12)
342
- c.drawString(1 * inch, y_position, "Summary:")
343
- y_position -= 0.3 * inch
344
-
345
- # Group violations by worker
346
- worker_violations = {}
347
- for v in violations:
348
- worker_id = v.get("worker_id", "Unknown")
349
- if worker_id not in worker_violations:
350
- worker_violations[worker_id] = []
351
- worker_violations[worker_id].append(v)
352
-
353
  c.setFont("Helvetica", 10)
354
- summary_data = {
355
- "Total Workers with Violations": len(worker_violations),
356
- "Total Violations Found": len(violations),
357
- "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
 
 
358
  }
359
-
360
- for key, value in summary_data.items():
361
  c.drawString(1 * inch, y_position, f"{key}: {value}")
362
- y_position -= 0.25 * inch
363
 
364
- # Detailed Violations by Worker
365
- y_position -= 0.5 * inch
366
- c.setFont("Helvetica-Bold", 12)
367
- c.drawString(1 * inch, y_position, "Violations by Worker:")
368
  y_position -= 0.3 * inch
369
-
370
- c.setFont("Helvetica", 10)
371
- for worker_id, worker_vios in worker_violations.items():
372
- c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
373
- y_position -= 0.2 * inch
374
-
375
- for v in worker_vios:
376
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
377
- time_str = f"{v.get('timestamp', 0.0):.2f}s"
378
- conf_str = f"{v.get('confidence', 0.0):.2f}"
379
-
380
- violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
381
- c.drawString(1.2 * inch, y_position, violation_text)
382
- y_position -= 0.2 * inch
383
-
384
  if y_position < 1 * inch:
385
  c.showPage()
386
  c.setFont("Helvetica", 10)
387
  y_position = 10 * inch
388
 
 
389
  c.save()
390
  pdf_file.seek(0)
391
 
392
- # Save PDF file
393
  with open(pdf_path, "wb") as f:
394
  f.write(pdf_file.getvalue())
395
-
396
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
397
  logger.info(f"PDF generated: {public_url}")
398
  return pdf_path, public_url, pdf_file
@@ -400,9 +214,23 @@ def generate_violation_pdf(violations, score):
400
  logger.error(f"Error generating PDF: {e}")
401
  return "", "", None
402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
404
  def connect_to_salesforce():
405
- """Connect to Salesforce with retry logic"""
406
  try:
407
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
408
  logger.info("Connected to Salesforce")
@@ -413,12 +241,10 @@ def connect_to_salesforce():
413
  raise
414
 
415
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
416
- """Upload PDF report to Salesforce"""
417
  try:
418
  if not pdf_file:
419
  logger.error("No PDF file provided for upload")
420
  return ""
421
-
422
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
423
  content_version_data = {
424
  "Title": f"Safety_Violation_Report_{int(time.time())}",
@@ -428,11 +254,9 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
428
  }
429
  content_version = sf.ContentVersion.create(content_version_data)
430
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
431
-
432
  if not result['records']:
433
  logger.error("Failed to retrieve ContentVersion")
434
  return ""
435
-
436
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
437
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
438
  return file_url
@@ -441,23 +265,12 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
441
  return ""
442
 
443
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
444
- """Push violation report to Salesforce"""
445
  try:
446
  sf = connect_to_salesforce()
447
-
448
- # Format violations for Salesforce
449
- violations_text = ""
450
- for v in violations:
451
- display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
452
- worker_id = v.get('worker_id', 'Unknown')
453
- timestamp = v.get('timestamp', 0.0)
454
- confidence = v.get('confidence', 0.0)
455
-
456
- violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
457
-
458
- if not violations_text:
459
- violations_text = "No violations detected."
460
-
461
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
462
 
463
  record_data = {
@@ -467,9 +280,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
467
  "Status__c": "Pending",
468
  "PDF_Report_URL__c": pdf_url
469
  }
470
-
471
  logger.info(f"Creating Salesforce record with data: {record_data}")
472
-
473
  try:
474
  record = sf.Safety_Video_Report__c.create(record_data)
475
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
@@ -477,7 +288,6 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
477
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
478
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
479
  logger.warning(f"Fell back to Account record: {record['id']}")
480
-
481
  record_id = record["id"]
482
 
483
  if pdf_file:
@@ -497,47 +307,46 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
497
  logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
498
  return None, ""
499
 
 
 
 
500
  def process_video(video_data):
501
- """Process video to detect safety violations"""
502
  try:
503
- os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
504
- logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
505
-
506
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
507
  with open(video_path, "wb") as f:
508
  f.write(video_data)
509
  logger.info(f"Video saved: {video_path}")
510
 
 
511
  cap = cv2.VideoCapture(video_path)
512
  if not cap.isOpened():
513
- os.remove(video_path)
514
  raise ValueError("Could not open video file")
515
 
 
516
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
517
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
 
518
  duration = total_frames / fps
519
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
520
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
521
- logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
522
 
523
- tracker = BYTETracker(
524
- track_thresh=CONFIG["TRACK_THRESH"],
525
- track_buffer=CONFIG["TRACK_BUFFER"],
526
- match_thresh=CONFIG["MATCH_THRESH"],
527
- frame_rate=fps
528
- )
529
 
530
- # Track unique violations by worker ID
531
- unique_violations = {} # {worker_id: {violation_type: first_detection_time}}
 
532
  snapshots = []
533
  start_time = time.time()
534
  frame_skip = CONFIG["FRAME_SKIP"]
535
- processed_frames = 0
536
 
537
- while processed_frames < total_frames:
 
538
  batch_frames = []
539
  batch_indices = []
540
 
 
541
  for _ in range(CONFIG["BATCH_SIZE"]):
542
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
543
  if frame_idx >= total_frames:
@@ -547,8 +356,6 @@ def process_video(video_data):
547
  if not ret:
548
  break
549
 
550
- frame = preprocess_frame(frame)
551
-
552
  # Skip frames if needed
553
  for _ in range(frame_skip - 1):
554
  if not cap.grab():
@@ -556,172 +363,127 @@ def process_video(video_data):
556
 
557
  batch_frames.append(frame)
558
  batch_indices.append(frame_idx)
559
- processed_frames += 1
560
 
 
561
  if not batch_frames:
562
  break
563
 
564
- # Process batch with YOLO model
565
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
566
 
 
567
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
568
  current_time = frame_idx / fps
569
 
570
- # Update progress every second
571
- if time.time() - start_time > 1.0:
572
- progress = (processed_frames / total_frames) * 100
573
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
574
  start_time = time.time()
575
 
 
576
  boxes = result.boxes
577
- track_inputs = []
578
-
579
  for box in boxes:
580
  cls = int(box.cls)
581
  conf = float(box.conf)
582
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
583
 
584
- if label is None:
585
- continue
586
-
587
- if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
588
  continue
589
 
590
- bbox = box.xywh.cpu().numpy()[0]
591
- track_inputs.append({
592
- "bbox": bbox,
593
- "conf": conf,
594
- "cls": cls
595
- })
 
 
596
 
597
- if not track_inputs:
598
- continue
599
-
600
- tracked_objects = tracker.update(
601
- np.array([t["bbox"] for t in track_inputs]),
602
- np.array([t["conf"] for t in track_inputs]),
603
- np.array([t["cls"] for t in track_inputs])
604
- )
605
-
606
- # Process tracked objects for violations
607
- for obj in tracked_objects:
608
- worker_id = obj['id']
609
- label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
610
- conf = obj['score']
611
- bbox = obj['bbox']
612
-
613
- if label is None:
614
- continue
615
-
616
- # Initialize worker if not seen before
617
- if worker_id not in unique_violations:
618
- unique_violations[worker_id] = {}
619
-
620
- # Check if this violation type has been recorded for this worker
621
- if label not in unique_violations[worker_id]:
622
- # This is a new violation type for this worker
623
- unique_violations[worker_id][label] = current_time
624
-
625
- # Create detection object
626
- detection = {
627
- "worker_id": worker_id,
628
- "violation": label,
629
- "confidence": round(conf, 2),
630
- "bounding_box": bbox,
631
- "timestamp": current_time
632
- }
633
-
634
- # Take snapshot for the new violation
635
- snapshot_frame = batch_frames[i].copy()
636
- snapshot_frame = draw_detections(snapshot_frame, [detection])
637
-
638
- # Add timestamp to snapshot
639
- cv2.putText(
640
- snapshot_frame,
641
- f"Time: {current_time:.2f}s",
642
- (10, 30),
643
- cv2.FONT_HERSHEY_SIMPLEX,
644
- 0.7,
645
- (255, 255, 255),
646
- 2
647
- )
648
-
649
- # Save snapshot with high quality
650
- snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
651
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
652
-
653
- cv2.imwrite(
654
- snapshot_path,
655
- snapshot_frame,
656
- [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
657
- )
658
-
659
- snapshots.append({
660
- "violation": label,
661
- "worker_id": worker_id,
662
- "timestamp": current_time,
663
- "snapshot_path": snapshot_path,
664
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
665
  })
666
-
667
- logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
668
 
669
  cap.release()
670
- if os.path.exists(video_path):
671
- os.remove(video_path)
672
-
673
  processing_time = time.time() - start_time
674
- logger.info(f"Processing complete in {processing_time:.2f}s")
675
-
676
- # Convert tracked violations to final violation list
677
- violations = []
678
- for worker_id, worker_violations in unique_violations.items():
679
- for label, detection_time in worker_violations.items():
680
- violation = {
681
- "worker_id": worker_id,
682
- "violation": label,
683
- "timestamp": detection_time
684
- }
685
- violations.append(violation)
 
 
 
 
 
 
 
 
 
 
 
 
 
686
 
 
687
  if not violations:
688
- logger.info("No violations detected after processing")
689
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
690
  return
691
 
692
- # Calculate safety score
693
  score = calculate_safety_score(violations)
694
-
695
- # Generate PDF report
696
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
697
-
698
- # Push report to Salesforce
699
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
700
 
701
- # Format violations table for display
702
- violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
703
- violation_table += "|-----------|-----------|----------|------------|\n"
704
-
705
- for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
706
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
707
- worker_id = v.get("worker_id", "Unknown")
708
- timestamp = v.get("timestamp", 0.0)
709
- confidence = v.get("confidence", 0.0)
710
-
711
- violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
712
-
713
- # Format snapshots for display
714
- snapshots_text = ""
715
- for s in snapshots:
716
- display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
717
- worker_id = s.get("worker_id", "Unknown")
718
- timestamp = s.get("timestamp", 0.0)
719
-
720
- snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
721
- snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
722
 
723
- if not snapshots_text:
724
- snapshots_text = "No snapshots captured."
 
 
725
 
726
  yield (
727
  violation_table,
@@ -733,27 +495,24 @@ def process_video(video_data):
733
 
734
  except Exception as e:
735
  logger.error(f"Error processing video: {e}", exc_info=True)
736
- if 'video_path' in locals() and os.path.exists(video_path):
737
- os.remove(video_path)
738
  yield f"Error processing video: {e}", "", "", "", ""
739
 
 
 
 
740
  def gradio_interface(video_file):
741
- """Gradio interface for the video processing"""
742
  if not video_file:
743
  return "No file uploaded.", "", "No file uploaded.", "", ""
744
-
745
  try:
746
  with open(video_file, "rb") as f:
747
  video_data = f.read()
748
 
749
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
750
  yield status, score, snapshots_text, record_id, details_url
751
-
752
  except Exception as e:
753
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
754
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
755
 
756
- # ========================== # Gradio Interface # ==========================
757
  interface = gr.Interface(
758
  fn=gradio_interface,
759
  inputs=gr.Video(label="Upload Site Video"),
@@ -765,7 +524,7 @@ interface = gr.Interface(
765
  gr.Textbox(label="Violation Details URL")
766
  ],
767
  title="Worksite Safety Violation Analyzer",
768
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
769
  allow_flagging="never"
770
  )
771
 
 
1
  import os
 
 
 
 
2
  import cv2
3
  import gradio as gr
4
  import torch
 
11
  from reportlab.lib.units import inch
12
  from io import BytesIO
13
  import base64
14
+ import logging
15
  from retrying import retry
16
  import uuid
17
  from multiprocessing import Pool, cpu_count
18
  from functools import partial
19
 
20
+ # ==========================
21
+ # Optimized Configuration
22
+ # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  CONFIG = {
24
  "MODEL_PATH": "yolov8_safety.pt",
25
  "FALLBACK_MODEL": "yolov8n.pt",
 
32
  4: "improper_tool_use"
33
  },
34
  "CLASS_COLORS": {
35
+ "no_helmet": (0, 0, 255), # Red
36
+ "no_harness": (0, 165, 255), # Orange
37
+ "unsafe_posture": (0, 255, 0), # Green
38
+ "unsafe_zone": (255, 0, 0), # Blue
39
+ "improper_tool_use": (255, 255, 0) # Yellow
40
  },
41
  "DISPLAY_NAMES": {
42
  "no_helmet": "No Helmet Violation",
 
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
+ "no_helmet": 0.75, # Increased for stricter helmet detection
57
+ "no_harness": 0.4,
58
+ "unsafe_posture": 0.4,
59
+ "unsafe_zone": 0.4,
60
+ "improper_tool_use": 0.4
61
  },
62
+ "MIN_VIOLATION_FRAMES": 3,
63
+ "WORKER_TRACKING_DURATION": 3.0,
64
+ "MAX_PROCESSING_TIME": 60, # 1 minute limit
65
+ "FRAME_SKIP": 2, # Process every 2nd frame for speed
66
+ "BATCH_SIZE": 16, # Frames per batch
67
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1) # Use all CPU cores except one
 
 
 
 
 
 
68
  }
69
 
70
+ # Setup logging
71
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
72
+ logger = logging.getLogger(__name__)
73
+
74
+ os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
75
+
76
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  logger.info(f"Using device: {device}")
78
 
 
87
  if not os.path.isfile(model_path):
88
  logger.info(f"Downloading fallback model: {model_path}")
89
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
 
90
  model = YOLO(model_path).to(device)
 
91
  return model
92
  except Exception as e:
93
  logger.error(f"Failed to load model: {e}")
 
95
 
96
  model = load_model()
97
 
98
+ # ==========================
99
+ # Optimized Helper Functions
100
+ # ==========================
 
 
 
101
  def draw_detections(frame, detections):
 
 
 
102
  for det in detections:
103
  label = det.get("violation", "Unknown")
104
  confidence = det.get("confidence", 0.0)
105
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
106
+
 
107
  x1 = int(x - w/2)
108
  y1 = int(y - h/2)
109
  x2 = int(x + w/2)
110
  y2 = int(y + h/2)
111
 
112
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
113
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
114
 
115
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
116
+ cv2.putText(frame, display_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
117
+ return frame
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ def calculate_iou(box1, box2):
120
+ x1, y1, w1, h1 = box1
121
+ x2, y2, w2, h2 = box2
 
 
 
 
 
 
122
 
123
+ x_left = max(x1 - w1/2, x2 - w2/2)
124
+ y_top = max(y1 - h1/2, y2 - h2/2)
125
+ x_right = min(x1 + w1/2, x2 + w2/2)
126
+ y_bottom = min(y1 + h1/2, y2 + h2/2)
127
+
128
+ if x_right < x_left or y_bottom < y_top:
129
+ return 0.0
130
+
131
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
132
+ box1_area = w1 * h1
133
+ box2_area = w2 * h2
134
+ union_area = box1_area + box2_area - intersection_area
135
+
136
+ return intersection_area / union_area
137
+
138
+ def process_frame_batch(frame_batch, frame_indices, fps):
139
+ batch_results = []
140
+ results = model(frame_batch, device=device, conf=0.1, verbose=False)
141
 
142
+ for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
143
+ current_time = frame_idx / fps
144
+ detections = []
145
+
146
+ boxes = result.boxes
147
+ for box in boxes:
148
+ cls = int(box.cls)
149
+ conf = float(box.conf)
150
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
151
+
152
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
153
+ continue
154
+
155
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
156
+ detections.append({
157
+ "frame": frame_idx,
158
+ "violation": label,
159
+ "confidence": round(conf, 2),
160
+ "bounding_box": bbox,
161
+ "timestamp": current_time
162
+ })
163
+
164
+ batch_results.append((frame_idx, detections))
165
 
166
+ return batch_results
 
167
 
168
  def generate_violation_pdf(violations, score):
 
169
  try:
170
  pdf_filename = f"violations_{int(time.time())}.pdf"
171
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
172
  pdf_file = BytesIO()
173
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
 
 
 
 
 
174
  c.setFont("Helvetica", 12)
175
+ c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  c.setFont("Helvetica", 10)
177
+
178
+ y_position = 9.5 * inch
179
+ report_data = {
180
+ "Compliance Score": f"{score}%",
181
+ "Violations Found": len(violations),
182
+ "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
183
  }
184
+ for key, value in report_data.items():
 
185
  c.drawString(1 * inch, y_position, f"{key}: {value}")
186
+ y_position -= 0.3 * inch
187
 
 
 
 
 
188
  y_position -= 0.3 * inch
189
+ c.drawString(1 * inch, y_position, "Violation Details:")
190
+ y_position -= 0.3 * inch
191
+ if not violations:
192
+ c.drawString(1 * inch, y_position, "No violations detected.")
193
+ else:
194
+ for v in violations:
 
195
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
196
+ text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
197
+ c.drawString(1 * inch, y_position, text)
198
+ y_position -= 0.3 * inch
 
 
 
 
199
  if y_position < 1 * inch:
200
  c.showPage()
201
  c.setFont("Helvetica", 10)
202
  y_position = 10 * inch
203
 
204
+ c.showPage()
205
  c.save()
206
  pdf_file.seek(0)
207
 
 
208
  with open(pdf_path, "wb") as f:
209
  f.write(pdf_file.getvalue())
 
210
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
211
  logger.info(f"PDF generated: {public_url}")
212
  return pdf_path, public_url, pdf_file
 
214
  logger.error(f"Error generating PDF: {e}")
215
  return "", "", None
216
 
217
+ def calculate_safety_score(violations):
218
+ penalties = {
219
+ "no_helmet": 25,
220
+ "no_harness": 30,
221
+ "unsafe_posture": 20,
222
+ "unsafe_zone": 35,
223
+ "improper_tool_use": 25
224
+ }
225
+ total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
226
+ score = 100 - total_penalty
227
+ return max(score, 0)
228
+
229
+ # ==========================
230
+ # Salesforce Integration
231
+ # ==========================
232
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
233
  def connect_to_salesforce():
 
234
  try:
235
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
236
  logger.info("Connected to Salesforce")
 
241
  raise
242
 
243
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
 
244
  try:
245
  if not pdf_file:
246
  logger.error("No PDF file provided for upload")
247
  return ""
 
248
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
249
  content_version_data = {
250
  "Title": f"Safety_Violation_Report_{int(time.time())}",
 
254
  }
255
  content_version = sf.ContentVersion.create(content_version_data)
256
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
 
257
  if not result['records']:
258
  logger.error("Failed to retrieve ContentVersion")
259
  return ""
 
260
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
261
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
262
  return file_url
 
265
  return ""
266
 
267
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
 
268
  try:
269
  sf = connect_to_salesforce()
270
+ violations_text = "\n".join(
271
+ f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
272
+ for v in violations
273
+ ) or "No violations detected."
 
 
 
 
 
 
 
 
 
 
274
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
275
 
276
  record_data = {
 
280
  "Status__c": "Pending",
281
  "PDF_Report_URL__c": pdf_url
282
  }
 
283
  logger.info(f"Creating Salesforce record with data: {record_data}")
 
284
  try:
285
  record = sf.Safety_Video_Report__c.create(record_data)
286
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
288
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
289
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
290
  logger.warning(f"Fell back to Account record: {record['id']}")
 
291
  record_id = record["id"]
292
 
293
  if pdf_file:
 
307
  logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
308
  return None, ""
309
 
310
+ # ==========================
311
+ # Fast Video Processing
312
+ # ==========================
313
  def process_video(video_data):
 
314
  try:
315
+ # Create temp video file
 
 
316
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
317
  with open(video_path, "wb") as f:
318
  f.write(video_data)
319
  logger.info(f"Video saved: {video_path}")
320
 
321
+ # Open video file
322
  cap = cv2.VideoCapture(video_path)
323
  if not cap.isOpened():
 
324
  raise ValueError("Could not open video file")
325
 
326
+ # Get video properties
327
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
328
+ fps = cap.get(cv2.CAP_PROP_FPS)
329
+ if fps <= 0:
330
+ fps = 30
331
  duration = total_frames / fps
332
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
333
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
334
 
335
+ logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
 
 
 
 
 
336
 
337
+ workers = []
338
+ violations = []
339
+ helmet_violations = {}
340
  snapshots = []
341
  start_time = time.time()
342
  frame_skip = CONFIG["FRAME_SKIP"]
 
343
 
344
+ # Process frames in batches
345
+ while True:
346
  batch_frames = []
347
  batch_indices = []
348
 
349
+ # Collect frames for this batch
350
  for _ in range(CONFIG["BATCH_SIZE"]):
351
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
352
  if frame_idx >= total_frames:
 
356
  if not ret:
357
  break
358
 
 
 
359
  # Skip frames if needed
360
  for _ in range(frame_skip - 1):
361
  if not cap.grab():
 
363
 
364
  batch_frames.append(frame)
365
  batch_indices.append(frame_idx)
 
366
 
367
+ # Break if no more frames
368
  if not batch_frames:
369
  break
370
 
371
+ # Run batch detection
372
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
373
 
374
+ # Process results for each frame in batch
375
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
376
  current_time = frame_idx / fps
377
 
378
+ # Update progress periodically
379
+ if time.time() - start_time > 1.0: # Update every second
380
+ progress = (frame_idx / total_frames) * 100
381
+ yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
382
  start_time = time.time()
383
 
384
+ # Process detections in this frame
385
  boxes = result.boxes
 
 
386
  for box in boxes:
387
  cls = int(box.cls)
388
  conf = float(box.conf)
389
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
390
 
391
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
 
 
 
392
  continue
393
 
394
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
395
+ detection = {
396
+ "frame": frame_idx,
397
+ "violation": label,
398
+ "confidence": round(conf, 2),
399
+ "bounding_box": bbox,
400
+ "timestamp": current_time
401
+ }
402
 
403
+ # Worker tracking
404
+ worker_id = None
405
+ max_iou = 0
406
+ for idx, worker in enumerate(workers):
407
+ iou = calculate_iou(bbox, worker["bbox"])
408
+ if iou > max_iou and iou > 0.4: # IOU threshold
409
+ max_iou = iou
410
+ worker_id = worker["id"]
411
+ workers[idx]["bbox"] = bbox
412
+ workers[idx]["last_seen"] = current_time
413
+
414
+ if worker_id is None:
415
+ worker_id = len(workers) + 1
416
+ workers.append({
417
+ "id": worker_id,
418
+ "bbox": bbox,
419
+ "first_seen": current_time,
420
+ "last_seen": current_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  })
422
+
423
+ detection["worker_id"] = worker_id
424
+
425
+ # Track helmet violations with stricter criteria
426
+ if detection["violation"] == "no_helmet":
427
+ # Only include high-confidence no_helmet detections
428
+ if conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
429
+ if worker_id not in helmet_violations:
430
+ helmet_violations[worker_id] = []
431
+ helmet_violations[worker_id].append(detection)
432
+ else:
433
+ violations.append(detection)
434
+
435
+ # Remove inactive workers
436
+ workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
437
 
438
  cap.release()
439
+ os.remove(video_path)
 
 
440
  processing_time = time.time() - start_time
441
+ logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
442
+
443
+ # Confirm helmet violations (require multiple detections)
444
+ for worker_id, detections in helmet_violations.items():
445
+ if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
446
+ # Select the detection with the highest confidence
447
+ best_detection = max(detections, key=lambda x: x["confidence"])
448
+ violations.append(best_detection)
449
+
450
+ # Capture snapshot for confirmed no_helmet violation
451
+ cap = cv2.VideoCapture(video_path)
452
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
453
+ ret, snapshot_frame = cap.read()
454
+ if ret:
455
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
456
+ snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
457
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
458
+ cv2.imwrite(snapshot_path, snapshot_frame)
459
+ snapshots.append({
460
+ "violation": "no_helmet",
461
+ "frame": best_detection["frame"],
462
+ "snapshot_path": snapshot_path,
463
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
464
+ })
465
+ cap.release()
466
 
467
+ # Generate results
468
  if not violations:
 
469
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
470
  return
471
 
 
472
  score = calculate_safety_score(violations)
 
 
473
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
 
474
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
475
 
476
+ violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
477
+ violation_table += "|------------------------|---------------|------------|-----------|\n"
478
+ for v in sorted(violations, key=lambda x: x["timestamp"]):
 
 
479
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
480
+ row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
481
+ violation_table += row
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
+ snapshots_text = "\n".join(
484
+ f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
485
+ for s in snapshots
486
+ ) if snapshots else "No snapshots captured."
487
 
488
  yield (
489
  violation_table,
 
495
 
496
  except Exception as e:
497
  logger.error(f"Error processing video: {e}", exc_info=True)
 
 
498
  yield f"Error processing video: {e}", "", "", "", ""
499
 
500
+ # ==========================
501
+ # Gradio Interface
502
+ # ==========================
503
  def gradio_interface(video_file):
 
504
  if not video_file:
505
  return "No file uploaded.", "", "No file uploaded.", "", ""
 
506
  try:
507
  with open(video_file, "rb") as f:
508
  video_data = f.read()
509
 
510
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
511
  yield status, score, snapshots_text, record_id, details_url
 
512
  except Exception as e:
513
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
514
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
515
 
 
516
  interface = gr.Interface(
517
  fn=gradio_interface,
518
  inputs=gr.Video(label="Upload Site Video"),
 
524
  gr.Textbox(label="Violation Details URL")
525
  ],
526
  title="Worksite Safety Violation Analyzer",
527
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
528
  allow_flagging="never"
529
  )
530