PrashanthB461 commited on
Commit
d047708
·
verified ·
1 Parent(s): 81b0d8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +472 -224
app.py CHANGED
@@ -1,4 +1,15 @@
1
- import os
 
 
 
 
 
 
 
 
 
 
 
2
  import cv2
3
  import gradio as gr
4
  import torch
@@ -11,15 +22,162 @@ from reportlab.pdfgen import canvas
11
  from reportlab.lib.units import inch
12
  from io import BytesIO
13
  import base64
14
- import logging
15
  from retrying import retry
16
  import uuid
17
  from multiprocessing import Pool, cpu_count
18
  from functools import partial
19
 
20
- # ==========================
21
- # Optimized Configuration
22
- # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  CONFIG = {
24
  "MODEL_PATH": "yolov8_safety.pt",
25
  "FALLBACK_MODEL": "yolov8n.pt",
@@ -32,11 +190,11 @@ CONFIG = {
32
  4: "improper_tool_use"
33
  },
34
  "CLASS_COLORS": {
35
- "no_helmet": (0, 0, 255), # Red
36
- "no_harness": (0, 165, 255), # Orange
37
- "unsafe_posture": (0, 255, 0), # Green
38
- "unsafe_zone": (255, 0, 0), # Blue
39
- "improper_tool_use": (255, 255, 0) # Yellow
40
  },
41
  "DISPLAY_NAMES": {
42
  "no_helmet": "No Helmet Violation",
@@ -53,26 +211,26 @@ CONFIG = {
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
- "no_helmet": 0.75, # Increased for stricter helmet detection
57
- "no_harness": 0.4,
58
- "unsafe_posture": 0.4,
59
- "unsafe_zone": 0.4,
60
- "improper_tool_use": 0.4
61
  },
62
- "MIN_VIOLATION_FRAMES": 3,
63
- "WORKER_TRACKING_DURATION": 3.0,
64
- "MAX_PROCESSING_TIME": 60, # 1 minute limit
65
- "FRAME_SKIP": 2, # Process every 2nd frame for speed
66
- "BATCH_SIZE": 16, # Frames per batch
67
- "PARALLEL_WORKERS": max(1, cpu_count() - 1) # Use all CPU cores except one
 
 
 
 
 
 
68
  }
69
 
70
- # Setup logging
71
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
72
- logger = logging.getLogger(__name__)
73
-
74
- os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
75
-
76
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  logger.info(f"Using device: {device}")
78
 
@@ -87,7 +245,9 @@ def load_model():
87
  if not os.path.isfile(model_path):
88
  logger.info(f"Downloading fallback model: {model_path}")
89
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
 
90
  model = YOLO(model_path).to(device)
 
91
  return model
92
  except Exception as e:
93
  logger.error(f"Failed to load model: {e}")
@@ -95,118 +255,151 @@ def load_model():
95
 
96
  model = load_model()
97
 
98
- # ==========================
99
- # Optimized Helper Functions
100
- # ==========================
 
 
 
101
  def draw_detections(frame, detections):
 
 
 
102
  for det in detections:
103
  label = det.get("violation", "Unknown")
104
  confidence = det.get("confidence", 0.0)
105
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
106
-
 
107
  x1 = int(x - w/2)
108
  y1 = int(y - h/2)
109
  x2 = int(x + w/2)
110
  y2 = int(y + h/2)
111
 
112
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
113
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
114
 
115
- display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
116
- cv2.putText(frame, display_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
117
- return frame
118
-
119
- def calculate_iou(box1, box2):
120
- x1, y1, w1, h1 = box1
121
- x2, y2, w2, h2 = box2
122
-
123
- x_left = max(x1 - w1/2, x2 - w2/2)
124
- y_top = max(y1 - h1/2, y2 - h2/2)
125
- x_right = min(x1 + w1/2, x2 + w2/2)
126
- y_bottom = min(y1 + h1/2, y2 + h2/2)
127
-
128
- if x_right < x_left or y_bottom < y_top:
129
- return 0.0
130
-
131
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
132
- box1_area = w1 * h1
133
- box2_area = w2 * h2
134
- union_area = box1_area + box2_area - intersection_area
135
-
136
- return intersection_area / union_area
137
-
138
- def process_frame_batch(frame_batch, frame_indices, fps):
139
- batch_results = []
140
- results = model(frame_batch, device=device, conf=0.1, verbose=False)
141
-
142
- for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
143
- current_time = frame_idx / fps
144
- detections = []
145
 
146
- boxes = result.boxes
147
- for box in boxes:
148
- cls = int(box.cls)
149
- conf = float(box.conf)
150
- label = CONFIG["VIOLATION_LABELS"].get(cls, None)
151
-
152
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
153
- continue
 
 
 
154
 
155
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
156
- detections.append({
157
- "frame": frame_idx,
158
- "violation": label,
159
- "confidence": round(conf, 2),
160
- "bounding_box": bbox,
161
- "timestamp": current_time
162
- })
 
 
 
 
 
 
 
163
 
164
- batch_results.append((frame_idx, detections))
 
 
 
 
 
 
 
 
165
 
166
- return batch_results
 
167
 
168
  def generate_violation_pdf(violations, score):
 
169
  try:
170
  pdf_filename = f"violations_{int(time.time())}.pdf"
171
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
172
  pdf_file = BytesIO()
173
  c = canvas.Canvas(pdf_file, pagesize=letter)
174
- c.setFont("Helvetica", 12)
 
 
175
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  c.setFont("Helvetica", 10)
177
-
178
- y_position = 9.5 * inch
179
- report_data = {
180
- "Compliance Score": f"{score}%",
181
- "Violations Found": len(violations),
182
- "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
183
  }
184
- for key, value in report_data.items():
 
185
  c.drawString(1 * inch, y_position, f"{key}: {value}")
186
- y_position -= 0.3 * inch
187
 
 
 
 
 
188
  y_position -= 0.3 * inch
189
- c.drawString(1 * inch, y_position, "Violation Details:")
190
- y_position -= 0.3 * inch
191
- if not violations:
192
- c.drawString(1 * inch, y_position, "No violations detected.")
193
- else:
194
- for v in violations:
 
195
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
196
- text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
197
- c.drawString(1 * inch, y_position, text)
198
- y_position -= 0.3 * inch
 
 
 
 
199
  if y_position < 1 * inch:
200
  c.showPage()
201
  c.setFont("Helvetica", 10)
202
  y_position = 10 * inch
203
 
204
- c.showPage()
205
  c.save()
206
  pdf_file.seek(0)
207
 
 
208
  with open(pdf_path, "wb") as f:
209
  f.write(pdf_file.getvalue())
 
210
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
211
  logger.info(f"PDF generated: {public_url}")
212
  return pdf_path, public_url, pdf_file
@@ -214,23 +407,9 @@ def generate_violation_pdf(violations, score):
214
  logger.error(f"Error generating PDF: {e}")
215
  return "", "", None
216
 
217
- def calculate_safety_score(violations):
218
- penalties = {
219
- "no_helmet": 25,
220
- "no_harness": 30,
221
- "unsafe_posture": 20,
222
- "unsafe_zone": 35,
223
- "improper_tool_use": 25
224
- }
225
- total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
226
- score = 100 - total_penalty
227
- return max(score, 0)
228
-
229
- # ==========================
230
- # Salesforce Integration
231
- # ==========================
232
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
233
  def connect_to_salesforce():
 
234
  try:
235
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
236
  logger.info("Connected to Salesforce")
@@ -241,10 +420,12 @@ def connect_to_salesforce():
241
  raise
242
 
243
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
 
244
  try:
245
  if not pdf_file:
246
  logger.error("No PDF file provided for upload")
247
  return ""
 
248
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
249
  content_version_data = {
250
  "Title": f"Safety_Violation_Report_{int(time.time())}",
@@ -254,9 +435,11 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
254
  }
255
  content_version = sf.ContentVersion.create(content_version_data)
256
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
 
257
  if not result['records']:
258
  logger.error("Failed to retrieve ContentVersion")
259
  return ""
 
260
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
261
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
262
  return file_url
@@ -265,12 +448,23 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
265
  return ""
266
 
267
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
 
268
  try:
269
  sf = connect_to_salesforce()
270
- violations_text = "\n".join(
271
- f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
272
- for v in violations
273
- ) or "No violations detected."
 
 
 
 
 
 
 
 
 
 
274
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
275
 
276
  record_data = {
@@ -280,7 +474,9 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
280
  "Status__c": "Pending",
281
  "PDF_Report_URL__c": pdf_url
282
  }
 
283
  logger.info(f"Creating Salesforce record with data: {record_data}")
 
284
  try:
285
  record = sf.Safety_Video_Report__c.create(record_data)
286
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
@@ -288,6 +484,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
288
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
289
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
290
  logger.warning(f"Fell back to Account record: {record['id']}")
 
291
  record_id = record["id"]
292
 
293
  if pdf_file:
@@ -307,46 +504,47 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
307
  logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
308
  return None, ""
309
 
310
- # ==========================
311
- # Fast Video Processing
312
- # ==========================
313
  def process_video(video_data):
 
314
  try:
315
- # Create temp video file
 
 
316
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
317
  with open(video_path, "wb") as f:
318
  f.write(video_data)
319
  logger.info(f"Video saved: {video_path}")
320
 
321
- # Open video file
322
  cap = cv2.VideoCapture(video_path)
323
  if not cap.isOpened():
 
324
  raise ValueError("Could not open video file")
325
 
326
- # Get video properties
327
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
328
- fps = cap.get(cv2.CAP_PROP_FPS)
329
- if fps <= 0:
330
- fps = 30
331
  duration = total_frames / fps
332
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
333
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
334
-
335
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
336
 
337
- workers = []
338
- violations = []
339
- helmet_violations = {}
 
 
 
 
 
 
340
  snapshots = []
341
  start_time = time.time()
342
  frame_skip = CONFIG["FRAME_SKIP"]
 
343
 
344
- # Process frames in batches
345
- while True:
346
  batch_frames = []
347
  batch_indices = []
348
 
349
- # Collect frames for this batch
350
  for _ in range(CONFIG["BATCH_SIZE"]):
351
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
352
  if frame_idx >= total_frames:
@@ -356,6 +554,8 @@ def process_video(video_data):
356
  if not ret:
357
  break
358
 
 
 
359
  # Skip frames if needed
360
  for _ in range(frame_skip - 1):
361
  if not cap.grab():
@@ -363,127 +563,172 @@ def process_video(video_data):
363
 
364
  batch_frames.append(frame)
365
  batch_indices.append(frame_idx)
 
366
 
367
- # Break if no more frames
368
  if not batch_frames:
369
  break
370
 
371
- # Run batch detection
372
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
373
 
374
- # Process results for each frame in batch
375
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
376
  current_time = frame_idx / fps
377
 
378
- # Update progress periodically
379
- if time.time() - start_time > 1.0: # Update every second
380
- progress = (frame_idx / total_frames) * 100
381
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
382
  start_time = time.time()
383
 
384
- # Process detections in this frame
385
  boxes = result.boxes
 
 
386
  for box in boxes:
387
  cls = int(box.cls)
388
  conf = float(box.conf)
389
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
390
 
391
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
 
 
 
392
  continue
393
 
394
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
395
- detection = {
396
- "frame": frame_idx,
397
- "violation": label,
398
- "confidence": round(conf, 2),
399
- "bounding_box": bbox,
400
- "timestamp": current_time
401
- }
402
 
403
- # Worker tracking
404
- worker_id = None
405
- max_iou = 0
406
- for idx, worker in enumerate(workers):
407
- iou = calculate_iou(bbox, worker["bbox"])
408
- if iou > max_iou and iou > 0.4: # IOU threshold
409
- max_iou = iou
410
- worker_id = worker["id"]
411
- workers[idx]["bbox"] = bbox
412
- workers[idx]["last_seen"] = current_time
413
-
414
- if worker_id is None:
415
- worker_id = len(workers) + 1
416
- workers.append({
417
- "id": worker_id,
418
- "bbox": bbox,
419
- "first_seen": current_time,
420
- "last_seen": current_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  })
422
-
423
- detection["worker_id"] = worker_id
424
-
425
- # Track helmet violations with stricter criteria
426
- if detection["violation"] == "no_helmet":
427
- # Only include high-confidence no_helmet detections
428
- if conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
429
- if worker_id not in helmet_violations:
430
- helmet_violations[worker_id] = []
431
- helmet_violations[worker_id].append(detection)
432
- else:
433
- violations.append(detection)
434
-
435
- # Remove inactive workers
436
- workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
437
 
438
  cap.release()
439
- os.remove(video_path)
 
 
440
  processing_time = time.time() - start_time
441
- logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
442
-
443
- # Confirm helmet violations (require multiple detections)
444
- for worker_id, detections in helmet_violations.items():
445
- if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
446
- # Select the detection with the highest confidence
447
- best_detection = max(detections, key=lambda x: x["confidence"])
448
- violations.append(best_detection)
449
-
450
- # Capture snapshot for confirmed no_helmet violation
451
- cap = cv2.VideoCapture(video_path)
452
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
453
- ret, snapshot_frame = cap.read()
454
- if ret:
455
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
456
- snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
457
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
458
- cv2.imwrite(snapshot_path, snapshot_frame)
459
- snapshots.append({
460
- "violation": "no_helmet",
461
- "frame": best_detection["frame"],
462
- "snapshot_path": snapshot_path,
463
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
464
- })
465
- cap.release()
466
 
467
- # Generate results
468
  if not violations:
 
469
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
470
  return
471
 
 
472
  score = calculate_safety_score(violations)
 
 
473
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
 
 
474
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
475
 
476
- violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
477
- violation_table += "|------------------------|---------------|------------|-----------|\n"
478
- for v in sorted(violations, key=lambda x: x["timestamp"]):
 
 
479
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
480
- row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
481
- violation_table += row
 
 
 
 
 
 
 
 
 
 
 
 
 
482
 
483
- snapshots_text = "\n".join(
484
- f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
485
- for s in snapshots
486
- ) if snapshots else "No snapshots captured."
487
 
488
  yield (
489
  violation_table,
@@ -495,24 +740,27 @@ def process_video(video_data):
495
 
496
  except Exception as e:
497
  logger.error(f"Error processing video: {e}", exc_info=True)
 
 
498
  yield f"Error processing video: {e}", "", "", "", ""
499
 
500
- # ==========================
501
- # Gradio Interface
502
- # ==========================
503
  def gradio_interface(video_file):
 
504
  if not video_file:
505
  return "No file uploaded.", "", "No file uploaded.", "", ""
 
506
  try:
507
  with open(video_file, "rb") as f:
508
  video_data = f.read()
509
 
510
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
511
  yield status, score, snapshots_text, record_id, details_url
 
512
  except Exception as e:
513
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
514
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
515
 
 
516
  interface = gr.Interface(
517
  fn=gradio_interface,
518
  inputs=gr.Video(label="Upload Site Video"),
@@ -524,7 +772,7 @@ interface = gr.Interface(
524
  gr.Textbox(label="Violation Details URL")
525
  ],
526
  title="Worksite Safety Violation Analyzer",
527
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
528
  allow_flagging="never"
529
  )
530
 
 
1
+ opencv-python>=4.10.0
2
+ gradio>=4.44.0
3
+ torch>=2.4.1
4
+ numpy>=1.26.4
5
+ ultralytics>=8.3.0
6
+ simple-salesforce>=1.12.6
7
+ reportlab>=4.2.2
8
+ retrying>=1.3.4import os
9
+ import sys
10
+ import subprocess
11
+ import logging
12
+ import warnings
13
  import cv2
14
  import gradio as gr
15
  import torch
 
22
  from reportlab.lib.units import inch
23
  from io import BytesIO
24
  import base64
 
25
  from retrying import retry
26
  import uuid
27
  from multiprocessing import Pool, cpu_count
28
  from functools import partial
29
 
30
+ # ========================== # Configuration and Setup # ==========================
31
+ os.environ['YOLO_CONFIG_DIR'] = '/tmp/Ultralytics'
32
+ os.makedirs('/tmp/Ultralytics', exist_ok=True)
33
+
34
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # ========================== # ByteTrack Implementation # ==========================
38
+ class BYTETracker:
39
+ def __init__(self, track_thresh=0.3, track_buffer=30, match_thresh=0.7, frame_rate=30):
40
+ self.track_thresh = track_thresh
41
+ self.track_buffer = track_buffer
42
+ self.match_thresh = match_thresh
43
+ self.frame_rate = frame_rate
44
+ self.next_id = 1
45
+ self.tracks = {} # Store active tracks
46
+ self.worker_history = {} # Track worker positions over time
47
+ self.last_positions = {} # Last known positions of workers
48
+
49
+ def update(self, dets, scores, cls):
50
+ tracks = []
51
+ current_time = time.time()
52
+
53
+ # Update existing tracks with new detections
54
+ for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
55
+ if score < self.track_thresh:
56
+ continue
57
+
58
+ x, y, w, h = det
59
+ matched = False
60
+ best_iou = 0
61
+ best_track_id = None
62
+
63
+ # Try to match with existing tracks
64
+ for track_id, track_info in self.tracks.items():
65
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
66
+ continue
67
+
68
+ tx, ty, tw, th = track_info['bbox']
69
+ iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
70
+
71
+ if iou > self.match_thresh and iou > best_iou:
72
+ best_iou = iou
73
+ best_track_id = track_id
74
+ matched = True
75
+
76
+ if matched:
77
+ # Update existing track
78
+ self.tracks[best_track_id].update({
79
+ 'bbox': [x, y, w, h],
80
+ 'score': score,
81
+ 'cls': cl,
82
+ 'last_seen': current_time
83
+ })
84
+
85
+ # Update position history
86
+ if best_track_id not in self.worker_history:
87
+ self.worker_history[best_track_id] = []
88
+ self.worker_history[best_track_id].append([x, y])
89
+ self.last_positions[best_track_id] = [x, y]
90
+
91
+ tracks.append({
92
+ 'id': best_track_id,
93
+ 'bbox': [x, y, w, h],
94
+ 'score': score,
95
+ 'cls': cl
96
+ })
97
+ else:
98
+ # Create new track
99
+ # Check if this detection might be the same worker from a different angle
100
+ same_worker = False
101
+ for worker_id, last_pos in self.last_positions.items():
102
+ if self._is_same_worker([x, y], last_pos):
103
+ self.tracks[worker_id] = {
104
+ 'bbox': [x, y, w, h],
105
+ 'score': score,
106
+ 'cls': cl,
107
+ 'last_seen': current_time
108
+ }
109
+ tracks.append({
110
+ 'id': worker_id,
111
+ 'bbox': [x, y, w, h],
112
+ 'score': score,
113
+ 'cls': cl
114
+ })
115
+ same_worker = True
116
+ break
117
+
118
+ if not same_worker:
119
+ self.tracks[self.next_id] = {
120
+ 'bbox': [x, y, w, h],
121
+ 'score': score,
122
+ 'cls': cl,
123
+ 'last_seen': current_time
124
+ }
125
+ self.worker_history[self.next_id] = [[x, y]]
126
+ self.last_positions[self.next_id] = [x, y]
127
+ tracks.append({
128
+ 'id': self.next_id,
129
+ 'bbox': [x, y, w, h],
130
+ 'score': score,
131
+ 'cls': cl
132
+ })
133
+ self.next_id += 1
134
+
135
+ # Clean up old tracks
136
+ current_time = time.time()
137
+ stale_ids = []
138
+ for track_id, track_info in self.tracks.items():
139
+ if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate:
140
+ stale_ids.append(track_id)
141
+
142
+ for track_id in stale_ids:
143
+ del self.tracks[track_id]
144
+ if track_id in self.worker_history:
145
+ del self.worker_history[track_id]
146
+ if track_id in self.last_positions:
147
+ del self.last_positions[track_id]
148
+
149
+ return tracks
150
+
151
+ def _calculate_iou(self, box1, box2):
152
+ """Calculate IOU between two boxes"""
153
+ x1, y1, w1, h1 = box1
154
+ x2, y2, w2, h2 = box2
155
+
156
+ # Calculate intersection coordinates
157
+ x_left = max(x1 - w1/2, x2 - w2/2)
158
+ y_top = max(y1 - h1/2, y2 - h2/2)
159
+ x_right = min(x1 + w1/2, x2 + w2/2)
160
+ y_bottom = min(y1 + h1/2, y2 + h2/2)
161
+
162
+ if x_right < x_left or y_bottom < y_top:
163
+ return 0.0
164
+
165
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
166
+
167
+ box1_area = w1 * h1
168
+ box2_area = w2 * h2
169
+
170
+ iou = intersection_area / (box1_area + box2_area - intersection_area)
171
+ return iou
172
+
173
+ def _is_same_worker(self, pos1, pos2, threshold=100):
174
+ """Check if two positions likely belong to the same worker"""
175
+ x1, y1 = pos1
176
+ x2, y2 = pos2
177
+ distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
178
+ return distance < threshold
179
+
180
+ # ========================== # Optimized Configuration # ==========================
181
  CONFIG = {
182
  "MODEL_PATH": "yolov8_safety.pt",
183
  "FALLBACK_MODEL": "yolov8n.pt",
 
190
  4: "improper_tool_use"
191
  },
192
  "CLASS_COLORS": {
193
+ "no_helmet": (0, 0, 255), # Red
194
+ "no_harness": (0, 165, 255), # Orange
195
+ "unsafe_posture": (0, 255, 0), # Green
196
+ "unsafe_zone": (255, 0, 0), # Blue
197
+ "improper_tool_use": (255, 255, 0) # Cyan
198
  },
199
  "DISPLAY_NAMES": {
200
  "no_helmet": "No Helmet Violation",
 
211
  },
212
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
213
  "CONFIDENCE_THRESHOLDS": {
214
+ "no_helmet": 0.5,
215
+ "no_harness": 0.3,
216
+ "unsafe_posture": 0.3,
217
+ "unsafe_zone": 0.3,
218
+ "improper_tool_use": 0.3
219
  },
220
+ "MIN_VIOLATION_FRAMES": 1,
221
+ "VIOLATION_COOLDOWN": 30.0, # Increased cooldown period
222
+ "WORKER_TRACKING_DURATION": 5.0,
223
+ "MAX_PROCESSING_TIME": 60,
224
+ "FRAME_SKIP": 2, # Skip more frames for faster processing
225
+ "BATCH_SIZE": 16,
226
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1),
227
+ "TRACK_BUFFER": 30,
228
+ "TRACK_THRESH": 0.3,
229
+ "MATCH_THRESH": 0.7,
230
+ "SNAPSHOT_QUALITY": 95, # Higher quality for better visibility
231
+ "MAX_WORKER_DISTANCE": 100 # Maximum pixel distance to consider same worker
232
  }
233
 
 
 
 
 
 
 
234
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
235
  logger.info(f"Using device: {device}")
236
 
 
245
  if not os.path.isfile(model_path):
246
  logger.info(f"Downloading fallback model: {model_path}")
247
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
248
+
249
  model = YOLO(model_path).to(device)
250
+ logger.info(f"Model classes: {model.names}")
251
  return model
252
  except Exception as e:
253
  logger.error(f"Failed to load model: {e}")
 
255
 
256
  model = load_model()
257
 
258
+ # ========================== # Helper Functions # ==========================
259
+ def preprocess_frame(frame):
260
+ """Apply basic preprocessing to enhance detection"""
261
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=20)
262
+ return frame
263
+
264
  def draw_detections(frame, detections):
265
+ """Draw bounding boxes and labels on detection frame with improved visibility"""
266
+ result_frame = frame.copy()
267
+
268
  for det in detections:
269
  label = det.get("violation", "Unknown")
270
  confidence = det.get("confidence", 0.0)
271
  x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
272
+ worker_id = det.get("worker_id", "Unknown")
273
+
274
  x1 = int(x - w/2)
275
  y1 = int(y - h/2)
276
  x2 = int(x + w/2)
277
  y2 = int(y + h/2)
278
 
279
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
 
280
 
281
+ # Draw thicker rectangle with border
282
+ cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ # Add black background behind text
285
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)} (Worker {worker_id})"
286
+ text_size = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
287
+ cv2.rectangle(result_frame, (x1, y1-text_size[1]-10), (x1+text_size[0]+10, y1), (0, 0, 0), -1)
288
+ cv2.putText(result_frame, display_text, (x1+5, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
289
+
290
+ # Add confidence score
291
+ conf_text = f"Conf: {confidence:.2f}"
292
+ cv2.putText(result_frame, conf_text, (x1+5, y2+20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
293
+
294
+ return result_frame
295
 
296
+ def calculate_safety_score(violations):
297
+ """Calculate safety score based on detected violations"""
298
+ penalties = {
299
+ "no_helmet": 25,
300
+ "no_harness": 30,
301
+ "unsafe_posture": 20,
302
+ "unsafe_zone": 35,
303
+ "improper_tool_use": 25
304
+ }
305
+
306
+ # Count unique violation types per worker
307
+ worker_violations = {}
308
+ for v in violations:
309
+ worker_id = v.get("worker_id", "Unknown")
310
+ violation_type = v.get("violation", "Unknown")
311
 
312
+ if worker_id not in worker_violations:
313
+ worker_violations[worker_id] = set()
314
+ worker_violations[worker_id].add(violation_type)
315
+
316
+ # Calculate total penalty
317
+ total_penalty = 0
318
+ for worker_violations_set in worker_violations.values():
319
+ worker_penalty = sum(penalties.get(v, 0) for v in worker_violations_set)
320
+ total_penalty += worker_penalty
321
 
322
+ score = max(0, 100 - total_penalty)
323
+ return score
324
 
325
  def generate_violation_pdf(violations, score):
326
+ """Generate a PDF report for the detected violations"""
327
  try:
328
  pdf_filename = f"violations_{int(time.time())}.pdf"
329
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
330
  pdf_file = BytesIO()
331
  c = canvas.Canvas(pdf_file, pagesize=letter)
332
+
333
+ # Title
334
+ c.setFont("Helvetica-Bold", 16)
335
  c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
336
+
337
+ # Basic Information
338
+ c.setFont("Helvetica", 12)
339
+ c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
340
+ c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
341
+
342
+ # Safety Score
343
+ c.setFont("Helvetica-Bold", 14)
344
+ c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
345
+
346
+ # Violation Summary
347
+ y_position = 8.2 * inch
348
+ c.setFont("Helvetica-Bold", 12)
349
+ c.drawString(1 * inch, y_position, "Summary:")
350
+ y_position -= 0.3 * inch
351
+
352
+ # Group violations by worker
353
+ worker_violations = {}
354
+ for v in violations:
355
+ worker_id = v.get("worker_id", "Unknown")
356
+ if worker_id not in worker_violations:
357
+ worker_violations[worker_id] = []
358
+ worker_violations[worker_id].append(v)
359
+
360
  c.setFont("Helvetica", 10)
361
+ summary_data = {
362
+ "Total Workers with Violations": len(worker_violations),
363
+ "Total Violations Found": len(violations),
364
+ "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
 
 
365
  }
366
+
367
+ for key, value in summary_data.items():
368
  c.drawString(1 * inch, y_position, f"{key}: {value}")
369
+ y_position -= 0.25 * inch
370
 
371
+ # Detailed Violations by Worker
372
+ y_position -= 0.5 * inch
373
+ c.setFont("Helvetica-Bold", 12)
374
+ c.drawString(1 * inch, y_position, "Violations by Worker:")
375
  y_position -= 0.3 * inch
376
+
377
+ c.setFont("Helvetica", 10)
378
+ for worker_id, worker_vios in worker_violations.items():
379
+ c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
380
+ y_position -= 0.2 * inch
381
+
382
+ for v in worker_vios:
383
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
384
+ time_str = f"{v.get('timestamp', 0.0):.2f}s"
385
+ conf_str = f"{v.get('confidence', 0.0):.2f}"
386
+
387
+ violation_text = f" - {display_name} at {time_str} (Confidence: {conf_str})"
388
+ c.drawString(1.2 * inch, y_position, violation_text)
389
+ y_position -= 0.2 * inch
390
+
391
  if y_position < 1 * inch:
392
  c.showPage()
393
  c.setFont("Helvetica", 10)
394
  y_position = 10 * inch
395
 
 
396
  c.save()
397
  pdf_file.seek(0)
398
 
399
+ # Save PDF file
400
  with open(pdf_path, "wb") as f:
401
  f.write(pdf_file.getvalue())
402
+
403
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
404
  logger.info(f"PDF generated: {public_url}")
405
  return pdf_path, public_url, pdf_file
 
407
  logger.error(f"Error generating PDF: {e}")
408
  return "", "", None
409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
411
  def connect_to_salesforce():
412
+ """Connect to Salesforce with retry logic"""
413
  try:
414
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
415
  logger.info("Connected to Salesforce")
 
420
  raise
421
 
422
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
423
+ """Upload PDF report to Salesforce"""
424
  try:
425
  if not pdf_file:
426
  logger.error("No PDF file provided for upload")
427
  return ""
428
+
429
  encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
430
  content_version_data = {
431
  "Title": f"Safety_Violation_Report_{int(time.time())}",
 
435
  }
436
  content_version = sf.ContentVersion.create(content_version_data)
437
  result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
438
+
439
  if not result['records']:
440
  logger.error("Failed to retrieve ContentVersion")
441
  return ""
442
+
443
  file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
444
  logger.info(f"PDF uploaded to Salesforce: {file_url}")
445
  return file_url
 
448
  return ""
449
 
450
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
451
+ """Push violation report to Salesforce"""
452
  try:
453
  sf = connect_to_salesforce()
454
+
455
+ # Format violations for Salesforce
456
+ violations_text = ""
457
+ for v in violations:
458
+ display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
459
+ worker_id = v.get('worker_id', 'Unknown')
460
+ timestamp = v.get('timestamp', 0.0)
461
+ confidence = v.get('confidence', 0.0)
462
+
463
+ violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
464
+
465
+ if not violations_text:
466
+ violations_text = "No violations detected."
467
+
468
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
469
 
470
  record_data = {
 
474
  "Status__c": "Pending",
475
  "PDF_Report_URL__c": pdf_url
476
  }
477
+
478
  logger.info(f"Creating Salesforce record with data: {record_data}")
479
+
480
  try:
481
  record = sf.Safety_Video_Report__c.create(record_data)
482
  logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
 
484
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
485
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
486
  logger.warning(f"Fell back to Account record: {record['id']}")
487
+
488
  record_id = record["id"]
489
 
490
  if pdf_file:
 
504
  logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
505
  return None, ""
506
 
 
 
 
507
  def process_video(video_data):
508
+ """Process video to detect safety violations"""
509
  try:
510
+ os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
511
+ logger.info(f"Output directory ensured: {CONFIG['OUTPUT_DIR']}")
512
+
513
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
514
  with open(video_path, "wb") as f:
515
  f.write(video_data)
516
  logger.info(f"Video saved: {video_path}")
517
 
 
518
  cap = cv2.VideoCapture(video_path)
519
  if not cap.isOpened():
520
+ os.remove(video_path)
521
  raise ValueError("Could not open video file")
522
 
 
523
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
524
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
 
525
  duration = total_frames / fps
526
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
527
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
528
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
529
 
530
+ tracker = BYTETracker(
531
+ track_thresh=CONFIG["TRACK_THRESH"],
532
+ track_buffer=CONFIG["TRACK_BUFFER"],
533
+ match_thresh=CONFIG["MATCH_THRESH"],
534
+ frame_rate=fps
535
+ )
536
+
537
+ # Track unique violations by worker ID
538
+ unique_violations = {} # {worker_id: {violation_type: first_detection_time}}
539
  snapshots = []
540
  start_time = time.time()
541
  frame_skip = CONFIG["FRAME_SKIP"]
542
+ processed_frames = 0
543
 
544
+ while processed_frames < total_frames:
 
545
  batch_frames = []
546
  batch_indices = []
547
 
 
548
  for _ in range(CONFIG["BATCH_SIZE"]):
549
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
550
  if frame_idx >= total_frames:
 
554
  if not ret:
555
  break
556
 
557
+ frame = preprocess_frame(frame)
558
+
559
  # Skip frames if needed
560
  for _ in range(frame_skip - 1):
561
  if not cap.grab():
 
563
 
564
  batch_frames.append(frame)
565
  batch_indices.append(frame_idx)
566
+ processed_frames += 1
567
 
 
568
  if not batch_frames:
569
  break
570
 
571
+ # Process batch with YOLO model
572
  results = model(batch_frames, device=device, conf=0.1, verbose=False)
573
 
 
574
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
575
  current_time = frame_idx / fps
576
 
577
+ # Update progress every second
578
+ if time.time() - start_time > 1.0:
579
+ progress = (processed_frames / total_frames) * 100
580
+ yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames})", "", "", "", ""
581
  start_time = time.time()
582
 
 
583
  boxes = result.boxes
584
+ track_inputs = []
585
+
586
  for box in boxes:
587
  cls = int(box.cls)
588
  conf = float(box.conf)
589
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
590
 
591
+ if label is None:
592
+ continue
593
+
594
+ if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
595
  continue
596
 
597
+ bbox = box.xywh.cpu().numpy()[0]
598
+ track_inputs.append({
599
+ "bbox": bbox,
600
+ "conf": conf,
601
+ "cls": cls
602
+ })
 
 
603
 
604
+ if not track_inputs:
605
+ continue
606
+
607
+ tracked_objects = tracker.update(
608
+ np.array([t["bbox"] for t in track_inputs]),
609
+ np.array([t["conf"] for t in track_inputs]),
610
+ np.array([t["cls"] for t in track_inputs])
611
+ )
612
+
613
+ # Process tracked objects for violations
614
+ for obj in tracked_objects:
615
+ worker_id = obj['id']
616
+ label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
617
+ conf = obj['score']
618
+ bbox = obj['bbox']
619
+
620
+ if label is None:
621
+ continue
622
+
623
+ # Initialize worker if not seen before
624
+ if worker_id not in unique_violations:
625
+ unique_violations[worker_id] = {}
626
+
627
+ # Check if this violation type has been recorded for this worker
628
+ if label not in unique_violations[worker_id]:
629
+ # This is a new violation type for this worker
630
+ unique_violations[worker_id][label] = current_time
631
+
632
+ # Create detection object
633
+ detection = {
634
+ "worker_id": worker_id,
635
+ "violation": label,
636
+ "confidence": round(conf, 2),
637
+ "bounding_box": bbox,
638
+ "timestamp": current_time
639
+ }
640
+
641
+ # Take snapshot for the new violation
642
+ snapshot_frame = batch_frames[i].copy()
643
+ snapshot_frame = draw_detections(snapshot_frame, [detection])
644
+
645
+ # Add timestamp to snapshot
646
+ cv2.putText(
647
+ snapshot_frame,
648
+ f"Time: {current_time:.2f}s",
649
+ (10, 30),
650
+ cv2.FONT_HERSHEY_SIMPLEX,
651
+ 0.7,
652
+ (255, 255, 255),
653
+ 2
654
+ )
655
+
656
+ # Save snapshot with high quality
657
+ snapshot_filename = f"violation_{label}_worker{worker_id}_{int(current_time*100)}.jpg"
658
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
659
+
660
+ cv2.imwrite(
661
+ snapshot_path,
662
+ snapshot_frame,
663
+ [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
664
+ )
665
+
666
+ snapshots.append({
667
+ "violation": label,
668
+ "worker_id": worker_id,
669
+ "timestamp": current_time,
670
+ "snapshot_path": snapshot_path,
671
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
672
  })
673
+
674
+ logger.info(f"Captured snapshot for {label} violation by worker {worker_id} at {current_time:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
  cap.release()
677
+ if os.path.exists(video_path):
678
+ os.remove(video_path)
679
+
680
  processing_time = time.time() - start_time
681
+ logger.info(f"Processing complete in {processing_time:.2f}s")
682
+
683
+ # Convert tracked violations to final violation list
684
+ violations = []
685
+ for worker_id, worker_violations in unique_violations.items():
686
+ for label, detection_time in worker_violations.items():
687
+ violation = {
688
+ "worker_id": worker_id,
689
+ "violation": label,
690
+ "timestamp": detection_time
691
+ }
692
+ violations.append(violation)
 
 
 
 
 
 
 
 
 
 
 
 
 
693
 
 
694
  if not violations:
695
+ logger.info("No violations detected after processing")
696
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
697
  return
698
 
699
+ # Calculate safety score
700
  score = calculate_safety_score(violations)
701
+
702
+ # Generate PDF report
703
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
704
+
705
+ # Push report to Salesforce
706
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
707
 
708
+ # Format violations table for display
709
+ violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
710
+ violation_table += "|-----------|-----------|----------|------------|\n"
711
+
712
+ for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
713
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
714
+ worker_id = v.get("worker_id", "Unknown")
715
+ timestamp = v.get("timestamp", 0.0)
716
+ confidence = v.get("confidence", 0.0)
717
+
718
+ violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
719
+
720
+ # Format snapshots for display
721
+ snapshots_text = ""
722
+ for s in snapshots:
723
+ display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
724
+ worker_id = s.get("worker_id", "Unknown")
725
+ timestamp = s.get("timestamp", 0.0)
726
+
727
+ snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
728
+ snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"
729
 
730
+ if not snapshots_text:
731
+ snapshots_text = "No snapshots captured."
 
 
732
 
733
  yield (
734
  violation_table,
 
740
 
741
  except Exception as e:
742
  logger.error(f"Error processing video: {e}", exc_info=True)
743
+ if 'video_path' in locals() and os.path.exists(video_path):
744
+ os.remove(video_path)
745
  yield f"Error processing video: {e}", "", "", "", ""
746
 
 
 
 
747
  def gradio_interface(video_file):
748
+ """Gradio interface for the video processing"""
749
  if not video_file:
750
  return "No file uploaded.", "", "No file uploaded.", "", ""
751
+
752
  try:
753
  with open(video_file, "rb") as f:
754
  video_data = f.read()
755
 
756
  for status, score, snapshots_text, record_id, details_url in process_video(video_data):
757
  yield status, score, snapshots_text, record_id, details_url
758
+
759
  except Exception as e:
760
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
761
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
762
 
763
+ # ========================== # Gradio Interface # ==========================
764
  interface = gr.Interface(
765
  fn=gradio_interface,
766
  inputs=gr.Video(label="Upload Site Video"),
 
772
  gr.Textbox(label="Violation Details URL")
773
  ],
774
  title="Worksite Safety Violation Analyzer",
775
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
776
  allow_flagging="never"
777
  )
778