PrashanthB461 commited on
Commit
6283a53
·
verified ·
1 Parent(s): 4ea55e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -43
app.py CHANGED
@@ -40,8 +40,8 @@ CONFIG = {
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
- "CONFIDENCE_THRESHOLD": 0.3, # Lowered threshold for detecting all violations
44
- "IOU_THRESHOLD": 0.5 # Added for worker tracking
45
  }
46
 
47
  # Setup logging
@@ -80,19 +80,12 @@ def calculate_iou(box1, box2):
80
  x1, y1, w1, h1 = box1
81
  x2, y2, w2, h2 = box2
82
 
83
- # Convert to top-left and bottom-right coordinates
84
  x1_min, y1_min = x1 - w1/2, y1 - h1/2
85
  x1_max, y1_max = x1 + w1/2, y1 + h1/2
86
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
87
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
88
 
89
- # Calculate intersection
90
- x_min = max(x1_min, x2_min)
91
- y_min = max(y1_min, y2_min)
92
- x_max = min(x1_max, x2_max)
93
- y_max = min(y1_max, y2_max)
94
-
95
- intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
96
  area1 = w1 * h1
97
  area2 = w2 * h2
98
  union = area1 + area2 - intersection
@@ -141,7 +134,7 @@ def generate_violation_pdf(violations, score):
141
  else:
142
  for v in violations:
143
  display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
144
- text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
145
  c.drawString(1 * inch, y_position, text)
146
  y_position -= 0.3 * inch
147
  if y_position < 1 * inch:
@@ -190,7 +183,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
190
  try:
191
  sf = connect_to_salesforce()
192
  violations_text = "\n".join(
193
- f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
194
  for v in violations
195
  ) or "No violations detected."
196
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
@@ -260,7 +253,6 @@ def process_video(video_data):
260
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
261
  workers = [] # List to track workers
262
 
263
- # Adding debug logging for violation labels
264
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
265
  logger.info(f"Using confidence threshold: {CONFIG['CONFIDENCE_THRESHOLD']}")
266
 
@@ -277,11 +269,9 @@ def process_video(video_data):
277
  logger.info("Processing time limit reached")
278
  break
279
 
280
- # Run detection on this frame
281
  results = model(frame, device=device)
282
  current_detections = []
283
 
284
- # Process detections from the model
285
  for result in results:
286
  boxes = result.boxes
287
  logger.info(f"Frame {frame_count}: Found {len(boxes)} potential detections")
@@ -290,10 +280,8 @@ def process_video(video_data):
290
  cls, conf = int(box.cls), float(box.conf)
291
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
292
 
293
- # Enhanced logging
294
  logger.info(f"Detection: class={cls}, conf={conf:.2f}, label={label}")
295
 
296
- # Skip if not a known violation or below confidence threshold
297
  if label not in CONFIG["VIOLATION_LABELS"].values():
298
  logger.info(f"Skipping unknown class: {cls}")
299
  continue
@@ -302,7 +290,6 @@ def process_video(video_data):
302
  logger.info(f"Skipping low confidence: {conf:.2f} < {CONFIG['CONFIDENCE_THRESHOLD']}")
303
  continue
304
 
305
- # Process valid detection
306
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
307
  logger.info(f"Valid detection: {label} with confidence: {conf:.2f}")
308
 
@@ -314,13 +301,10 @@ def process_video(video_data):
314
  "frame": frame_count
315
  })
316
 
317
- # Process detections and associate with workers
318
- # FIXED: Improved worker tracking logic
319
  for detection in current_detections:
320
  matched_worker = None
321
  max_iou = 0
322
 
323
- # Try to match with existing workers
324
  for worker in workers:
325
  iou = calculate_iou(detection["bounding_box"], worker["bbox"])
326
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
@@ -328,9 +312,7 @@ def process_video(video_data):
328
  matched_worker = worker
329
 
330
  if matched_worker:
331
- # Update existing worker
332
  if detection["violation"] not in matched_worker["violations"]:
333
- # New violation for this worker
334
  logger.info(f"New violation for worker {matched_worker['id']}: {detection['violation']}")
335
  matched_worker["violations"].add(detection["violation"])
336
  violations.append({
@@ -342,7 +324,6 @@ def process_video(video_data):
342
  "worker_id": matched_worker["id"]
343
  })
344
 
345
- # Save snapshot for this violation type if not already taken
346
  if not snapshot_taken[detection["violation"]]:
347
  snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
348
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
@@ -352,14 +333,14 @@ def process_video(video_data):
352
  "violation": detection["violation"],
353
  "frame": frame_count,
354
  "snapshot_path": snapshot_path,
355
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
 
 
356
  })
357
 
358
- # Update worker position
359
  matched_worker["bbox"] = detection["bounding_box"]
360
  matched_worker["last_frame"] = frame_count
361
  else:
362
- # New worker detected
363
  worker_id = len(workers) + 1
364
  logger.info(f"New worker {worker_id} with violation: {detection['violation']}")
365
  workers.append({
@@ -378,7 +359,6 @@ def process_video(video_data):
378
  "worker_id": worker_id
379
  })
380
 
381
- # Save snapshot for this violation type if not already taken
382
  if not snapshot_taken[detection["violation"]]:
383
  snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
384
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
@@ -388,10 +368,11 @@ def process_video(video_data):
388
  "violation": detection["violation"],
389
  "frame": frame_count,
390
  "snapshot_path": snapshot_path,
391
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
 
 
392
  })
393
 
394
- # Clean up workers that haven't been seen for a while
395
  active_workers = [w for w in workers if frame_count - w["last_frame"] < CONFIG["FRAME_SKIP"] * 5]
396
  if len(active_workers) != len(workers):
397
  logger.info(f"Cleaned up {len(workers) - len(active_workers)} inactive workers")
@@ -402,7 +383,6 @@ def process_video(video_data):
402
  video.release()
403
  os.remove(video_path)
404
 
405
- # Final log of violations detected
406
  violation_types = {}
407
  for v in violations:
408
  violation_types[v["violation"]] = violation_types.get(v["violation"], 0) + 1
@@ -445,9 +425,9 @@ def process_video(video_data):
445
 
446
  def gradio_interface(video_file):
447
  if not video_file:
448
- return "No file uploaded.", "", "No file uploaded.", "", ""
449
  try:
450
- yield "Processing video... please wait.", "", "", "", ""
451
 
452
  with open(video_file, "rb") as f:
453
  video_data = f.read()
@@ -455,52 +435,56 @@ def gradio_interface(video_file):
455
  result = process_video(video_data)
456
 
457
  if result.get("message"):
458
- yield result["message"], "", "", "", ""
459
  return
460
 
461
  violation_table = "No violations detected."
462
  if result["violations"]:
463
- header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
464
- separator = "|------------------------|---------------|------------|-----------|\n"
465
  rows = []
466
  violation_name_map = CONFIG["DISPLAY_NAMES"]
467
  for v in result["violations"]:
468
  display_name = violation_name_map.get(v["violation"], v["violation"])
469
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
470
  rows.append(row)
471
  violation_table = header + separator + "\n".join(rows)
472
 
473
  snapshots_text = "No snapshots captured."
 
474
  if result["snapshots"]:
475
  violation_name_map = CONFIG["DISPLAY_NAMES"]
476
  snapshots_text = "\n".join(
477
- f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
478
  for s in result["snapshots"]
479
  )
 
480
 
481
  yield (
482
  violation_table,
483
  f"Safety Score: {result['score']}%",
484
  snapshots_text,
485
  f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
486
- result["violation_details_url"] or "N/A"
 
487
  )
488
  except Exception as e:
489
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
490
- yield f"Error: {str(e)}", "", "Error in processing.", "", ""
491
 
