PrashanthB461 commited on
Commit
ba9ee16
·
verified ·
1 Parent(s): 160244b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -5
app.py CHANGED
@@ -18,8 +18,8 @@ from retrying import retry
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt",
22
- "FALLBACK_MODEL": "yolov8n.pt", # updated key to match first code
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
@@ -54,11 +54,18 @@ logger.info(f"Using device: {device}")
54
 
55
  def load_model():
56
  try:
57
- model_path = CONFIG["MODEL_PATH"] if os.path.isfile(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
58
- if model_path == CONFIG["FALLBACK_MODEL"]:
 
 
 
 
59
  logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
 
 
 
 
60
  model = YOLO(model_path).to(device)
61
- logger.info(f"Model loaded: {model_path}")
62
  return model
63
  except Exception as e:
64
  logger.error(f"Failed to load model: {e}")
@@ -309,3 +316,68 @@ def process_video(video_data):
309
  "violation_details_url": "",
310
  "message": f"Error processing video: {e}"
311
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt", # Make sure this file exists in your directory
22
+ "FALLBACK_MODEL": "yolov8n.pt", # Fallback model for testing
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
 
54
 
55
  def load_model():
56
  try:
57
+ # Check if the model file exists
58
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
59
+ model_path = CONFIG["MODEL_PATH"]
60
+ logger.info(f"Model loaded: {model_path}")
61
+ else:
62
+ model_path = CONFIG["FALLBACK_MODEL"]
63
  logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
64
+ # Download fallback model if necessary
65
+ if not os.path.isfile(model_path):
66
+ logger.info(f"Downloading fallback model: {model_path}")
67
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
68
  model = YOLO(model_path).to(device)
 
69
  return model
70
  except Exception as e:
71
  logger.error(f"Failed to load model: {e}")
 
316
  "violation_details_url": "",
317
  "message": f"Error processing video: {e}"
318
  }
319
+
320
+ def gradio_interface(video_file):
321
+ if not video_file:
322
+ return "No file uploaded.", "", "No file uploaded.", "", ""
323
+ try:
324
+ yield "Processing video... please wait.", "", "", "", ""
325
+
326
+ with open(video_file, "rb") as f:
327
+ video_data = f.read()
328
+
329
+ result = process_video(video_data)
330
+
331
+ if result.get("message"):
332
+ yield result["message"], "", "", "", ""
333
+ return
334
+
335
+ violation_table = "No violations detected."
336
+ if result["violations"]:
337
+ header = "| Violation | Timestamp (s) | Confidence | Bounding Box |\n"
338
+ separator = "|------------------------|---------------|------------|--------------------------|\n"
339
+ rows = []
340
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
341
+ for v in result["violations"]:
342
+ display_name = violation_name_map.get(v["violation"], v["violation"])
343
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['bounding_box']} |"
344
+ rows.append(row)
345
+ violation_table = header + separator + "\n".join(rows)
346
+
347
+ snapshots_text = "No snapshots captured."
348
+ if result["snapshots"]:
349
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
350
+ snapshots_text = "\n".join(
351
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
352
+ for s in result["snapshots"]
353
+ )
354
+
355
+ yield (
356
+ violation_table,
357
+ f"Safety Score: {result['score']}%",
358
+ snapshots_text,
359
+ f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
360
+ result["violation_details_url"] or "N/A"
361
+ )
362
+ except Exception as e:
363
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
364
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
365
+
366
+ interface = gr.Interface(
367
+ fn=gradio_interface,
368
+ inputs=gr.Video(label="Upload Site Video"),
369
+ outputs=[
370
+ gr.Markdown(label="Detected Safety Violations"),
371
+ gr.Textbox(label="Compliance Score"),
372
+ gr.Markdown(label="Snapshots"),
373
+ gr.Textbox(label="Salesforce Record ID"),
374
+ gr.Textbox(label="Violation Details URL")
375
+ ],
376
+ title="Worksite Safety Violation Analyzer",
377
+ description="Upload site videos to detect safety violations (No Helmet Violation, No Harness Violation, Unsafe Posture Violation). Non-violations are ignored.",
378
+ allow_flagging="never"
379
+ )
380
+
381
+ if __name__ == "__main__":
382
+ logger.info("Launching Safety Analyzer App...")
383
+ interface.launch()