PrashanthB461 commited on
Commit
9cc7878
·
verified ·
1 Parent(s): d0fdccb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -10
app.py CHANGED
@@ -260,6 +260,10 @@ def process_video(video_data):
260
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
261
  workers = [] # List to track workers
262
 
 
 
 
 
263
  while True:
264
  ret, frame = video.read()
265
  if not ret:
@@ -273,21 +277,35 @@ def process_video(video_data):
273
  logger.info("Processing time limit reached")
274
  break
275
 
 
276
  results = model(frame, device=device)
277
  current_detections = []
 
 
278
  for result in results:
279
- for box in result.boxes:
 
 
 
280
  cls, conf = int(box.cls), float(box.conf)
281
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
282
 
283
- # Log detected violations
284
- if label:
285
- logger.info(f"Violation Detected: {label} with confidence: {conf}")
286
-
287
- if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
 
 
 
 
288
  continue
289
 
 
290
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
 
 
291
  current_detections.append({
292
  "violation": label,
293
  "confidence": round(conf, 2),
@@ -296,10 +314,13 @@ def process_video(video_data):
296
  "frame": frame_count
297
  })
298
 
299
- # Process detections and workers
 
300
  for detection in current_detections:
301
  matched_worker = None
302
  max_iou = 0
 
 
303
  for worker in workers:
304
  iou = calculate_iou(detection["bounding_box"], worker["bbox"])
305
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
@@ -309,6 +330,8 @@ def process_video(video_data):
309
  if matched_worker:
310
  # Update existing worker
311
  if detection["violation"] not in matched_worker["violations"]:
 
 
312
  matched_worker["violations"].add(detection["violation"])
313
  violations.append({
314
  "frame": frame_count,
@@ -318,18 +341,34 @@ def process_video(video_data):
318
  "timestamp": detection["timestamp"],
319
  "worker_id": matched_worker["id"]
320
  })
321
- snapshot_taken[detection["violation"]] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  matched_worker["bbox"] = detection["bounding_box"]
323
  matched_worker["last_frame"] = frame_count
324
  else:
325
- # New worker
326
  worker_id = len(workers) + 1
 
327
  workers.append({
328
  "id": worker_id,
329
  "violations": {detection["violation"]},
330
  "bbox": detection["bounding_box"],
331
  "last_frame": frame_count
332
  })
 
333
  violations.append({
334
  "frame": frame_count,
335
  "violation": detection["violation"],
@@ -338,11 +377,37 @@ def process_video(video_data):
338
  "timestamp": detection["timestamp"],
339
  "worker_id": worker_id
340
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
 
 
 
 
 
 
342
  frame_count += 1
343
 
344
  video.release()
345
  os.remove(video_path)
 
 
 
 
 
 
 
346
 
347
  if not violations:
348
  logger.info("No violations detected")
@@ -441,4 +506,4 @@ interface = gr.Interface(
441
 
442
  if __name__ == "__main__":
443
  logger.info("Launching Safety Analyzer App...")
444
- interface.launch()
 
260
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
261
  workers = [] # List to track workers
262
 
263
+ # Adding debug logging for violation labels
264
+ logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
265
+ logger.info(f"Using confidence threshold: {CONFIG['CONFIDENCE_THRESHOLD']}")
266
+
267
  while True:
268
  ret, frame = video.read()
269
  if not ret:
 
277
  logger.info("Processing time limit reached")
278
  break
279
 
280
+ # Run detection on this frame
281
  results = model(frame, device=device)
282
  current_detections = []
283
+
284
+ # Process detections from the model
285
  for result in results:
286
+ boxes = result.boxes
287
+ logger.info(f"Frame {frame_count}: Found {len(boxes)} potential detections")
288
+
289
+ for box in boxes:
290
  cls, conf = int(box.cls), float(box.conf)
291
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
292
 
293
+ # Enhanced logging
294
+ logger.info(f"Detection: class={cls}, conf={conf:.2f}, label={label}")
295
+
296
+ # Skip if not a known violation or below confidence threshold
297
+ if label not in CONFIG["VIOLATION_LABELS"].values():
298
+ logger.info(f"Skipping unknown class: {cls}")
299
+ continue
300
+
301
+ if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
302
+ logger.info(f"Skipping low confidence: {conf:.2f} < {CONFIG['CONFIDENCE_THRESHOLD']}")
303
  continue
304
 
305
+ # Process valid detection
306
  bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
307
+ logger.info(f"Valid detection: {label} with confidence: {conf:.2f}")
308
+
309
  current_detections.append({
310
  "violation": label,
311
  "confidence": round(conf, 2),
 
314
  "frame": frame_count
315
  })
316
 
317
+ # Process detections and associate with workers
318
+ # FIXED: Improved worker tracking logic
319
  for detection in current_detections:
320
  matched_worker = None
321
  max_iou = 0
322
+
323
+ # Try to match with existing workers
324
  for worker in workers:
325
  iou = calculate_iou(detection["bounding_box"], worker["bbox"])
326
  if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
 
330
  if matched_worker:
331
  # Update existing worker
332
  if detection["violation"] not in matched_worker["violations"]:
333
+ # New violation for this worker
334
+ logger.info(f"New violation for worker {matched_worker['id']}: {detection['violation']}")
335
  matched_worker["violations"].add(detection["violation"])
336
  violations.append({
337
  "frame": frame_count,
 
341
  "timestamp": detection["timestamp"],
342
  "worker_id": matched_worker["id"]
343
  })
344
+
345
+ # Save snapshot for this violation type if not already taken
346
+ if not snapshot_taken[detection["violation"]]:
347
+ snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
348
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
349
+ cv2.imwrite(snapshot_path, frame)
350
+ snapshot_taken[detection["violation"]] = True
351
+ snapshots.append({
352
+ "violation": detection["violation"],
353
+ "frame": frame_count,
354
+ "snapshot_path": snapshot_path,
355
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
356
+ })
357
+
358
+ # Update worker position
359
  matched_worker["bbox"] = detection["bounding_box"]
360
  matched_worker["last_frame"] = frame_count
361
  else:
362
+ # New worker detected
363
  worker_id = len(workers) + 1
364
+ logger.info(f"New worker {worker_id} with violation: {detection['violation']}")
365
  workers.append({
366
  "id": worker_id,
367
  "violations": {detection["violation"]},
368
  "bbox": detection["bounding_box"],
369
  "last_frame": frame_count
370
  })
371
+
372
  violations.append({
373
  "frame": frame_count,
374
  "violation": detection["violation"],
 
377
  "timestamp": detection["timestamp"],
378
  "worker_id": worker_id
379
  })
380
+
381
+ # Save snapshot for this violation type if not already taken
382
+ if not snapshot_taken[detection["violation"]]:
383
+ snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
384
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
385
+ cv2.imwrite(snapshot_path, frame)
386
+ snapshot_taken[detection["violation"]] = True
387
+ snapshots.append({
388
+ "violation": detection["violation"],
389
+ "frame": frame_count,
390
+ "snapshot_path": snapshot_path,
391
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
392
+ })
393
 
394
+ # Clean up workers that haven't been seen for a while
395
+ active_workers = [w for w in workers if frame_count - w["last_frame"] < CONFIG["FRAME_SKIP"] * 5]
396
+ if len(active_workers) != len(workers):
397
+ logger.info(f"Cleaned up {len(workers) - len(active_workers)} inactive workers")
398
+ workers = active_workers
399
+
400
  frame_count += 1
401
 
402
  video.release()
403
  os.remove(video_path)
404
+
405
+ # Final log of violations detected
406
+ violation_types = {}
407
+ for v in violations:
408
+ violation_types[v["violation"]] = violation_types.get(v["violation"], 0) + 1
409
+
410
+ logger.info(f"Detection complete. Found violations: {violation_types}")
411
 
412
  if not violations:
413
  logger.info("No violations detected")
 
506
 
507
  if __name__ == "__main__":
508
  logger.info("Launching Safety Analyzer App...")
509
+ interface.launch()