PrashanthB461 commited on
Commit
d7cc76c
·
verified ·
1 Parent(s): 40b5e84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -17
app.py CHANGED
@@ -51,7 +51,7 @@ CONFIG = {
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
  "FRAME_SKIP": {
54
- "no_helmet": 3, # Lower skip for frequent violations
55
  "no_harness": 2,
56
  "unsafe_posture": 2,
57
  "unsafe_zone": 2,
@@ -59,14 +59,14 @@ CONFIG = {
59
  },
60
  "MAX_PROCESSING_TIME": 60,
61
  "CONFIDENCE_THRESHOLDS": {
62
- "no_helmet": 0.3, # Slightly higher to reduce false positives
63
  "no_harness": 0.2,
64
  "unsafe_posture": 0.2,
65
  "unsafe_zone": 0.2,
66
  "improper_tool_use": 0.2
67
  },
68
  "IOU_THRESHOLD": 0.4,
69
- "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
70
  }
71
 
72
  # Setup logging
@@ -103,9 +103,9 @@ model = load_model()
103
  def draw_detections(frame, detections):
104
  """Draw bounding boxes and labels on frame"""
105
  for det in detections:
106
- label = det["violation"]
107
- confidence = det["confidence"]
108
- x, y, w, h = det["bounding_box"]
109
 
110
  x1 = int(x - w/2)
111
  y1 = int(y - h/2)
@@ -178,8 +178,8 @@ def generate_violation_pdf(violations, score):
178
  c.drawString(1 * inch, y_position, "No violations detected.")
179
  else:
180
  for v in violations:
181
- display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
182
- text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
183
  c.drawString(1 * inch, y_position, text)
184
  y_position -= 0.3 * inch
185
  if y_position < 1 * inch:
@@ -228,7 +228,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
228
  try:
229
  sf = connect_to_salesforce()
230
  violations_text = "\n".join(
231
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
232
  for v in violations
233
  ) or "No violations detected."
234
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
@@ -275,7 +275,7 @@ def calculate_safety_score(violations):
275
  "unsafe_zone": 35,
276
  "improper_tool_use": 25
277
  }
278
- total_penalty = sum(penalties.get(v["violation"], 0) for v in violations)
279
  score = 100 - total_penalty
280
  return max(score, 0)
281
 
@@ -337,24 +337,32 @@ def process_video(video_data):
337
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
 
339
  if label is None:
 
340
  continue
341
 
342
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
 
343
  continue
344
 
345
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
 
347
  current_detections.append({
348
  "frame": frame_count,
349
- "vi臉olation": label,
350
  "confidence": round(conf, 2),
351
  "bounding_box": bbox,
352
  "timestamp": current_time
353
  })
354
 
 
 
355
  # Process detections and associate with workers
356
  for detection in current_detections:
357
- violation_type = detection["violation"]
 
 
 
 
358
  # Skip No Helmet detection if worker is compliant
359
  if violation_type == "no_helmet":
360
  matched_worker = None
@@ -366,7 +374,8 @@ def process_video(video_data):
366
  matched_worker = worker
367
 
368
  if matched_worker and matched_worker["id"] in helmet_compliance:
369
- continue # Skip if worker is known to wear a helmet
 
370
 
371
  # Find or create worker
372
  matched_worker = None
@@ -393,14 +402,16 @@ def process_video(video_data):
393
 
394
  # Skip if this violation type is already confirmed for this worker
395
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
 
396
  continue
397
 
398
  detection["worker_id"] = worker_id
399
  violation_history[violation_type].append(detection)
400
 
401
  # Update helmet compliance (simulate by checking if No Helmet confidence is low)
402
- if violation_type == "no_helmet" and conf < 0.5:
403
  helmet_compliance[worker_id] = True
 
404
 
405
  # Clean up old workers
406
  workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
@@ -413,6 +424,7 @@ def process_video(video_data):
413
  # Process violation history to confirm persistent violations
414
  for violation_type, detections in violation_history.items():
415
  if not detections:
 
416
  continue
417
 
418
  worker_violations = {}
@@ -429,6 +441,7 @@ def process_video(video_data):
429
 
430
  # Skip No Helmet if worker is compliant
431
  if violation_type == "no_helmet" and worker_id in helmet_compliance:
 
432
  continue
433
 
434
  best_detection = max(worker_dets, key=lambda x: x["confidence"])
@@ -457,6 +470,7 @@ def process_video(video_data):
457
  "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
458
  })
459
  snapshot_taken[violation_type] = True
 
460
 
461
  if not violations:
462
  logger.info("No persistent violations detected")
@@ -473,6 +487,7 @@ def process_video(video_data):
473
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
474
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
475
 
 
476
  return {
477
  "violations": violations,
478
  "snapshots": snapshots,
@@ -517,8 +532,8 @@ def gradio_interface(video_file):
517
  rows = []
518
  violation_name_map = CONFIG["DISPLAY_NAMES"]
519
  for v in result["violations"]:
520
- display_name = violation_name_map.get(v["violation"], v["violation"])
521
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
522
  rows.append(row)
523
  violation_table = header + separator + "\n".join(rows)
524
 
@@ -526,7 +541,7 @@ def gradio_interface(video_file):
526
  if result["snapshots"]:
527
  violation_name_map = CONFIG["DISPLAY_NAMES"]
528
  snapshots_text = "\n".join(
529
- f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
530
  for s in result["snapshots"]
531
  )
532
 
 
51
  },
52
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
53
  "FRAME_SKIP": {
54
+ "no_helmet": 3,
55
  "no_harness": 2,
56
  "unsafe_posture": 2,
57
  "unsafe_zone": 2,
 
59
  },
60
  "MAX_PROCESSING_TIME": 60,
61
  "CONFIDENCE_THRESHOLDS": {
62
+ "no_helmet": 0.3,
63
  "no_harness": 0.2,
64
  "unsafe_posture": 0.2,
65
  "unsafe_zone": 0.2,
66
  "improper_tool_use": 0.2
67
  },
68
  "IOU_THRESHOLD": 0.4,
69
+ "MIN_VIOLATION_FRAMES": 3
70
  }
71
 
72
  # Setup logging
 
103
  def draw_detections(frame, detections):
104
  """Draw bounding boxes and labels on frame"""
105
  for det in detections:
106
+ label = det.get("violation", "Unknown")
107
+ confidence = det.get("confidence", 0.0)
108
+ x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
109
 
110
  x1 = int(x - w/2)
111
  y1 = int(y - h/2)
 
178
  c.drawString(1 * inch, y_position, "No violations detected.")
179
  else:
180
  for v in violations:
181
+ display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
182
+ text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
183
  c.drawString(1 * inch, y_position, text)
184
  y_position -= 0.3 * inch
185
  if y_position < 1 * inch:
 
228
  try:
229
  sf = connect_to_salesforce()
230
  violations_text = "\n".join(
231
+ f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
232
  for v in violations
233
  ) or "No violations detected."
234
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
 
275
  "unsafe_zone": 35,
276
  "improper_tool_use": 25
277
  }
278
+ total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
279
  score = 100 - total_penalty
280
  return max(score, 0)
281
 
 
337
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
 
339
  if label is None:
340
+ logger.warning(f"Unknown class ID {cls} detected, skipping")
341
  continue
342
 
343
  if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
344
+ logger.debug(f"Detection {label} with confidence {conf:.2f} below threshold, skipping")
345
  continue
346
 
347
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
348
 
349
  current_detections.append({
350
  "frame": frame_count,
351
+ "violation": label, # Corrected key
352
  "confidence": round(conf, 2),
353
  "bounding_box": bbox,
354
  "timestamp": current_time
355
  })
356
 
357
+ logger.debug(f"Frame {frame_count}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
358
+
359
  # Process detections and associate with workers
360
  for detection in current_detections:
361
+ violation_type = detection.get("violation", None)
362
+ if violation_type is None:
363
+ logger.error(f"Invalid detection, missing 'violation' key: {detection}")
364
+ continue
365
+
366
  # Skip No Helmet detection if worker is compliant
367
  if violation_type == "no_helmet":
368
  matched_worker = None
 
374
  matched_worker = worker
375
 
376
  if matched_worker and matched_worker["id"] in helmet_compliance:
377
+ logger.debug(f"Worker {matched_worker['id']} has helmet, skipping no_helmet violation")
378
+ continue
379
 
380
  # Find or create worker
381
  matched_worker = None
 
402
 
403
  # Skip if this violation type is already confirmed for this worker
404
  if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
405
+ logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
406
  continue
407
 
408
  detection["worker_id"] = worker_id
409
  violation_history[violation_type].append(detection)
410
 
411
  # Update helmet compliance (simulate by checking if No Helmet confidence is low)
412
+ if violation_type == "no_helmet" and detection["confidence"] < 0.5:
413
  helmet_compliance[worker_id] = True
414
+ logger.debug(f"Worker {worker_id} marked as helmet compliant")
415
 
416
  # Clean up old workers
417
  workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
 
424
  # Process violation history to confirm persistent violations
425
  for violation_type, detections in violation_history.items():
426
  if not detections:
427
+ logger.info(f"No detections for {violation_type}")
428
  continue
429
 
430
  worker_violations = {}
 
441
 
442
  # Skip No Helmet if worker is compliant
443
  if violation_type == "no_helmet" and worker_id in helmet_compliance:
444
+ logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
445
  continue
446
 
447
  best_detection = max(worker_dets, key=lambda x: x["confidence"])
 
470
  "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
471
  })
472
  snapshot_taken[violation_type] = True
473
+ logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
474
 
475
  if not violations:
476
  logger.info("No persistent violations detected")
 
487
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
488
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
489
 
490
+ logger.info(f"Processing complete: {len(violations)} violations detected, score: {score}%")
491
  return {
492
  "violations": violations,
493
  "snapshots": snapshots,
 
532
  rows = []
533
  violation_name_map = CONFIG["DISPLAY_NAMES"]
534
  for v in result["violations"]:
535
+ display_name = violation_name_map.get(v.get("violation", "Unknown"), "Unknown")
536
+ row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |"
537
  rows.append(row)
538
  violation_table = header + separator + "\n".join(rows)
539
 
 
541
  if result["snapshots"]:
542
  violation_name_map = CONFIG["DISPLAY_NAMES"]
543
  snapshots_text = "\n".join(
544
+ f"- Snapshot for {violation_name_map.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: ![]({s.get('snapshot_base64', '')})"
545
  for s in result["snapshots"]
546
  )
547