Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|
| 25 |
|
| 26 |
# ========================== # Enhanced Tracker Implementation # ==========================
|
| 27 |
class EnhancedTracker:
|
| 28 |
-
def __init__(self, track_thresh=0.4, track_buffer=60, match_thresh=0.
|
| 29 |
self.track_thresh = track_thresh
|
| 30 |
self.track_buffer = track_buffer
|
| 31 |
self.match_thresh = match_thresh
|
|
@@ -34,8 +34,9 @@ class EnhancedTracker:
|
|
| 34 |
self.tracks = {}
|
| 35 |
self.worker_history = {}
|
| 36 |
self.last_positions = {}
|
| 37 |
-
self.flagged_violations = {}
|
| 38 |
-
self.permanent_violations = {}
|
|
|
|
| 39 |
|
| 40 |
def update(self, dets, scores, cls):
|
| 41 |
tracks = []
|
|
@@ -96,7 +97,6 @@ class EnhancedTracker:
|
|
| 96 |
for worker_id, last_pos in self.last_positions.items():
|
| 97 |
if self._is_same_worker([x, y], last_pos):
|
| 98 |
self._update_existing_track(worker_id, x, y, w, h, score, cl, current_time)
|
| 99 |
-
logger.debug(f"Reused existing worker ID {worker_id} for new detection")
|
| 100 |
return worker_id
|
| 101 |
|
| 102 |
self.tracks[self.next_id] = {
|
|
@@ -107,8 +107,7 @@ class EnhancedTracker:
|
|
| 107 |
}
|
| 108 |
self.worker_history[self.next_id] = [[x, y]]
|
| 109 |
self.last_positions[self.next_id] = [x, y]
|
| 110 |
-
self.permanent_violations[self.next_id] = set()
|
| 111 |
-
logger.info(f"Created new worker ID {self.next_id}")
|
| 112 |
self.next_id += 1
|
| 113 |
return self.next_id - 1
|
| 114 |
|
|
@@ -121,8 +120,6 @@ class EnhancedTracker:
|
|
| 121 |
self.worker_history.pop(tid, None)
|
| 122 |
self.last_positions.pop(tid, None)
|
| 123 |
self.flagged_violations.pop(tid, None)
|
| 124 |
-
self.permanent_violations.pop(tid, None)
|
| 125 |
-
logger.debug(f"Cleaned up stale worker ID {tid}")
|
| 126 |
|
| 127 |
def _calculate_iou(self, box1, box2):
|
| 128 |
x1, y1, w1, h1 = box1
|
|
@@ -141,37 +138,32 @@ class EnhancedTracker:
|
|
| 141 |
area2 = w2 * h2
|
| 142 |
return intersection / (area1 + area2 - intersection)
|
| 143 |
|
| 144 |
-
def _is_same_worker(self, pos1, pos2, threshold=
|
| 145 |
x1, y1 = pos1
|
| 146 |
x2, y2 = pos2
|
| 147 |
-
|
| 148 |
-
logger.debug(f"Distance between positions: {distance:.2f}")
|
| 149 |
-
return distance < threshold
|
| 150 |
|
| 151 |
def should_alert_violation(self, worker_id, violation_type, position, cooldown=45.0, distance_thresh=120):
|
| 152 |
if worker_id not in self.permanent_violations:
|
| 153 |
self.permanent_violations[worker_id] = set()
|
| 154 |
-
|
| 155 |
-
|
| 156 |
if violation_type == "no_helmet":
|
| 157 |
if "no_helmet" in self.permanent_violations[worker_id]:
|
| 158 |
-
logger.debug(f"Skipped no_helmet violation for worker {worker_id} (already reported)")
|
| 159 |
return False
|
| 160 |
self.permanent_violations[worker_id].add("no_helmet")
|
| 161 |
-
logger.info(f"Recorded no_helmet violation for worker {worker_id}")
|
| 162 |
return True
|
| 163 |
|
|
|
|
| 164 |
if worker_id not in self.flagged_violations:
|
| 165 |
self.flagged_violations[worker_id] = {}
|
| 166 |
|
| 167 |
if violation_type in self.flagged_violations[worker_id]:
|
| 168 |
last_time, last_pos, _ = self.flagged_violations[worker_id][violation_type]
|
| 169 |
if time.time() - last_time < cooldown and self._is_same_worker(position, last_pos, distance_thresh):
|
| 170 |
-
logger.debug(f"Skipped {violation_type} for worker {worker_id} due to cooldown")
|
| 171 |
return False
|
| 172 |
|
| 173 |
self.flagged_violations[worker_id][violation_type] = (time.time(), position, 1)
|
| 174 |
-
logger.info(f"Recorded {violation_type} violation for worker {worker_id}")
|
| 175 |
return True
|
| 176 |
|
| 177 |
# ========================== # Optimized Configuration # ==========================
|
|
@@ -217,13 +209,13 @@ CONFIG = {
|
|
| 217 |
"VIOLATION_COOLDOWN": 45.0,
|
| 218 |
"WORKER_TRACKING_DURATION": 20.0,
|
| 219 |
"MAX_PROCESSING_TIME": 45,
|
| 220 |
-
"FRAME_SKIP": 5,
|
| 221 |
-
"BATCH_SIZE": 16,
|
| 222 |
"TRACK_BUFFER": 60,
|
| 223 |
"TRACK_THRESH": 0.4,
|
| 224 |
-
"MATCH_THRESH": 0.
|
| 225 |
-
"SNAPSHOT_QUALITY": 85,
|
| 226 |
-
"MAX_WORKER_DISTANCE":
|
| 227 |
"VIOLATION_DISTANCE_THRESH": 120
|
| 228 |
}
|
| 229 |
|
|
@@ -232,18 +224,10 @@ logger.info(f"Using device: {device}")
|
|
| 232 |
|
| 233 |
def load_model():
|
| 234 |
try:
|
| 235 |
-
model_path = CONFIG["MODEL_PATH"]
|
| 236 |
if not os.path.isfile(model_path):
|
| 237 |
-
|
| 238 |
-
raise FileNotFoundError(f"Custom model {model_path} not found")
|
| 239 |
-
|
| 240 |
model = YOLO(model_path).to(device)
|
| 241 |
-
expected_classes = set(CONFIG["VIOLATION_LABELS"].values())
|
| 242 |
-
model_classes = set(model.names.values())
|
| 243 |
-
if not expected_classes.issubset(model_classes):
|
| 244 |
-
logger.error(f"Model classes {model.names} do not contain required safety classes: {expected_classes}")
|
| 245 |
-
raise ValueError("Loaded model does not support required safety violation classes")
|
| 246 |
-
|
| 247 |
logger.info(f"Model loaded with classes: {model.names}")
|
| 248 |
return model
|
| 249 |
except Exception as e:
|
|
@@ -252,13 +236,10 @@ def load_model():
|
|
| 252 |
|
| 253 |
model = load_model()
|
| 254 |
|
| 255 |
-
# Global cache for no_helmet violations
|
| 256 |
-
NO_HELMET_CACHE = set()
|
| 257 |
-
|
| 258 |
# ========================== # Helper Functions # ==========================
|
| 259 |
def preprocess_frame(frame):
|
| 260 |
frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=25)
|
| 261 |
-
frame = cv2.GaussianBlur(frame, (5, 5), 0)
|
| 262 |
return frame
|
| 263 |
|
| 264 |
def draw_detections(frame, detections):
|
|
@@ -501,20 +482,6 @@ def process_video(video_file):
|
|
| 501 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 502 |
|
| 503 |
if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
|
| 504 |
-
worker_id = None
|
| 505 |
-
if label == "no_helmet":
|
| 506 |
-
for tid, track in tracker.tracks.items():
|
| 507 |
-
if tid in NO_HELMET_CACHE:
|
| 508 |
-
continue
|
| 509 |
-
tx, ty, tw, th = track['bbox']
|
| 510 |
-
iou = tracker._calculate_iou(box.xywh.cpu().numpy()[0], [tx, ty, tw, th])
|
| 511 |
-
if iou > CONFIG["MATCH_THRESH"]:
|
| 512 |
-
worker_id = tid
|
| 513 |
-
break
|
| 514 |
-
if worker_id and worker_id in NO_HELMET_CACHE:
|
| 515 |
-
logger.debug(f"Skipped no_helmet detection for worker {worker_id} (already in cache)")
|
| 516 |
-
continue
|
| 517 |
-
|
| 518 |
track_inputs.append({
|
| 519 |
"bbox": box.xywh.cpu().numpy()[0],
|
| 520 |
"conf": conf,
|
|
@@ -535,19 +502,11 @@ def process_video(video_file):
|
|
| 535 |
bbox = obj['bbox']
|
| 536 |
position = (bbox[0], bbox[1])
|
| 537 |
|
| 538 |
-
if label == "no_helmet" and worker_id in NO_HELMET_CACHE:
|
| 539 |
-
logger.debug(f"Skipped no_helmet violation for worker {worker_id} (already reported)")
|
| 540 |
-
continue
|
| 541 |
-
|
| 542 |
if label and tracker.should_alert_violation(
|
| 543 |
worker_id, label, position,
|
| 544 |
CONFIG["VIOLATION_COOLDOWN"],
|
| 545 |
CONFIG["VIOLATION_DISTANCE_THRESH"]
|
| 546 |
):
|
| 547 |
-
if label == "no_helmet":
|
| 548 |
-
NO_HELMET_CACHE.add(worker_id)
|
| 549 |
-
logger.info(f"Added worker {worker_id} to no_helmet cache")
|
| 550 |
-
|
| 551 |
if worker_id not in unique_violations:
|
| 552 |
unique_violations[worker_id] = {}
|
| 553 |
|
|
@@ -625,9 +584,6 @@ def gradio_interface(video_file):
|
|
| 625 |
if not video_file:
|
| 626 |
return "No file uploaded", "", "", "", ""
|
| 627 |
|
| 628 |
-
NO_HELMET_CACHE.clear()
|
| 629 |
-
logger.info("Cleared NO_HELMET_CACHE for new video processing")
|
| 630 |
-
|
| 631 |
for result in process_video(video_file):
|
| 632 |
yield result
|
| 633 |
|
|
|
|
| 25 |
|
| 26 |
# ========================== # Enhanced Tracker Implementation # ==========================
|
| 27 |
class EnhancedTracker:
|
| 28 |
+
def __init__(self, track_thresh=0.4, track_buffer=60, match_thresh=0.8, frame_rate=30):
|
| 29 |
self.track_thresh = track_thresh
|
| 30 |
self.track_buffer = track_buffer
|
| 31 |
self.match_thresh = match_thresh
|
|
|
|
| 34 |
self.tracks = {}
|
| 35 |
self.worker_history = {}
|
| 36 |
self.last_positions = {}
|
| 37 |
+
self.flagged_violations = {} # {worker_id: {violation_type: (time, position, count)}}
|
| 38 |
+
self.permanent_violations = {} # {worker_id: set(violation_types)}
|
| 39 |
+
self.violation_zones = {}
|
| 40 |
|
| 41 |
def update(self, dets, scores, cls):
|
| 42 |
tracks = []
|
|
|
|
| 97 |
for worker_id, last_pos in self.last_positions.items():
|
| 98 |
if self._is_same_worker([x, y], last_pos):
|
| 99 |
self._update_existing_track(worker_id, x, y, w, h, score, cl, current_time)
|
|
|
|
| 100 |
return worker_id
|
| 101 |
|
| 102 |
self.tracks[self.next_id] = {
|
|
|
|
| 107 |
}
|
| 108 |
self.worker_history[self.next_id] = [[x, y]]
|
| 109 |
self.last_positions[self.next_id] = [x, y]
|
| 110 |
+
self.permanent_violations[self.next_id] = set() # Initialize permanent violations set
|
|
|
|
| 111 |
self.next_id += 1
|
| 112 |
return self.next_id - 1
|
| 113 |
|
|
|
|
| 120 |
self.worker_history.pop(tid, None)
|
| 121 |
self.last_positions.pop(tid, None)
|
| 122 |
self.flagged_violations.pop(tid, None)
|
|
|
|
|
|
|
| 123 |
|
| 124 |
def _calculate_iou(self, box1, box2):
|
| 125 |
x1, y1, w1, h1 = box1
|
|
|
|
| 138 |
area2 = w2 * h2
|
| 139 |
return intersection / (area1 + area2 - intersection)
|
| 140 |
|
| 141 |
+
def _is_same_worker(self, pos1, pos2, threshold=80):
|
| 142 |
x1, y1 = pos1
|
| 143 |
x2, y2 = pos2
|
| 144 |
+
return np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < threshold
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def should_alert_violation(self, worker_id, violation_type, position, cooldown=45.0, distance_thresh=120):
|
| 147 |
if worker_id not in self.permanent_violations:
|
| 148 |
self.permanent_violations[worker_id] = set()
|
| 149 |
+
|
| 150 |
+
# For no_helmet, only report once per worker
|
| 151 |
if violation_type == "no_helmet":
|
| 152 |
if "no_helmet" in self.permanent_violations[worker_id]:
|
|
|
|
| 153 |
return False
|
| 154 |
self.permanent_violations[worker_id].add("no_helmet")
|
|
|
|
| 155 |
return True
|
| 156 |
|
| 157 |
+
# For other violations, use cooldown logic
|
| 158 |
if worker_id not in self.flagged_violations:
|
| 159 |
self.flagged_violations[worker_id] = {}
|
| 160 |
|
| 161 |
if violation_type in self.flagged_violations[worker_id]:
|
| 162 |
last_time, last_pos, _ = self.flagged_violations[worker_id][violation_type]
|
| 163 |
if time.time() - last_time < cooldown and self._is_same_worker(position, last_pos, distance_thresh):
|
|
|
|
| 164 |
return False
|
| 165 |
|
| 166 |
self.flagged_violations[worker_id][violation_type] = (time.time(), position, 1)
|
|
|
|
| 167 |
return True
|
| 168 |
|
| 169 |
# ========================== # Optimized Configuration # ==========================
|
|
|
|
| 209 |
"VIOLATION_COOLDOWN": 45.0,
|
| 210 |
"WORKER_TRACKING_DURATION": 20.0,
|
| 211 |
"MAX_PROCESSING_TIME": 45,
|
| 212 |
+
"FRAME_SKIP": 5, # Increased for faster processing
|
| 213 |
+
"BATCH_SIZE": 16, # Increased for better throughput
|
| 214 |
"TRACK_BUFFER": 60,
|
| 215 |
"TRACK_THRESH": 0.4,
|
| 216 |
+
"MATCH_THRESH": 0.8,
|
| 217 |
+
"SNAPSHOT_QUALITY": 85, # Slightly reduced for faster saving
|
| 218 |
+
"MAX_WORKER_DISTANCE": 80,
|
| 219 |
"VIOLATION_DISTANCE_THRESH": 120
|
| 220 |
}
|
| 221 |
|
|
|
|
| 224 |
|
| 225 |
def load_model():
|
| 226 |
try:
|
| 227 |
+
model_path = CONFIG["MODEL_PATH"] if os.path.isfile(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
|
| 228 |
if not os.path.isfile(model_path):
|
| 229 |
+
torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
|
|
|
|
|
|
|
| 230 |
model = YOLO(model_path).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
logger.info(f"Model loaded with classes: {model.names}")
|
| 232 |
return model
|
| 233 |
except Exception as e:
|
|
|
|
| 236 |
|
| 237 |
model = load_model()
|
| 238 |
|
|
|
|
|
|
|
|
|
|
| 239 |
# ========================== # Helper Functions # ==========================
|
| 240 |
def preprocess_frame(frame):
|
| 241 |
frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=25)
|
| 242 |
+
frame = cv2.GaussianBlur(frame, (5, 5), 0) # Noise reduction
|
| 243 |
return frame
|
| 244 |
|
| 245 |
def draw_detections(frame, detections):
|
|
|
|
| 482 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 483 |
|
| 484 |
if label and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
track_inputs.append({
|
| 486 |
"bbox": box.xywh.cpu().numpy()[0],
|
| 487 |
"conf": conf,
|
|
|
|
| 502 |
bbox = obj['bbox']
|
| 503 |
position = (bbox[0], bbox[1])
|
| 504 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
if label and tracker.should_alert_violation(
|
| 506 |
worker_id, label, position,
|
| 507 |
CONFIG["VIOLATION_COOLDOWN"],
|
| 508 |
CONFIG["VIOLATION_DISTANCE_THRESH"]
|
| 509 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
if worker_id not in unique_violations:
|
| 511 |
unique_violations[worker_id] = {}
|
| 512 |
|
|
|
|
| 584 |
if not video_file:
|
| 585 |
return "No file uploaded", "", "", "", ""
|
| 586 |
|
|
|
|
|
|
|
|
|
|
| 587 |
for result in process_video(video_file):
|
| 588 |
yield result
|
| 589 |
|