Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -51,7 +51,7 @@ CONFIG = {
|
|
| 51 |
},
|
| 52 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 53 |
"FRAME_SKIP": {
|
| 54 |
-
"no_helmet": 3,
|
| 55 |
"no_harness": 2,
|
| 56 |
"unsafe_posture": 2,
|
| 57 |
"unsafe_zone": 2,
|
|
@@ -59,14 +59,14 @@ CONFIG = {
|
|
| 59 |
},
|
| 60 |
"MAX_PROCESSING_TIME": 60,
|
| 61 |
"CONFIDENCE_THRESHOLDS": {
|
| 62 |
-
"no_helmet": 0.3,
|
| 63 |
"no_harness": 0.2,
|
| 64 |
"unsafe_posture": 0.2,
|
| 65 |
"unsafe_zone": 0.2,
|
| 66 |
"improper_tool_use": 0.2
|
| 67 |
},
|
| 68 |
"IOU_THRESHOLD": 0.4,
|
| 69 |
-
"MIN_VIOLATION_FRAMES": 3
|
| 70 |
}
|
| 71 |
|
| 72 |
# Setup logging
|
|
@@ -103,9 +103,9 @@ model = load_model()
|
|
| 103 |
def draw_detections(frame, detections):
|
| 104 |
"""Draw bounding boxes and labels on frame"""
|
| 105 |
for det in detections:
|
| 106 |
-
label = det
|
| 107 |
-
confidence = det
|
| 108 |
-
x, y, w, h = det
|
| 109 |
|
| 110 |
x1 = int(x - w/2)
|
| 111 |
y1 = int(y - h/2)
|
|
@@ -178,8 +178,8 @@ def generate_violation_pdf(violations, score):
|
|
| 178 |
c.drawString(1 * inch, y_position, "No violations detected.")
|
| 179 |
else:
|
| 180 |
for v in violations:
|
| 181 |
-
display_name = CONFIG["DISPLAY_NAMES"].get(v
|
| 182 |
-
text = f"{display_name} at {v
|
| 183 |
c.drawString(1 * inch, y_position, text)
|
| 184 |
y_position -= 0.3 * inch
|
| 185 |
if y_position < 1 * inch:
|
|
@@ -228,7 +228,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
|
| 228 |
try:
|
| 229 |
sf = connect_to_salesforce()
|
| 230 |
violations_text = "\n".join(
|
| 231 |
-
f"{CONFIG['DISPLAY_NAMES'].get(v
|
| 232 |
for v in violations
|
| 233 |
) or "No violations detected."
|
| 234 |
pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
|
|
@@ -275,7 +275,7 @@ def calculate_safety_score(violations):
|
|
| 275 |
"unsafe_zone": 35,
|
| 276 |
"improper_tool_use": 25
|
| 277 |
}
|
| 278 |
-
total_penalty = sum(penalties.get(v
|
| 279 |
score = 100 - total_penalty
|
| 280 |
return max(score, 0)
|
| 281 |
|
|
@@ -337,24 +337,32 @@ def process_video(video_data):
|
|
| 337 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 338 |
|
| 339 |
if label is None:
|
|
|
|
| 340 |
continue
|
| 341 |
|
| 342 |
if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
|
|
|
|
| 343 |
continue
|
| 344 |
|
| 345 |
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
| 346 |
|
| 347 |
current_detections.append({
|
| 348 |
"frame": frame_count,
|
| 349 |
-
"
|
| 350 |
"confidence": round(conf, 2),
|
| 351 |
"bounding_box": bbox,
|
| 352 |
"timestamp": current_time
|
| 353 |
})
|
| 354 |
|
|
|
|
|
|
|
| 355 |
# Process detections and associate with workers
|
| 356 |
for detection in current_detections:
|
| 357 |
-
violation_type = detection
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
# Skip No Helmet detection if worker is compliant
|
| 359 |
if violation_type == "no_helmet":
|
| 360 |
matched_worker = None
|
|
@@ -366,7 +374,8 @@ def process_video(video_data):
|
|
| 366 |
matched_worker = worker
|
| 367 |
|
| 368 |
if matched_worker and matched_worker["id"] in helmet_compliance:
|
| 369 |
-
|
|
|
|
| 370 |
|
| 371 |
# Find or create worker
|
| 372 |
matched_worker = None
|
|
@@ -393,14 +402,16 @@ def process_video(video_data):
|
|
| 393 |
|
| 394 |
# Skip if this violation type is already confirmed for this worker
|
| 395 |
if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
|
|
|
|
| 396 |
continue
|
| 397 |
|
| 398 |
detection["worker_id"] = worker_id
|
| 399 |
violation_history[violation_type].append(detection)
|
| 400 |
|
| 401 |
# Update helmet compliance (simulate by checking if No Helmet confidence is low)
|
| 402 |
-
if violation_type == "no_helmet" and
|
| 403 |
helmet_compliance[worker_id] = True
|
|
|
|
| 404 |
|
| 405 |
# Clean up old workers
|
| 406 |
workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
|
|
@@ -413,6 +424,7 @@ def process_video(video_data):
|
|
| 413 |
# Process violation history to confirm persistent violations
|
| 414 |
for violation_type, detections in violation_history.items():
|
| 415 |
if not detections:
|
|
|
|
| 416 |
continue
|
| 417 |
|
| 418 |
worker_violations = {}
|
|
@@ -429,6 +441,7 @@ def process_video(video_data):
|
|
| 429 |
|
| 430 |
# Skip No Helmet if worker is compliant
|
| 431 |
if violation_type == "no_helmet" and worker_id in helmet_compliance:
|
|
|
|
| 432 |
continue
|
| 433 |
|
| 434 |
best_detection = max(worker_dets, key=lambda x: x["confidence"])
|
|
@@ -457,6 +470,7 @@ def process_video(video_data):
|
|
| 457 |
"snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
|
| 458 |
})
|
| 459 |
snapshot_taken[violation_type] = True
|
|
|
|
| 460 |
|
| 461 |
if not violations:
|
| 462 |
logger.info("No persistent violations detected")
|
|
@@ -473,6 +487,7 @@ def process_video(video_data):
|
|
| 473 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 474 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
| 475 |
|
|
|
|
| 476 |
return {
|
| 477 |
"violations": violations,
|
| 478 |
"snapshots": snapshots,
|
|
@@ -517,8 +532,8 @@ def gradio_interface(video_file):
|
|
| 517 |
rows = []
|
| 518 |
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 519 |
for v in result["violations"]:
|
| 520 |
-
display_name = violation_name_map.get(v
|
| 521 |
-
row = f"| {display_name:<22} | {v
|
| 522 |
rows.append(row)
|
| 523 |
violation_table = header + separator + "\n".join(rows)
|
| 524 |
|
|
@@ -526,7 +541,7 @@ def gradio_interface(video_file):
|
|
| 526 |
if result["snapshots"]:
|
| 527 |
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 528 |
snapshots_text = "\n".join(
|
| 529 |
-
f"- Snapshot for {violation_name_map.get(s
|
| 530 |
for s in result["snapshots"]
|
| 531 |
)
|
| 532 |
|
|
|
|
| 51 |
},
|
| 52 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 53 |
"FRAME_SKIP": {
|
| 54 |
+
"no_helmet": 3,
|
| 55 |
"no_harness": 2,
|
| 56 |
"unsafe_posture": 2,
|
| 57 |
"unsafe_zone": 2,
|
|
|
|
| 59 |
},
|
| 60 |
"MAX_PROCESSING_TIME": 60,
|
| 61 |
"CONFIDENCE_THRESHOLDS": {
|
| 62 |
+
"no_helmet": 0.3,
|
| 63 |
"no_harness": 0.2,
|
| 64 |
"unsafe_posture": 0.2,
|
| 65 |
"unsafe_zone": 0.2,
|
| 66 |
"improper_tool_use": 0.2
|
| 67 |
},
|
| 68 |
"IOU_THRESHOLD": 0.4,
|
| 69 |
+
"MIN_VIOLATION_FRAMES": 3
|
| 70 |
}
|
| 71 |
|
| 72 |
# Setup logging
|
|
|
|
| 103 |
def draw_detections(frame, detections):
|
| 104 |
"""Draw bounding boxes and labels on frame"""
|
| 105 |
for det in detections:
|
| 106 |
+
label = det.get("violation", "Unknown")
|
| 107 |
+
confidence = det.get("confidence", 0.0)
|
| 108 |
+
x, y, w, h = det.get("bounding_box", [0, 0, 0, 0])
|
| 109 |
|
| 110 |
x1 = int(x - w/2)
|
| 111 |
y1 = int(y - h/2)
|
|
|
|
| 178 |
c.drawString(1 * inch, y_position, "No violations detected.")
|
| 179 |
else:
|
| 180 |
for v in violations:
|
| 181 |
+
display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
|
| 182 |
+
text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
|
| 183 |
c.drawString(1 * inch, y_position, text)
|
| 184 |
y_position -= 0.3 * inch
|
| 185 |
if y_position < 1 * inch:
|
|
|
|
| 228 |
try:
|
| 229 |
sf = connect_to_salesforce()
|
| 230 |
violations_text = "\n".join(
|
| 231 |
+
f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
|
| 232 |
for v in violations
|
| 233 |
) or "No violations detected."
|
| 234 |
pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
|
|
|
|
| 275 |
"unsafe_zone": 35,
|
| 276 |
"improper_tool_use": 25
|
| 277 |
}
|
| 278 |
+
total_penalty = sum(penalties.get(v.get("violation", "Unknown"), 0) for v in violations)
|
| 279 |
score = 100 - total_penalty
|
| 280 |
return max(score, 0)
|
| 281 |
|
|
|
|
| 337 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 338 |
|
| 339 |
if label is None:
|
| 340 |
+
logger.warning(f"Unknown class ID {cls} detected, skipping")
|
| 341 |
continue
|
| 342 |
|
| 343 |
if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
|
| 344 |
+
logger.debug(f"Detection {label} with confidence {conf:.2f} below threshold, skipping")
|
| 345 |
continue
|
| 346 |
|
| 347 |
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
| 348 |
|
| 349 |
current_detections.append({
|
| 350 |
"frame": frame_count,
|
| 351 |
+
"violation": label, # Corrected key
|
| 352 |
"confidence": round(conf, 2),
|
| 353 |
"bounding_box": bbox,
|
| 354 |
"timestamp": current_time
|
| 355 |
})
|
| 356 |
|
| 357 |
+
logger.debug(f"Frame {frame_count}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")
|
| 358 |
+
|
| 359 |
# Process detections and associate with workers
|
| 360 |
for detection in current_detections:
|
| 361 |
+
violation_type = detection.get("violation", None)
|
| 362 |
+
if violation_type is None:
|
| 363 |
+
logger.error(f"Invalid detection, missing 'violation' key: {detection}")
|
| 364 |
+
continue
|
| 365 |
+
|
| 366 |
# Skip No Helmet detection if worker is compliant
|
| 367 |
if violation_type == "no_helmet":
|
| 368 |
matched_worker = None
|
|
|
|
| 374 |
matched_worker = worker
|
| 375 |
|
| 376 |
if matched_worker and matched_worker["id"] in helmet_compliance:
|
| 377 |
+
logger.debug(f"Worker {matched_worker['id']} has helmet, skipping no_helmet violation")
|
| 378 |
+
continue
|
| 379 |
|
| 380 |
# Find or create worker
|
| 381 |
matched_worker = None
|
|
|
|
| 402 |
|
| 403 |
# Skip if this violation type is already confirmed for this worker
|
| 404 |
if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
|
| 405 |
+
logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
|
| 406 |
continue
|
| 407 |
|
| 408 |
detection["worker_id"] = worker_id
|
| 409 |
violation_history[violation_type].append(detection)
|
| 410 |
|
| 411 |
# Update helmet compliance (simulate by checking if No Helmet confidence is low)
|
| 412 |
+
if violation_type == "no_helmet" and detection["confidence"] < 0.5:
|
| 413 |
helmet_compliance[worker_id] = True
|
| 414 |
+
logger.debug(f"Worker {worker_id} marked as helmet compliant")
|
| 415 |
|
| 416 |
# Clean up old workers
|
| 417 |
workers = [w for w in workers if current_time - w["last_seen"] < 5.0]
|
|
|
|
| 424 |
# Process violation history to confirm persistent violations
|
| 425 |
for violation_type, detections in violation_history.items():
|
| 426 |
if not detections:
|
| 427 |
+
logger.info(f"No detections for {violation_type}")
|
| 428 |
continue
|
| 429 |
|
| 430 |
worker_violations = {}
|
|
|
|
| 441 |
|
| 442 |
# Skip No Helmet if worker is compliant
|
| 443 |
if violation_type == "no_helmet" and worker_id in helmet_compliance:
|
| 444 |
+
logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
|
| 445 |
continue
|
| 446 |
|
| 447 |
best_detection = max(worker_dets, key=lambda x: x["confidence"])
|
|
|
|
| 470 |
"snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
|
| 471 |
})
|
| 472 |
snapshot_taken[violation_type] = True
|
| 473 |
+
logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")
|
| 474 |
|
| 475 |
if not violations:
|
| 476 |
logger.info("No persistent violations detected")
|
|
|
|
| 487 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 488 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
| 489 |
|
| 490 |
+
logger.info(f"Processing complete: {len(violations)} violations detected, score: {score}%")
|
| 491 |
return {
|
| 492 |
"violations": violations,
|
| 493 |
"snapshots": snapshots,
|
|
|
|
| 532 |
rows = []
|
| 533 |
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 534 |
for v in result["violations"]:
|
| 535 |
+
display_name = violation_name_map.get(v.get("violation", "Unknown"), "Unknown")
|
| 536 |
+
row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |"
|
| 537 |
rows.append(row)
|
| 538 |
violation_table = header + separator + "\n".join(rows)
|
| 539 |
|
|
|
|
| 541 |
if result["snapshots"]:
|
| 542 |
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 543 |
snapshots_text = "\n".join(
|
| 544 |
+
f"- Snapshot for {violation_name_map.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: })"
|
| 545 |
for s in result["snapshots"]
|
| 546 |
)
|
| 547 |
|