PrashanthB461 commited on
Commit
1aebe73
·
verified ·
1 Parent(s): 921f6bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -158
app.py CHANGED
@@ -15,7 +15,7 @@ import logging
15
  from retrying import retry
16
 
17
  # ==========================
18
- # Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
@@ -24,12 +24,16 @@ CONFIG = {
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
- 2: "unsafe_posture"
 
 
28
  },
29
  "DISPLAY_NAMES": {
30
  "no_helmet": "No Helmet Violation",
31
  "no_harness": "No Harness Violation",
32
- "unsafe_posture": "Unsafe Posture Violation"
 
 
33
  },
34
  "SF_CREDENTIALS": {
35
  "username": "prashanth1ai@safety.com",
@@ -38,10 +42,17 @@ CONFIG = {
38
  "domain": "login"
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
- "FRAME_SKIP": 15,
42
- "MAX_PROCESSING_TIME": 30,
43
- "CONFIDENCE_THRESHOLD": 0.1, # Lowered for debugging
44
- "IOU_THRESHOLD": 0.5 # For worker tracking
 
 
 
 
 
 
 
45
  }
46
 
47
  # Setup logging
@@ -73,7 +84,7 @@ def load_model():
73
  model = load_model()
74
 
75
  # ==========================
76
- # Helper Functions
77
  # ==========================
78
  def calculate_iou(box1, box2):
79
  """Calculate Intersection over Union (IoU) for two bounding boxes."""
@@ -99,8 +110,17 @@ def calculate_iou(box1, box2):
99
 
100
  return intersection / union if union > 0 else 0
101
 
 
 
 
 
 
 
 
 
 
102
  # ==========================
103
- # Salesforce Integration
104
  # ==========================
105
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
106
  def connect_to_salesforce():
@@ -141,7 +161,7 @@ def generate_violation_pdf(violations, score):
141
  else:
142
  for v in violations:
143
  display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
144
- text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, Worker ID: {v['worker_id']})"
145
  c.drawString(1 * inch, y_position, text)
146
  y_position -= 0.3 * inch
147
  if y_position < 1 * inch:
@@ -190,7 +210,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
190
  try:
191
  sf = connect_to_salesforce()
192
  violations_text = "\n".join(
193
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, Worker ID: {v['worker_id']})"
194
  for v in violations
195
  ) or "No violations detected."
196
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
@@ -233,14 +253,23 @@ def calculate_safety_score(violations):
233
  penalties = {
234
  "no_helmet": 25,
235
  "no_harness": 30,
236
- "unsafe_posture": 20
 
 
237
  }
238
- total_penalty = sum(penalties.get(v["violation"], 0) for v in violations)
239
- logger.info(f"Total Penalty: {total_penalty}")
 
 
 
 
 
240
  score = 100 - total_penalty
241
- logger.info(f"Calculated Score: {score}")
242
  return max(score, 0)
243
 
 
 
 
244
  def process_video(video_data):
245
  try:
246
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
@@ -252,16 +281,19 @@ def process_video(video_data):
252
  if not video.isOpened():
253
  raise ValueError("Could not open video file")
254
 
255
- violations, snapshots, raw_detections = [], [], []
256
  frame_count = 0
257
  start_time = time.time()
258
  fps = video.get(cv2.CAP_PROP_FPS)
 
 
259
 
260
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
261
  workers = [] # List to track workers
 
262
 
 
263
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
264
- logger.info(f"Using confidence threshold: {CONFIG['CONFIDENCE_THRESHOLD']}")
265
 
266
  while True:
267
  ret, frame = video.read()
@@ -276,154 +308,134 @@ def process_video(video_data):
276
  logger.info("Processing time limit reached")
277
  break
278
 
 
279
  results = model(frame, device=device)
280
- current_detections = []
281
 
282
  for result in results:
283
  boxes = result.boxes
284
- logger.info(f"Frame {frame_count}: Found {len(boxes)} potential detections")
285
 
286
  for box in boxes:
287
- cls, conf = int(box.cls), float(box.conf)
 
288
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
289
 
290
- # Log all raw detections
291
- logger.info(f"Raw Detection: class={cls}, conf={conf:.2f}, label={label}")
292
- raw_detections.append({
293
- "frame": frame_count,
294
- "class": cls,
295
- "confidence": round(conf, 2),
296
- "label": label if label in CONFIG["VIOLATION_LABELS"].values() else "unknown",
297
- "timestamp": frame_count / fps
298
- })
299
-
300
- if label not in CONFIG["VIOLATION_LABELS"].values():
301
- logger.info(f"Skipping unknown class: {cls}")
302
- continue
303
 
304
- if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
305
- logger.info(f"Skipping low confidence: {conf:.2f} < {CONFIG['CONFIDENCE_THRESHOLD']}")
 
306
  continue
307
 
308
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
309
- logger.info(f"Valid detection: {label} with confidence: {conf:.2f}")
310
 
311
- current_detections.append({
 
 
312
  "violation": label,
313
  "confidence": round(conf, 2),
314
  "bounding_box": bbox,
315
- "timestamp": frame_count / fps,
316
- "frame": frame_count
317
  })
318
 
319
- for detection in current_detections:
320
- matched_worker = None
321
- max_iou = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
- for worker in workers:
324
- iou = calculate_iou(detection["bounding_box"], worker["bbox"])
325
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
326
- max_iou = iou
327
- matched_worker = worker
328
-
329
- if matched_worker:
330
- if detection["violation"] not in matched_worker["violations"]:
331
- logger.info(f"New violation for worker {matched_worker['id']}: {detection['violation']}")
332
- matched_worker["violations"].add(detection["violation"])
333
- violations.append({
334
- "frame": frame_count,
335
- "violation": detection["violation"],
336
- "confidence": detection["confidence"],
337
- "bounding_box": detection["bounding_box"],
338
- "timestamp": detection["timestamp"],
339
- "worker_id": matched_worker["id"]
340
- })
341
 
342
- if not snapshot_taken[detection["violation"]]:
343
- snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
344
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
345
- cv2.imwrite(snapshot_path, frame)
346
- snapshot_taken[detection["violation"]] = True
347
  snapshots.append({
348
- "violation": detection["violation"],
349
- "frame": frame_count,
350
  "snapshot_path": snapshot_path,
351
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
352
- "timestamp": detection["timestamp"],
353
- "confidence": detection["confidence"]
354
  })
355
-
356
- matched_worker["bbox"] = detection["bounding_box"]
357
- matched_worker["last_frame"] = frame_count
358
- else:
359
- worker_id = len(workers) + 1
360
- logger.info(f"New worker {worker_id} with violation: {detection['violation']}")
361
- workers.append({
362
- "id": worker_id,
363
- "violations": {detection["violation"]},
364
- "bbox": detection["bounding_box"],
365
- "last_frame": frame_count
366
- })
367
-
368
- violations.append({
369
- "frame": frame_count,
370
- "violation": detection["violation"],
371
- "confidence": detection["confidence"],
372
- "bounding_box": detection["bounding_box"],
373
- "timestamp": detection["timestamp"],
374
- "worker_id": worker_id
375
- })
376
-
377
- if not snapshot_taken[detection["violation"]]:
378
- snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
379
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
380
- cv2.imwrite(snapshot_path, frame)
381
- snapshot_taken[detection["violation"]] = True
382
- snapshots.append({
383
- "violation": detection["violation"],
384
- "frame": frame_count,
385
- "snapshot_path": snapshot_path,
386
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
387
- "timestamp": detection["timestamp"],
388
- "confidence": detection["confidence"]
389
- })
390
-
391
- active_workers = [w for w in workers if frame_count - w["last_frame"] < CONFIG["FRAME_SKIP"] * 5]
392
- if len(active_workers) != len(workers):
393
- logger.info(f"Cleaned up {len(workers) - len(active_workers)} inactive workers")
394
- workers = active_workers
395
-
396
- frame_count += 1
397
 
