PrashanthB461 committed on
Commit
af9bf78
·
verified ·
1 Parent(s): 220ca2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -20
app.py CHANGED
@@ -40,10 +40,10 @@ FFMPEG_AVAILABLE = check_ffmpeg()
40
 
41
  # ========================== # ByteTrack Implementation # ==========================
42
  class BYTETracker:
43
- def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
- self.match_thresh = match_thresh
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
@@ -159,7 +159,7 @@ class BYTETracker:
159
  iou = intersection_area / (box1_area + box2_area - intersection_area)
160
  return iou
161
 
162
- def _is_same_worker(self, pos1, pos2, threshold=300):
163
  x1, y1 = pos1
164
  x2, y2 = pos2
165
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
@@ -197,12 +197,12 @@ CONFIG = {
197
  "domain": "login"
198
  },
199
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
200
- "CONFIDENCE_THRESHOLDS": {
201
- "no_helmet": 0.5,
202
- "no_harness": 0.3,
203
- "unsafe_posture": 0.3,
204
- "unsafe_zone": 0.3,
205
- "improper_tool_use": 0.3
206
  },
207
  "MIN_VIOLATION_FRAMES": 1,
208
  "VIOLATION_COOLDOWN": 30.0,
@@ -213,10 +213,10 @@ CONFIG = {
213
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
214
  "TRACK_BUFFER": 90,
215
  "TRACK_THRESH": 0.3,
216
- "MATCH_THRESH": 0.5,
217
  "SNAPSHOT_QUALITY": 95,
218
- "MAX_WORKER_DISTANCE": 300,
219
- "TARGET_RESOLUTION": (384, 384) # Changed to 384x384 (divisible by 32)
220
  }
221
 
222
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -548,11 +548,12 @@ def process_video(video_data, temp_dir):
548
 
549
  worker_id_mapping = {}
550
  unique_violations = {}
551
- violation_frames = {} # Store frame indices for violations
552
  start_time = time.time()
553
  frame_skip = CONFIG["FRAME_SKIP"]
554
  processed_frames = 0
555
  last_yield_time = start_time
 
556
 
557
  while processed_frames < total_frames:
558
  batch_frames = []
@@ -583,7 +584,7 @@ def process_video(video_data, temp_dir):
583
  break
584
 
585
  try:
586
- batch_frames_np = np.array(batch_frames) # Shape: (batch, height, width, channels)
587
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
588
  batch_frames_tensor = batch_frames_tensor.to(device)
589
  if device.type == "cuda":
@@ -638,6 +639,7 @@ def process_video(video_data, temp_dir):
638
  np.array([t["conf"] for t in track_inputs]),
639
  np.array([t["cls"] for t in track_inputs])
640
  )
 
641
 
642
  for obj in tracked_objects:
643
  tracker_id = obj['id']
@@ -648,10 +650,9 @@ def process_video(video_data, temp_dir):
648
  if label is None:
649
  continue
650
 
651
- if not worker_id_mapping:
652
- worker_id_mapping[tracker_id] = 1
653
- else:
654
- worker_id_mapping[tracker_id] = worker_id_mapping[list(worker_id_mapping.keys())[0]]
655
 
656
  worker_id = worker_id_mapping[tracker_id]
657
 
@@ -664,6 +665,7 @@ def process_video(video_data, temp_dir):
664
  cap.release()
665
  processing_time = time.time() - start_time
666
  logger.info(f"Processing complete in {processing_time:.2f}s")
 
667
 
668
  violations = []
669
  for (worker_id, label), detection_time in unique_violations.items():
@@ -671,7 +673,7 @@ def process_video(video_data, temp_dir):
671
  "worker_id": worker_id,
672
  "violation": label,
673
  "timestamp": detection_time,
674
- "confidence": 0.0, # Will be updated after reprocessing frames
675
  "frame_idx": violation_frames[(worker_id, label)]