492
  interface = gr.Interface(
493
  fn=gradio_interface,
494
  inputs=gr.Video(label="Upload Site Video"),
495
- outputs=[
496
  gr.Markdown(label="Detected Safety Violations"),
497
  gr.Textbox(label="Compliance Score"),
498
  gr.Markdown(label="Snapshots"),
499
  gr.Textbox(label="Salesforce Record ID"),
500
- gr.Textbox(label="Violation Details URL")
 
501
  ],
502
  title="Worksite Safety Violation Analyzer",
503
- description="Upload site videos to detect safety violations (No Helmet Violation, No Harness Violation, Unsafe Posture Violation). Non-violations are ignored.",
504
  allow_flagging="never"
505
  )
506
 
 
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
+ "CONFIDENCE_THRESHOLD": 0.25, # Lowered for better detection
44
+ "IOU_THRESHOLD": 0.5 # For worker tracking
45
  }
46
 
47
  # Setup logging
 
80
  x1, y1, w1, h1 = box1
81
  x2, y2, w2, h2 = box2
82
 
 
83
  x1_min, y1_min = x1 - w1/2, y1 - h1/2
84
  x1_max, y1_max = x1 + w1/2, y1 + h1/2
85
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
86
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
87
 
88
+ intersection = max(0, x2_min - x1_max) * max(0, y2_min - y1_max)
 
 
 
 
 
 
89
  area1 = w1 * h1
90
  area2 = w2 * h2
91
  union = area1 + area2 - intersection
 
134
  else:
135
  for v in violations:
136
  display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
137
+ text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, Worker ID: {v['worker_id']})"
138
  c.drawString(1 * inch, y_position, text)
139
  y_position -= 0.3 * inch
140
  if y_position < 1 * inch:
 
183
  try:
184
  sf = connect_to_salesforce()
185
  violations_text = "\n".join(
186
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']:.2f}, Worker ID: {v['worker_id']})"
187
  for v in violations
188
  ) or "No violations detected."
189
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
 
253
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
254
  workers = [] # List to track workers
255
 
 
256
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
257
  logger.info(f"Using confidence threshold: {CONFIG['CONFIDENCE_THRESHOLD']}")
258
 
 
269
  logger.info("Processing time limit reached")
270
  break
271
 
 
272
  results = model(frame, device=device)
273
  current_detections = []
274
 
 
275
  for result in results:
276
  boxes = result.boxes
277
  logger.info(f"Frame {frame_count}: Found {len(boxes)} potential detections")
 
280
  cls, conf = int(box.cls), float(box.conf)
281
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
282
 
 
283
  logger.info(f"Detection: class={cls}, conf={conf:.2f}, label={label}")
284
 
 
285
  if label not in CONFIG["VIOLATION_LABELS"].values():
286
  logger.info(f"Skipping unknown class: {cls}")
287
  continue
 
290
  logger.info(f"Skipping low confidence: {conf:.2f} < {CONFIG['CONFIDENCE_THRESHOLD']}")
291
  continue
292
 
 
293
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
294
  logger.info(f"Valid detection: {label} with confidence: {conf:.2f}")
295
 
 
301
  "frame": frame_count
302
  })
303
 
 
 
304
  for detection in current_detections:
305
  matched_worker = None
306
  max_iou = 0
307
 
 
308
  for worker in workers:
309
  iou = calculate_iou(detection["bounding_box"], worker["bbox"])
310
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
 
312
  matched_worker = worker
313
 
314
  if matched_worker:
 
315
  if detection["violation"] not in matched_worker["violations"]:
 
316
  logger.info(f"New violation for worker {matched_worker['id']}: {detection['violation']}")
317
  matched_worker["violations"].add(detection["violation"])
318
  violations.append({
 
324
  "worker_id": matched_worker["id"]
325
  })
326
 
 
327
  if not snapshot_taken[detection["violation"]]:
328
  snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
329
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
 
333
  "violation": detection["violation"],
334
  "frame": frame_count,
335
  "snapshot_path": snapshot_path,
