PrashanthB461 commited on
Commit
e877767
·
verified ·
1 Parent(s): 462ec80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -39,7 +39,6 @@ CONFIG = {
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
41
  "FRAME_SKIP": 15, # Process every 15th frame
42
- "MAX_PROCESSING_TIME": 25, # Cap video processing at 25s
43
  "CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
44
  }
45
 
@@ -239,27 +238,20 @@ def process_video(video_data):
239
 
240
  violations, snapshots = [], []
241
  frame_count = 0
242
- start_time = time.time()
243
  fps = video.get(cv2.CAP_PROP_FPS)
244
- max_frames = int(60 * fps) # Process up to 1 minute
245
 
246
  # Track one snapshot per violation type
247
  snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
248
 
249
  while True:
250
  ret, frame = video.read()
251
- if not ret or frame_count >= max_frames:
252
  break
253
 
254
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
255
  frame_count += 1
256
  continue
257
 
258
- # Stop if processing time exceeds 25 seconds
259
- if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
260
- logger.info("Processing time limit reached")
261
- break
262
-
263
  results = model(frame, device=device)
264
  seen_violations = set()
265
  for result in results:
@@ -268,11 +260,9 @@ def process_video(video_data):
268
  label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
269
  # Only process specified violations
270
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
271
- logger.warning(f"Unexpected detection: {label} (cls: {cls}, conf: {conf}) - ignored")
272
  continue
273
  # Apply confidence threshold
274
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
275
- logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
276
  continue
277
  if label in seen_violations:
278
  continue
@@ -313,7 +303,8 @@ def process_video(video_data):
313
  "snapshots": [],
314
  "score": 100,
315
  "salesforce_record_id": None,
316
- "violation_details_url": ""
 
317
  }
318
 
319
  score = calculate_safety_score(violations)
@@ -325,7 +316,8 @@ def process_video(video_data):
325
  "snapshots": snapshots,
326
  "score": score,
327
  "salesforce_record_id": report_id,
328
- "violation_details_url": final_pdf_url
 
329
  }
330
  except Exception as e:
331
  logger.error(f"Error processing video: {e}")
@@ -334,7 +326,8 @@ def process_video(video_data):
334
  "snapshots": [],
335
  "score": 100,
336
  "salesforce_record_id": None,
337
- "violation_details_url": ""
 
338
  }
339
 
340
  # ==========================
@@ -348,14 +341,28 @@ def gradio_interface(video_file):
348
  video_data = f.read()
349
  result = process_video(video_data)
350
 
 
 
 
 
351
  violation_table = "No violations detected."
352
  if result["violations"]:
353
- header = "| Violation | Timestamp | Confidence | Bounding Box | Violation Details |\n"
354
- separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
355
  rows = []
356
  for v in result["violations"]:
357
  display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
358
- row = f"| {display_name:<16} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
 
 
 
 
 
 
 
 
 
 
359
  rows.append(row)
360
  violation_table = header + separator + "\n".join(rows)
361
 
 
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
41
  "FRAME_SKIP": 15, # Process every 15th frame
 
42
  "CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
43
  }
44
 
 
238
 
239
  violations, snapshots = [], []
240
  frame_count = 0
 
241
  fps = video.get(cv2.CAP_PROP_FPS)
 
242
 
243
  # Track one snapshot per violation type
244
  snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
245
 
246
  while True:
247
  ret, frame = video.read()
248
+ if not ret:
249
  break
250
 
251
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
252
  frame_count += 1
253
  continue
254
 
 
 
 
 
 
255
  results = model(frame, device=device)
256
  seen_violations = set()
257
  for result in results:
 
260
  label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
261
  # Only process specified violations
262
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
 
263
  continue
264
  # Apply confidence threshold
265
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
266
  continue
267
  if label in seen_violations:
268
  continue
 
303
  "snapshots": [],
304
  "score": 100,
305
  "salesforce_record_id": None,
306
+ "violation_details_url": "",
307
+ "message": "No violations detected here."
308
  }
309
 
310
  score = calculate_safety_score(violations)
 
316
  "snapshots": snapshots,
317
  "score": score,
318
  "salesforce_record_id": report_id,
319
+ "violation_details_url": final_pdf_url,
320
+ "message": ""
321
  }
322
  except Exception as e:
323
  logger.error(f"Error processing video: {e}")
 
326
  "snapshots": [],
327
  "score": 100,
328
  "salesforce_record_id": None,
329
+ "violation_details_url": "",
330
+ "message": "Error processing video."
331
  }
332
 
333
  # ==========================
 
341
  video_data = f.read()
342
  result = process_video(video_data)
343
 
344
+ if result.get("message"):
345
+ # Show message (like "No violations detected here.")
346
+ return result["message"], f"Safety Score: {result['score']}%", "", "N/A", "N/A"
347
+
348
  violation_table = "No violations detected."
349
  if result["violations"]:
350
+ header = "| Violation | Timestamp | Confidence | Violation Details |\n"
351
+ separator = "|-------------------|-----------|------------|---------------------------------|\n"
352
  rows = []
353
  for v in result["violations"]:
354
  display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
355
+ # Provide clearer human-readable violation explanation
356
+ if v["violation"] == "no_helmet":
357
+ details = "Employee not wearing helmet"
358
+ elif v["violation"] == "no_harness":
359
+ details = "Employee not wearing proper harness"
360
+ elif v["violation"] == "unsafe_posture":
361
+ details = "Employee in unsafe posture/zone"
362
+ else:
363
+ details = "Violation detected"
364
+
365
+ row = f"| {display_name:<17} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {details:<31} |"
366
  rows.append(row)
367
  violation_table = header + separator + "\n".join(rows)
368