676
  })
677
 
@@ -680,7 +682,6 @@ def process_video(video_data, temp_dir):
680
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
681
  return
682
 
683
- # Reopen video to capture snapshots for violations
684
  snapshots = []
685
  cap = cv2.VideoCapture(video_path)
686
  for violation in violations:
 
40
 
41
  # ========================== # ByteTrack Implementation # ==========================
42
  class BYTETracker:
43
+ def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.4, frame_rate=30):
44
  self.track_thresh = track_thresh
45
  self.track_buffer = track_buffer
46
+ self.match_thresh = match_thresh # Lowered to 0.4 to improve tracking sensitivity
47
  self.frame_rate = frame_rate
48
  self.next_id = 1
49
  self.tracks = {}
 
159
  iou = intersection_area / (box1_area + box2_area - intersection_area)
160
  return iou
161
 
162
+ def _is_same_worker(self, pos1, pos2, threshold=100): # Reduced threshold for 384x384 frames
163
  x1, y1 = pos1
164
  x2, y2 = pos2
165
  distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
 
197
  "domain": "login"
198
  },
199
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
200
+ "CONFIDENCE_THRESHOLDS": { # Lowered thresholds to improve detection
201
+ "no_helmet": 0.4,
202
+ "no_harness": 0.25,
203
+ "unsafe_posture": 0.25,
204
+ "unsafe_zone": 0.25,
205
+ "improper_tool_use": 0.25
206
  },
207
  "MIN_VIOLATION_FRAMES": 1,
208
  "VIOLATION_COOLDOWN": 30.0,
 
213
  "PARALLEL_WORKERS": max(1, cpu_count() - 1),
214
  "TRACK_BUFFER": 90,
215
  "TRACK_THRESH": 0.3,
216
+ "MATCH_THRESH": 0.4,
217
  "SNAPSHOT_QUALITY": 95,
218
+ "MAX_WORKER_DISTANCE": 100, # Adjusted to match BYTETracker threshold
219
+ "TARGET_RESOLUTION": (384, 384)
220
  }
221
 
222
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
548
 
549
  worker_id_mapping = {}
550
  unique_violations = {}
551
+ violation_frames = {}
552
  start_time = time.time()
553
  frame_skip = CONFIG["FRAME_SKIP"]
554
  processed_frames = 0
555
  last_yield_time = start_time
556
+ worker_counter = 1 # For assigning unique worker IDs
557
 
558
  while processed_frames < total_frames:
559
  batch_frames = []
 
584
  break
585
 
586
  try:
587
+ batch_frames_np = np.array(batch_frames)
588
  batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
589
  batch_frames_tensor = batch_frames_tensor.to(device)
590
  if device.type == "cuda":
 
639
  np.array([t["conf"] for t in track_inputs]),
640
  np.array([t["cls"] for t in track_inputs])
641
  )
642
+ logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
643
 
644
  for obj in tracked_objects:
645
  tracker_id = obj['id']
 
650
  if label is None:
651
  continue
652
 
653
+ if tracker_id not in worker_id_mapping:
654
+ worker_id_mapping[tracker_id] = worker_counter
655
+ worker_counter += 1
 
656
 
657
  worker_id = worker_id_mapping[tracker_id]
658
 
 
665
  cap.release()
666
  processing_time = time.time() - start_time
667
  logger.info(f"Processing complete in {processing_time:.2f}s")
668
+ logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
669
 
670
  violations = []
671
  for (worker_id, label), detection_time in unique_violations.items():
 
673
  "worker_id": worker_id,
674
  "violation": label,
675
  "timestamp": detection_time,
676
+ "confidence": 0.0,
677
  "frame_idx": violation_frames[(worker_id, label)]
678
  })
679
 
 
682
  yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
683
  return
684
 
 
685
  snapshots = []
686
  cap = cv2.VideoCapture(video_path)
687
  for violation in violations: