PrashanthB461 committed on
Commit
c7afc58
·
verified ·
1 Parent(s): 01caa2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -17
app.py CHANGED
@@ -263,16 +263,16 @@ CONFIG = {
263
  "MIN_VIOLATION_FRAMES": 2,
264
  "VIOLATION_COOLDOWN": 30.0,
265
  "WORKER_TRACKING_DURATION": 10.0,
266
- "MAX_PROCESSING_TIME": 60, # Reduced for early termination
267
- "FRAME_SKIP": 4, # Increased to reduce frames processed
268
- "BATCH_SIZE": 4, # Reduced for CPU efficiency
269
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
270
  "TRACK_BUFFER": 150,
271
  "TRACK_THRESH": 0.3,
272
  "MATCH_THRESH": 0.5,
273
  "SNAPSHOT_QUALITY": 95,
274
  "MAX_WORKER_DISTANCE": 150,
275
- "TARGET_RESOLUTION": (320, 320), # Reduced for faster inference
276
  "HELMET_VALIDATION_FRAMES": 3
277
  }
278
 
@@ -308,8 +308,7 @@ def preprocess_frame(frame):
308
  target_res = CONFIG["TARGET_RESOLUTION"]
309
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
310
  frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20)
311
- kernel = np.array([[-1,-1,-1], [-1, 9,-1], [-1,-1,-1]])
312
- frame = cv2.filter2D(frame, -1, kernel)
313
  return frame
314
 
315
  def is_unsafe_posture(box, frame_shape):
@@ -615,38 +614,45 @@ def process_video(video_data, temp_dir):
615
  unique_violations = {}
616
  violation_frames = {}
617
  helmet_detections = {}
618
- frame_detections = {} # Store detections for snapshot reuse
619
  start_time = time.time()
620
  frame_skip = CONFIG["FRAME_SKIP"]
621
- processed_frames = 0
 
622
  last_yield_time = start_time
623
  worker_counter = 1
624
 
625
- while processed_frames < total_frames:
626
  batch_frames = []
627
  batch_indices = []
628
  batch_originals = []
629
 
630
  for _ in range(CONFIG["BATCH_SIZE"]):
631
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
 
632
  if frame_idx >= total_frames:
 
633
  break
634
  ret, frame = cap.read()
635
  if not ret:
636
- logger.warning(f"Failed to read frame {frame_idx}. Skipping.")
 
637
  break
638
  original_frame = frame.copy()
639
  frame = preprocess_frame(frame)
640
  for _ in range(frame_skip - 1):
641
  if not cap.grab():
 
 
642
  break
 
643
  batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
644
  batch_indices.append(frame_idx)
645
  batch_originals.append(original_frame)
646
- processed_frames += frame_skip
647
 
648
  if not batch_frames:
649
- logger.info("No more frames to process.")
650
  break
651
 
652
  # Check for timeout
@@ -673,10 +679,11 @@ def process_video(video_data, temp_dir):
673
 
674
  current_time = time.time()
675
  if current_time - last_yield_time > 0.1:
676
- progress = (processed_frames / total_frames) * 100
 
677
  elapsed_time = current_time - start_time
678
  fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
679
- yield f"Processing video... {progress:.1f}% complete (Frame {processed_frames}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", f"Elapsed: {elapsed_time:.1f}s"
680
  last_yield_time = current_time
681
 
682
  for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
@@ -685,7 +692,7 @@ def process_video(video_data, temp_dir):
685
  person_boxes = []
686
  tool_boxes = []
687
 
688
- frame_detections[frame_idx] = [] # Store detections for this frame
689
 
690
  for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
691
  label_name = model.config.id2label[label.item()]
@@ -770,6 +777,11 @@ def process_video(video_data, temp_dir):
770
  logger.info(f"Processing complete in {processing_time:.2f}s")
771
  logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
772
 
 
 
 
 
 
773
  violations = []
774
  for (worker_id, label), detection_time in unique_violations.items():
775
  frame_idx = violation_frames[(worker_id, label)]
@@ -787,7 +799,6 @@ def process_video(video_data, temp_dir):
787
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
788
  return
789
 
790
- # Generate violation table early for intermediate output
791
  violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
792
  violation_table += "|-----------|-----------|----------|------------|\n"
793
  for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
@@ -810,7 +821,6 @@ def process_video(video_data, temp_dir):
810
  continue
811
 
812
  frame = preprocess_frame(frame)
813
- # Reuse detections instead of re-running inference
814
  detections = frame_detections.get(frame_idx, [])
