Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ from retrying import retry
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
-
"MODEL_PATH": "yolov8n.pt", # Lightweight model,
|
| 22 |
"OUTPUT_DIR": "static/output",
|
| 23 |
"VIOLATION_LABELS": {
|
| 24 |
0: "no_helmet",
|
|
@@ -58,7 +58,7 @@ def load_model():
|
|
| 58 |
try:
|
| 59 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 60 |
logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
|
| 61 |
-
logger.warning("Ensure yolov8n.pt is trained to detect
|
| 62 |
return model
|
| 63 |
except Exception as e:
|
| 64 |
logger.error(f"Failed to load model: {e}")
|
|
@@ -252,13 +252,14 @@ def process_video(video_data):
|
|
| 252 |
for box in result.boxes:
|
| 253 |
cls, conf = int(box.cls), float(box.conf)
|
| 254 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
|
| 255 |
-
# Log unexpected classes
|
| 256 |
if label not in ["no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"]:
|
| 257 |
logger.warning(f"Unexpected class detected: {label} (cls: {cls}, conf: {conf})")
|
| 258 |
continue
|
| 259 |
-
# Only process specified violations
|
| 260 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 261 |
continue
|
|
|
|
| 262 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 263 |
logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
|
| 264 |
continue
|
|
@@ -376,7 +377,7 @@ interface = gr.Interface(
|
|
| 376 |
gr.Textbox(label="Violation Details URL")
|
| 377 |
],
|
| 378 |
title="Worksite Safety Violation Analyzer",
|
| 379 |
-
description="Upload site videos to detect safety violations (
|
| 380 |
)
|
| 381 |
|
| 382 |
if __name__ == "__main__":
|
|
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
+
"MODEL_PATH": "yolov8n.pt", # Lightweight model, must be trained for violations only
|
| 22 |
"OUTPUT_DIR": "static/output",
|
| 23 |
"VIOLATION_LABELS": {
|
| 24 |
0: "no_helmet",
|
|
|
|
| 58 |
try:
|
| 59 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 60 |
logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
|
| 61 |
+
logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if positive cases are detected.")
|
| 62 |
return model
|
| 63 |
except Exception as e:
|
| 64 |
logger.error(f"Failed to load model: {e}")
|
|
|
|
| 252 |
for box in result.boxes:
|
| 253 |
cls, conf = int(box.cls), float(box.conf)
|
| 254 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
|
| 255 |
+
# Log and skip any unexpected classes
|
| 256 |
if label not in ["no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"]:
|
| 257 |
logger.warning(f"Unexpected class detected: {label} (cls: {cls}, conf: {conf})")
|
| 258 |
continue
|
| 259 |
+
# Only process specified violations
|
| 260 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 261 |
continue
|
| 262 |
+
# Apply confidence threshold
|
| 263 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 264 |
logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
|
| 265 |
continue
|
|
|
|
| 377 |
gr.Textbox(label="Violation Details URL")
|
| 378 |
],
|
| 379 |
title="Worksite Safety Violation Analyzer",
|
| 380 |
+
description="Upload site videos to detect safety violations (no helmet, no harness, unsafe posture). Positive cases are ignored."
|
| 381 |
)
|
| 382 |
|
| 383 |
if __name__ == "__main__":
|