PrashanthB461 commited on
Commit
366a953
·
verified ·
1 Parent(s): f4592c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -29
app.py CHANGED
@@ -49,7 +49,7 @@ CONFIG = {
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 5, # Process every 5th frame (balance speed vs. accuracy)
53
  "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
54
  "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
  "no_helmet": 0.4,
@@ -59,7 +59,7 @@ CONFIG = {
59
  "improper_tool_use": 0.35
60
  },
61
  "IOU_THRESHOLD": 0.4, # For worker tracking
62
- "MIN_VIOLATION_FRAMES": 3 # Min frames to confirm a violation
63
  }
64
 
65
  # Setup logging
@@ -239,6 +239,7 @@ def process_video(video_path):
239
  snapshots = []
240
  workers = []
241
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
 
242
 
243
  while cap.isOpened():
244
  ret, frame = cap.read()
@@ -271,37 +272,18 @@ def process_video(video_path):
271
  "timestamp": current_time
272
  }
273
 
274
- # Track worker
275
- matched_worker = None
276
- max_iou = 0
277
- for worker in workers:
278
- iou = calculate_iou(worker["bbox"], bbox)
279
- if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
280
- max_iou = iou
281
- matched_worker = worker
282
-
283
- if matched_worker:
284
- worker_id = matched_worker["id"]
285
- matched_worker["bbox"] = bbox
286
- else:
287
- worker_id = len(workers) + 1
288
- workers.append({"id": worker_id, "bbox": bbox})
289
-
290
  detection["worker_id"] = worker_id
291
  violations.append(detection)
292
 
293
- # Capture snapshot if first detection of this type
294
  if not snapshot_taken[label]:
295
- snapshot_path = os.path.join(
296
- CONFIG["OUTPUT_DIR"],
297
- f"{label}_{frame_count}.jpg"
298
- )
299
- cv2.imwrite(snapshot_path, draw_detections(frame.copy(), [detection]))
300
- snapshots.append({
301
- "violation": label,
302
- "frame": frame_count,
303
- "path": snapshot_path,
304
- "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
305
  })
306
  snapshot_taken[label] = True
307
 
@@ -309,6 +291,23 @@ def process_video(video_path):
309
 
310
  cap.release()
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  # Filter violations (require min frames)
313
  filtered_violations = []
314
  violation_counts = {}
 
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 10, # Increased to process every 10th frame (faster)
53
  "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
54
  "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
  "no_helmet": 0.4,
 
59
  "improper_tool_use": 0.35
60
  },
61
  "IOU_THRESHOLD": 0.4, # For worker tracking
62
+ "MIN_VIOLATION_FRAMES": 2 # Reduced to 2 frames for faster confirmation
63
  }
64
 
65
  # Setup logging
 
239
  snapshots = []
240
  workers = []
241
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
242
+ frames_to_save = [] # Store frames for snapshot saving later
243
 
244
  while cap.isOpened():
245
  ret, frame = cap.read()
 
272
  "timestamp": current_time
273
  }
274
 
275
+ # Simplified worker tracking
276
+ worker_id = len(workers) + 1 # Assign new ID without IoU for speed
277
+ workers.append({"id": worker_id, "bbox": bbox})
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  detection["worker_id"] = worker_id
279
  violations.append(detection)
280
 
281
+ # Store frame for snapshot if first detection of this type
282
  if not snapshot_taken[label]:
283
+ frames_to_save.append({
284
+ "frame": frame.copy(),
285
+ "detection": detection,
286
+ "label": label
 
 
 
 
 
 
287
  })
288
  snapshot_taken[label] = True
289
 
 
291
 
292
  cap.release()
293
 
294
+ # Save snapshots after processing all frames
295
+ for item in frames_to_save:
296
+ frame = item["frame"]
297
+ detection = item["detection"]
298
+ label = item["label"]
299
+ snapshot_path = os.path.join(
300
+ CONFIG["OUTPUT_DIR"],
301
+ f"{label}_{detection['frame']}.jpg"
302
+ )
303
+ cv2.imwrite(snapshot_path, draw_detections(frame, [detection]))
304
+ snapshots.append({
305
+ "violation": label,
306
+ "frame": detection["frame"],
307
+ "path": snapshot_path,
308
+ "url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
309
+ })
310
+
311
  # Filter violations (require min frames)
312
  filtered_violations = []
313
  violation_counts = {}