398
- video.release()
399
- os.remove(video_path)
400
-
401
- violation_types = {}
402
- for v in violations:
403
- violation_types[v["violation"]] = violation_types.get(v["violation"], 0) + 1
404
-
405
- logger.info(f"Detection complete. Found violations: {violation_types}")
406
-
407
- if not violations:
408
- logger.info("No violations detected")
409
  return {
410
  "violations": [],
411
  "snapshots": [],
412
- "raw_detections": raw_detections,
413
  "score": 100,
414
  "salesforce_record_id": None,
415
  "violation_details_url": "",
416
  "message": "No violations detected in the video."
417
  }
418
 
419
- score = calculate_safety_score(violations)
420
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
421
- report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
422
 
423
  return {
424
- "violations": violations,
425
  "snapshots": snapshots,
426
- "raw_detections": raw_detections,
427
  "score": score,
428
  "salesforce_record_id": report_id,
429
  "violation_details_url": final_pdf_url,
@@ -434,18 +446,20 @@ def process_video(video_data):
434
  return {
435
  "violations": [],
436
  "snapshots": [],
437
- "raw_detections": [],
438
  "score": 100,
439
  "salesforce_record_id": None,
440
  "violation_details_url": "",
441
  "message": f"Error processing video: {e}"
442
  }
443
 
 
 
 
444
  def gradio_interface(video_file):
445
  if not video_file:
446
- return "No file uploaded.", "", "No file uploaded.", "", "", [], "No raw detections."
447
  try:
448
- yield "Processing video... please wait.", "", "", "", "", [], "Processing..."
449
 
450
  with open(video_file, "rb") as f:
451
  video_data = f.read()
@@ -453,71 +467,55 @@ def gradio_interface(video_file):
453
  result = process_video(video_data)
454
 
455
  if result.get("message"):
456
- yield result["message"], "", "", "", "", [], "Error in processing."
457
  return
458
 
459
  violation_table = "No violations detected."
460
  if result["violations"]:
461
- header = "| Violation | Timestamp (s) | Confidence | Worker ID | Frame |\n"
462
- separator = "|------------------------|---------------|------------|-----------|-------|\n"
463
  rows = []
464
  violation_name_map = CONFIG["DISPLAY_NAMES"]
465
  for v in result["violations"]:
466
  display_name = violation_name_map.get(v["violation"], v["violation"])
467
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} | {v['frame']} |"
468
  rows.append(row)
469
  violation_table = header + separator + "\n".join(rows)
470
 
471
  snapshots_text = "No snapshots captured."
472
- snapshot_images = []
473
  if result["snapshots"]:
474
  violation_name_map = CONFIG["DISPLAY_NAMES"]
475
  snapshots_text = "\n".join(
476
- f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at {s['timestamp']:.2f}s (Frame {s['frame']}, Confidence: {s['confidence']:.2f}): ![]({s['snapshot_base64']})"
477
  for s in result["snapshots"]
478
  )
479
- snapshot_images = [s["snapshot_base64"] for s in result["snapshots"]]
480
-
481
- raw_detections_text = "No raw detections logged."
482
- if result["raw_detections"]:
483
- header = "| Frame | Timestamp (s) | Class | Label | Confidence |\n"
484
- separator = "|-------|---------------|-------|----------------|------------|\n"
485
- rows = []
486
- for d in result["raw_detections"]:
487
- row = f"| {d['frame']:<5} | {d['timestamp']:.2f} | {d['class']:<5} | {d['label']:<14} | {d['confidence']:.2f} |"
488
- rows.append(row)
489
- raw_detections_text = header + separator + "\n".join(rows)
490
 
491
  yield (
492
  violation_table,
493
  f"Safety Score: {result['score']}%",
494
  snapshots_text,
495
  f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
496
- result["violation_details_url"] or "N/A",
497
- snapshot_images,
498
- raw_detections_text
499
  )
500
  except Exception as e:
501
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
502
- yield f"Error: {str(e)}", "", "Error in processing.", "", "", [], "Error in processing."
503
 
