Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,7 +24,12 @@ CONFIG = {
|
|
| 24 |
0: "no_helmet",
|
| 25 |
1: "no_harness",
|
| 26 |
2: "unsafe_posture",
|
| 27 |
-
3: "unsafe_zone" # Ignored in scoring
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
},
|
| 29 |
"SF_CREDENTIALS": {
|
| 30 |
"username": "prashanth1ai@safety.com",
|
|
@@ -58,7 +63,7 @@ def load_model():
|
|
| 58 |
try:
|
| 59 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 60 |
logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
|
| 61 |
-
logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if
|
| 62 |
return model
|
| 63 |
except Exception as e:
|
| 64 |
logger.error(f"Failed to load model: {e}")
|
|
@@ -104,7 +109,8 @@ def generate_violation_pdf(violations, score):
|
|
| 104 |
c.drawString(1 * inch, y_position, "Violation Details:")
|
| 105 |
y_position -= 0.3 * inch
|
| 106 |
for v in violations:
|
| 107 |
-
|
|
|
|
| 108 |
c.drawString(1 * inch, y_position, text)
|
| 109 |
y_position -= 0.3 * inch
|
| 110 |
if y_position < 1 * inch:
|
|
@@ -155,7 +161,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
|
| 155 |
try:
|
| 156 |
sf = connect_to_salesforce()
|
| 157 |
violations_text = "\n".join(
|
| 158 |
-
f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
|
| 159 |
for v in violations
|
| 160 |
) or "No violations detected."
|
| 161 |
pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
|
|
@@ -251,13 +257,10 @@ def process_video(video_data):
|
|
| 251 |
for result in results:
|
| 252 |
for box in result.boxes:
|
| 253 |
cls, conf = int(box.cls), float(box.conf)
|
| 254 |
-
label = CONFIG["VIOLATION_LABELS"].get(cls, f"
|
| 255 |
-
# Log and skip any unexpected classes
|
| 256 |
-
if label not in ["no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"]:
|
| 257 |
-
logger.warning(f"Unexpected class detected: {label} (cls: {cls}, conf: {conf})")
|
| 258 |
-
continue
|
| 259 |
# Only process specified violations
|
| 260 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
|
|
|
| 261 |
continue
|
| 262 |
# Apply confidence threshold
|
| 263 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
|
@@ -339,19 +342,19 @@ def gradio_interface(video_file):
|
|
| 339 |
|
| 340 |
violation_table = "No violations detected."
|
| 341 |
if result["violations"]:
|
| 342 |
-
header = "| Violation
|
| 343 |
-
separator = "
|
| 344 |
rows = []
|
| 345 |
for v in result["violations"]:
|
| 346 |
-
|
| 347 |
-
row = f"| {
|
| 348 |
rows.append(row)
|
| 349 |
violation_table = header + separator + "\n".join(rows)
|
| 350 |
|
| 351 |
snapshots_text = "No snapshots captured."
|
| 352 |
if result["snapshots"]:
|
| 353 |
snapshots_text = "\n".join(
|
| 354 |
-
f"- Snapshot for {
|
| 355 |
for s in result["snapshots"]
|
| 356 |
)
|
| 357 |
|
|
@@ -366,7 +369,7 @@ def gradio_interface(video_file):
|
|
| 366 |
logger.error(f"Error in Gradio interface: {e}")
|
| 367 |
return f"Error: {str(e)}", "", "Error in processing.", "", ""
|
| 368 |
|
| 369 |
-
interface = gr.
|
| 370 |
fn=gradio_interface,
|
| 371 |
inputs=gr.Video(label="Upload Site Video"),
|
| 372 |
outputs=[
|
|
@@ -377,7 +380,7 @@ interface = gr.Interface(
|
|
| 377 |
gr.Textbox(label="Violation Details URL")
|
| 378 |
],
|
| 379 |
title="Worksite Safety Violation Analyzer",
|
| 380 |
-
description="Upload site videos to detect safety violations (
|
| 381 |
)
|
| 382 |
|
| 383 |
if __name__ == "__main__":
|
|
|
|
| 24 |
0: "no_helmet",
|
| 25 |
1: "no_harness",
|
| 26 |
2: "unsafe_posture",
|
| 27 |
+
3: "unsafe_zone" # Ignored in scoring and table
|
| 28 |
+
},
|
| 29 |
+
"DISPLAY_NAMES": { # Mapping for user-friendly violation names
|
| 30 |
+
"no_helmet": "Missing Helmet",
|
| 31 |
+
"no_harness": "Missing Harness",
|
| 32 |
+
"unsafe_posture": "Unsafe Posture"
|
| 33 |
},
|
| 34 |
"SF_CREDENTIALS": {
|
| 35 |
"username": "prashanth1ai@safety.com",
|
|
|
|
| 63 |
try:
|
| 64 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 65 |
logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
|
| 66 |
+
logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if unexpected classes are detected.")
|
| 67 |
return model
|
| 68 |
except Exception as e:
|
| 69 |
logger.error(f"Failed to load model: {e}")
|
|
|
|
| 109 |
c.drawString(1 * inch, y_position, "Violation Details:")
|
| 110 |
y_position -= 0.3 * inch
|
| 111 |
for v in violations:
|
| 112 |
+
display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
|
| 113 |
+
text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
|
| 114 |
c.drawString(1 * inch, y_position, text)
|
| 115 |
y_position -= 0.3 * inch
|
| 116 |
if y_position < 1 * inch:
|
|
|
|
| 161 |
try:
|
| 162 |
sf = connect_to_salesforce()
|
| 163 |
violations_text = "\n".join(
|
| 164 |
+
f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
|
| 165 |
for v in violations
|
| 166 |
) or "No violations detected."
|
| 167 |
pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
|
|
|
|
| 257 |
for result in results:
|
| 258 |
for box in result.boxes:
|
| 259 |
cls, conf = int(box.cls), float(box.conf)
|
| 260 |
+
label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
# Only process specified violations
|
| 262 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 263 |
+
logger.warning(f"Ignoring detection: {label} (cls: {cls}, conf: {conf}) - not a target violation")
|
| 264 |
continue
|
| 265 |
# Apply confidence threshold
|
| 266 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
|
|
|
| 342 |
|
| 343 |
violation_table = "No violations detected."
|
| 344 |
if result["violations"]:
|
| 345 |
+
header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
|
| 346 |
+
separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
|
| 347 |
rows = []
|
| 348 |
for v in result["violations"]:
|
| 349 |
+
display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
|
| 350 |
+
row = f"| {display_name:<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
|
| 351 |
rows.append(row)
|
| 352 |
violation_table = header + separator + "\n".join(rows)
|
| 353 |
|
| 354 |
snapshots_text = "No snapshots captured."
|
| 355 |
if result["snapshots"]:
|
| 356 |
snapshots_text = "\n".join(
|
| 357 |
+
f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: "
|
| 358 |
for s in result["snapshots"]
|
| 359 |
)
|
| 360 |
|
|
|
|
| 369 |
logger.error(f"Error in Gradio interface: {e}")
|
| 370 |
return f"Error: {str(e)}", "", "Error in processing.", "", ""
|
| 371 |
|
| 372 |
+
interface = gr.example(
|
| 373 |
fn=gradio_interface,
|
| 374 |
inputs=gr.Video(label="Upload Site Video"),
|
| 375 |
outputs=[
|
|
|
|
| 380 |
gr.Textbox(label="Violation Details URL")
|
| 381 |
],
|
| 382 |
title="Worksite Safety Violation Analyzer",
|
| 383 |
+
description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored."
|
| 384 |
)
|
| 385 |
|
| 386 |
if __name__ == "__main__":
|