Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -263,16 +263,16 @@ CONFIG = {
|
|
| 263 |
"MIN_VIOLATION_FRAMES": 2,
|
| 264 |
"VIOLATION_COOLDOWN": 30.0,
|
| 265 |
"WORKER_TRACKING_DURATION": 10.0,
|
| 266 |
-
"MAX_PROCESSING_TIME":
|
| 267 |
-
"FRAME_SKIP": 4,
|
| 268 |
-
"BATCH_SIZE":
|
| 269 |
"PARALLEL_WORKERS": max(1, cpu_count() - 1),
|
| 270 |
"TRACK_BUFFER": 150,
|
| 271 |
"TRACK_THRESH": 0.3,
|
| 272 |
"MATCH_THRESH": 0.5,
|
| 273 |
"SNAPSHOT_QUALITY": 95,
|
| 274 |
"MAX_WORKER_DISTANCE": 150,
|
| 275 |
-
"TARGET_RESOLUTION": (
|
| 276 |
"HELMET_VALIDATION_FRAMES": 3
|
| 277 |
}
|
| 278 |
|
|
@@ -308,8 +308,7 @@ def preprocess_frame(frame):
|
|
| 308 |
target_res = CONFIG["TARGET_RESOLUTION"]
|
| 309 |
frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
|
| 310 |
frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20)
|
| 311 |
-
|
| 312 |
-
frame = cv2.filter2D(frame, -1, kernel)
|
| 313 |
return frame
|
| 314 |
|
| 315 |
def is_unsafe_posture(box, frame_shape):
|
|
@@ -615,38 +614,45 @@ def process_video(video_data, temp_dir):
|
|
| 615 |
unique_violations = {}
|
| 616 |
violation_frames = {}
|
| 617 |
helmet_detections = {}
|
| 618 |
-
frame_detections = {}
|
| 619 |
start_time = time.time()
|
| 620 |
frame_skip = CONFIG["FRAME_SKIP"]
|
| 621 |
-
processed_frames = 0
|
|
|
|
| 622 |
last_yield_time = start_time
|
| 623 |
worker_counter = 1
|
| 624 |
|
| 625 |
-
while
|
| 626 |
batch_frames = []
|
| 627 |
batch_indices = []
|
| 628 |
batch_originals = []
|
| 629 |
|
| 630 |
for _ in range(CONFIG["BATCH_SIZE"]):
|
| 631 |
frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
|
|
|
|
| 632 |
if frame_idx >= total_frames:
|
|
|
|
| 633 |
break
|
| 634 |
ret, frame = cap.read()
|
| 635 |
if not ret:
|
| 636 |
-
logger.warning(f"Failed to read frame {frame_idx}.
|
|
|
|
| 637 |
break
|
| 638 |
original_frame = frame.copy()
|
| 639 |
frame = preprocess_frame(frame)
|
| 640 |
for _ in range(frame_skip - 1):
|
| 641 |
if not cap.grab():
|
|
|
|
|
|
|
| 642 |
break
|
|
|
|
| 643 |
batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
| 644 |
batch_indices.append(frame_idx)
|
| 645 |
batch_originals.append(original_frame)
|
| 646 |
-
processed_frames +=
|
| 647 |
|
| 648 |
if not batch_frames:
|
| 649 |
-
logger.info("No more frames to process.")
|
| 650 |
break
|
| 651 |
|
| 652 |
# Check for timeout
|
|
@@ -673,10 +679,11 @@ def process_video(video_data, temp_dir):
|
|
| 673 |
|
| 674 |
current_time = time.time()
|
| 675 |
if current_time - last_yield_time > 0.1:
|
| 676 |
-
progress = (
|
|
|
|
| 677 |
elapsed_time = current_time - start_time
|
| 678 |
fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
|
| 679 |
-
yield f"Processing video... {progress:.1f}% complete (Frame {
|
| 680 |
last_yield_time = current_time
|
| 681 |
|
| 682 |
for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
|
|
@@ -685,7 +692,7 @@ def process_video(video_data, temp_dir):
|
|
| 685 |
person_boxes = []
|
| 686 |
tool_boxes = []
|
| 687 |
|
| 688 |
-
frame_detections[frame_idx] = []
|
| 689 |
|
| 690 |
for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
|
| 691 |
label_name = model.config.id2label[label.item()]
|
|
@@ -770,6 +777,11 @@ def process_video(video_data, temp_dir):
|
|
| 770 |
logger.info(f"Processing complete in {processing_time:.2f}s")
|
| 771 |
logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
|
| 772 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
violations = []
|
| 774 |
for (worker_id, label), detection_time in unique_violations.items():
|
| 775 |
frame_idx = violation_frames[(worker_id, label)]
|
|
@@ -787,7 +799,6 @@ def process_video(video_data, temp_dir):
|
|
| 787 |
yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
|
| 788 |
return
|
| 789 |
|
| 790 |
-
# Generate violation table early for intermediate output
|
| 791 |
violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
|
| 792 |
violation_table += "|-----------|-----------|----------|------------|\n"
|
| 793 |
for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
|
|
@@ -810,7 +821,6 @@ def process_video(video_data, temp_dir):
|
|
| 810 |
continue
|
| 811 |
|
| 812 |
frame = preprocess_frame(frame)
|
| 813 |
-
# Reuse detections instead of re-running inference
|
| 814 |
detections = frame_detections.get(frame_idx, [])
|
| 815 |
for det in detections:
|
| 816 |
if det["label"] == violation["violation"]:
|
|
|
|
| 263 |
"MIN_VIOLATION_FRAMES": 2,
|
| 264 |
"VIOLATION_COOLDOWN": 30.0,
|
| 265 |
"WORKER_TRACKING_DURATION": 10.0,
|
| 266 |
+
"MAX_PROCESSING_TIME": 120, # Increased to allow more time for CPU processing
|
| 267 |
+
"FRAME_SKIP": 4,
|
| 268 |
+
"BATCH_SIZE": 2, # Reduced for better CPU performance
|
| 269 |
"PARALLEL_WORKERS": max(1, cpu_count() - 1),
|
| 270 |
"TRACK_BUFFER": 150,
|
| 271 |
"TRACK_THRESH": 0.3,
|
| 272 |
"MATCH_THRESH": 0.5,
|
| 273 |
"SNAPSHOT_QUALITY": 95,
|
| 274 |
"MAX_WORKER_DISTANCE": 150,
|
| 275 |
+
"TARGET_RESOLUTION": (256, 256), # Further reduced for faster inference
|
| 276 |
"HELMET_VALIDATION_FRAMES": 3
|
| 277 |
}
|
| 278 |
|
|
|
|
| 308 |
target_res = CONFIG["TARGET_RESOLUTION"]
|
| 309 |
frame = cv2.resize(frame, target_res, interpolation=cv2.INTER_LINEAR)
|
| 310 |
frame = cv2.convertScaleAbs(frame, alpha=1.3, beta=20)
|
| 311 |
+
# Removed cv2.filter2D to reduce processing time
|
|
|
|
| 312 |
return frame
|
| 313 |
|
| 314 |
def is_unsafe_posture(box, frame_shape):
|
|
|
|
| 614 |
unique_violations = {}
|
| 615 |
violation_frames = {}
|
| 616 |
helmet_detections = {}
|
| 617 |
+
frame_detections = {}
|
| 618 |
start_time = time.time()
|
| 619 |
frame_skip = CONFIG["FRAME_SKIP"]
|
| 620 |
+
processed_frames = 0 # Track actual frames processed
|
| 621 |
+
frames_read = 0 # Track frames read from video
|
| 622 |
last_yield_time = start_time
|
| 623 |
worker_counter = 1
|
| 624 |
|
| 625 |
+
while True:
|
| 626 |
batch_frames = []
|
| 627 |
batch_indices = []
|
| 628 |
batch_originals = []
|
| 629 |
|
| 630 |
for _ in range(CONFIG["BATCH_SIZE"]):
|
| 631 |
frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
|
| 632 |
+
frames_read = frame_idx
|
| 633 |
if frame_idx >= total_frames:
|
| 634 |
+
logger.info("Reached end of video.")
|
| 635 |
break
|
| 636 |
ret, frame = cap.read()
|
| 637 |
if not ret:
|
| 638 |
+
logger.warning(f"Failed to read frame {frame_idx}. Assuming end of video.")
|
| 639 |
+
frames_read = total_frames # Assume we've reached the end
|
| 640 |
break
|
| 641 |
original_frame = frame.copy()
|
| 642 |
frame = preprocess_frame(frame)
|
| 643 |
for _ in range(frame_skip - 1):
|
| 644 |
if not cap.grab():
|
| 645 |
+
logger.warning(f"Failed to grab frame after {frame_idx}. Assuming end of video.")
|
| 646 |
+
frames_read = total_frames
|
| 647 |
break
|
| 648 |
+
frames_read += 1
|
| 649 |
batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
| 650 |
batch_indices.append(frame_idx)
|
| 651 |
batch_originals.append(original_frame)
|
| 652 |
+
processed_frames += 1 # Increment for each frame actually processed
|
| 653 |
|
| 654 |
if not batch_frames:
|
| 655 |
+
logger.info("No more frames to process in this batch.")
|
| 656 |
break
|
| 657 |
|
| 658 |
# Check for timeout
|
|
|
|
| 679 |
|
| 680 |
current_time = time.time()
|
| 681 |
if current_time - last_yield_time > 0.1:
|
| 682 |
+
progress = (frames_read / total_frames) * 100
|
| 683 |
+
progress = min(progress, 100.0) # Ensure progress doesn't exceed 100%
|
| 684 |
elapsed_time = current_time - start_time
|
| 685 |
fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
|
| 686 |
+
yield f"Processing video... {progress:.1f}% complete (Frame {frames_read}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", f"Elapsed: {elapsed_time:.1f}s"
|
| 687 |
last_yield_time = current_time
|
| 688 |
|
| 689 |
for i, (result, frame_idx, original_frame) in enumerate(zip(results, batch_indices, batch_originals)):
|
|
|
|
| 692 |
person_boxes = []
|
| 693 |
tool_boxes = []
|
| 694 |
|
| 695 |
+
frame_detections[frame_idx] = []
|
| 696 |
|
| 697 |
for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
|
| 698 |
label_name = model.config.id2label[label.item()]
|
|
|
|
| 777 |
logger.info(f"Processing complete in {processing_time:.2f}s")
|
| 778 |
logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
|
| 779 |
|
| 780 |
+
# Ensure final progress update
|
| 781 |
+
final_progress = (frames_read / total_frames) * 100
|
| 782 |
+
final_progress = min(final_progress, 100.0)
|
| 783 |
+
yield f"Processing video... {final_progress:.1f}% complete (Frame {frames_read}/{total_frames})", "", "", "", "", f"Elapsed: {processing_time:.1f}s"
|
| 784 |
+
|
| 785 |
violations = []
|
| 786 |
for (worker_id, label), detection_time in unique_violations.items():
|
| 787 |
frame_idx = violation_frames[(worker_id, label)]
|
|
|
|
| 799 |
yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
|
| 800 |
return
|
| 801 |
|
|
|
|
| 802 |
violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
|
| 803 |
violation_table += "|-----------|-----------|----------|------------|\n"
|
| 804 |
for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
|
|
|
|
| 821 |
continue
|
| 822 |
|
| 823 |
frame = preprocess_frame(frame)
|
|
|
|
| 824 |
detections = frame_detections.get(frame_idx, [])
|
| 825 |
for det in detections:
|
| 826 |
if det["label"] == violation["violation"]:
|