504
  interface = gr.Interface(
505
  fn=gradio_interface,
506
  inputs=gr.Video(label="Upload Site Video"),
507
- outputs=[
508
  gr.Markdown(label="Detected Safety Violations"),
509
  gr.Textbox(label="Compliance Score"),
510
  gr.Markdown(label="Snapshots"),
511
  gr.Textbox(label="Salesforce Record ID"),
512
- gr.Textbox(label="Violation Details URL"),
513
- gr.Gallery(label="Violation Snapshots"),
514
- gr.Markdown(label="Raw Detections (Debug)")
515
  ],
516
  title="Worksite Safety Violation Analyzer",
517
- description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture). Non-violations are ignored.",
518
  allow_flagging="never"
519
  )
520
 
521
  if __name__ == "__main__":
522
- logger.info("Launching Safety Analyzer App...")
523
  interface.launch()
 
15
  from retrying import retry
16
 
17
  # ==========================
18
+ # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
 
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
+ 2: "unsafe_posture",
28
+ 3: "unsafe_zone",
29
+ 4: "improper_tool_use"
30
  },
31
  "DISPLAY_NAMES": {
32
  "no_helmet": "No Helmet Violation",
33
  "no_harness": "No Harness Violation",
34
+ "unsafe_posture": "Unsafe Posture Violation",
35
+ "unsafe_zone": "Unsafe Zone Entry",
36
+ "improper_tool_use": "Improper Tool Use"
37
  },
38
  "SF_CREDENTIALS": {
39
  "username": "prashanth1ai@safety.com",
 
42
  "domain": "login"
43
  },
44
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
45
+ "FRAME_SKIP": 10, # Reduced for better detection
46
+ "MAX_PROCESSING_TIME": 45,
47
+ "CONFIDENCE_THRESHOLD": {
48
+ "no_helmet": 0.4,
49
+ "no_harness": 0.35,
50
+ "unsafe_posture": 0.3,
51
+ "unsafe_zone": 0.3,
52
+ "improper_tool_use": 0.35
53
+ },
54
+ "IOU_THRESHOLD": 0.4,
55
+ "MIN_VIOLATION_DURATION": 2 # seconds
56
  }
57
 
58
  # Setup logging
 
84
  model = load_model()
85
 
86
  # ==========================
87
+ # Enhanced Helper Functions
88
  # ==========================
89
  def calculate_iou(box1, box2):
90
  """Calculate Intersection over Union (IoU) for two bounding boxes."""
 
110
 
111
  return intersection / union if union > 0 else 0
112
 
113
+ def is_violation_persistent(violation_type, worker_id, violations, fps):
114
+ """Check if a violation persists for the required duration."""
115
+ violation_times = [v['timestamp'] for v in violations
116
+ if v['violation'] == violation_type and v['worker_id'] == worker_id]
117
+ if len(violation_times) < 2:
118
+ return False
119
+ duration = max(violation_times) - min(violation_times)
120
+ return duration >= CONFIG["MIN_VIOLATION_DURATION"]
121
+
122
  # ==========================
123
+ # Salesforce Integration (unchanged)
124
  # ==========================
125
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
126
  def connect_to_salesforce():
 
161
  else:
162
  for v in violations:
163
  display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
164
+ text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
165
  c.drawString(1 * inch, y_position, text)
166
  y_position -= 0.3 * inch
167
  if y_position < 1 * inch:
 
210
  try:
211
  sf = connect_to_salesforce()
212
  violations_text = "\n".join(
213
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
214
  for v in violations
215
  ) or "No violations detected."
216
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
 
253
  penalties = {
254
  "no_helmet": 25,
255
  "no_harness": 30,
256
+ "unsafe_posture": 20,
257
+ "unsafe_zone": 35,
258
+ "improper_tool_use": 25
259
  }
260
+ # Count unique violations per worker
261
+ unique_violations = set()
262
+ for v in violations:
263
+ key = (v["worker_id"], v["violation"])
264
+ unique_violations.add(key)
265
+
266
+ total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
267
  score = 100 - total_penalty
 
268
  return max(score, 0)
269
 
270
+ # ==========================
271
+ # Enhanced Video Processing
272
+ # ==========================
273
  def process_video(video_data):
274
  try:
275
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
 
281
  if not video.isOpened():
282
  raise ValueError("Could not open video file")
283
 
284
+ violations, snapshots = [], []
285
  frame_count = 0
286
  start_time = time.time()
287
  fps = video.get(cv2.CAP_PROP_FPS)
288
+ if fps <= 0:
289
+ fps = 30 # Default assumption if FPS cannot be determined
290
 
291
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
292
  workers = [] # List to track workers
293
+ violation_history = [] # Track all potential violations before filtering
294
 
295
+ logger.info(f"Processing video with FPS: {fps}")
296
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
 
297
 
298
  while True:
299
  ret, frame = video.read()
 
308
  logger.info("Processing time limit reached")
309
  break
310
 
311
+ # Run detection on this frame
312
  results = model(frame, device=device)
313
+ current_time = frame_count / fps
314
 
315
  for result in results:
316
  boxes = result.boxes
317
+ logger.debug(f"Frame {frame_count}: Found {len(boxes)} potential detections")
318
 
319
  for box in boxes:
320
+ cls = int(box.cls)
321
+ conf = float(box.conf)
322
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
323
 
324
+ if label is None:
325
+ continue # Skip unknown classes
 
 
 
 
 
 
 
 
 
 
 
326
 
327
+ conf_threshold = CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3)
328
+ if conf < conf_threshold:
329
+ logger.debug(f"Skipping {label} with low confidence: {conf:.2f} < {conf_threshold}")
330
  continue
331
 
332
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
 
333
 
334
+ # Store potential violation (will filter later)
335
+ violation_history.append({
336
+ "frame": frame_count,
337
  "violation": label,
338
  "confidence": round(conf, 2),
339
  "bounding_box": bbox,
340
+ "timestamp": current_time
 
341
  })
342
 
343
+ frame_count += 1
344
+
345
+ video.release()
346
+ os.remove(video_path)
347
+
348
+ # Process violation history to track workers and persistent violations
349
+ workers = []
350
+ for v in violation_history:
351
+ # Find matching worker
352
+ matched_worker = None
353
+ max_iou = 0
354
+
355
+ for worker in workers:
356
+ iou = calculate_iou(v["bounding_box"], worker["bbox"])
357
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
358
+ max_iou = iou
359
+ matched_worker = worker
360
+
361
+ if matched_worker:
362
+ # Update worker's violation history
363
+ matched_worker["violations"].append(v)
364
+ matched_worker["bbox"] = v["bounding_box"]
365
+ matched_worker["last_seen"] = v["timestamp"]
366
+ v["worker_id"] = matched_worker["id"]
367
+ else:
368
+ # New worker
369
+ worker_id = len(workers) + 1
370
+ workers.append({
371
+ "id": worker_id,
372
+ "bbox": v["bounding_box"],
373
+ "violations": [v],
374
+ "first_seen": v["timestamp"],
375
+ "last_seen": v["timestamp"]
376
+ })
377
+ v["worker_id"] = worker_id
378
+
379
+ # Filter violations to only include those that persist for minimum duration
380
+ final_violations = []
381
+ for worker in workers:
382
+ # Group violations by type
383
+ violations_by_type = {}
384
+ for v in worker["violations"]:
385
+ if v["violation"] not in violations_by_type:
386
+ violations_by_type[v["violation"]] = []
387
+ violations_by_type[v["violation"]].append(v)
388
+
389
+ # Check each violation type for persistence
390
+ for violation_type, v_list in violations_by_type.items():
391
+ if len(v_list) < 2:
392
+ continue # Need multiple detections to check duration
393
 
394
+ duration = max(v["timestamp"] for v in v_list) - min(v["timestamp"] for v in v_list)
395
+ if duration >= CONFIG["MIN_VIOLATION_DURATION"]:
396
+ # Take the highest confidence detection
397
+ best_detection = max(v_list, key=lambda x: x["confidence"])
398
+ final_violations.append(best_detection)
399
+
400
+ # Capture snapshot if not already taken
401
+ if not snapshot_taken[violation_type]:
402
+ # We need to get the frame for this violation
403
+ cap = cv2.VideoCapture(video_path)
404
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
405
+ ret, snapshot_frame = cap.read()
406
+ cap.release()
 
 
 
 
 
407
 
408
+ if ret:
409
+ snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
410
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
411
+ cv2.imwrite(snapshot_path, snapshot_frame)
 
412
  snapshots.append({
413
+ "violation": violation_type,
414
+ "frame": best_detection["frame"],
415
  "snapshot_path": snapshot_path,
416
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
 
 
417
  })
418
+ snapshot_taken[violation_type] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
+ # Final processing
421
+ if not final_violations:
422
+ logger.info("No persistent violations detected")
 
 
 
 
 
 
 
 
423
  return {
424
  "violations": [],
425
  "snapshots": [],
 
426
  "score": 100,
427
  "salesforce_record_id": None,
428
  "violation_details_url": "",
429
  "message": "No violations detected in the video."
430
  }
431
 
432
+ score = calculate_safety_score(final_violations)
433
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(final_violations, score)
434
+ report_id, final_pdf_url = push_report_to_salesforce(final_violations, score, pdf_path, pdf_file)
435
 
436
  return {
437
+ "violations": final_violations,
438
  "snapshots": snapshots,
 
439
  "score": score,
440
  "salesforce_record_id": report_id,
441
  "violation_details_url": final_pdf_url,
 
446
  return {
447
  "violations": [],
448
  "snapshots": [],
 
449
  "score": 100,
450
  "salesforce_record_id": None,
451
  "violation_details_url": "",
452
  "message": f"Error processing video: {e}"
453
  }
454
 
455
+ # ==========================
456
+ # Gradio Interface (unchanged)
457
+ # ==========================
458
  def gradio_interface(video_file):
459
  if not video_file:
460
+ return "No file uploaded.", "", "No file uploaded.", "", ""
461
  try:
462
+ yield "Processing video... please wait.", "", "", "", ""
463
 
464
  with open(video_file, "rb") as f:
465
  video_data = f.read()
 
467
  result = process_video(video_data)
468
 
469
  if result.get("message"):
470
+ yield result["message"], "", "", "", ""
471
  return
472
 
473
  violation_table = "No violations detected."
474
  if result["violations"]:
475
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
476
+ separator = "|------------------------|---------------|------------|-----------|\n"
477
  rows = []
478
  violation_name_map = CONFIG["DISPLAY_NAMES"]
479
  for v in result["violations"]:
480
  display_name = violation_name_map.get(v["violation"], v["violation"])
481
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
482
  rows.append(row)
483
  violation_table = header + separator + "\n".join(rows)
484
 
485
  snapshots_text = "No snapshots captured."
 
486
  if result["snapshots"]:
487
  violation_name_map = CONFIG["DISPLAY_NAMES"]
488
  snapshots_text = "\n".join(
489
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
490
  for s in result["snapshots"]
491
  )
 
 
 
 
 
 
 
 
 
 
 
492
 
493
  yield (
494
  violation_table,
495
  f"Safety Score: {result['score']}%",
496
  snapshots_text,
497
  f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
498
+ result["violation_details_url"] or "N/A"
 
 
499
  )
500
  except Exception as e:
501
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
502
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
503
 
504
  interface = gr.Interface(
505
  fn=gradio_interface,
506
  inputs=gr.Video(label="Upload Site Video"),
507
+ outputs=[
508
  gr.Markdown(label="Detected Safety Violations"),
509
  gr.Textbox(label="Compliance Score"),
510
  gr.Markdown(label="Snapshots"),
511
  gr.Textbox(label="Salesforce Record ID"),
512
+ gr.Textbox(label="Violation Details URL")
 
 
513
  ],
514
  title="Worksite Safety Violation Analyzer",
515
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
516
  allow_flagging="never"
517
  )
518
 
519
  if __name__ == "__main__":
520
+ logger.info("Launching Enhanced Safety Analyzer App...")
521
  interface.launch()