PrashanthB461 commited on
Commit
3bcf6dd
·
verified ·
1 Parent(s): e877767

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -54
app.py CHANGED
@@ -18,46 +18,40 @@ from retrying import retry
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt", # Custom-trained model for specific violations
22
- "FALLBACK_MODEL_PATH": "yolov8n.pt", # Fallback if custom model is missing
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
  2: "unsafe_posture"
28
  },
29
- "DISPLAY_NAMES": { # Mapping for user-friendly violation names
30
- "no_helmet": "Missing Helmet",
31
- "no_harness": "Missing Harness",
32
- "unsafe_posture": "Unsafe Posture"
33
  },
34
  "SF_CREDENTIALS": {
35
  "username": "your_username@safety.com",
36
  "password": "your_password",
37
  "security_token": "your_security_token",
38
- "domain": "login" # Use "test" for sandbox
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
41
- "FRAME_SKIP": 15, # Process every 15th frame
42
- "CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
 
43
  }
44
 
45
  # Setup logging
46
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
47
  logger = logging.getLogger(__name__)
48
 
49
- # Ensure output directory exists
50
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
51
 
52
- # ==========================
53
- # Device Setup
54
- # ==========================
55
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
  logger.info(f"Using device: {device}")
57
 
58
- # ==========================
59
- # Model Loading
60
- # ==========================
61
  def load_model():
62
  try:
63
  model_path = CONFIG["MODEL_PATH"]
@@ -75,9 +69,6 @@ def load_model():
75
 
76
  model = load_model()
77
 
78
- # ==========================
79
- # Salesforce Integration
80
- # ==========================
81
  @retry(stop_max_attempt_number=2, wait_fixed=1000)
82
  def connect_to_salesforce():
83
  try:
@@ -207,9 +198,6 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
207
  logger.error(f"Salesforce record creation failed: {e}")
208
  return None, ""
209
 
210
- # ==========================
211
- # Safety Score Calculation
212
- # ==========================
213
  def calculate_safety_score(violations):
214
  penalties = {
215
  "no_helmet": 25,
@@ -222,9 +210,6 @@ def calculate_safety_score(violations):
222
  score -= penalties[v["violation"]]
223
  return max(score, 0)
224
 
225
- # ==========================
226
- # Video Processing
227
- # ==========================
228
  def process_video(video_data):
229
  try:
230
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
@@ -238,30 +223,32 @@ def process_video(video_data):
238
 
239
  violations, snapshots = [], []
240
  frame_count = 0
 
241
  fps = video.get(cv2.CAP_PROP_FPS)
242
 
243
- # Track one snapshot per violation type
244
  snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
245
 
246
  while True:
247
  ret, frame = video.read()
248
  if not ret:
249
- break
250
 
251
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
252
  frame_count += 1
253
  continue
254
 
 
 
 
 
255
  results = model(frame, device=device)
256
  seen_violations = set()
257
  for result in results:
258
  for box in result.boxes:
259
  cls, conf = int(box.cls), float(box.conf)
260
- label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
261
- # Only process specified violations
262
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
263
  continue
264
- # Apply confidence threshold
265
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
266
  continue
267
  if label in seen_violations:
@@ -277,7 +264,6 @@ def process_video(video_data):
277
  }
278
  violations.append(violation)
279
 
280
- # Save only one snapshot per violation type
281
  if not snapshot_taken[label]:
282
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
283
  cv2.imwrite(snapshot_path, frame)
@@ -304,7 +290,7 @@ def process_video(video_data):
304
  "score": 100,
305
  "salesforce_record_id": None,
306
  "violation_details_url": "",
307
- "message": "No violations detected here."
308
  }
309
 
310
  score = calculate_safety_score(violations)
@@ -327,12 +313,9 @@ def process_video(video_data):
327
  "score": 100,
328
  "salesforce_record_id": None,
329
  "violation_details_url": "",
330
- "message": "Error processing video."
331
  }
332
 
333
- # ==========================
334
- # Gradio Interface
335
- # ==========================
336
  def gradio_interface(video_file):
337
  if not video_file:
338
  return "No file uploaded.", "", "No file uploaded.", "", ""
@@ -342,34 +325,26 @@ def gradio_interface(video_file):
342
  result = process_video(video_data)
343
 
344
  if result.get("message"):
345
- # Show message (like "No violations detected here.")
346
- return result["message"], f"Safety Score: {result['score']}%", "", "N/A", "N/A"
347
 
348
  violation_table = "No violations detected."
349
  if result["violations"]:
350
- header = "| Violation | Timestamp | Confidence | Violation Details |\n"
351
- separator = "|-------------------|-----------|------------|---------------------------------|\n"
352
  rows = []
 
353
  for v in result["violations"]:
354
- display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
355
- # Provide clearer human-readable violation explanation
356
- if v["violation"] == "no_helmet":
357
- details = "Employee not wearing helmet"
358
- elif v["violation"] == "no_harness":
359
- details = "Employee not wearing proper harness"
360
- elif v["violation"] == "unsafe_posture":
361
- details = "Employee in unsafe posture/zone"
362
- else:
363
- details = "Violation detected"
364
-
365
- row = f"| {display_name:<17} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {details:<31} |"
366
  rows.append(row)
