PrashanthB461 commited on
Commit
6104e09
·
verified ·
1 Parent(s): 1aebe73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -87
app.py CHANGED
@@ -28,6 +28,13 @@ CONFIG = {
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
 
 
 
 
 
 
 
31
  "DISPLAY_NAMES": {
32
  "no_helmet": "No Helmet Violation",
33
  "no_harness": "No Harness Violation",
@@ -42,17 +49,11 @@ CONFIG = {
42
  "domain": "login"
43
  },
44
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
45
- "FRAME_SKIP": 10, # Reduced for better detection
46
- "MAX_PROCESSING_TIME": 45,
47
- "CONFIDENCE_THRESHOLD": {
48
- "no_helmet": 0.4,
49
- "no_harness": 0.35,
50
- "unsafe_posture": 0.3,
51
- "unsafe_zone": 0.3,
52
- "improper_tool_use": 0.35
53
- },
54
  "IOU_THRESHOLD": 0.4,
55
- "MIN_VIOLATION_DURATION": 2 # seconds
56
  }
57
 
58
  # Setup logging
@@ -86,6 +87,27 @@ model = load_model()
86
  # ==========================
87
  # Enhanced Helper Functions
88
  # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def calculate_iou(box1, box2):
90
  """Calculate Intersection over Union (IoU) for two bounding boxes."""
91
  x1, y1, w1, h1 = box1
@@ -110,15 +132,6 @@ def calculate_iou(box1, box2):
110
 
111
  return intersection / union if union > 0 else 0
112
 
113
- def is_violation_persistent(violation_type, worker_id, violations, fps):
114
- """Check if a violation persists for the required duration."""
115
- violation_times = [v['timestamp'] for v in violations
116
- if v['violation'] == violation_type and v['worker_id'] == worker_id]
117
- if len(violation_times) < 2:
118
- return False
119
- duration = max(violation_times) - min(violation_times)
120
- return duration >= CONFIG["MIN_VIOLATION_DURATION"]
121
-
122
  # ==========================
123
  # Salesforce Integration (unchanged)
124
  # ==========================
@@ -281,16 +294,18 @@ def process_video(video_data):
281
  if not video.isOpened():
282
  raise ValueError("Could not open video file")
283
 
284
- violations, snapshots = [], []
 
285
  frame_count = 0
286
  start_time = time.time()
287
  fps = video.get(cv2.CAP_PROP_FPS)
288
  if fps <= 0:
289
  fps = 30 # Default assumption if FPS cannot be determined
290
 
 
 
 
291
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
292
- workers = [] # List to track workers
293
- violation_history = [] # Track all potential violations before filtering
294
 
295
  logger.info(f"Processing video with FPS: {fps}")
296
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
@@ -308,31 +323,28 @@ def process_video(video_data):
308
  logger.info("Processing time limit reached")
309
  break
310
 
 
 
311
  # Run detection on this frame
312
  results = model(frame, device=device)
313
- current_time = frame_count / fps
314
 
 
315
  for result in results:
316
  boxes = result.boxes
317
- logger.debug(f"Frame {frame_count}: Found {len(boxes)} potential detections")
318
-
319
  for box in boxes:
320
  cls = int(box.cls)
321
  conf = float(box.conf)
322
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
323
 
324
  if label is None:
325
- continue # Skip unknown classes
326
 
327
- conf_threshold = CONFIG["CONFIDENCE_THRESHOLD"].get(label, 0.3)
328
- if conf < conf_threshold:
329
- logger.debug(f"Skipping {label} with low confidence: {conf:.2f} < {conf_threshold}")
330
  continue
331
 
332
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
333
 
334
- # Store potential violation (will filter later)
335
- violation_history.append({
336
  "frame": frame_count,
337
  "violation": label,
338
  "confidence": round(conf, 2),
@@ -340,72 +352,73 @@ def process_video(video_data):
340
  "timestamp": current_time
341
  })
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  frame_count += 1
344
 
345
  video.release()
346
  os.remove(video_path)
347
 
348
- # Process violation history to track workers and persistent violations
349
- workers = []
350
- for v in violation_history:
351
- # Find matching worker
352
- matched_worker = None
353
- max_iou = 0
354
-
355
- for worker in workers:
356
- iou = calculate_iou(v["bounding_box"], worker["bbox"])
357
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
358
- max_iou = iou
359
- matched_worker = worker
360
-
361
- if matched_worker:
362
- # Update worker's violation history
363
- matched_worker["violations"].append(v)
364
- matched_worker["bbox"] = v["bounding_box"]
365
- matched_worker["last_seen"] = v["timestamp"]
366
- v["worker_id"] = matched_worker["id"]
367
- else:
368
- # New worker
369
- worker_id = len(workers) + 1
370
- workers.append({
371
- "id": worker_id,
372
- "bbox": v["bounding_box"],
373
- "violations": [v],
374
- "first_seen": v["timestamp"],
375
- "last_seen": v["timestamp"]
376
- })
377
- v["worker_id"] = worker_id
378
-
379
- # Filter violations to only include those that persist for minimum duration
380
- final_violations = []
381
- for worker in workers:
382
- # Group violations by type
383
- violations_by_type = {}
384
- for v in worker["violations"]:
385
- if v["violation"] not in violations_by_type:
386
- violations_by_type[v["violation"]] = []
387
- violations_by_type[v["violation"]].append(v)
388
-
389
- # Check each violation type for persistence
390
- for violation_type, v_list in violations_by_type.items():
391
- if len(v_list) < 2:
392
- continue # Need multiple detections to check duration
393
 
394
- duration = max(v["timestamp"] for v in v_list) - min(v["timestamp"] for v in v_list)
395
- if duration >= CONFIG["MIN_VIOLATION_DURATION"]:
 
 
 
 
 
 
 
 
396
  # Take the highest confidence detection
397
- best_detection = max(v_list, key=lambda x: x["confidence"])
398
- final_violations.append(best_detection)
399
 
400
  # Capture snapshot if not already taken
401
  if not snapshot_taken[violation_type]:
402
- # We need to get the frame for this violation
403
  cap = cv2.VideoCapture(video_path)
404
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
405
  ret, snapshot_frame = cap.read()
406
  cap.release()
407
 
408
  if ret:
 
 
 
409
  snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
410
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
411
  cv2.imwrite(snapshot_path, snapshot_frame)
@@ -418,7 +431,7 @@ def process_video(video_data):
418
  snapshot_taken[violation_type] = True
419
 
420
  # Final processing
421
- if not final_violations:
422
  logger.info("No persistent violations detected")
423
  return {
424
  "violations": [],
@@ -429,12 +442,12 @@ def process_video(video_data):
429
  "message": "No violations detected in the video."
430
  }
431
 
432
- score = calculate_safety_score(final_violations)
433
- pdf_path, pdf_url, pdf_file = generate_violation_pdf(final_violations, score)
434
- report_id, final_pdf_url = push_report_to_salesforce(final_violations, score, pdf_path, pdf_file)
435
 
436
  return {
437
- "violations": final_violations,
438
  "snapshots": snapshots,
439
  "score": score,
440
  "salesforce_record_id": report_id,
@@ -453,7 +466,7 @@ def process_video(video_data):
453
  }
454
 
455
  # ==========================
456
- # Gradio Interface (unchanged)
457
  # ==========================
458
  def gradio_interface(video_file):
459
  if not video_file:
 
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
+ "CLASS_COLORS": {
32
+ "no_helmet": (0, 0, 255), # Red
33
+ "no_harness": (0, 165, 255), # Orange
34
+ "unsafe_posture": (0, 255, 0), # Green
35
+ "unsafe_zone": (255, 0, 0), # Blue
36
+ "improper_tool_use": (255, 255, 0) # Yellow
37
+ },
38
  "DISPLAY_NAMES": {
39
  "no_helmet": "No Helmet Violation",
40
  "no_harness": "No Harness Violation",
 
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 5, # Reduced for better detection
53
+ "MAX_PROCESSING_TIME": 60,
54
+ "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
 
 
 
 
 
 
55
  "IOU_THRESHOLD": 0.4,
56
+ "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
57
  }
58
 
59
  # Setup logging
 
87
  # ==========================
88
  # Enhanced Helper Functions
89
  # ==========================
90
+ def draw_detections(frame, detections):
91
+ """Draw bounding boxes and labels on frame"""
92
+ for det in detections:
93
+ label = det["violation"]
94
+ confidence = det["confidence"]
95
+ x, y, w, h = det["bounding_box"]
96
+
97
+ # Convert from center coordinates to corner coordinates
98
+ x1 = int(x - w/2)
99
+ y1 = int(y - h/2)
100
+ x2 = int(x + w/2)
101
+ y2 = int(y + h/2)
102
+
103
+ color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
+
106
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
+ cv2.putText(frame, display_text, (x1, y1-10),
108
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
109
+ return frame
110
+
111
  def calculate_iou(box1, box2):
112
  """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
  x1, y1, w1, h1 = box1
 
132
 
133
  return intersection / union if union > 0 else 0
134
 
 
 
 
 
 
 
 
 
 
135
  # ==========================
136
  # Salesforce Integration (unchanged)
137
  # ==========================
 
294
  if not video.isOpened():
295
  raise ValueError("Could not open video file")
296
 
297
+ violations = []
298
+ snapshots = []
299
  frame_count = 0
300
  start_time = time.time()
301
  fps = video.get(cv2.CAP_PROP_FPS)
302
  if fps <= 0:
303
  fps = 30 # Default assumption if FPS cannot be determined
304
 
305
+ # Structure to track workers and their violations
306
+ workers = []
307
+ violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
 
 
309
 
310
  logger.info(f"Processing video with FPS: {fps}")
311
  logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
 
323
  logger.info("Processing time limit reached")
324
  break
325
 
326
+ current_time = frame_count / fps
327
+
328
  # Run detection on this frame
329
  results = model(frame, device=device)
 
330
 
331
+ current_detections = []
332
  for result in results:
333
  boxes = result.boxes
 
 
334
  for box in boxes:
335
  cls = int(box.cls)
336
  conf = float(box.conf)
337
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
 
339
  if label is None:
340
+ continue
341
 
342
+ if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
 
343
  continue
344
 
345
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
 
347
+ current_detections.append({
 
348
  "frame": frame_count,
349
  "violation": label,
350
  "confidence": round(conf, 2),
 
352
  "timestamp": current_time
353
  })
354
 
355
+ # Process detections and associate with workers
356
+ for detection in current_detections:
357
+ # Find matching worker
358
+ matched_worker = None
359
+ max_iou = 0
360
+
361
+ for worker in workers:
362
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
+ max_iou = iou
365
+ matched_worker = worker
366
+
367
+ if matched_worker:
368
+ # Update worker's position
369
+ matched_worker["bbox"] = detection["bounding_box"]
370
+ matched_worker["last_seen"] = current_time
371
+ worker_id = matched_worker["id"]
372
+ else:
373
+ # New worker
374
+ worker_id = len(workers) + 1
375
+ workers.append({
376
+ "id": worker_id,
377
+ "bbox": detection["bounding_box"],
378
+ "first_seen": current_time,
379
+ "last_seen": current_time
380
+ })
381
+
382
+ # Add to violation history
383
+ detection["worker_id"] = worker_id
384
+ violation_history[detection["violation"]].append(detection)
385
+
386
  frame_count += 1
387
 
388
  video.release()
389
  os.remove(video_path)
390
 
391
+ # Process violation history to confirm persistent violations
392
+ for violation_type, detections in violation_history.items():
393
+ if not detections:
394
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
+ # Group by worker
397
+ worker_violations = {}
398
+ for det in detections:
399
+ if det["worker_id"] not in worker_violations:
400
+ worker_violations[det["worker_id"]] = []
401
+ worker_violations[det["worker_id"]].append(det)
402
+
403
+ # Check each worker's violations for persistence
404
+ for worker_id, worker_dets in worker_violations.items():
405
+ if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
  # Take the highest confidence detection
407
+ best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
+ violations.append(best_detection)
409
 
410
  # Capture snapshot if not already taken
411
  if not snapshot_taken[violation_type]:
412
+ # Get the frame for this violation
413
  cap = cv2.VideoCapture(video_path)
414
  cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
  ret, snapshot_frame = cap.read()
416
  cap.release()
417
 
418
  if ret:
419
+ # Draw detections on snapshot
420
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
+
422
  snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
  cv2.imwrite(snapshot_path, snapshot_frame)
 
431
  snapshot_taken[violation_type] = True
432
 
433
  # Final processing
434
+ if not violations:
435
  logger.info("No persistent violations detected")
436
  return {
437
  "violations": [],
 
442
  "message": "No violations detected in the video."
443
  }
444
 
445
+ score = calculate_safety_score(violations)
446
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
447
+ report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
448
 
449
  return {
450
+ "violations": violations,
451
  "snapshots": snapshots,
452
  "score": score,
453
  "salesforce_record_id": report_id,
 
466
  }
467
 
468
  # ==========================
469
+ # Gradio Interface
470
  # ==========================
471
  def gradio_interface(video_file):
472
  if not video_file: