PrashanthB461 commited on
Commit
2d42edf
·
verified ·
1 Parent(s): 0b5f150

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -33
app.py CHANGED
@@ -49,18 +49,15 @@ CONFIG = {
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 15, # Increased to process every 15th frame
53
- "MAX_FRAMES": 100, # Limit the number of frames to process
54
- "FRAME_RESIZE": (640, 480), # Downscale frames to this resolution
55
- "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
56
  "CONFIDENCE_THRESHOLD": { # Per-class thresholds
57
  "no_helmet": 0.4,
58
  "no_harness": 0.3,
59
  "unsafe_posture": 0.25,
60
  "unsafe_zone": 0.3,
61
  "improper_tool_use": 0.35
62
- },
63
- "MIN_VIOLATION_FRAMES": 2 # Min frames to confirm a violation
64
  }
65
 
66
  # Setup logging
@@ -219,12 +216,12 @@ def push_report_to_salesforce(violations, score, pdf_file):
219
  # Video Processing
220
  # ==========================
221
  def process_video(video_path):
222
- """Analyze video for safety violations."""
223
  try:
 
224
  cap = cv2.VideoCapture(video_path)
225
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
226
  frame_count = 0
227
- processed_frames = 0
228
  violations = []
229
  snapshots = []
230
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
@@ -235,17 +232,15 @@ def process_video(video_path):
235
  if not ret:
236
  break
237
 
238
- # Stop if max frames reached
239
- if processed_frames >= CONFIG["MAX_FRAMES"]:
 
240
  break
241
 
242
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
243
  frame_count += 1
244
  continue
245
 
246
- # Downscale frame for faster inference
247
- frame = cv2.resize(frame, CONFIG["FRAME_RESIZE"], interpolation=cv2.INTER_AREA)
248
-
249
  # Run detection
250
  results = model(frame, device=device)
251
  current_time = frame_count / fps
@@ -268,8 +263,6 @@ def process_video(video_path):
268
  "timestamp": current_time
269
  }
270
 
271
- # Assign a generic worker ID (skipping tracking for speed)
272
- detection["worker_id"] = f"Worker_{frame_count}"
273
  violations.append(detection)
274
 
275
  # Store frame for snapshot if first detection of this type
@@ -282,7 +275,6 @@ def process_video(video_path):
282
  snapshot_taken[label] = True
283
 
284
  frame_count += 1
285
- processed_frames += 1
286
 
287
  cap.release()
288
 
@@ -303,17 +295,6 @@ def process_video(video_path):
303
  "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
304
  })
305
 
306
- # Filter violations (require min frames)
307
- filtered_violations = []
308
- violation_counts = {}
309
- for v in violations:
310
- key = (v["worker_id"], v["violation"])
311
- violation_counts[key] = violation_counts.get(key, 0) + 1
312
-
313
- for v in violations:
314
- if violation_counts[(v["worker_id"], v["violation"])] >= CONFIG["MIN_VIOLATION_FRAMES"]:
315
- filtered_violations.append(v)
316
-
317
  # Calculate safety score
318
  penalty_weights = {
319
  "no_helmet": 25,
@@ -322,12 +303,12 @@ def process_video(video_path):
322
  "unsafe_zone": 35,
323
  "improper_tool_use": 25
324
  }
325
- unique_violations = set((v["worker_id"], v["violation"]) for v in filtered_violations)
326
- total_penalty = sum(penalty_weights.get(v, 0) for _, v in unique_violations)
327
  safety_score = max(100 - total_penalty, 0)
328
 
329
  return {
330
- "violations": filtered_violations,
331
  "snapshots": snapshots,
332
  "score": safety_score,
333
  "message": ""
@@ -350,6 +331,8 @@ def analyze_video(video_file):
350
  return "No video uploaded", "", "", "", ""
351
 
352
  try:
 
 
353
  # Process video
354
  result = process_video(video_file)
355
  if result["message"]:
@@ -366,13 +349,18 @@ def analyze_video(video_file):
366
  pdf_file
367
  )
368
 
 
 
 
 
 
369
  # Format outputs
370
  violation_table = (
371
- "| Violation Type | Timestamp (s) | Confidence | Worker ID |\n"
372
- "|------------------------|---------------|------------|-----------|\n" +
373
  "\n".join(
374
  f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<22} | "
375
- f"{v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
376
  for v in result["violations"]
377
  ) if result["violations"] else "No violations detected."
378
  )
 
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 20, # Increased to process every 20th frame (faster)
53
+ "MAX_PROCESSING_TIME": 25, # Max processing time (seconds), leaving 5s for post-processing
 
 
54
  "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
  "no_helmet": 0.4,
56
  "no_harness": 0.3,
57
  "unsafe_posture": 0.25,
58
  "unsafe_zone": 0.3,
59
  "improper_tool_use": 0.35
60
+ }
 
61
  }
62
 
63
  # Setup logging
 
216
  # Video Processing
217
  # ==========================
218
  def process_video(video_path):
219
+ """Analyze video for safety violations within 30 seconds."""
220
  try:
221
+ start_time = time.time()
222
  cap = cv2.VideoCapture(video_path)
223
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
224
  frame_count = 0
 
225
  violations = []
226
  snapshots = []
227
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
 
232
  if not ret:
233
  break
234
 
235
+ # Stop if processing time exceeds limit
236
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
237
+ logger.info("Reached max processing time, stopping frame analysis")
238
  break
239
 
240
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
241
  frame_count += 1
242
  continue
243
 
 
 
 
244
  # Run detection
245
  results = model(frame, device=device)
246
  current_time = frame_count / fps
 
263
  "timestamp": current_time
264
  }
265
 
 
 
266
  violations.append(detection)
267
 
268
  # Store frame for snapshot if first detection of this type
 
275
  snapshot_taken[label] = True
276
 
277
  frame_count += 1
 
278
 
279
  cap.release()
280
 
 
295
  "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
296
  })
297
 
 
 
 
 
 
 
 
 
 
 
 
298
  # Calculate safety score
299
  penalty_weights = {
300
  "no_helmet": 25,
 
303
  "unsafe_zone": 35,
304
  "improper_tool_use": 25
305
  }
306
+ unique_violations = set(v["violation"] for v in violations)
307
+ total_penalty = sum(penalty_weights.get(v, 0) for v in unique_violations)
308
  safety_score = max(100 - total_penalty, 0)
309
 
310
  return {
311
+ "violations": violations,
312
  "snapshots": snapshots,
313
  "score": safety_score,
314
  "message": ""
 
331
  return "No video uploaded", "", "", "", ""
332
 
333
  try:
334
+ start_time = time.time()
335
+
336
  # Process video
337
  result = process_video(video_file)
338
  if result["message"]:
 
349
  pdf_file
350
  )
351
 
352
+ # Check total time
353
+ total_time = time.time() - start_time
354
+ if total_time > 30:
355
+ logger.warning(f"Processing took {total_time:.2f}s, exceeded 30s target")
356
+
357
  # Format outputs
358
  violation_table = (
359
+ "| Violation Type | Timestamp (s) | Confidence |\n"
360
+ "|------------------------|---------------|------------|\n" +
361
  "\n".join(
362
  f"| {CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation']):<22} | "
363
+ f"{v['timestamp']:.2f} | {v['confidence']:.2f} |"
364
  for v in result["violations"]
365
  ) if result["violations"] else "No violations detected."
366
  )