367
  violation_table = header + separator + "\n".join(rows)
368
 
369
  snapshots_text = "No snapshots captured."
370
  if result["snapshots"]:
 
371
  snapshots_text = "\n".join(
372
- f"- Snapshot for {CONFIG['DISPLAY_NAMES'].get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
373
  for s in result["snapshots"]
374
  )
375
 
@@ -395,7 +370,7 @@ interface = gr.Interface(
395
  gr.Textbox(label="Violation Details URL")
396
  ],
397
  title="Worksite Safety Violation Analyzer",
398
- description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored."
399
  )
400
 
401
  if __name__ == "__main__":
 
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt",
22
+ "FALLBACK_MODEL_PATH": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
  2: "unsafe_posture"
28
  },
29
+ "DISPLAY_NAMES": {
30
+ "no_helmet": "No Helmet Violation",
31
+ "no_harness": "No Harness Violation",
32
+ "unsafe_posture": "Unsafe Posture Violation"
33
  },
34
  "SF_CREDENTIALS": {
35
  "username": "your_username@safety.com",
36
  "password": "your_password",
37
  "security_token": "your_security_token",
38
+ "domain": "login"
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
41
+ "FRAME_SKIP": 15,
42
+ "MAX_PROCESSING_TIME": 30, # Updated to 30 seconds
43
+ "CONFIDENCE_THRESHOLD": 0.5
44
  }
45
 
46
  # Setup logging
47
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
48
  logger = logging.getLogger(__name__)
49
 
 
50
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
51
 
 
 
 
52
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
  logger.info(f"Using device: {device}")
54
 
 
 
 
55
  def load_model():
56
  try:
57
  model_path = CONFIG["MODEL_PATH"]
 
69
 
70
  model = load_model()
71
 
 
 
 
72
  @retry(stop_max_attempt_number=2, wait_fixed=1000)
73
  def connect_to_salesforce():
74
  try:
 
198
  logger.error(f"Salesforce record creation failed: {e}")
199
  return None, ""
200
 
 
 
 
201
  def calculate_safety_score(violations):
202
  penalties = {
203
  "no_helmet": 25,
 
210
  score -= penalties[v["violation"]]
211
  return max(score, 0)
212
 
 
 
 
213
  def process_video(video_data):
214
  try:
215
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
 
223
 
224
  violations, snapshots = [], []
225
  frame_count = 0
226
+ start_time = time.time()
227
  fps = video.get(cv2.CAP_PROP_FPS)
228
 
 
229
  snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
230
 
231
  while True:
232
  ret, frame = video.read()
233
  if not ret:
234
+ break # End of video
235
 
236
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
237
  frame_count += 1
238
  continue
239
 
240
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
241
+ logger.info("Processing time limit of 30 seconds reached")
242
+ break
243
+
244
  results = model(frame, device=device)
245
  seen_violations = set()
246
  for result in results:
247
  for box in result.boxes:
248
  cls, conf = int(box.cls), float(box.conf)
249
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
 
250
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
251
  continue
 
252
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
253
  continue
254
  if label in seen_violations:
 
264
  }
265
  violations.append(violation)
266
 
 
267
  if not snapshot_taken[label]:
268
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
269
  cv2.imwrite(snapshot_path, frame)
 
290
  "score": 100,
291
  "salesforce_record_id": None,
292
  "violation_details_url": "",
293
+ "message": "No violations detected in the video."
294
  }
295
 
296
  score = calculate_safety_score(violations)
 
313
  "score": 100,
314
  "salesforce_record_id": None,
315
  "violation_details_url": "",
316
+ "message": f"Error processing video: {e}"
317
  }
318
 
 
 
 
319
  def gradio_interface(video_file):
320
  if not video_file:
321
  return "No file uploaded.", "", "No file uploaded.", "", ""
 
325
  result = process_video(video_data)
326
 
327
  if result.get("message"):
328
+ # If message present (either no violations or error), show it plainly
329
+ return result["message"], "", "", "", ""
330
 
331
  violation_table = "No violations detected."
332
  if result["violations"]:
333
+ header = "| Violation | Timestamp (s) | Confidence | Bounding Box |\n"
334
+ separator = "|------------------------|---------------|------------|--------------------------|\n"
335
  rows = []
336
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
337
  for v in result["violations"]:
338
+ display_name = violation_name_map.get(v["violation"], v["violation"])
339
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['bounding_box']} |"
 
 
 
 
 
 
 
 
 
 
340
  rows.append(row)
341
  violation_table = header + separator + "\n".join(rows)
342
 
343
  snapshots_text = "No snapshots captured."
344
  if result["snapshots"]:
345
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
346
  snapshots_text = "\n".join(
347
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
348
  for s in result["snapshots"]
349
  )
350
 
 
370
  gr.Textbox(label="Violation Details URL")
371
  ],
372
  title="Worksite Safety Violation Analyzer",
373
+ description="Upload site videos to detect safety violations (No Helmet Violation, No Harness Violation, Unsafe Posture Violation). Non-violations are ignored."
374
  )
375
 
376
  if __name__ == "__main__":