336
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
337
+ "timestamp": detection["timestamp"],
338
+ "confidence": detection["confidence"]
339
  })
340
 
 
341
  matched_worker["bbox"] = detection["bounding_box"]
342
  matched_worker["last_frame"] = frame_count
343
  else:
 
344
  worker_id = len(workers) + 1
345
  logger.info(f"New worker {worker_id} with violation: {detection['violation']}")
346
  workers.append({
 
359
  "worker_id": worker_id
360
  })
361
 
 
362
  if not snapshot_taken[detection["violation"]]:
363
  snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
364
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
 
368
  "violation": detection["violation"],
369
  "frame": frame_count,
370
  "snapshot_path": snapshot_path,
371
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
372
+ "timestamp": detection["timestamp"],
373
+ "confidence": detection["confidence"]
374
  })
375
 
 
376
  active_workers = [w for w in workers if frame_count - w["last_frame"] < CONFIG["FRAME_SKIP"] * 5]
377
  if len(active_workers) != len(workers):
378
  logger.info(f"Cleaned up {len(workers) - len(active_workers)} inactive workers")
 
383
  video.release()
384
  os.remove(video_path)
385
 
 
386
  violation_types = {}
387
  for v in violations:
388
  violation_types[v["violation"]] = violation_types.get(v["violation"], 0) + 1
 
425
 
426
  def gradio_interface(video_file):
427
  if not video_file:
428
+ return "No file uploaded.", "", "No file uploaded.", "", "", []
429
  try:
430
+ yield "Processing video... please wait.", "", "", "", "", []
431
 
432
  with open(video_file, "rb") as f:
433
  video_data = f.read()
 
435
  result = process_video(video_data)
436
 
437
  if result.get("message"):
438
+ yield result["message"], "", "", "", "", []
439
  return
440
 
441
  violation_table = "No violations detected."
442
  if result["violations"]:
443
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID | Frame |\n"
444
+ separator = "|------------------------|---------------|------------|-----------|-------|\n"
445
  rows = []
446
  violation_name_map = CONFIG["DISPLAY_NAMES"]
447
  for v in result["violations"]:
448
  display_name = violation_name_map.get(v["violation"], v["violation"])
449
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} | {v['frame']} |"
450
  rows.append(row)
451
  violation_table = header + separator + "\n".join(rows)
452
 
453
  snapshots_text = "No snapshots captured."
454
+ snapshot_images = []
455
  if result["snapshots"]:
456
  violation_name_map = CONFIG["DISPLAY_NAMES"]
457
  snapshots_text = "\n".join(
458
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at {s['timestamp']:.2f}s (Frame {s['frame']}, Confidence: {s['confidence']:.2f}): ![]({s['snapshot_base64']})"
459
  for s in result["snapshots"]
460
  )
461
+ snapshot_images = [s["snapshot_base64"] for s in result["snapshots"]]
462
 
463
  yield (
464
  violation_table,
465
  f"Safety Score: {result['score']}%",
466
  snapshots_text,
467
  f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
468
+ result["violation_details_url"] or "N/A",
469
+ snapshot_images
470
  )
471
  except Exception as e:
472
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
473
+ yield f"Error: {str(e)}", "", "Error in processing.", "", "", []
474
 
475
  interface = gr.Interface(
476
  fn=gradio_interface,
477
  inputs=gr.Video(label="Upload Site Video"),
478
+ outputs=[
479
  gr.Markdown(label="Detected Safety Violations"),
480
  gr.Textbox(label="Compliance Score"),
481
  gr.Markdown(label="Snapshots"),
482
  gr.Textbox(label="Salesforce Record ID"),
483
+ gr.Textbox(label="Violation Details URL"),
484
+ gr.Gallery(label="Violation Snapshots")
485
  ],
486
  title="Worksite Safety Violation Analyzer",
487
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture). Non-violations are ignored.",
488
  allow_flagging="never"
489
  )
490