815
  for det in detections:
816
  if det["label"] == violation["violation"]:
 
263
  "MIN_VIOLATION_FRAMES": 2,
264
  "VIOLATION_COOLDOWN": 30.0,
265
  "WORKER_TRACKING_DURATION": 10.0,
266
+ "MAX_PROCESSING_TIME": 120, # Increased to allow more time for CPU processing
267
+ "FRAME_SKIP": 4,
268
+ "BATCH_SIZE": 2, # Reduced for better CPU performance
269
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
270
  "TRACK_BUFFER": 150,
271
  "TRACK_THRESH": 0.3,
272
  "MATCH_THRESH": 0.5,
273
  "SNAPSHOT_QUALITY": 95,
274
  "MAX_WORKER_DISTANCE": 150,
275
+ "TARGET_RESOLUTION": (256, 256), # Further reduced for faster inference
276
  "HELMET_VALIDATION_FRAMES": 3
277
  }
278
 
 
308
  target_res = CONFIG["TARGET_RESOLUTION"]
309
  frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
310
  frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20)
311
+ # Removed cv2.filter2D to reduce processing time
 
312
  return frame
313
 
314
  def is_unsafe_posture(box, frame_shape):
 
614
  unique_violations = {}
615
  violation_frames = {}
616
  helmet_detections = {}
617
+ frame_detections = {}
618
  start_time = time.time()
619
  frame_skip = CONFIG["FRAME_SKIP"]
620
+ processed_frames = 0 # Track actual frames processed
621
+ frames_read = 0 # Track frames read from video
622
  last_yield_time = start_time
623
  worker_counter = 1
624
 
625
+ while True:
626
  batch_frames = []
627
  batch_indices = []
628
  batch_originals = []
629
 
630
  for _ in range(CONFIG["BATCH_SIZE"]):
631
  frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
632
+ frames_read = frame_idx
633
  if frame_idx >= total_frames:
634
+ logger.info("Reached end of video.")
635
  break
636
  ret, frame = cap.read()
637
  if not ret:
638
+ logger.warning(f"Failed to read frame {frame_idx}. Assuming end of video.")
639
+ frames_read = total_frames # Assume we've reached the end
640
  break
641
  original_frame = frame.copy()
642
  frame = preprocess_frame(frame)
643
  for _ in range(frame_skip - 1):
644
  if not cap.grab():
645
+ logger.warning(f"Failed to grab frame after {frame_idx}. Assuming end of video.")
646
+ frames_read = total_frames
647
  break
648
+ frames_read += 1
649
  batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
650
  batch_indices.append(frame_idx)
651
  batch_originals.append(original_frame)
652
+ processed_frames += 1 # Increment for each frame actually processed
653
 
654
  if not batch_frames:
655
+ logger.info("No more frames to process in this batch.")
656
  break
657
 
658
  # Check for timeout
 
679
 
680
  current_time = time.time()
681
  if current_time - last_yield_time > 0.1:
682
+ progress = (frames_read / total_frames) * 100
683
+ progress = min(progress, 100.0) # Ensure progress doesn't exceed 100%
684
  elapsed_time = current_time - start_time
685
  fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
686
+ yield f"Processing video... {progress:.1f}% complete (Frame {frames_read}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", f"Elapsed: {elapsed_time:.1f}s"
687
  last_yield_time = current_time
688
 
689
  for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
 
692
  person_boxes = []
693
  tool_boxes = []
694
 
695
+ frame_detections[frame_idx] = []
696
 
697
  for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
698
  label_name = model.config.id2label[label.item()]
 
777
  logger.info(f"Processing complete in {processing_time:.2f}s")
778
  logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
779
 
780
+ # Ensure final progress update
781
+ final_progress = (frames_read / total_frames) * 100
782
+ final_progress = min(final_progress, 100.0)
783
+ yield f"Processing video... {final_progress:.1f}% complete (Frame {frames_read}/{total_frames})", "", "", "", "", f"Elapsed: {processing_time:.1f}s"
784
+
785
  violations = []
786
  for (worker_id, label), detection_time in unique_violations.items():
787
  frame_idx = violation_frames[(worker_id, label)]
 
799
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
800
  return
801
 
 
802
  violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
803
  violation_table += "|-----------|-----------|----------|------------|\n"
804
  for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
 
821
  continue
822
 
823
  frame = preprocess_frame(frame)
 
824
  detections = frame_detections.get(frame_idx, [])
825
  for det in detections:
826
  if det["label"] == violation["violation"]: