PrashanthB461 commited on
Commit
aa6c939
·
verified ·
1 Parent(s): 6283a53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -11
app.py CHANGED
@@ -40,7 +40,7 @@ CONFIG = {
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
- "CONFIDENCE_THRESHOLD": 0.25, # Lowered for better detection
44
  "IOU_THRESHOLD": 0.5 # For worker tracking
45
  }
46
 
@@ -80,12 +80,19 @@ def calculate_iou(box1, box2):
80
  x1, y1, w1, h1 = box1
81
  x2, y2, w2, h2 = box2
82
 
 
83
  x1_min, y1_min = x1 - w1/2, y1 - h1/2
84
  x1_max, y1_max = x1 + w1/2, y1 + h1/2
85
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
86
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
87
 
88
- intersection = max(0, x2_min - x1_max) * max(0, y2_min - y1_max)
 
 
 
 
 
 
89
  area1 = w1 * h1
90
  area2 = w2 * h2
91
  union = area1 + area2 - intersection
@@ -245,7 +252,7 @@ def process_video(video_data):
245
  if not video.isOpened():
246
  raise ValueError("Could not open video file")
247
 
248
- violations, snapshots = [], []
249
  frame_count = 0
250
  start_time = time.time()
251
  fps = video.get(cv2.CAP_PROP_FPS)
@@ -280,7 +287,15 @@ def process_video(video_data):
280
  cls, conf = int(box.cls), float(box.conf)
281
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
282
 
283
- logger.info(f"Detection: class={cls}, conf={conf:.2f}, label={label}")
 
 
 
 
 
 
 
 
284
 
285
  if label not in CONFIG["VIOLATION_LABELS"].values():
286
  logger.info(f"Skipping unknown class: {cls}")
@@ -394,6 +409,7 @@ def process_video(video_data):
394
  return {
395
  "violations": [],
396
  "snapshots": [],
 
397
  "score": 100,
398
  "salesforce_record_id": None,
399
  "violation_details_url": "",
@@ -407,6 +423,7 @@ def process_video(video_data):
407
  return {
408
  "violations": violations,
409
  "snapshots": snapshots,
 
410
  "score": score,
411
  "salesforce_record_id": report_id,
412
  "violation_details_url": final_pdf_url,
@@ -417,6 +434,7 @@ def process_video(video_data):
417
  return {
418
  "violations": [],
419
  "snapshots": [],
 
420
  "score": 100,
421
  "salesforce_record_id": None,
422
  "violation_details_url": "",
@@ -425,9 +443,9 @@ def process_video(video_data):
425
 
426
  def gradio_interface(video_file):
427
  if not video_file:
428
- return "No file uploaded.", "", "No file uploaded.", "", "", []
429
  try:
430
- yield "Processing video... please wait.", "", "", "", "", []
431
 
432
  with open(video_file, "rb") as f:
433
  video_data = f.read()
@@ -435,7 +453,7 @@ def gradio_interface(video_file):
435
  result = process_video(video_data)
436
 
437
  if result.get("message"):
438
- yield result["message"], "", "", "", "", []
439
  return
440
 
441
  violation_table = "No violations detected."
@@ -460,19 +478,30 @@ def gradio_interface(video_file):
460
  )
461
  snapshot_images = [s["snapshot_base64"] for s in result["snapshots"]]
462
 
 
 
 
 
 
 
 
 
 
 
463
  yield (
464
  violation_table,
465
  f"Safety Score: {result['score']}%",
466
  snapshots_text,
467
  f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
468
  result["violation_details_url"] or "N/A",
469
- snapshot_images
 
470
  )
471
  except Exception as e:
472
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
473
- yield f"Error: {str(e)}", "", "Error in processing.", "", "", []
474
 
475
- interface = gr.Interface(
476
  fn=gradio_interface,
477
  inputs=gr.Video(label="Upload Site Video"),
478
  outputs=[
@@ -481,7 +510,8 @@ interface = gr.Interface(
481
  gr.Markdown(label="Snapshots"),
482
  gr.Textbox(label="Salesforce Record ID"),
483
  gr.Textbox(label="Violation Details URL"),
484
- gr.Gallery(label="Violation Snapshots")
 
485
  ],
486
  title="Worksite Safety Violation Analyzer",
487
  description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture). Non-violations are ignored.",
 
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
+ "CONFIDENCE_THRESHOLD": 0.1, # Lowered for debugging
44
  "IOU_THRESHOLD": 0.5 # For worker tracking
45
  }
46
 
 
80
  x1, y1, w1, h1 = box1
81
  x2, y2, w2, h2 = box2
82
 
83
+ # Convert to top-left and bottom-right coordinates
84
  x1_min, y1_min = x1 - w1/2, y1 - h1/2
85
  x1_max, y1_max = x1 + w1/2, y1 + h1/2
86
  x2_min, y2_min = x2 - w2/2, y2 - h2/2
87
  x2_max, y2_max = x2 + w2/2, y2 + h2/2
88
 
89
+ # Calculate intersection
90
+ x_min = max(x1_min, x2_min)
91
+ y_min = max(y1_min, y2_min)
92
+ x_max = min(x1_max, x2_max)
93
+ y_max = min(y1_max, y2_max)
94
+
95
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
96
  area1 = w1 * h1
97
  area2 = w2 * h2
98
  union = area1 + area2 - intersection
 
252
  if not video.isOpened():
253
  raise ValueError("Could not open video file")
254
 
255
+ violations, snapshots, raw_detections = [], [], []
256
  frame_count = 0
257
  start_time = time.time()
258
  fps = video.get(cv2.CAP_PROP_FPS)
 
287
  cls, conf = int(box.cls), float(box.conf)
288
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
289
 
290
+ # Log all raw detections
291
+ logger.info(f"Raw Detection: class={cls}, conf={conf:.2f}, label={label}")
292
+ raw_detections.append({
293
+ "frame": frame_count,
294
+ "class": cls,
295
+ "confidence": round(conf, 2),
296
+ "label": label if label in CONFIG["VIOLATION_LABELS"].values() else "unknown",
297
+ "timestamp": frame_count / fps
298
+ })
299
 
300
  if label not in CONFIG["VIOLATION_LABELS"].values():
301
  logger.info(f"Skipping unknown class: {cls}")
 
409
  return {
410
  "violations": [],
411
  "snapshots": [],
412
+ "raw_detections": raw_detections,
413
  "score": 100,
414
  "salesforce_record_id": None,
415
  "violation_details_url": "",
 
423
  return {
424
  "violations": violations,
425
  "snapshots": snapshots,
426
+ "raw_detections": raw_detections,
427
  "score": score,
428
  "salesforce_record_id": report_id,
429
  "violation_details_url": final_pdf_url,
 
434
  return {
435
  "violations": [],
436
  "snapshots": [],
437
+ "raw_detections": [],
438
  "score": 100,
439
  "salesforce_record_id": None,
440
  "violation_details_url": "",
 
443
 
444
  def gradio_interface(video_file):
445
  if not video_file:
446
+ return "No file uploaded.", "", "No file uploaded.", "", "", [], "No raw detections."
447
  try:
448
+ yield "Processing video... please wait.", "", "", "", "", [], "Processing..."
449
 
450
  with open(video_file, "rb") as f:
451
  video_data = f.read()
 
453
  result = process_video(video_data)
454
 
455
  if result.get("message"):
456
+ yield result["message"], "", "", "", "", [], "Error in processing."
457
  return
458
 
459
  violation_table = "No violations detected."
 
478
  )
479
  snapshot_images = [s["snapshot_base64"] for s in result["snapshots"]]
480
 
481
+ raw_detections_text = "No raw detections logged."
482
+ if result["raw_detections"]:
483
+ header = "| Frame | Timestamp (s) | Class | Label | Confidence |\n"
484
+ separator = "|-------|---------------|-------|----------------|------------|\n"
485
+ rows = []
486
+ for d in result["raw_detections"]:
487
+ row = f"| {d['frame']:<5} | {d['timestamp']:.2f} | {d['class']:<5} | {d['label']:<14} | {d['confidence']:.2f} |"
488
+ rows.append(row)
489
+ raw_detections_text = header + separator + "\n".join(rows)
490
+
491
  yield (
492
  violation_table,
493
  f"Safety Score: {result['score']}%",
494
  snapshots_text,
495
  f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
496
  result["violation_details_url"] or "N/A",
497
+ snapshot_images,
498
+ raw_detections_text
499
  )
500
  except Exception as e:
501
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
502
+ yield f"Error: {str(e)}", "", "Error in processing.", "", "", [], "Error in processing."
503
 
504
+ interface beq gr.Interface(
505
  fn=gradio_interface,
506
  inputs=gr.Video(label="Upload Site Video"),
507
  outputs=[
 
510
  gr.Markdown(label="Snapshots"),
511
  gr.Textbox(label="Salesforce Record ID"),
512
  gr.Textbox(label="Violation Details URL"),
513
+ gr.Gallery(label="Violation Snapshots"),
514
+ gr.Markdown(label="Raw Detections (Debug)")
515
  ],
516
  title="Worksite Safety Violation Analyzer",
517
  description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture). Non-violations are ignored.",