Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -40,7 +40,7 @@ CONFIG = {
|
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 41 |
"FRAME_SKIP": 15,
|
| 42 |
"MAX_PROCESSING_TIME": 30,
|
| 43 |
-
"CONFIDENCE_THRESHOLD": 0.
|
| 44 |
"IOU_THRESHOLD": 0.5 # For worker tracking
|
| 45 |
}
|
| 46 |
|
|
@@ -80,12 +80,19 @@ def calculate_iou(box1, box2):
|
|
| 80 |
x1, y1, w1, h1 = box1
|
| 81 |
x2, y2, w2, h2 = box2
|
| 82 |
|
|
|
|
| 83 |
x1_min, y1_min = x1 - w1/2, y1 - h1/2
|
| 84 |
x1_max, y1_max = x1 + w1/2, y1 + h1/2
|
| 85 |
x2_min, y2_min = x2 - w2/2, y2 - h2/2
|
| 86 |
x2_max, y2_max = x2 + w2/2, y2 + h2/2
|
| 87 |
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
area1 = w1 * h1
|
| 90 |
area2 = w2 * h2
|
| 91 |
union = area1 + area2 - intersection
|
|
@@ -245,7 +252,7 @@ def process_video(video_data):
|
|
| 245 |
if not video.isOpened():
|
| 246 |
raise ValueError("Could not open video file")
|
| 247 |
|
| 248 |
-
violations, snapshots = [], []
|
| 249 |
frame_count = 0
|
| 250 |
start_time = time.time()
|
| 251 |
fps = video.get(cv2.CAP_PROP_FPS)
|
|
@@ -280,7 +287,15 @@ def process_video(video_data):
|
|
| 280 |
cls, conf = int(box.cls), float(box.conf)
|
| 281 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 282 |
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
if label not in CONFIG["VIOLATION_LABELS"].values():
|
| 286 |
logger.info(f"Skipping unknown class: {cls}")
|
|
@@ -394,6 +409,7 @@ def process_video(video_data):
|
|
| 394 |
return {
|
| 395 |
"violations": [],
|
| 396 |
"snapshots": [],
|
|
|
|
| 397 |
"score": 100,
|
| 398 |
"salesforce_record_id": None,
|
| 399 |
"violation_details_url": "",
|
|
@@ -407,6 +423,7 @@ def process_video(video_data):
|
|
| 407 |
return {
|
| 408 |
"violations": violations,
|
| 409 |
"snapshots": snapshots,
|
|
|
|
| 410 |
"score": score,
|
| 411 |
"salesforce_record_id": report_id,
|
| 412 |
"violation_details_url": final_pdf_url,
|
|
@@ -417,6 +434,7 @@ def process_video(video_data):
|
|
| 417 |
return {
|
| 418 |
"violations": [],
|
| 419 |
"snapshots": [],
|
|
|
|
| 420 |
"score": 100,
|
| 421 |
"salesforce_record_id": None,
|
| 422 |
"violation_details_url": "",
|
|
@@ -425,9 +443,9 @@ def process_video(video_data):
|
|
| 425 |
|
| 426 |
def gradio_interface(video_file):
|
| 427 |
if not video_file:
|
| 428 |
-
return "No file uploaded.", "", "No file uploaded.", "", "", []
|
| 429 |
try:
|
| 430 |
-
yield "Processing video... please wait.", "", "", "", "", []
|
| 431 |
|
| 432 |
with open(video_file, "rb") as f:
|
| 433 |
video_data = f.read()
|
|
@@ -435,7 +453,7 @@ def gradio_interface(video_file):
|
|
| 435 |
result = process_video(video_data)
|
| 436 |
|
| 437 |
if result.get("message"):
|
| 438 |
-
yield result["message"], "", "", "", "", []
|
| 439 |
return
|
| 440 |
|
| 441 |
violation_table = "No violations detected."
|
|
@@ -460,19 +478,30 @@ def gradio_interface(video_file):
|
|
| 460 |
)
|
| 461 |
snapshot_images = [s["snapshot_base64"] for s in result["snapshots"]]
|
| 462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
yield (
|
| 464 |
violation_table,
|
| 465 |
f"Safety Score: {result['score']}%",
|
| 466 |
snapshots_text,
|
| 467 |
f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
|
| 468 |
result["violation_details_url"] or "N/A",
|
| 469 |
-
snapshot_images
|
|
|
|
| 470 |
)
|
| 471 |
except Exception as e:
|
| 472 |
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 473 |
-
yield f"Error: {str(e)}", "", "Error in processing.", "", "", []
|
| 474 |
|
| 475 |
-
interface
|
| 476 |
fn=gradio_interface,
|
| 477 |
inputs=gr.Video(label="Upload Site Video"),
|
| 478 |
outputs=[
|
|
@@ -481,7 +510,8 @@ interface = gr.Interface(
|
|
| 481 |
gr.Markdown(label="Snapshots"),
|
| 482 |
gr.Textbox(label="Salesforce Record ID"),
|
| 483 |
gr.Textbox(label="Violation Details URL"),
|
| 484 |
-
gr.Gallery(label="Violation Snapshots")
|
|
|
|
| 485 |
],
|
| 486 |
title="Worksite Safety Violation Analyzer",
|
| 487 |
description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture). Non-violations are ignored.",
|
|
|
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 41 |
"FRAME_SKIP": 15,
|
| 42 |
"MAX_PROCESSING_TIME": 30,
|
| 43 |
+
"CONFIDENCE_THRESHOLD": 0.1, # Lowered for debugging
|
| 44 |
"IOU_THRESHOLD": 0.5 # For worker tracking
|
| 45 |
}
|
| 46 |
|
|
|
|
| 80 |
x1, y1, w1, h1 = box1
|
| 81 |
x2, y2, w2, h2 = box2
|
| 82 |
|
| 83 |
+
# Convert to top-left and bottom-right coordinates
|
| 84 |
x1_min, y1_min = x1 - w1/2, y1 - h1/2
|
| 85 |
x1_max, y1_max = x1 + w1/2, y1 + h1/2
|
| 86 |
x2_min, y2_min = x2 - w2/2, y2 - h2/2
|
| 87 |
x2_max, y2_max = x2 + w2/2, y2 + h2/2
|
| 88 |
|
| 89 |
+
# Calculate intersection
|
| 90 |
+
x_min = max(x1_min, x2_min)
|
| 91 |
+
y_min = max(y1_min, y2_min)
|
| 92 |
+
x_max = min(x1_max, x2_max)
|
| 93 |
+
y_max = min(y1_max, y2_max)
|
| 94 |
+
|
| 95 |
+
intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
|
| 96 |
area1 = w1 * h1
|
| 97 |
area2 = w2 * h2
|
| 98 |
union = area1 + area2 - intersection
|
|
|
|
| 252 |
if not video.isOpened():
|
| 253 |
raise ValueError("Could not open video file")
|
| 254 |
|
| 255 |
+
violations, snapshots, raw_detections = [], [], []
|
| 256 |
frame_count = 0
|
| 257 |
start_time = time.time()
|
| 258 |
fps = video.get(cv2.CAP_PROP_FPS)
|
|
|
|
| 287 |
cls, conf = int(box.cls), float(box.conf)
|
| 288 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 289 |
|
| 290 |
+
# Log all raw detections
|
| 291 |
+
logger.info(f"Raw Detection: class={cls}, conf={conf:.2f}, label={label}")
|
| 292 |
+
raw_detections.append({
|
| 293 |
+
"frame": frame_count,
|
| 294 |
+
"class": cls,
|
| 295 |
+
"confidence": round(conf, 2),
|
| 296 |
+
"label": label if label in CONFIG["VIOLATION_LABELS"].values() else "unknown",
|
| 297 |
+
"timestamp": frame_count / fps
|
| 298 |
+
})
|
| 299 |
|
| 300 |
if label not in CONFIG["VIOLATION_LABELS"].values():
|
| 301 |
logger.info(f"Skipping unknown class: {cls}")
|
|
|
|
| 409 |
return {
|
| 410 |
"violations": [],
|
| 411 |
"snapshots": [],
|
| 412 |
+
"raw_detections": raw_detections,
|
| 413 |
"score": 100,
|
| 414 |
"salesforce_record_id": None,
|
| 415 |
"violation_details_url": "",
|
|
|
|
| 423 |
return {
|
| 424 |
"violations": violations,
|
| 425 |
"snapshots": snapshots,
|
| 426 |
+
"raw_detections": raw_detections,
|
| 427 |
"score": score,
|
| 428 |
"salesforce_record_id": report_id,
|
| 429 |
"violation_details_url": final_pdf_url,
|
|
|
|
| 434 |
return {
|
| 435 |
"violations": [],
|
| 436 |
"snapshots": [],
|
| 437 |
+
"raw_detections": [],
|
| 438 |
"score": 100,
|
| 439 |
"salesforce_record_id": None,
|
| 440 |
"violation_details_url": "",
|
|
|
|
| 443 |
|
| 444 |
def gradio_interface(video_file):
|
| 445 |
if not video_file:
|
| 446 |
+
return "No file uploaded.", "", "No file uploaded.", "", "", [], "No raw detections."
|
| 447 |
try:
|
| 448 |
+
yield "Processing video... please wait.", "", "", "", "", [], "Processing..."
|
| 449 |
|
| 450 |
with open(video_file, "rb") as f:
|
| 451 |
video_data = f.read()
|
|
|
|
| 453 |
result = process_video(video_data)
|
| 454 |
|
| 455 |
if result.get("message"):
|
| 456 |
+
yield result["message"], "", "", "", "", [], "Error in processing."
|
| 457 |
return
|
| 458 |
|
| 459 |
violation_table = "No violations detected."
|
|
|
|
| 478 |
)
|
| 479 |
snapshot_images = [s["snapshot_base64"] for s in result["snapshots"]]
|
| 480 |
|
| 481 |
+
raw_detections_text = "No raw detections logged."
|
| 482 |
+
if result["raw_detections"]:
|
| 483 |
+
header = "| Frame | Timestamp (s) | Class | Label | Confidence |\n"
|
| 484 |
+
separator = "|-------|---------------|-------|----------------|------------|\n"
|
| 485 |
+
rows = []
|
| 486 |
+
for d in result["raw_detections"]:
|
| 487 |
+
row = f"| {d['frame']:<5} | {d['timestamp']:.2f} | {d['class']:<5} | {d['label']:<14} | {d['confidence']:.2f} |"
|
| 488 |
+
rows.append(row)
|
| 489 |
+
raw_detections_text = header + separator + "\n".join(rows)
|
| 490 |
+
|
| 491 |
yield (
|
| 492 |
violation_table,
|
| 493 |
f"Safety Score: {result['score']}%",
|
| 494 |
snapshots_text,
|
| 495 |
f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
|
| 496 |
result["violation_details_url"] or "N/A",
|
| 497 |
+
snapshot_images,
|
| 498 |
+
raw_detections_text
|
| 499 |
)
|
| 500 |
except Exception as e:
|
| 501 |
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 502 |
+
yield f"Error: {str(e)}", "", "Error in processing.", "", "", [], "Error in processing."
|
| 503 |
|
| 504 |
+
interface beq gr.Interface(
|
| 505 |
fn=gradio_interface,
|
| 506 |
inputs=gr.Video(label="Upload Site Video"),
|
| 507 |
outputs=[
|
|
|
|
| 510 |
gr.Markdown(label="Snapshots"),
|
| 511 |
gr.Textbox(label="Salesforce Record ID"),
|
| 512 |
gr.Textbox(label="Violation Details URL"),
|
| 513 |
+
gr.Gallery(label="Violation Snapshots"),
|
| 514 |
+
gr.Markdown(label="Raw Detections (Debug)")
|
| 515 |
],
|
| 516 |
title="Worksite Safety Violation Analyzer",
|
| 517 |
description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture). Non-violations are ignored.",
|