PrashanthB461 commited on
Commit
2d7e132
·
verified ·
1 Parent(s): 97fe32f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -24,7 +24,12 @@ CONFIG = {
24
  0: "no_helmet",
25
  1: "no_harness",
26
  2: "unsafe_posture",
27
- 3: "unsafe_zone" # Ignored in scoring
 
 
 
 
 
28
  },
29
  "SF_CREDENTIALS": {
30
  "username": "prashanth1ai@safety.com",
@@ -58,7 +63,7 @@ def load_model():
58
  try:
59
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
60
  logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
61
- logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if positive cases are detected.")
62
  return model
63
  except Exception as e:
64
  logger.error(f"Failed to load model: {e}")
@@ -104,7 +109,8 @@ def generate_violation_pdf(violations, score):
104
  c.drawString(1 * inch, y_position, "Violation Details:")
105
  y_position -= 0.3 * inch
106
  for v in violations:
107
- text = f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
 
108
  c.drawString(1 * inch, y_position, text)
109
  y_position -= 0.3 * inch
110
  if y_position < 1 * inch:
@@ -155,7 +161,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
155
  try:
156
  sf = connect_to_salesforce()
157
  violations_text = "\n".join(
158
- f"{v['violation']} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
159
  for v in violations
160
  ) or "No violations detected."
161
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
@@ -251,13 +257,10 @@ def process_video(video_data):
251
  for result in results:
252
  for box in result.boxes:
253
  cls, conf = int(box.cls), float(box.conf)
254
- label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
255
- # Log and skip any unexpected classes
256
- if label not in ["no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"]:
257
- logger.warning(f"Unexpected class detected: {label} (cls: {cls}, conf: {conf})")
258
- continue
259
  # Only process specified violations
260
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
 
261
  continue
262
  # Apply confidence threshold
263
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
@@ -339,19 +342,19 @@ def gradio_interface(video_file):
339
 
340
  violation_table = "No violations detected."
341
  if result["violations"]:
342
- header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
343
- separator = "|---------------|-----------|------------|--------------------------|-------------------------|\n"
344
  rows = []
345
  for v in result["violations"]:
346
- violation_name = v["violation"].replace("no_", "").replace("unsafe_", "")
347
- row = f"| {violation_name:<13} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
348
  rows.append(row)
349
  violation_table = header + separator + "\n".join(rows)
350
 
351
  snapshots_text = "No snapshots captured."
352
  if result["snapshots"]:
353
  snapshots_text = "\n".join(
354
- f"- Snapshot for {s['violation'].replace('no_', '').replace('unsafe_', '')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
355
  for s in result["snapshots"]
356
  )
357
 
@@ -366,7 +369,7 @@ def gradio_interface(video_file):
366
  logger.error(f"Error in Gradio interface: {e}")
367
  return f"Error: {str(e)}", "", "Error in processing.", "", ""
368
 
369
- interface = gr.Interface(
370
  fn=gradio_interface,
371
  inputs=gr.Video(label="Upload Site Video"),
372
  outputs=[
@@ -377,7 +380,7 @@ interface = gr.Interface(
377
  gr.Textbox(label="Violation Details URL")
378
  ],
379
  title="Worksite Safety Violation Analyzer",
380
- description="Upload site videos to detect safety violations (no helmet, no harness, unsafe posture). Positive cases are ignored."
381
  )
382
 
383
  if __name__ == "__main__":
 
24
  0: "no_helmet",
25
  1: "no_harness",
26
  2: "unsafe_posture",
27
+ 3: "unsafe_zone" # Ignored in scoring and table
28
+ },
29
+ "DISPLAY_NAMES": { # Mapping for user-friendly violation names
30
+ "no_helmet": "Missing Helmet",
31
+ "no_harness": "Missing Harness",
32
+ "unsafe_posture": "Unsafe Posture"
33
  },
34
  "SF_CREDENTIALS": {
35
  "username": "prashanth1ai@safety.com",
 
63
  try:
64
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
65
  logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
66
+ logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if unexpected classes are detected.")
67
  return model
68
  except Exception as e:
69
  logger.error(f"Failed to load model: {e}")
 
109
  c.drawString(1 * inch, y_position, "Violation Details:")
110
  y_position -= 0.3 * inch
111
  for v in violations:
112
+ display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
113
+ text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
114
  c.drawString(1 * inch, y_position, text)
115
  y_position -= 0.3 * inch
116
  if y_position < 1 * inch:
 
161
  try:
162
  sf = connect_to_salesforce()
163
  violations_text = "\n".join(
164
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
165
  for v in violations
166
  ) or "No violations detected."
167
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
 
257
  for result in results:
258
  for box in result.boxes:
259
  cls, conf = int(box.cls), float(box.conf)
260
+ label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
 
 
 
 
261
  # Only process specified violations
262
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
263
+ logger.warning(f"Ignoring detection: {label} (cls: {cls}, conf: {conf}) - not a target violation")
264
  continue
265
  # Apply confidence threshold
266
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
342
 
343
  violation_table = "No violations detected."
344
  if result["violations"]:
345
+ header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
346
+ separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
347
  rows = []
348
  for v in result["violations"]:
349
+ display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
350
+ row = f"| {display_name:<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
351
  rows.append(row)
352
  violation_table = header + separator + "\n".join(rows)
353
 
354
  snapshots_text = "No snapshots captured."
355
  if result["snapshots"]:
356
  snapshots_text = "\n".join(
357
+ f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
358
  for s in result["snapshots"]
359
  )
360
 
 
369
  logger.error(f"Error in Gradio interface: {e}")
370
  return f"Error: {str(e)}", "", "Error in processing.", "", ""
371
 
372
+ interface = gr.example(
373
  fn=gradio_interface,
374
  inputs=gr.Video(label="Upload Site Video"),
375
  outputs=[
 
380
  gr.Textbox(label="Violation Details URL")
381
  ],
382
  title="Worksite Safety Violation Analyzer",
383
+ description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored."
384
  )
385
 
386
  if __name__ == "__main__":