PrashanthB461 commited on
Commit
c031717
·
verified ·
1 Parent(s): 1285b50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -62
app.py CHANGED
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
25
 
26
  # ========================== # Enhanced Tracker Implementation # ==========================
27
  class EnhancedTracker:
28
- def __init__(self, track_thresh=0.4, track_buffer=60, match_thresh=0.9, frame_rate=30):
29
  self.track_thresh = track_thresh
30
  self.track_buffer = track_buffer
31
  self.match_thresh = match_thresh
@@ -34,8 +34,9 @@ class EnhancedTracker:
34
  self.tracks = {}
35
  self.worker_history = {}
36
  self.last_positions = {}
37
- self.flagged_violations = {}
38
- self.permanent_violations = {}
 
39
 
40
  def update(self, dets, scores, cls):
41
  tracks = []
@@ -96,7 +97,6 @@ class EnhancedTracker:
96
  for worker_id, last_pos in self.last_positions.items():
97
  if self._is_same_worker([x, y], last_pos):
98
  self._update_existing_track(worker_id, x, y, w, h, score, cl, current_time)
99
- logger.debug(f"Reused existing worker ID {worker_id} for new detection")
100
  return worker_id
101
 
102
  self.tracks[self.next_id] = {
@@ -107,8 +107,7 @@ class EnhancedTracker:
107
  }
108
  self.worker_history[self.next_id] = [[x, y]]
109
  self.last_positions[self.next_id] = [x, y]
110
- self.permanent_violations[self.next_id] = set()
111
- logger.info(f"Created new worker ID {self.next_id}")
112
  self.next_id += 1
113
  return self.next_id - 1
114
 
@@ -121,8 +120,6 @@ class EnhancedTracker:
121
  self.worker_history.pop(tid, None)
122
  self.last_positions.pop(tid, None)
123
  self.flagged_violations.pop(tid, None)
124
- self.permanent_violations.pop(tid, None)
125
- logger.debug(f"Cleaned up stale worker ID {tid}")
126
 
127
  def _calculate_iou(self, box1, box2):
128
  x1, y1, w1, h1 = box1
@@ -141,37 +138,32 @@ class EnhancedTracker:
141
  area2 = w2 * h2
142
  return intersection / (area1 + area2 - intersection)
143
 
144
- def _is_same_worker(self, pos1, pos2, threshold=50):
145
  x1, y1 = pos1
146
  x2, y2 = pos2
147
- distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
148
- logger.debug(f"Distance between positions: {distance:.2f}")
149
- return distance < threshold
150
 
151
  def should_alert_violation(self, worker_id, violation_type, position, cooldown=45.0, distance_thresh=120):
152
  if worker_id not in self.permanent_violations:
153
  self.permanent_violations[worker_id] = set()
154
- logger.debug(f"Initialized permanent_violations for worker {worker_id}")
155
-
156
  if violation_type == "no_helmet":
157
  if "no_helmet" in self.permanent_violations[worker_id]:
158
- logger.debug(f"Skipped no_helmet violation for worker {worker_id} (already reported)")
159
  return False
160
  self.permanent_violations[worker_id].add("no_helmet")
161
- logger.info(f"Recorded no_helmet violation for worker {worker_id}")
162
  return True
163
 
 
164
  if worker_id not in self.flagged_violations:
165
  self.flagged_violations[worker_id] = {}
166
 
167
  if violation_type in self.flagged_violations[worker_id]:
168
  last_time, last_pos, _ = self.flagged_violations[worker_id][violation_type]
169
  if time.time() - last_time < cooldown and self._is_same_worker(position, last_pos, distance_thresh):
170
- logger.debug(f"Skipped {violation_type} for worker {worker_id} due to cooldown")
171
  return False
172
 
173
  self.flagged_violations[worker_id][violation_type] = (time.time(), position, 1)
174
- logger.info(f"Recorded {violation_type} violation for worker {worker_id}")
175
  return True
176
 
177
  # ========================== # Optimized Configuration # ==========================
@@ -217,13 +209,13 @@ CONFIG = {
217
  "VIOLATION_COOLDOWN": 45.0,
218
  "WORKER_TRACKING_DURATION": 20.0,
219
  "MAX_PROCESSING_TIME": 45,
220
- "FRAME_SKIP": 5,
221
- "BATCH_SIZE": 16,
222
  "TRACK_BUFFER": 60,
223
  "TRACK_THRESH": 0.4,
224
- "MATCH_THRESH": 0.9,
225
- "SNAPSHOT_QUALITY": 85,
226
- "MAX_WORKER_DISTANCE": 50,
227
  "VIOLATION_DISTANCE_THRESH": 120
228
  }
229
 
@@ -232,18 +224,10 @@ logger.info(f"Using device: {device}")
232
 
233
  def load_model():
234
  try:
235
- model_path = CONFIG["MODEL_PATH"]
236
  if not os.path.isfile(model_path):
237
- logger.error(f"Custom model {model_path} not found. Please provide a trained safety model.")
238
- raise FileNotFoundError(f"Custom model {model_path} not found")
239
-
240
  model = YOLO(model_path).to(device)
241
- expected_classes = set(CONFIG["VIOLATION_LABELS"].values())
242
- model_classes = set(model.names.values())
243
- if not expected_classes.issubset(model_classes):
244
- logger.error(f"Model classes {model.names} do not contain required safety classes: {expected_classes}")
245
- raise ValueError("Loaded model does not support required safety violation classes")
246
-
247
  logger.info(f"Model loaded with classes: {model.names}")
248
  return model
249
  except Exception as e:
@@ -252,13 +236,10 @@ def load_model():
252
 
253
  model = load_model()
254
 
255
- # Global cache for no_helmet violations
256
- NO_HELMET_CACHE = set()
257
-
258
  # ========================== # Helper Functions # ==========================
259
  def preprocess_frame(frame):
260
  frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=25)
261
- frame = cv2.GaussianBlur(frame, (5, 5), 0)
262
  return frame
263
 
264
  def draw_detections(frame, detections):
@@ -501,20 +482,6 @@ def process_video(video_file):
501
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
502
 
503
  if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
