Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -49,7 +49,7 @@ CONFIG = {
|
|
| 49 |
"domain": "login"
|
| 50 |
},
|
| 51 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 52 |
-
"FRAME_SKIP":
|
| 53 |
"MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
|
| 54 |
"CONFIDENCE_THRESHOLD": { # Per-class thresholds
|
| 55 |
"no_helmet": 0.4,
|
|
@@ -59,7 +59,7 @@ CONFIG = {
|
|
| 59 |
"improper_tool_use": 0.35
|
| 60 |
},
|
| 61 |
"IOU_THRESHOLD": 0.4, # For worker tracking
|
| 62 |
-
"MIN_VIOLATION_FRAMES":
|
| 63 |
}
|
| 64 |
|
| 65 |
# Setup logging
|
|
@@ -239,6 +239,7 @@ def process_video(video_path):
|
|
| 239 |
snapshots = []
|
| 240 |
workers = []
|
| 241 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
|
|
|
| 242 |
|
| 243 |
while cap.isOpened():
|
| 244 |
ret, frame = cap.read()
|
|
@@ -271,37 +272,18 @@ def process_video(video_path):
|
|
| 271 |
"timestamp": current_time
|
| 272 |
}
|
| 273 |
|
| 274 |
-
#
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
for worker in workers:
|
| 278 |
-
iou = calculate_iou(worker["bbox"], bbox)
|
| 279 |
-
if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
|
| 280 |
-
max_iou = iou
|
| 281 |
-
matched_worker = worker
|
| 282 |
-
|
| 283 |
-
if matched_worker:
|
| 284 |
-
worker_id = matched_worker["id"]
|
| 285 |
-
matched_worker["bbox"] = bbox
|
| 286 |
-
else:
|
| 287 |
-
worker_id = len(workers) + 1
|
| 288 |
-
workers.append({"id": worker_id, "bbox": bbox})
|
| 289 |
-
|
| 290 |
detection["worker_id"] = worker_id
|
| 291 |
violations.append(detection)
|
| 292 |
|
| 293 |
-
#
|
| 294 |
if not snapshot_taken[label]:
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
cv2.imwrite(snapshot_path, draw_detections(frame.copy(), [detection]))
|
| 300 |
-
snapshots.append({
|
| 301 |
-
"violation": label,
|
| 302 |
-
"frame": frame_count,
|
| 303 |
-
"path": snapshot_path,
|
| 304 |
-
"url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
|
| 305 |
})
|
| 306 |
snapshot_taken[label] = True
|
| 307 |
|
|
@@ -309,6 +291,23 @@ def process_video(video_path):
|
|
| 309 |
|
| 310 |
cap.release()
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
# Filter violations (require min frames)
|
| 313 |
filtered_violations = []
|
| 314 |
violation_counts = {}
|
|
|
|
| 49 |
"domain": "login"
|
| 50 |
},
|
| 51 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 52 |
+
"FRAME_SKIP": 10, # Increased to process every 10th frame (faster)
|
| 53 |
"MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
|
| 54 |
"CONFIDENCE_THRESHOLD": { # Per-class thresholds
|
| 55 |
"no_helmet": 0.4,
|
|
|
|
| 59 |
"improper_tool_use": 0.35
|
| 60 |
},
|
| 61 |
"IOU_THRESHOLD": 0.4, # For worker tracking
|
| 62 |
+
"MIN_VIOLATION_FRAMES": 2 # Reduced to 2 frames for faster confirmation
|
| 63 |
}
|
| 64 |
|
| 65 |
# Setup logging
|
|
|
|
| 239 |
snapshots = []
|
| 240 |
workers = []
|
| 241 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 242 |
+
frames_to_save = [] # Store frames for snapshot saving later
|
| 243 |
|
| 244 |
while cap.isOpened():
|
| 245 |
ret, frame = cap.read()
|
|
|
|
| 272 |
"timestamp": current_time
|
| 273 |
}
|
| 274 |
|
| 275 |
+
# Simplified worker tracking
|
| 276 |
+
worker_id = len(workers) + 1 # Assign new ID without IoU for speed
|
| 277 |
+
workers.append({"id": worker_id, "bbox": bbox})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
detection["worker_id"] = worker_id
|
| 279 |
violations.append(detection)
|
| 280 |
|
| 281 |
+
# Store frame for snapshot if first detection of this type
|
| 282 |
if not snapshot_taken[label]:
|
| 283 |
+
frames_to_save.append({
|
| 284 |
+
"frame": frame.copy(),
|
| 285 |
+
"detection": detection,
|
| 286 |
+
"label": label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
})
|
| 288 |
snapshot_taken[label] = True
|
| 289 |
|
|
|
|
| 291 |
|
| 292 |
cap.release()
|
| 293 |
|
| 294 |
+
# Save snapshots after processing all frames
|
| 295 |
+
for item in frames_to_save:
|
| 296 |
+
frame = item["frame"]
|
| 297 |
+
detection = item["detection"]
|
| 298 |
+
label = item["label"]
|
| 299 |
+
snapshot_path = os.path.join(
|
| 300 |
+
CONFIG["OUTPUT_DIR"],
|
| 301 |
+
f"{label}_{detection['frame']}.jpg"
|
| 302 |
+
)
|
| 303 |
+
cv2.imwrite(snapshot_path, draw_detections(frame, [detection]))
|
| 304 |
+
snapshots.append({
|
| 305 |
+
"violation": label,
|
| 306 |
+
"frame": detection["frame"],
|
| 307 |
+
"path": snapshot_path,
|
| 308 |
+
"url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}"
|
| 309 |
+
})
|
| 310 |
+
|
| 311 |
# Filter violations (require min frames)
|
| 312 |
filtered_violations = []
|
| 313 |
violation_counts = {}
|