Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -40,10 +40,10 @@ FFMPEG_AVAILABLE = check_ffmpeg()
|
|
| 40 |
|
| 41 |
# ========================== # ByteTrack Implementation # ==========================
|
| 42 |
class BYTETracker:
|
| 43 |
-
def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.
|
| 44 |
self.track_thresh = track_thresh
|
| 45 |
self.track_buffer = track_buffer
|
| 46 |
-
self.match_thresh = match_thresh
|
| 47 |
self.frame_rate = frame_rate
|
| 48 |
self.next_id = 1
|
| 49 |
self.tracks = {}
|
|
@@ -159,7 +159,7 @@ class BYTETracker:
|
|
| 159 |
iou = intersection_area / (box1_area + box2_area - intersection_area)
|
| 160 |
return iou
|
| 161 |
|
| 162 |
-
def _is_same_worker(self, pos1, pos2, threshold=
|
| 163 |
x1, y1 = pos1
|
| 164 |
x2, y2 = pos2
|
| 165 |
distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
|
|
@@ -197,12 +197,12 @@ CONFIG = {
|
|
| 197 |
"domain": "login"
|
| 198 |
},
|
| 199 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 200 |
-
"CONFIDENCE_THRESHOLDS": {
|
| 201 |
-
"no_helmet": 0.
|
| 202 |
-
"no_harness": 0.
|
| 203 |
-
"unsafe_posture": 0.
|
| 204 |
-
"unsafe_zone": 0.
|
| 205 |
-
"improper_tool_use": 0.
|
| 206 |
},
|
| 207 |
"MIN_VIOLATION_FRAMES": 1,
|
| 208 |
"VIOLATION_COOLDOWN": 30.0,
|
|
@@ -213,10 +213,10 @@ CONFIG = {
|
|
| 213 |
"PARALLEL_WORKERS": max(1, cpu_count() - 1),
|
| 214 |
"TRACK_BUFFER": 90,
|
| 215 |
"TRACK_THRESH": 0.3,
|
| 216 |
-
"MATCH_THRESH": 0.
|
| 217 |
"SNAPSHOT_QUALITY": 95,
|
| 218 |
-
"MAX_WORKER_DISTANCE":
|
| 219 |
-
"TARGET_RESOLUTION": (384, 384)
|
| 220 |
}
|
| 221 |
|
| 222 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -548,11 +548,12 @@ def process_video(video_data, temp_dir):
|
|
| 548 |
|
| 549 |
worker_id_mapping = {}
|
| 550 |
unique_violations = {}
|
| 551 |
-
violation_frames = {}
|
| 552 |
start_time = time.time()
|
| 553 |
frame_skip = CONFIG["FRAME_SKIP"]
|
| 554 |
processed_frames = 0
|
| 555 |
last_yield_time = start_time
|
|
|
|
| 556 |
|
| 557 |
while processed_frames < total_frames:
|
| 558 |
batch_frames = []
|
|
@@ -583,7 +584,7 @@ def process_video(video_data, temp_dir):
|
|
| 583 |
break
|
| 584 |
|
| 585 |
try:
|
| 586 |
-
batch_frames_np = np.array(batch_frames)
|
| 587 |
batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
|
| 588 |
batch_frames_tensor = batch_frames_tensor.to(device)
|
| 589 |
if device.type == "cuda":
|
|
@@ -638,6 +639,7 @@ def process_video(video_data, temp_dir):
|
|
| 638 |
np.array([t["conf"] for t in track_inputs]),
|
| 639 |
np.array([t["cls"] for t in track_inputs])
|
| 640 |
)
|
|
|
|
| 641 |
|
| 642 |
for obj in tracked_objects:
|
| 643 |
tracker_id = obj['id']
|
|
@@ -648,10 +650,9 @@ def process_video(video_data, temp_dir):
|
|
| 648 |
if label is None:
|
| 649 |
continue
|
| 650 |
|
| 651 |
-
if not worker_id_mapping:
|
| 652 |
-
worker_id_mapping[tracker_id] =
|
| 653 |
-
|
| 654 |
-
worker_id_mapping[tracker_id] = worker_id_mapping[list(worker_id_mapping.keys())[0]]
|
| 655 |
|
| 656 |
worker_id = worker_id_mapping[tracker_id]
|
| 657 |
|
|
@@ -664,6 +665,7 @@ def process_video(video_data, temp_dir):
|
|
| 664 |
cap.release()
|
| 665 |
processing_time = time.time() - start_time
|
| 666 |
logger.info(f"Processing complete in {processing_time:.2f}s")
|
|
|
|
| 667 |
|
| 668 |
violations = []
|
| 669 |
for (worker_id, label), detection_time in unique_violations.items():
|
|
@@ -671,7 +673,7 @@ def process_video(video_data, temp_dir):
|
|
| 671 |
"worker_id": worker_id,
|
| 672 |
"violation": label,
|
| 673 |
"timestamp": detection_time,
|
| 674 |
-
"confidence": 0.0,
|
| 675 |
"frame_idx": violation_frames[(worker_id, label)]
|
| 676 |
})
|
| 677 |
|
|
@@ -680,7 +682,6 @@ def process_video(video_data, temp_dir):
|
|
| 680 |
yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
|
| 681 |
return
|
| 682 |
|
| 683 |
-
# Reopen video to capture snapshots for violations
|
| 684 |
snapshots = []
|
| 685 |
cap = cv2.VideoCapture(video_path)
|
| 686 |
for violation in violations:
|
|
|
|
| 40 |
|
| 41 |
# ========================== # ByteTrack Implementation # ==========================
|
| 42 |
class BYTETracker:
|
| 43 |
+
def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.4, frame_rate=30):
|
| 44 |
self.track_thresh = track_thresh
|
| 45 |
self.track_buffer = track_buffer
|
| 46 |
+
self.match_thresh = match_thresh # Lowered to 0.4 to improve tracking sensitivity
|
| 47 |
self.frame_rate = frame_rate
|
| 48 |
self.next_id = 1
|
| 49 |
self.tracks = {}
|
|
|
|
| 159 |
iou = intersection_area / (box1_area + box2_area - intersection_area)
|
| 160 |
return iou
|
| 161 |
|
| 162 |
+
def _is_same_worker(self, pos1, pos2, threshold=100): # Reduced threshold for 384x384 frames
|
| 163 |
x1, y1 = pos1
|
| 164 |
x2, y2 = pos2
|
| 165 |
distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
|
|
|
|
| 197 |
"domain": "login"
|
| 198 |
},
|
| 199 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 200 |
+
"CONFIDENCE_THRESHOLDS": { # Lowered thresholds to improve detection
|
| 201 |
+
"no_helmet": 0.4,
|
| 202 |
+
"no_harness": 0.25,
|
| 203 |
+
"unsafe_posture": 0.25,
|
| 204 |
+
"unsafe_zone": 0.25,
|
| 205 |
+
"improper_tool_use": 0.25
|
| 206 |
},
|
| 207 |
"MIN_VIOLATION_FRAMES": 1,
|
| 208 |
"VIOLATION_COOLDOWN": 30.0,
|
|
|
|
| 213 |
"PARALLEL_WORKERS": max(1, cpu_count() - 1),
|
| 214 |
"TRACK_BUFFER": 90,
|
| 215 |
"TRACK_THRESH": 0.3,
|
| 216 |
+
"MATCH_THRESH": 0.4,
|
| 217 |
"SNAPSHOT_QUALITY": 95,
|
| 218 |
+
"MAX_WORKER_DISTANCE": 100, # Adjusted to match BYTETracker threshold
|
| 219 |
+
"TARGET_RESOLUTION": (384, 384)
|
| 220 |
}
|
| 221 |
|
| 222 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 548 |
|
| 549 |
worker_id_mapping = {}
|
| 550 |
unique_violations = {}
|
| 551 |
+
violation_frames = {}
|
| 552 |
start_time = time.time()
|
| 553 |
frame_skip = CONFIG["FRAME_SKIP"]
|
| 554 |
processed_frames = 0
|
| 555 |
last_yield_time = start_time
|
| 556 |
+
worker_counter = 1 # For assigning unique worker IDs
|
| 557 |
|
| 558 |
while processed_frames < total_frames:
|
| 559 |
batch_frames = []
|
|
|
|
| 584 |
break
|
| 585 |
|
| 586 |
try:
|
| 587 |
+
batch_frames_np = np.array(batch_frames)
|
| 588 |
batch_frames_tensor = torch.from_numpy(batch_frames_np).permute(0, 3, 1, 2).float() / 255.0
|
| 589 |
batch_frames_tensor = batch_frames_tensor.to(device)
|
| 590 |
if device.type == "cuda":
|
|
|
|
| 639 |
np.array([t["conf"] for t in track_inputs]),
|
| 640 |
np.array([t["cls"] for t in track_inputs])
|
| 641 |
)
|
| 642 |
+
logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
|
| 643 |
|
| 644 |
for obj in tracked_objects:
|
| 645 |
tracker_id = obj['id']
|
|
|
|
| 650 |
if label is None:
|
| 651 |
continue
|
| 652 |
|
| 653 |
+
if tracker_id not in worker_id_mapping:
|
| 654 |
+
worker_id_mapping[tracker_id] = worker_counter
|
| 655 |
+
worker_counter += 1
|
|
|
|
| 656 |
|
| 657 |
worker_id = worker_id_mapping[tracker_id]
|
| 658 |
|
|
|
|
| 665 |
cap.release()
|
| 666 |
processing_time = time.time() - start_time
|
| 667 |
logger.info(f"Processing complete in {processing_time:.2f}s")
|
| 668 |
+
logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")
|
| 669 |
|
| 670 |
violations = []
|
| 671 |
for (worker_id, label), detection_time in unique_violations.items():
|
|
|
|
| 673 |
"worker_id": worker_id,
|
| 674 |
"violation": label,
|
| 675 |
"timestamp": detection_time,
|
| 676 |
+
"confidence": 0.0,
|
| 677 |
"frame_idx": violation_frames[(worker_id, label)]
|
| 678 |
})
|
| 679 |
|
|
|
|
| 682 |
yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
|
| 683 |
return
|
| 684 |
|
|
|
|
| 685 |
snapshots = []
|
| 686 |
cap = cv2.VideoCapture(video_path)
|
| 687 |
for violation in violations:
|