504
- worker_id = None
505
- if label == "no_helmet":
506
- for tid, track in tracker.tracks.items():
507
- if tid in NO_HELMET_CACHE:
508
- continue
509
- tx, ty, tw, th = track['bbox']
510
- iou = tracker._calculate_iou(box.xywh.cpu().numpy()[0], [tx, ty, tw, th])
511
- if iou > CONFIG["MATCH_THRESH"]:
512
- worker_id = tid
513
- break
514
- if worker_id and worker_id in NO_HELMET_CACHE:
515
- logger.debug(f"Skipped no_helmet detection for worker {worker_id} (already in cache)")
516
- continue
517
-
518
  track_inputs.append({
519
  "bbox": box.xywh.cpu().numpy()[0],
520
  "conf": conf,
@@ -535,19 +502,11 @@ def process_video(video_file):
535
  bbox = obj['bbox']
536
  position = (bbox[0], bbox[1])
537
 
538
- if label == "no_helmet" and worker_id in NO_HELMET_CACHE:
539
- logger.debug(f"Skipped no_helmet violation for worker {worker_id} (already reported)")
540
- continue
541
-
542
  if label and tracker.should_alert_violation(
543
  worker_id, label, position,
544
  CONFIG["VIOLATION_COOLDOWN"],
545
  CONFIG["VIOLATION_DISTANCE_THRESH"]
546
  ):
547
- if label == "no_helmet":
548
- NO_HELMET_CACHE.add(worker_id)
549
- logger.info(f"Added worker {worker_id} to no_helmet cache")
550
-
551
  if worker_id not in unique_violations:
552
  unique_violations[worker_id] = {}
553
 
@@ -625,9 +584,6 @@ def gradio_interface(video_file):
625
  if not video_file:
626
  return "No file uploaded", "", "", "", ""
627
 
628
- NO_HELMET_CACHE.clear()
629
- logger.info("Cleared NO_HELMET_CACHE for new video processing")
630
-
631
  for result in process_video(video_file):
632
  yield result
633
 
 
25
 
26
  # ========================== # Enhanced Tracker Implementation # ==========================
27
  class EnhancedTracker:
28
+ def __init__(self, track_thresh=0.4, track_buffer=60, match_thresh=0.8, frame_rate=30):
29
  self.track_thresh = track_thresh
30
  self.track_buffer = track_buffer
31
  self.match_thresh = match_thresh
 
34
  self.tracks = {}
35
  self.worker_history = {}
36
  self.last_positions = {}
37
+ self.flagged_violations = {} # {worker_id: {violation_type: (time, position, count)}}
38
+ self.permanent_violations = {} # {worker_id: set(violation_types)}
39
+ self.violation_zones = {}
40
 
41
  def update(self, dets, scores, cls):
42
  tracks = []
 
97
  for worker_id, last_pos in self.last_positions.items():
98
  if self._is_same_worker([x, y], last_pos):
99
  self._update_existing_track(worker_id, x, y, w, h, score, cl, current_time)
 
100
  return worker_id
101
 
102
  self.tracks[self.next_id] = {
 
107
  }
108
  self.worker_history[self.next_id] = [[x, y]]
109
  self.last_positions[self.next_id] = [x, y]
110
+ self.permanent_violations[self.next_id] = set() # Initialize permanent violations set
 
111
  self.next_id += 1
112
  return self.next_id - 1
113
 
 
120
  self.worker_history.pop(tid, None)
121
  self.last_positions.pop(tid, None)
122
  self.flagged_violations.pop(tid, None)
 
 
123
 
124
  def _calculate_iou(self, box1, box2):
125
  x1, y1, w1, h1 = box1
 
138
  area2 = w2 * h2
139
  return intersection / (area1 + area2 - intersection)
140
 
141
+ def _is_same_worker(self, pos1, pos2, threshold=80):
142
  x1, y1 = pos1
143
  x2, y2 = pos2
144
+ return np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < threshold
 
 
145
 
146
  def should_alert_violation(self, worker_id, violation_type, position, cooldown=45.0, distance_thresh=120):
147
  if worker_id not in self.permanent_violations:
148
  self.permanent_violations[worker_id] = set()
149
+
150
+ # For no_helmet, only report once per worker
151
  if violation_type == "no_helmet":
152
  if "no_helmet" in self.permanent_violations[worker_id]:
 
153
  return False
154
  self.permanent_violations[worker_id].add("no_helmet")
 
155
  return True
156
 
157
+ # For other violations, use cooldown logic
158
  if worker_id not in self.flagged_violations:
159
  self.flagged_violations[worker_id] = {}
160
 
161
  if violation_type in self.flagged_violations[worker_id]:
162
  last_time, last_pos, _ = self.flagged_violations[worker_id][violation_type]
163
  if time.time() - last_time < cooldown and self._is_same_worker(position, last_pos, distance_thresh):
 
164
  return False
165
 
166
  self.flagged_violations[worker_id][violation_type] = (time.time(), position, 1)
 
167
  return True
168
 
169
  # ========================== # Optimized Configuration # ==========================
 
209
  "VIOLATION_COOLDOWN": 45.0,
210
  "WORKER_TRACKING_DURATION": 20.0,
211
  "MAX_PROCESSING_TIME": 45,
212
+ "FRAME_SKIP": 5, # Increased for faster processing
213
+ "BATCH_SIZE": 16, # Increased for better throughput
214
  "TRACK_BUFFER": 60,
215
  "TRACK_THRESH": 0.4,
216
+ "MATCH_THRESH": 0.8,
217
+ "SNAPSHOT_QUALITY": 85, # Slightly reduced for faster saving
218
+ "MAX_WORKER_DISTANCE": 80,
219
  "VIOLATION_DISTANCE_THRESH": 120
220
  }
221
 
 
224
 
225
  def load_model():
226
  try:
227
+ model_path = CONFIG["MODEL_PATH"] if os.path.isfile(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
228
  if not os.path.isfile(model_path):
229
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
 
 
230
  model = YOLO(model_path).to(device)
 
 
 
 
 
 
231
  logger.info(f"Model loaded with classes: {model.names}")
232
  return model
233
  except Exception as e:
 
236
 
237
  model = load_model()
238
 
 
 
 
239
  # ========================== # Helper Functions # ==========================
240
  def preprocess_frame(frame):
241
  frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=25)
242
+ frame = cv2.GaussianBlur(frame, (5, 5), 0) # Noise reduction
243
  return frame
244
 
245
  def draw_detections(frame, detections):
 
482
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
483
 
484
  if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  track_inputs.append({
486
  "bbox": box.xywh.cpu().numpy()[0],
487
  "conf": conf,
 
502
  bbox = obj['bbox']
503
  position = (bbox[0], bbox[1])
504
 
 
 
 
 
505
  if label and tracker.should_alert_violation(
506
  worker_id, label, position,
507
  CONFIG["VIOLATION_COOLDOWN"],
508
  CONFIG["VIOLATION_DISTANCE_THRESH"]
509
  ):
 
 
 
 
510
  if worker_id not in unique_violations:
511
  unique_violations[worker_id] = {}
512
 
 
584
  if not video_file:
585
  return "No file uploaded", "", "", "", ""
586
 
 
 
 
587
  for result in process_video(video_file):
588
  yield result
589