Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,8 +18,8 @@ from retrying import retry
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
-
"MODEL_PATH": "yolov8_safety.pt",
|
| 22 |
-
"FALLBACK_MODEL": "yolov8n.pt", #
|
| 23 |
"OUTPUT_DIR": "static/output",
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
|
@@ -54,11 +54,18 @@ logger.info(f"Using device: {device}")
|
|
| 54 |
|
| 55 |
def load_model():
|
| 56 |
try:
|
| 57 |
-
|
| 58 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
model = YOLO(model_path).to(device)
|
| 61 |
-
logger.info(f"Model loaded: {model_path}")
|
| 62 |
return model
|
| 63 |
except Exception as e:
|
| 64 |
logger.error(f"Failed to load model: {e}")
|
|
@@ -309,3 +316,68 @@ def process_video(video_data):
|
|
| 309 |
"violation_details_url": "",
|
| 310 |
"message": f"Error processing video: {e}"
|
| 311 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
+
"MODEL_PATH": "yolov8_safety.pt", # Make sure this file exists in your directory
|
| 22 |
+
"FALLBACK_MODEL": "yolov8n.pt", # Fallback model for testing
|
| 23 |
"OUTPUT_DIR": "static/output",
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
|
|
|
| 54 |
|
| 55 |
def load_model():
|
| 56 |
try:
|
| 57 |
+
# Check if the model file exists
|
| 58 |
+
if os.path.isfile(CONFIG["MODEL_PATH"]):
|
| 59 |
+
model_path = CONFIG["MODEL_PATH"]
|
| 60 |
+
logger.info(f"Model loaded: {model_path}")
|
| 61 |
+
else:
|
| 62 |
+
model_path = CONFIG["FALLBACK_MODEL"]
|
| 63 |
logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
|
| 64 |
+
# Download fallback model if necessary
|
| 65 |
+
if not os.path.isfile(model_path):
|
| 66 |
+
logger.info(f"Downloading fallback model: {model_path}")
|
| 67 |
+
torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
|
| 68 |
model = YOLO(model_path).to(device)
|
|
|
|
| 69 |
return model
|
| 70 |
except Exception as e:
|
| 71 |
logger.error(f"Failed to load model: {e}")
|
|
|
|
| 316 |
"violation_details_url": "",
|
| 317 |
"message": f"Error processing video: {e}"
|
| 318 |
}
|
| 319 |
+
|
| 320 |
+
def gradio_interface(video_file):
|
| 321 |
+
if not video_file:
|
| 322 |
+
return "No file uploaded.", "", "No file uploaded.", "", ""
|
| 323 |
+
try:
|
| 324 |
+
yield "Processing video... please wait.", "", "", "", ""
|
| 325 |
+
|
| 326 |
+
with open(video_file, "rb") as f:
|
| 327 |
+
video_data = f.read()
|
| 328 |
+
|
| 329 |
+
result = process_video(video_data)
|
| 330 |
+
|
| 331 |
+
if result.get("message"):
|
| 332 |
+
yield result["message"], "", "", "", ""
|
| 333 |
+
return
|
| 334 |
+
|
| 335 |
+
violation_table = "No violations detected."
|
| 336 |
+
if result["violations"]:
|
| 337 |
+
header = "| Violation | Timestamp (s) | Confidence | Bounding Box |\n"
|
| 338 |
+
separator = "|------------------------|---------------|------------|--------------------------|\n"
|
| 339 |
+
rows = []
|
| 340 |
+
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 341 |
+
for v in result["violations"]:
|
| 342 |
+
display_name = violation_name_map.get(v["violation"], v["violation"])
|
| 343 |
+
row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['bounding_box']} |"
|
| 344 |
+
rows.append(row)
|
| 345 |
+
violation_table = header + separator + "\n".join(rows)
|
| 346 |
+
|
| 347 |
+
snapshots_text = "No snapshots captured."
|
| 348 |
+
if result["snapshots"]:
|
| 349 |
+
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 350 |
+
snapshots_text = "\n".join(
|
| 351 |
+
f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: "
|
| 352 |
+
for s in result["snapshots"]
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
yield (
|
| 356 |
+
violation_table,
|
| 357 |
+
f"Safety Score: {result['score']}%",
|
| 358 |
+
snapshots_text,
|
| 359 |
+
f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
|
| 360 |
+
result["violation_details_url"] or "N/A"
|
| 361 |
+
)
|
| 362 |
+
except Exception as e:
|
| 363 |
+
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 364 |
+
yield f"Error: {str(e)}", "", "Error in processing.", "", ""
|
| 365 |
+
|
| 366 |
+
interface = gr.Interface(
|
| 367 |
+
fn=gradio_interface,
|
| 368 |
+
inputs=gr.Video(label="Upload Site Video"),
|
| 369 |
+
outputs=[
|
| 370 |
+
gr.Markdown(label="Detected Safety Violations"),
|
| 371 |
+
gr.Textbox(label="Compliance Score"),
|
| 372 |
+
gr.Markdown(label="Snapshots"),
|
| 373 |
+
gr.Textbox(label="Salesforce Record ID"),
|
| 374 |
+
gr.Textbox(label="Violation Details URL")
|
| 375 |
+
],
|
| 376 |
+
title="Worksite Safety Violation Analyzer",
|
| 377 |
+
description="Upload site videos to detect safety violations (No Helmet Violation, No Harness Violation, Unsafe Posture Violation). Non-violations are ignored.",
|
| 378 |
+
allow_flagging="never"
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
if __name__ == "__main__":
|
| 382 |
+
logger.info("Launching Safety Analyzer App...")
|
| 383 |
+
interface.launch()
|