PrashanthB461 committed on
Commit
1df6374
·
verified ·
1 Parent(s): 188385c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -173
app.py CHANGED
@@ -18,7 +18,7 @@ from multiprocessing import Pool, cpu_count
18
  from functools import partial
19
 
20
  # ==========================
21
- # Ultra-Fast Configuration
22
  # ==========================
23
  CONFIG = {
24
  "MODEL_PATH": "yolov8_safety.pt",
@@ -32,16 +32,16 @@ CONFIG = {
32
  4: "improper_tool_use"
33
  },
34
  "CLASS_COLORS": {
35
- "no_helmet": (0, 0, 255),
36
- "no_harness": (0, 165, 255),
37
- "unsafe_posture": (0, 255, 0),
38
- "unsafe_zone": (255, 0, 0),
39
- "improper_tool_use": (255, 255, 0)
40
  },
41
  "DISPLAY_NAMES": {
42
  "no_helmet": "No Helmet Violation",
43
  "no_harness": "No Harness Violation",
44
- "unsafe_posture": "Unsafe Posture Violation",
45
  "unsafe_zone": "Unsafe Zone Entry",
46
  "improper_tool_use": "Improper Tool Use"
47
  },
@@ -53,21 +53,17 @@ CONFIG = {
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
- "no_helmet": 0.55,
57
- "no_harness": 0.1,
58
- "unsafe_posture": 0.1,
59
- "unsafe_zone": 0.1,
60
- "improper_tool_use": 0.1
61
  },
62
- "IOU_THRESHOLD": 0.45,
63
- "MIN_VIOLATION_FRAMES": 2,
64
- "HELMET_CONFIDENCE_THRESHOLD": 0.6,
65
- "WORKER_TRACKING_DURATION": 2.5,
66
- "MAX_PROCESSING_TIME": 30,
67
- "PARALLEL_WORKERS": 2, # Reduced for Hugging Face Spaces
68
- "CHUNK_SIZE": 20, # Increased for faster batch processing
69
- "FRAME_SAMPLE_RATE": 2, # Process every 2nd frame
70
- "MAX_FRAME_WIDTH": 640 # Resize frames to this width
71
  }
72
 
73
  # Setup logging
@@ -101,15 +97,6 @@ model = load_model()
101
  # ==========================
102
  # Optimized Helper Functions
103
  # ==========================
104
- def resize_frame(frame, max_width):
105
- height, width = frame.shape[:2]
106
- if width > max_width:
107
- scale = max_width / width
108
- new_width = int(width * scale)
109
- new_height = int(height * scale)
110
- frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
111
- return frame
112
-
113
  def draw_detections(frame, detections):
114
  for det in detections:
115
  label = det.get("violation", "Unknown")
@@ -125,14 +112,14 @@ def draw_detections(frame, detections):
125
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
126
 
127
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
128
- cv2.putText(frame, display_text, (x1, y1-10),
129
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
130
  return frame
131
 
132
  def calculate_iou(box1, box2):
133
  x1, y1, w1, h1 = box1
134
  x2, y2, w2, h2 = box2
135
 
 
136
  x_left = max(x1 - w1/2, x2 - w2/2)
137
  y_top = max(y1 - h1/2, y2 - h2/2)
138
  x_right = min(x1 + w1/2, x2 + w2/2)
@@ -150,9 +137,7 @@ def calculate_iou(box1, box2):
150
 
151
  def process_frame_batch(frame_batch, frame_indices, fps):
152
  batch_results = []
153
- start_inference = time.time()
154
- results = model(frame_batch, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"], verbose=False)
155
- logger.info(f"Inference time for batch of {len(frame_batch)} frames: {time.time() - start_inference:.2f}s")
156
 
157
  for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
158
  current_time = frame_idx / fps
@@ -164,7 +149,7 @@ def process_frame_batch(frame_batch, frame_indices, fps):
164
  conf = float(box.conf)
165
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
166
 
167
- if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.2):
168
  continue
169
 
170
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
@@ -182,7 +167,6 @@ def process_frame_batch(frame_batch, frame_indices, fps):
182
 
183
  def generate_violation_pdf(violations, score):
184
  try:
185
- start_pdf = time.time()
186
  pdf_filename = f"violations_{int(time.time())}.pdf"
187
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
188
  pdf_file = BytesIO()
@@ -224,7 +208,6 @@ def generate_violation_pdf(violations, score):
224
  with open(pdf_path, "wb") as f:
225
  f.write(pdf_file.getvalue())
226
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
227
- logger.info(f"PDF generation time: {time.time() - start_pdf:.2f}s")
228
  logger.info(f"PDF generated: {public_url}")
229
  return pdf_path, public_url, pdf_file
230
  except Exception as e:
@@ -244,11 +227,10 @@ def calculate_safety_score(violations):
244
  return max(score, 0)
245
 
246
  # ==========================
247
- # Ultra-Fast Video Processing
248
  # ==========================
249
  def process_video(video_data):
250
  try:
251
- start_time = time.time()
252
  # Create temp video file
253
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
254
  with open(video_path, "wb") as f:
@@ -256,7 +238,6 @@ def process_video(video_data):
256
  logger.info(f"Video saved: {video_path}")
257
 
258
  # Open video file
259
- start_read = time.time()
260
  cap = cv2.VideoCapture(video_path)
261
  if not cap.isOpened():
262
  raise ValueError("Could not open video file")
@@ -265,163 +246,107 @@ def process_video(video_data):
265
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
266
  fps = cap.get(cv2.CAP_PROP_FPS)
267
  if fps <= 0:
268
- fps = 30
269
  duration = total_frames / fps
270
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
271
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
272
 
273
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
274
 
275
- # Check if processing will exceed time limit
276
- if duration > CONFIG["MAX_PROCESSING_TIME"]:
277
- logger.warning(f"Video duration {duration:.2f}s exceeds max processing time {CONFIG['MAX_PROCESSING_TIME']}s")
278
- cap.release()
279
- os.remove(video_path)
280
- yield "Video duration too long. Please upload a shorter video.", "", "", "", ""
281
- return
282
-
283
- # Estimate processing feasibility
284
- estimated_frames = total_frames // CONFIG["FRAME_SAMPLE_RATE"]
285
- if estimated_frames * 0.1 > CONFIG["MAX_PROCESSING_TIME"]: # Rough estimate: 0.1s per frame
286
- logger.warning(f"Too many frames ({estimated_frames}) to process within {CONFIG['MAX_PROCESSING_TIME']}s")
287
- cap.release()
288
- os.remove(video_path)
289
- yield "Video has too many frames to process within 30 seconds.", "", "", "", ""
290
- return
291
-
292
- # Read frames with sampling
293
- frame_batches = []
294
- frame_indices_batches = []
295
- current_batch = []
296
- current_indices = []
297
- frame_count = 0
298
- sampled_frame_count = 0
299
-
300
- while True:
301
  ret, frame = cap.read()
302
  if not ret:
303
  break
304
- if frame_count % CONFIG["FRAME_SAMPLE_RATE"] == 0:
305
- frame = resize_frame(frame, CONFIG["MAX_FRAME_WIDTH"])
306
- current_batch.append(frame)
307
- current_indices.append(frame_count)
308
- sampled_frame_count += 1
309
- frame_count += 1
310
-
311
- if len(current_batch) >= CONFIG["CHUNK_SIZE"]:
312
- frame_batches.append(current_batch)
313
- frame_indices_batches.append(current_indices)
314
- current_batch = []
315
- current_indices = []
316
-
317
- if current_batch:
318
- frame_batches.append(current_batch)
319
- frame_indices_batches.append(current_indices)
320
-
321
  cap.release()
322
- logger.info(f"Frame reading time: {time.time() - start_read:.2f}s")
323
- logger.info(f"Total frames: {frame_count}, Sampled frames: {sampled_frame_count}")
324
 
325
- # Process frames in parallel
 
326
  violations = []
327
  helmet_violations = {}
328
  snapshots = []
329
- last_progress_time = start_time
 
 
 
 
330
 
331
- start_parallel = time.time()
332
  with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
333
  process_func = partial(process_frame_batch, fps=fps)
334
  results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
335
-
336
- # Flatten and sort results
337
- all_detections = []
338
- for batch_result in results:
339
- all_detections.extend(batch_result)
340
- all_detections.sort(key=lambda x: x[0])
341
-
342
- logger.info(f"Parallel processing time: {time.time() - start_parallel:.2f}s")
343
-
344
- # Worker tracking
345
- start_tracking = time.time()
346
- workers = []
347
- for frame_idx, detections in all_detections:
348
- current_time = frame_idx / fps
349
-
350
- # Update progress every second
351
- if time.time() - last_progress_time > 1.0:
352
- progress = (frame_idx / frame_count) * 100
353
- yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{frame_count})", "", "", "", ""
354
- last_progress_time = time.time()
355
-
356
- # Early termination if time limit approached
357
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"] - 2:
358
- logger.warning("Approaching max processing time, terminating early")
359
- break
360
-
361
- for detection in detections:
362
- worker_id = None
363
- max_iou = 0
364
- for idx, worker in enumerate(workers):
365
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
366
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
367
- max_iou = iou
368
- worker_id = worker["id"]
369
- workers[idx]["bbox"] = detection["bounding_box"]
370
- workers[idx]["last_seen"] = current_time
371
-
372
- if worker_id is None:
373
- worker_id = len(workers) + 1
374
- workers.append({
375
- "id": worker_id,
376
- "bbox": detection["bounding_box"],
377
- "first_seen": current_time,
378
- "last_seen": current_time
379
- })
380
 
381
- detection["worker_id"] = worker_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
- if detection["violation"] == "no_helmet":
384
- if worker_id not in helmet_violations:
385
- helmet_violations[worker_id] = []
386
- helmet_violations[worker_id].append(detection)
387
- else:
388
- violations.append(detection)
389
 
390
- workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
 
 
 
 
 
 
391
 
392
- logger.info(f"Worker tracking time: {time.time() - start_tracking:.2f}s")
 
393
 
394
- # Process helmet violations
395
- start_snapshot = time.time()
396
  for worker_id, detections in helmet_violations.items():
397
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
398
  best_detection = max(detections, key=lambda x: x["confidence"])
399
- if best_detection["confidence"] >= CONFIG["HELMET_CONFIDENCE_THRESHOLD"]:
400
- violations.append(best_detection)
401
-
402
- # Capture snapshot
403
- cap = cv2.VideoCapture(video_path)
404
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
405
- ret, snapshot_frame = cap.read()
406
- if ret:
407
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
408
- snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
409
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
410
- cv2.imwrite(snapshot_path, snapshot_frame)
411
- snapshots.append({
412
- "violation": "no_helmet",
413
- "frame": best_detection["frame"],
414
- "snapshot_path": snapshot_path,
415
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
416
- })
417
- cap.release()
418
-
419
- logger.info(f"Snapshot generation time: {time.time() - start_snapshot:.2f}s")
420
 
421
  os.remove(video_path)
422
  processing_time = time.time() - start_time
423
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
424
 
 
425
  if not violations:
426
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
427
  return
@@ -429,6 +354,7 @@ def process_video(video_data):
429
  score = calculate_safety_score(violations)
430
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
431
 
 
432
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
433
  violation_table += "|------------------------|---------------|------------|-----------|\n"
434
  for v in sorted(violations, key=lambda x: x["timestamp"]):
@@ -436,12 +362,13 @@ def process_video(video_data):
436
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
437
  violation_table += row
438
 
 
439
  snapshots_text = "\n".join(
440
  f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
441
  for s in snapshots
442
  ) if snapshots else "No snapshots captured."
443
 
444
- start_salesforce = time.time()
445
  try:
446
  sf = connect_to_salesforce()
447
  record_data = {
@@ -458,8 +385,6 @@ def process_video(video_data):
458
  except Exception as e:
459
  logger.error(f"Salesforce integration failed: {e}")
460
  record_id = "N/A (Salesforce error)"
461
-
462
- logger.info(f"Salesforce integration time: {time.time() - start_salesforce:.2f}s")
463
 
464
  yield (
465
  violation_table,
@@ -533,11 +458,11 @@ interface = gr.Interface(
533
  gr.Textbox(label="Salesforce Record ID"),
534
  gr.Textbox(label="Violation Details URL")
535
  ],
536
- title="Ultra-Fast Safety Violation Analyzer",
537
- description="Upload site videos to detect safety violations in under 30 seconds (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use).",
538
  allow_flagging="never"
539
  )
540
 
541
  if __name__ == "__main__":
542
- logger.info("Launching Ultra-Fast Safety Analyzer App...")
543
  interface.launch()
 
18
  from functools import partial
19
 
20
  # ==========================
21
+ # Optimized Configuration
22
  # ==========================
23
  CONFIG = {
24
  "MODEL_PATH": "yolov8_safety.pt",
 
32
  4: "improper_tool_use"
33
  },
34
  "CLASS_COLORS": {
35
+ "no_helmet": (0, 0, 255), # Red
36
+ "no_harness": (0, 165, 255), # Orange
37
+ "unsafe_posture": (0, 255, 0), # Green
38
+ "unsafe_zone": (255, 0, 0), # Blue
39
+ "improper_tool_use": (255, 255, 0) # Yellow
40
  },
41
  "DISPLAY_NAMES": {
42
  "no_helmet": "No Helmet Violation",
43
  "no_harness": "No Harness Violation",
44
+ "unsafe_posture": "Unsafe Posture",
45
  "unsafe_zone": "Unsafe Zone Entry",
46
  "improper_tool_use": "Improper Tool Use"
47
  },
 
53
  },
54
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
55
  "CONFIDENCE_THRESHOLDS": {
56
+ "no_helmet": 0.6, # Higher threshold to reduce false positives
57
+ "no_harness": 0.4,
58
+ "unsafe_posture": 0.4,
59
+ "unsafe_zone": 0.4,
60
+ "improper_tool_use": 0.4
61
  },
62
+ "MIN_VIOLATION_FRAMES": 3, # Require 3+ detections to confirm violation
63
+ "WORKER_TRACKING_DURATION": 3.0, # Track workers for 3 seconds
64
+ "MAX_PROCESSING_TIME": 30, # 30-second limit
65
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all CPU cores
66
+ "CHUNK_SIZE": 8 # Frames per batch
 
 
 
 
67
  }
68
 
69
  # Setup logging
 
97
  # ==========================
98
  # Optimized Helper Functions
99
  # ==========================
 
 
 
 
 
 
 
 
 
100
  def draw_detections(frame, detections):
101
  for det in detections:
102
  label = det.get("violation", "Unknown")
 
112
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
113
 
114
  display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
115
+ cv2.putText(frame, display_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
116
  return frame
117
 
118
  def calculate_iou(box1, box2):
119
  x1, y1, w1, h1 = box1
120
  x2, y2, w2, h2 = box2
121
 
122
+ # Calculate intersection area
123
  x_left = max(x1 - w1/2, x2 - w2/2)
124
  y_top = max(y1 - h1/2, y2 - h2/2)
125
  x_right = min(x1 + w1/2, x2 + w2/2)
 
137
 
138
  def process_frame_batch(frame_batch, frame_indices, fps):
139
  batch_results = []
140
+ results = model(frame_batch, device=device, conf=0.1, verbose=False)
 
 
141
 
142
  for idx, (result, frame_idx) in enumerate(zip(results, frame_indices)):
143
  current_time = frame_idx / fps
 
149
  conf = float(box.conf)
150
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
151
 
152
+ if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
153
  continue
154
 
155
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
 
167
 
168
  def generate_violation_pdf(violations, score):
169
  try:
 
170
  pdf_filename = f"violations_{int(time.time())}.pdf"
171
  pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
172
  pdf_file = BytesIO()
 
208
  with open(pdf_path, "wb") as f:
209
  f.write(pdf_file.getvalue())
210
  public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
 
211
  logger.info(f"PDF generated: {public_url}")
212
  return pdf_path, public_url, pdf_file
213
  except Exception as e:
 
227
  return max(score, 0)
228
 
229
  # ==========================
230
+ # Optimized Video Processing
231
  # ==========================
232
  def process_video(video_data):
233
  try:
 
234
  # Create temp video file
235
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
236
  with open(video_path, "wb") as f:
 
238
  logger.info(f"Video saved: {video_path}")
239
 
240
  # Open video file
 
241
  cap = cv2.VideoCapture(video_path)
242
  if not cap.isOpened():
243
  raise ValueError("Could not open video file")
 
246
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
247
  fps = cap.get(cv2.CAP_PROP_FPS)
248
  if fps <= 0:
249
+ fps = 30 # Default assumption if FPS not available
250
  duration = total_frames / fps
251
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
252
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
253
 
254
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
255
 
256
+ # Read all frames upfront
257
+ all_frames = []
258
+ all_indices = []
259
+ for frame_idx in range(total_frames):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  ret, frame = cap.read()
261
  if not ret:
262
  break
263
+ all_frames.append(frame)
264
+ all_indices.append(frame_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  cap.release()
 
 
266
 
267
+ # Process frames in parallel batches
268
+ workers = []
269
  violations = []
270
  helmet_violations = {}
271
  snapshots = []
272
+ start_time = time.time()
273
+
274
+ # Split frames into batches
275
+ frame_batches = [all_frames[i:i + CONFIG["CHUNK_SIZE"]] for i in range(0, len(all_frames), CONFIG["CHUNK_SIZE"])]
276
+ frame_indices_batches = [all_indices[i:i + CONFIG["CHUNK_SIZE"]] for i in range(0, len(all_indices), CONFIG["CHUNK_SIZE"])]
277
 
278
+ # Process batches in parallel
279
  with Pool(processes=CONFIG["PARALLEL_WORKERS"]) as pool:
280
  process_func = partial(process_frame_batch, fps=fps)
281
  results = pool.starmap(process_func, zip(frame_batches, frame_indices_batches))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ # Flatten results and track workers
284
+ for batch_result in results:
285
+ for frame_idx, detections in batch_result:
286
+ current_time = frame_idx / fps
287
+
288
+ for detection in detections:
289
+ # Worker tracking
290
+ worker_id = None
291
+ max_iou = 0
292
+ for idx, worker in enumerate(workers):
293
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
294
+ if iou > max_iou and iou > 0.4: # IOU threshold
295
+ max_iou = iou
296
+ worker_id = worker["id"]
297
+ workers[idx]["bbox"] = detection["bounding_box"]
298
+ workers[idx]["last_seen"] = current_time
299
+
300
+ if worker_id is None:
301
+ worker_id = len(workers) + 1
302
+ workers.append({
303
+ "id": worker_id,
304
+ "bbox": detection["bounding_box"],
305
+ "first_seen": current_time,
306
+ "last_seen": current_time
307
+ })
308
 
309
+ detection["worker_id"] = worker_id
 
 
 
 
 
310
 
311
+ # Track helmet violations separately
312
+ if detection["violation"] == "no_helmet":
313
+ if worker_id not in helmet_violations:
314
+ helmet_violations[worker_id] = []
315
+ helmet_violations[worker_id].append(detection)
316
+ else:
317
+ violations.append(detection)
318
 
319
+ # Remove inactive workers
320
+ workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
321
 
322
+ # Confirm helmet violations (require multiple detections)
 
323
  for worker_id, detections in helmet_violations.items():
324
  if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
325
  best_detection = max(detections, key=lambda x: x["confidence"])
326
+ violations.append(best_detection)
327
+
328
+ # Capture snapshot
329
+ cap = cv2.VideoCapture(video_path)
330
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
331
+ ret, snapshot_frame = cap.read()
332
+ if ret:
333
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
334
+ snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
335
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
336
+ cv2.imwrite(snapshot_path, snapshot_frame)
337
+ snapshots.append({
338
+ "violation": "no_helmet",
339
+ "frame": best_detection["frame"],
340
+ "snapshot_path": snapshot_path,
341
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
342
+ })
343
+ cap.release()
 
 
 
344
 
345
  os.remove(video_path)
346
  processing_time = time.time() - start_time
347
  logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
348
 
349
+ # Generate results
350
  if not violations:
351
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
352
  return
 
354
  score = calculate_safety_score(violations)
355
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
356
 
357
+ # Generate violation table
358
  violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
359
  violation_table += "|------------------------|---------------|------------|-----------|\n"
360
  for v in sorted(violations, key=lambda x: x["timestamp"]):
 
362
  row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
363
  violation_table += row
364
 
365
+ # Generate snapshots text
366
  snapshots_text = "\n".join(
367
  f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], 'Unknown')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
368
  for s in snapshots
369
  ) if snapshots else "No snapshots captured."
370
 
371
+ # Push to Salesforce
372
  try:
373
  sf = connect_to_salesforce()
374
  record_data = {
 
385
  except Exception as e:
386
  logger.error(f"Salesforce integration failed: {e}")
387
  record_id = "N/A (Salesforce error)"
 
 
388
 
389
  yield (
390
  violation_table,
 
458
  gr.Textbox(label="Salesforce Record ID"),
459
  gr.Textbox(label="Violation Details URL")
460
  ],
461
+ title="Worksite Safety Violation Analyzer",
462
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use).",
463
  allow_flagging="never"
464
  )
465
 
466
  if __name__ == "__main__":
467
+ logger.info("Launching Safety Analyzer App...")
468
  interface.launch()