Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -260,6 +260,10 @@ def process_video(video_data):
|
|
| 260 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 261 |
workers = [] # List to track workers
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
while True:
|
| 264 |
ret, frame = video.read()
|
| 265 |
if not ret:
|
|
@@ -273,21 +277,35 @@ def process_video(video_data):
|
|
| 273 |
logger.info("Processing time limit reached")
|
| 274 |
break
|
| 275 |
|
|
|
|
| 276 |
results = model(frame, device=device)
|
| 277 |
current_detections = []
|
|
|
|
|
|
|
| 278 |
for result in results:
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
| 280 |
cls, conf = int(box.cls), float(box.conf)
|
| 281 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 282 |
|
| 283 |
-
#
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
if label not in CONFIG["VIOLATION_LABELS"].values()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
continue
|
| 289 |
|
|
|
|
| 290 |
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
|
|
|
|
|
|
| 291 |
current_detections.append({
|
| 292 |
"violation": label,
|
| 293 |
"confidence": round(conf, 2),
|
|
@@ -296,10 +314,13 @@ def process_video(video_data):
|
|
| 296 |
"frame": frame_count
|
| 297 |
})
|
| 298 |
|
| 299 |
-
# Process detections and workers
|
|
|
|
| 300 |
for detection in current_detections:
|
| 301 |
matched_worker = None
|
| 302 |
max_iou = 0
|
|
|
|
|
|
|
| 303 |
for worker in workers:
|
| 304 |
iou = calculate_iou(detection["bounding_box"], worker["bbox"])
|
| 305 |
if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
|
|
@@ -309,6 +330,8 @@ def process_video(video_data):
|
|
| 309 |
if matched_worker:
|
| 310 |
# Update existing worker
|
| 311 |
if detection["violation"] not in matched_worker["violations"]:
|
|
|
|
|
|
|
| 312 |
matched_worker["violations"].add(detection["violation"])
|
| 313 |
violations.append({
|
| 314 |
"frame": frame_count,
|
|
@@ -318,18 +341,34 @@ def process_video(video_data):
|
|
| 318 |
"timestamp": detection["timestamp"],
|
| 319 |
"worker_id": matched_worker["id"]
|
| 320 |
})
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
matched_worker["bbox"] = detection["bounding_box"]
|
| 323 |
matched_worker["last_frame"] = frame_count
|
| 324 |
else:
|
| 325 |
-
# New worker
|
| 326 |
worker_id = len(workers) + 1
|
|
|
|
| 327 |
workers.append({
|
| 328 |
"id": worker_id,
|
| 329 |
"violations": {detection["violation"]},
|
| 330 |
"bbox": detection["bounding_box"],
|
| 331 |
"last_frame": frame_count
|
| 332 |
})
|
|
|
|
| 333 |
violations.append({
|
| 334 |
"frame": frame_count,
|
| 335 |
"violation": detection["violation"],
|
|
@@ -338,11 +377,37 @@ def process_video(video_data):
|
|
| 338 |
"timestamp": detection["timestamp"],
|
| 339 |
"worker_id": worker_id
|
| 340 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
frame_count += 1
|
| 343 |
|
| 344 |
video.release()
|
| 345 |
os.remove(video_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
if not violations:
|
| 348 |
logger.info("No violations detected")
|
|
@@ -441,4 +506,4 @@ interface = gr.Interface(
|
|
| 441 |
|
| 442 |
if __name__ == "__main__":
|
| 443 |
logger.info("Launching Safety Analyzer App...")
|
| 444 |
-
interface.launch()
|
|
|
|
| 260 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 261 |
workers = [] # List to track workers
|
| 262 |
|
| 263 |
+
# Adding debug logging for violation labels
|
| 264 |
+
logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
|
| 265 |
+
logger.info(f"Using confidence threshold: {CONFIG['CONFIDENCE_THRESHOLD']}")
|
| 266 |
+
|
| 267 |
while True:
|
| 268 |
ret, frame = video.read()
|
| 269 |
if not ret:
|
|
|
|
| 277 |
logger.info("Processing time limit reached")
|
| 278 |
break
|
| 279 |
|
| 280 |
+
# Run detection on this frame
|
| 281 |
results = model(frame, device=device)
|
| 282 |
current_detections = []
|
| 283 |
+
|
| 284 |
+
# Process detections from the model
|
| 285 |
for result in results:
|
| 286 |
+
boxes = result.boxes
|
| 287 |
+
logger.info(f"Frame {frame_count}: Found {len(boxes)} potential detections")
|
| 288 |
+
|
| 289 |
+
for box in boxes:
|
| 290 |
cls, conf = int(box.cls), float(box.conf)
|
| 291 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 292 |
|
| 293 |
+
# Enhanced logging
|
| 294 |
+
logger.info(f"Detection: class={cls}, conf={conf:.2f}, label={label}")
|
| 295 |
+
|
| 296 |
+
# Skip if not a known violation or below confidence threshold
|
| 297 |
+
if label not in CONFIG["VIOLATION_LABELS"].values():
|
| 298 |
+
logger.info(f"Skipping unknown class: {cls}")
|
| 299 |
+
continue
|
| 300 |
+
|
| 301 |
+
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 302 |
+
logger.info(f"Skipping low confidence: {conf:.2f} < {CONFIG['CONFIDENCE_THRESHOLD']}")
|
| 303 |
continue
|
| 304 |
|
| 305 |
+
# Process valid detection
|
| 306 |
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
| 307 |
+
logger.info(f"Valid detection: {label} with confidence: {conf:.2f}")
|
| 308 |
+
|
| 309 |
current_detections.append({
|
| 310 |
"violation": label,
|
| 311 |
"confidence": round(conf, 2),
|
|
|
|
| 314 |
"frame": frame_count
|
| 315 |
})
|
| 316 |
|
| 317 |
+
# Process detections and associate with workers
|
| 318 |
+
# FIXED: Improved worker tracking logic
|
| 319 |
for detection in current_detections:
|
| 320 |
matched_worker = None
|
| 321 |
max_iou = 0
|
| 322 |
+
|
| 323 |
+
# Try to match with existing workers
|
| 324 |
for worker in workers:
|
| 325 |
iou = calculate_iou(detection["bounding_box"], worker["bbox"])
|
| 326 |
if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
|
|
|
|
| 330 |
if matched_worker:
|
| 331 |
# Update existing worker
|
| 332 |
if detection["violation"] not in matched_worker["violations"]:
|
| 333 |
+
# New violation for this worker
|
| 334 |
+
logger.info(f"New violation for worker {matched_worker['id']}: {detection['violation']}")
|
| 335 |
matched_worker["violations"].add(detection["violation"])
|
| 336 |
violations.append({
|
| 337 |
"frame": frame_count,
|
|
|
|
| 341 |
"timestamp": detection["timestamp"],
|
| 342 |
"worker_id": matched_worker["id"]
|
| 343 |
})
|
| 344 |
+
|
| 345 |
+
# Save snapshot for this violation type if not already taken
|
| 346 |
+
if not snapshot_taken[detection["violation"]]:
|
| 347 |
+
snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
|
| 348 |
+
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
|
| 349 |
+
cv2.imwrite(snapshot_path, frame)
|
| 350 |
+
snapshot_taken[detection["violation"]] = True
|
| 351 |
+
snapshots.append({
|
| 352 |
+
"violation": detection["violation"],
|
| 353 |
+
"frame": frame_count,
|
| 354 |
+
"snapshot_path": snapshot_path,
|
| 355 |
+
"snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
|
| 356 |
+
})
|
| 357 |
+
|
| 358 |
+
# Update worker position
|
| 359 |
matched_worker["bbox"] = detection["bounding_box"]
|
| 360 |
matched_worker["last_frame"] = frame_count
|
| 361 |
else:
|
| 362 |
+
# New worker detected
|
| 363 |
worker_id = len(workers) + 1
|
| 364 |
+
logger.info(f"New worker {worker_id} with violation: {detection['violation']}")
|
| 365 |
workers.append({
|
| 366 |
"id": worker_id,
|
| 367 |
"violations": {detection["violation"]},
|
| 368 |
"bbox": detection["bounding_box"],
|
| 369 |
"last_frame": frame_count
|
| 370 |
})
|
| 371 |
+
|
| 372 |
violations.append({
|
| 373 |
"frame": frame_count,
|
| 374 |
"violation": detection["violation"],
|
|
|
|
| 377 |
"timestamp": detection["timestamp"],
|
| 378 |
"worker_id": worker_id
|
| 379 |
})
|
| 380 |
+
|
| 381 |
+
# Save snapshot for this violation type if not already taken
|
| 382 |
+
if not snapshot_taken[detection["violation"]]:
|
| 383 |
+
snapshot_filename = f"{detection['violation']}_{frame_count}.jpg"
|
| 384 |
+
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
|
| 385 |
+
cv2.imwrite(snapshot_path, frame)
|
| 386 |
+
snapshot_taken[detection["violation"]] = True
|
| 387 |
+
snapshots.append({
|
| 388 |
+
"violation": detection["violation"],
|
| 389 |
+
"frame": frame_count,
|
| 390 |
+
"snapshot_path": snapshot_path,
|
| 391 |
+
"snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
|
| 392 |
+
})
|
| 393 |
|
| 394 |
+
# Clean up workers that haven't been seen for a while
|
| 395 |
+
active_workers = [w for w in workers if frame_count - w["last_frame"] < CONFIG["FRAME_SKIP"] * 5]
|
| 396 |
+
if len(active_workers) != len(workers):
|
| 397 |
+
logger.info(f"Cleaned up {len(workers) - len(active_workers)} inactive workers")
|
| 398 |
+
workers = active_workers
|
| 399 |
+
|
| 400 |
frame_count += 1
|
| 401 |
|
| 402 |
video.release()
|
| 403 |
os.remove(video_path)
|
| 404 |
+
|
| 405 |
+
# Final log of violations detected
|
| 406 |
+
violation_types = {}
|
| 407 |
+
for v in violations:
|
| 408 |
+
violation_types[v["violation"]] = violation_types.get(v["violation"], 0) + 1
|
| 409 |
+
|
| 410 |
+
logger.info(f"Detection complete. Found violations: {violation_types}")
|
| 411 |
|
| 412 |
if not violations:
|
| 413 |
logger.info("No violations detected")
|
|
|
|
| 506 |
|
| 507 |
if __name__ == "__main__":
|
| 508 |
logger.info("Launching Safety Analyzer App...")
|
| 509 |
+
interface.launch()
|