Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,11 +30,11 @@ CONFIG = {
|
|
| 30 |
4: "improper_tool_use"
|
| 31 |
},
|
| 32 |
"CLASS_COLORS": {
|
| 33 |
-
"no_helmet": (0, 0, 255),
|
| 34 |
-
"no_harness": (0, 165, 255),
|
| 35 |
-
"unsafe_posture": (0, 255, 0),
|
| 36 |
-
"unsafe_zone": (255, 0, 0),
|
| 37 |
-
"improper_tool_use": (255, 255, 0)
|
| 38 |
},
|
| 39 |
"DISPLAY_NAMES": {
|
| 40 |
"no_helmet": "No Helmet Violation",
|
|
@@ -50,13 +50,7 @@ CONFIG = {
|
|
| 50 |
"domain": "login"
|
| 51 |
},
|
| 52 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 53 |
-
"FRAME_SKIP":
|
| 54 |
-
"no_helmet": 2,
|
| 55 |
-
"no_harness": 1,
|
| 56 |
-
"unsafe_posture": 1,
|
| 57 |
-
"unsafe_zone": 1,
|
| 58 |
-
"improper_tool_use": 1
|
| 59 |
-
},
|
| 60 |
"CONFIDENCE_THRESHOLDS": {
|
| 61 |
"no_helmet": 0.5,
|
| 62 |
"no_harness": 0.15,
|
|
@@ -65,8 +59,9 @@ CONFIG = {
|
|
| 65 |
"improper_tool_use": 0.15
|
| 66 |
},
|
| 67 |
"IOU_THRESHOLD": 0.4,
|
| 68 |
-
"MIN_VIOLATION_FRAMES":
|
| 69 |
-
"HELMET_CONFIDENCE_THRESHOLD": 0.7
|
|
|
|
| 70 |
}
|
| 71 |
|
| 72 |
# Setup logging
|
|
@@ -101,7 +96,6 @@ model = load_model()
|
|
| 101 |
# Enhanced Helper Functions
|
| 102 |
# ==========================
|
| 103 |
def draw_detections(frame, detections):
|
| 104 |
-
"""Draw bounding boxes and labels on frame"""
|
| 105 |
for det in detections:
|
| 106 |
label = det.get("violation", "Unknown")
|
| 107 |
confidence = det.get("confidence", 0.0)
|
|
@@ -121,7 +115,6 @@ def draw_detections(frame, detections):
|
|
| 121 |
return frame
|
| 122 |
|
| 123 |
def calculate_iou(box1, box2):
|
| 124 |
-
"""Calculate Intersection over Union (IoU) for two bounding boxes."""
|
| 125 |
x1, y1, w1, h1 = box1
|
| 126 |
x2, y2, w2, h2 = box2
|
| 127 |
|
|
@@ -299,34 +292,41 @@ def process_video(video_data):
|
|
| 299 |
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 300 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 301 |
if fps <= 0:
|
| 302 |
-
fps = 30
|
| 303 |
video_duration = total_frames / fps
|
| 304 |
logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
|
| 305 |
|
| 306 |
workers = []
|
| 307 |
violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 308 |
-
confirmed_violations = {}
|
| 309 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 310 |
-
helmet_compliance = {}
|
| 311 |
-
detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
-
|
| 314 |
ret, frame = video.read()
|
| 315 |
if not ret:
|
| 316 |
-
break
|
| 317 |
-
|
| 318 |
-
current_time = frame_count / fps
|
| 319 |
-
min_frame_skip = min(CONFIG["FRAME_SKIP"].values())
|
| 320 |
-
if frame_count % min_frame_skip != 0:
|
| 321 |
-
frame_count += 1
|
| 322 |
continue
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
| 327 |
|
| 328 |
# Run detection on this frame
|
| 329 |
-
results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"]
|
| 330 |
|
| 331 |
current_detections = []
|
| 332 |
for result in results:
|
|
@@ -347,7 +347,7 @@ def process_video(video_data):
|
|
| 347 |
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
| 348 |
|
| 349 |
current_detections.append({
|
| 350 |
-
"frame":
|
| 351 |
"violation": label,
|
| 352 |
"confidence": round(conf, 2),
|
| 353 |
"bounding_box": bbox,
|
|
@@ -355,16 +355,14 @@ def process_video(video_data):
|
|
| 355 |
})
|
| 356 |
detection_counts[label] += 1
|
| 357 |
|
| 358 |
-
logger.debug(f"Frame {
|
| 359 |
|
| 360 |
-
# Process detections and associate with workers
|
| 361 |
for detection in current_detections:
|
| 362 |
violation_type = detection.get("violation", None)
|
| 363 |
if violation_type is None:
|
| 364 |
logger.error(f"Invalid detection, missing 'violation' key: {detection}")
|
| 365 |
continue
|
| 366 |
|
| 367 |
-
# Helmet compliance check
|
| 368 |
if violation_type == "no_helmet":
|
| 369 |
matched_worker = None
|
| 370 |
max_iou = 0
|
|
@@ -376,7 +374,6 @@ def process_video(video_data):
|
|
| 376 |
|
| 377 |
if matched_worker:
|
| 378 |
worker_id = matched_worker["id"]
|
| 379 |
-
# Require high confidence and persistence for No Helmet violation
|
| 380 |
if worker_id not in helmet_compliance:
|
| 381 |
helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
|
| 382 |
helmet_compliance[worker_id]["no_helmet_frames"] += 1
|
|
@@ -387,7 +384,6 @@ def process_video(video_data):
|
|
| 387 |
logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
|
| 388 |
continue
|
| 389 |
|
| 390 |
-
# Find or create worker
|
| 391 |
matched_worker = None
|
| 392 |
max_iou = 0
|
| 393 |
|
|
@@ -409,11 +405,9 @@ def process_video(video_data):
|
|
| 409 |
"first_seen": current_time,
|
| 410 |
"last_seen": current_time
|
| 411 |
})
|
| 412 |
-
# Initialize helmet compliance for new worker
|
| 413 |
if worker_id not in helmet_compliance:
|
| 414 |
helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
|
| 415 |
|
| 416 |
-
# Skip if this violation type is already confirmed for this worker
|
| 417 |
if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
|
| 418 |
logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
|
| 419 |
continue
|
|
@@ -421,14 +415,10 @@ def process_video(video_data):
|
|
| 421 |
detection["worker_id"] = worker_id
|
| 422 |
violation_history[violation_type].append(detection)
|
| 423 |
|
| 424 |
-
# Clean up old workers
|
| 425 |
workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
|
| 426 |
|
| 427 |
-
frame_count += 1
|
| 428 |
-
|
| 429 |
logger.info(f"Detection counts: {detection_counts}")
|
| 430 |
|
| 431 |
-
# Process violation history to confirm persistent violations
|
| 432 |
for violation_type, detections in violation_history.items():
|
| 433 |
if not detections:
|
| 434 |
logger.info(f"No detections for {violation_type}")
|
|
@@ -442,16 +432,13 @@ def process_video(video_data):
|
|
| 442 |
|
| 443 |
for worker_id, worker_dets in worker_violations.items():
|
| 444 |
if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
|
| 445 |
-
# Skip if already confirmed
|
| 446 |
if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
|
| 447 |
continue
|
| 448 |
|
| 449 |
-
# Skip No Helmet if worker is compliant
|
| 450 |
if violation_type == "no_helmet":
|
| 451 |
if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
|
| 452 |
logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
|
| 453 |
continue
|
| 454 |
-
# Require persistent No Helmet detections
|
| 455 |
if helmet_compliance[worker_id]["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
|
| 456 |
logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
|
| 457 |
continue
|
|
@@ -486,7 +473,6 @@ def process_video(video_data):
|
|
| 486 |
logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
|
| 487 |
cap.release()
|
| 488 |
|
| 489 |
-
# Clean up video file after snapshots are captured
|
| 490 |
video.release()
|
| 491 |
os.remove(video_path)
|
| 492 |
logger.info(f"Video file {video_path} removed")
|
|
@@ -537,7 +523,6 @@ def gradio_interface(video_file):
|
|
| 537 |
with open(video_file, "rb") as f:
|
| 538 |
video_data = f.read()
|
| 539 |
|
| 540 |
-
# Use generator to yield updates
|
| 541 |
for status, score, snapshots_text, record_id, details_url in process_video(video_data):
|
| 542 |
yield status, score, snapshots_text, record_id, details_url
|
| 543 |
except Exception as e:
|
|
|
|
| 30 |
4: "improper_tool_use"
|
| 31 |
},
|
| 32 |
"CLASS_COLORS": {
|
| 33 |
+
"no_helmet": (0, 0, 255),
|
| 34 |
+
"no_harness": (0, 165, 255),
|
| 35 |
+
"unsafe_posture": (0, 255, 0),
|
| 36 |
+
"unsafe_zone": (255, 0, 0),
|
| 37 |
+
"improper_tool_use": (255, 255, 0)
|
| 38 |
},
|
| 39 |
"DISPLAY_NAMES": {
|
| 40 |
"no_helmet": "No Helmet Violation",
|
|
|
|
| 50 |
"domain": "login"
|
| 51 |
},
|
| 52 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 53 |
+
"FRAME_SKIP": 5, # Increased to process fewer frames (1 frame every 5 frames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
"CONFIDENCE_THRESHOLDS": {
|
| 55 |
"no_helmet": 0.5,
|
| 56 |
"no_harness": 0.15,
|
|
|
|
| 59 |
"improper_tool_use": 0.15
|
| 60 |
},
|
| 61 |
"IOU_THRESHOLD": 0.4,
|
| 62 |
+
"MIN_VIOLATION_FRAMES": 2, # Reduced to ensure violations are detected with fewer frames
|
| 63 |
+
"HELMET_CONFIDENCE_THRESHOLD": 0.7,
|
| 64 |
+
"MAX_PROCESSING_TIME": 30 # Enforce 30-second processing limit
|
| 65 |
}
|
| 66 |
|
| 67 |
# Setup logging
|
|
|
|
| 96 |
# Enhanced Helper Functions
|
| 97 |
# ==========================
|
| 98 |
def draw_detections(frame, detections):
|
|
|
|
| 99 |
for det in detections:
|
| 100 |
label = det.get("violation", "Unknown")
|
| 101 |
confidence = det.get("confidence", 0.0)
|
|
|
|
| 115 |
return frame
|
| 116 |
|
| 117 |
def calculate_iou(box1, box2):
|
|
|
|
| 118 |
x1, y1, w1, h1 = box1
|
| 119 |
x2, y2, w2, h2 = box2
|
| 120 |
|
|
|
|
| 292 |
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 293 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 294 |
if fps <= 0:
|
| 295 |
+
fps = 30
|
| 296 |
video_duration = total_frames / fps
|
| 297 |
logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
|
| 298 |
|
| 299 |
workers = []
|
| 300 |
violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 301 |
+
confirmed_violations = {}
|
| 302 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 303 |
+
helmet_compliance = {}
|
| 304 |
+
detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 305 |
+
start_time = time.time()
|
| 306 |
+
|
| 307 |
+
# Calculate frames to process within 30 seconds
|
| 308 |
+
target_frames = int(total_frames / CONFIG["FRAME_SKIP"])
|
| 309 |
+
frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)
|
| 310 |
+
|
| 311 |
+
processed_frames = 0
|
| 312 |
+
for idx in frame_indices:
|
| 313 |
+
elapsed_time = time.time() - start_time
|
| 314 |
+
if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 315 |
+
logger.info(f"Processing time limit of {CONFIG['MAX_PROCESSING_TIME']} seconds reached. Processed {processed_frames}/{target_frames} frames.")
|
| 316 |
+
break
|
| 317 |
|
| 318 |
+
video.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
| 319 |
ret, frame = video.read()
|
| 320 |
if not ret:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
continue
|
| 322 |
|
| 323 |
+
processed_frames += 1
|
| 324 |
+
current_time = idx / fps
|
| 325 |
+
progress = (processed_frames / target_frames) * 100
|
| 326 |
+
yield f"Processing video... {progress:.1f}% complete (Frame {idx}/{total_frames})", "", "", "", ""
|
| 327 |
|
| 328 |
# Run detection on this frame
|
| 329 |
+
results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"])
|
| 330 |
|
| 331 |
current_detections = []
|
| 332 |
for result in results:
|
|
|
|
| 347 |
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
| 348 |
|
| 349 |
current_detections.append({
|
| 350 |
+
"frame": idx,
|
| 351 |
"violation": label,
|
| 352 |
"confidence": round(conf, 2),
|
| 353 |
"bounding_box": bbox,
|
|
|
|
| 355 |
})
|
| 356 |
detection_counts[label] += 1
|
| 357 |
|
| 358 |
+
logger.debug(f"Frame {idx}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
|
| 359 |
|
|
|
|
| 360 |
for detection in current_detections:
|
| 361 |
violation_type = detection.get("violation", None)
|
| 362 |
if violation_type is None:
|
| 363 |
logger.error(f"Invalid detection, missing 'violation' key: {detection}")
|
| 364 |
continue
|
| 365 |
|
|
|
|
| 366 |
if violation_type == "no_helmet":
|
| 367 |
matched_worker = None
|
| 368 |
max_iou = 0
|
|
|
|
| 374 |
|
| 375 |
if matched_worker:
|
| 376 |
worker_id = matched_worker["id"]
|
|
|
|
| 377 |
if worker_id not in helmet_compliance:
|
| 378 |
helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
|
| 379 |
helmet_compliance[worker_id]["no_helmet_frames"] += 1
|
|
|
|
| 384 |
logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
|
| 385 |
continue
|
| 386 |
|
|
|
|
| 387 |
matched_worker = None
|
| 388 |
max_iou = 0
|
| 389 |
|
|
|
|
| 405 |
"first_seen": current_time,
|
| 406 |
"last_seen": current_time
|
| 407 |
})
|
|
|
|
| 408 |
if worker_id not in helmet_compliance:
|
| 409 |
helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
|
| 410 |
|
|
|
|
| 411 |
if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
|
| 412 |
logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
|
| 413 |
continue
|
|
|
|
| 415 |
detection["worker_id"] = worker_id
|
| 416 |
violation_history[violation_type].append(detection)
|
| 417 |
|
|
|
|
| 418 |
workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
|
| 419 |
|
|
|
|
|
|
|
| 420 |
logger.info(f"Detection counts: {detection_counts}")
|
| 421 |
|
|
|
|
| 422 |
for violation_type, detections in violation_history.items():
|
| 423 |
if not detections:
|
| 424 |
logger.info(f"No detections for {violation_type}")
|
|
|
|
| 432 |
|
| 433 |
for worker_id, worker_dets in worker_violations.items():
|
| 434 |
if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
|
|
|
|
| 435 |
if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
|
| 436 |
continue
|
| 437 |
|
|
|
|
| 438 |
if violation_type == "no_helmet":
|
| 439 |
if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
|
| 440 |
logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
|
| 441 |
continue
|
|
|
|
| 442 |
if helmet_compliance[worker_id]["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
|
| 443 |
logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
|
| 444 |
continue
|
|
|
|
| 473 |
logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
|
| 474 |
cap.release()
|
| 475 |
|
|
|
|
| 476 |
video.release()
|
| 477 |
os.remove(video_path)
|
| 478 |
logger.info(f"Video file {video_path} removed")
|
|
|
|
| 523 |
with open(video_file, "rb") as f:
|
| 524 |
video_data = f.read()
|
| 525 |
|
|
|
|
| 526 |
for status, score, snapshots_text, record_id, details_url in process_video(video_data):
|
| 527 |
yield status, score, snapshots_text, record_id, details_url
|
| 528 |
except Exception as e:
|