Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,8 +18,8 @@ from retrying import retry
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
-
"MODEL_PATH": "yolov8_safety.pt",
|
| 22 |
-
"FALLBACK_MODEL": "yolov8n.pt",
|
| 23 |
"OUTPUT_DIR": "static/output",
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
|
@@ -40,7 +40,8 @@ CONFIG = {
|
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 41 |
"FRAME_SKIP": 15,
|
| 42 |
"MAX_PROCESSING_TIME": 30,
|
| 43 |
-
"CONFIDENCE_THRESHOLD": 0.5
|
|
|
|
| 44 |
}
|
| 45 |
|
| 46 |
# Setup logging
|
|
@@ -54,14 +55,12 @@ logger.info(f"Using device: {device}")
|
|
| 54 |
|
| 55 |
def load_model():
|
| 56 |
try:
|
| 57 |
-
# Check if the model file exists
|
| 58 |
if os.path.isfile(CONFIG["MODEL_PATH"]):
|
| 59 |
model_path = CONFIG["MODEL_PATH"]
|
| 60 |
logger.info(f"Model loaded: {model_path}")
|
| 61 |
else:
|
| 62 |
model_path = CONFIG["FALLBACK_MODEL"]
|
| 63 |
logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
|
| 64 |
-
# Download fallback model if necessary
|
| 65 |
if not os.path.isfile(model_path):
|
| 66 |
logger.info(f"Downloading fallback model: {model_path}")
|
| 67 |
torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
|
|
@@ -73,6 +72,33 @@ def load_model():
|
|
| 73 |
|
| 74 |
model = load_model()
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
# ==========================
|
| 77 |
# Salesforce Integration
|
| 78 |
# ==========================
|
|
@@ -81,7 +107,7 @@ def connect_to_salesforce():
|
|
| 81 |
try:
|
| 82 |
sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
|
| 83 |
logger.info("Connected to Salesforce")
|
| 84 |
-
sf.describe()
|
| 85 |
return sf
|
| 86 |
except Exception as e:
|
| 87 |
logger.error(f"Salesforce connection failed: {e}")
|
|
@@ -229,7 +255,7 @@ def process_video(video_data):
|
|
| 229 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 230 |
|
| 231 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 232 |
-
|
| 233 |
|
| 234 |
while True:
|
| 235 |
ret, frame = video.read()
|
|
@@ -245,6 +271,7 @@ def process_video(video_data):
|
|
| 245 |
break
|
| 246 |
|
| 247 |
results = model(frame, device=device)
|
|
|
|
| 248 |
for result in results:
|
| 249 |
for box in result.boxes:
|
| 250 |
cls, conf = int(box.cls), float(box.conf)
|
|
@@ -252,31 +279,82 @@ def process_video(video_data):
|
|
| 252 |
if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 253 |
continue
|
| 254 |
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
continue
|
| 258 |
-
|
| 259 |
-
detected_violations[violation_key] = {
|
| 260 |
-
"frame": frame_count,
|
| 261 |
"violation": label,
|
| 262 |
"confidence": round(conf, 2),
|
| 263 |
-
"bounding_box":
|
| 264 |
-
"timestamp": frame_count / fps
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
cv2.imwrite(snapshot_path, frame)
|
| 271 |
with open(snapshot_path, "rb") as img_file:
|
| 272 |
img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
|
| 273 |
snapshots.append({
|
| 274 |
-
"violation":
|
| 275 |
"frame": frame_count,
|
| 276 |
"snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
|
| 277 |
"snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
|
| 278 |
})
|
| 279 |
-
snapshot_taken[
|
| 280 |
|
| 281 |
frame_count += 1
|
| 282 |
|
|
@@ -334,13 +412,13 @@ def gradio_interface(video_file):
|
|
| 334 |
|
| 335 |
violation_table = "No violations detected."
|
| 336 |
if result["violations"]:
|
| 337 |
-
header = "| Violation | Timestamp (s) | Confidence |
|
| 338 |
-
separator = "
|
| 339 |
rows = []
|
| 340 |
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 341 |
for v in result["violations"]:
|
| 342 |
display_name = violation_name_map.get(v["violation"], v["violation"])
|
| 343 |
-
row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['
|
| 344 |
rows.append(row)
|
| 345 |
violation_table = header + separator + "\n".join(rows)
|
| 346 |
|
|
@@ -380,4 +458,4 @@ interface = gr.Interface(
|
|
| 380 |
|
| 381 |
if __name__ == "__main__":
|
| 382 |
logger.info("Launching Safety Analyzer App...")
|
| 383 |
-
interface.launch()
|
|
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
+
"MODEL_PATH": "yolov8_safety.pt",
|
| 22 |
+
"FALLBACK_MODEL": "yolov8n.pt",
|
| 23 |
"OUTPUT_DIR": "static/output",
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
|
|
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 41 |
"FRAME_SKIP": 15,
|
| 42 |
"MAX_PROCESSING_TIME": 30,
|
| 43 |
+
"CONFIDENCE_THRESHOLD": 0.5,
|
| 44 |
+
"IOU_THRESHOLD": 0.5 # Added for worker tracking
|
| 45 |
}
|
| 46 |
|
| 47 |
# Setup logging
|
|
|
|
| 55 |
|
| 56 |
def load_model():
|
| 57 |
try:
|
|
|
|
| 58 |
if os.path.isfile(CONFIG["MODEL_PATH"]):
|
| 59 |
model_path = CONFIG["MODEL_PATH"]
|
| 60 |
logger.info(f"Model loaded: {model_path}")
|
| 61 |
else:
|
| 62 |
model_path = CONFIG["FALLBACK_MODEL"]
|
| 63 |
logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
|
|
|
|
| 64 |
if not os.path.isfile(model_path):
|
| 65 |
logger.info(f"Downloading fallback model: {model_path}")
|
| 66 |
torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
|
|
|
|
| 72 |
|
| 73 |
model = load_model()
|
| 74 |
|
| 75 |
+
# ==========================
|
| 76 |
+
# Helper Functions
|
| 77 |
+
# ==========================
|
| 78 |
+
def calculate_iou(box1, box2):
|
| 79 |
+
"""Calculate Intersection over Union (IoU) for two bounding boxes."""
|
| 80 |
+
x1, y1, w1, h1 = box1
|
| 81 |
+
x2, y2, w2, h2 = box2
|
| 82 |
+
|
| 83 |
+
# Convert to top-left and bottom-right coordinates
|
| 84 |
+
x1_min, y1_min = x1 - w1/2, y1 - h1/2
|
| 85 |
+
x1_max, y1_max = x1 + w1/2, y1 + h1/2
|
| 86 |
+
x2_min, y2_min = x2 - w2/2, y2 - h2/2
|
| 87 |
+
x2_max, y2_max = x2 + w2/2, y2 + h2/2
|
| 88 |
+
|
| 89 |
+
# Calculate intersection
|
| 90 |
+
x_min = max(x1_min, x2_min)
|
| 91 |
+
y_min = max(y1_min, y2_min)
|
| 92 |
+
x_max = min(x1_max, x2_max)
|
| 93 |
+
y_max = min(y1_max, y2_max)
|
| 94 |
+
|
| 95 |
+
intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
|
| 96 |
+
area1 = w1 * h1
|
| 97 |
+
area2 = w2 * h2
|
| 98 |
+
union = area1 + area2 - intersection
|
| 99 |
+
|
| 100 |
+
return intersection / union if union > 0 else 0
|
| 101 |
+
|
| 102 |
# ==========================
|
| 103 |
# Salesforce Integration
|
| 104 |
# ==========================
|
|
|
|
| 107 |
try:
|
| 108 |
sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
|
| 109 |
logger.info("Connected to Salesforce")
|
| 110 |
+
sf.describe()
|
| 111 |
return sf
|
| 112 |
except Exception as e:
|
| 113 |
logger.error(f"Salesforce connection failed: {e}")
|
|
|
|
| 255 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 256 |
|
| 257 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 258 |
+
workers = [] # List to track workers: [{"id": int, "violations": set(), "bbox": list, "last_frame": int}]
|
| 259 |
|
| 260 |
while True:
|
| 261 |
ret, frame = video.read()
|
|
|
|
| 271 |
break
|
| 272 |
|
| 273 |
results = model(frame, device=device)
|
| 274 |
+
current_detections = []
|
| 275 |
for result in results:
|
| 276 |
for box in result.boxes:
|
| 277 |
cls, conf = int(box.cls), float(box.conf)
|
|
|
|
| 279 |
if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 280 |
continue
|
| 281 |
|
| 282 |
+
bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
|
| 283 |
+
current_detections.append({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
"violation": label,
|
| 285 |
"confidence": round(conf, 2),
|
| 286 |
+
"bounding_box": bbox,
|
| 287 |
+
"timestamp": frame_count / fps,
|
| 288 |
+
"frame": frame_count
|
| 289 |
+
})
|
| 290 |
+
|
| 291 |
+
# Assign detections to workers
|
| 292 |
+
for detection in current_detections:
|
| 293 |
+
matched_worker = None
|
| 294 |
+
max_iou = 0
|
| 295 |
+
for worker in workers:
|
| 296 |
+
iou = calculate_iou(detection["bounding_box"], worker["bbox"])
|
| 297 |
+
if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
|
| 298 |
+
max_iou = iou
|
| 299 |
+
matched_worker = worker
|
| 300 |
+
|
| 301 |
+
if matched_worker:
|
| 302 |
+
# Update existing worker
|
| 303 |
+
if detection["violation"] not in matched_worker["violations"]:
|
| 304 |
+
matched_worker["violations"].add(detection["violation"])
|
| 305 |
+
violations.append({
|
| 306 |
+
"frame": frame_count,
|
| 307 |
+
"violation": detection["violation"],
|
| 308 |
+
"confidence": detection["confidence"],
|
| 309 |
+
"bounding_box": detection["bounding_box"],
|
| 310 |
+
"timestamp": detection["timestamp"],
|
| 311 |
+
"worker_id": matched_worker["id"]
|
| 312 |
+
})
|
| 313 |
+
# Take snapshot if not already taken for this violation type
|
| 314 |
+
if not snapshot_taken[detection["violation"]]:
|
| 315 |
+
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{detection['violation']}.jpg")
|
| 316 |
+
cv2.imwrite(snapshot_path, frame)
|
| 317 |
+
with open(snapshot_path, "rb") as img_file:
|
| 318 |
+
img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
|
| 319 |
+
snapshots.append({
|
| 320 |
+
"violation": detection["violation"],
|
| 321 |
+
"frame": frame_count,
|
| 322 |
+
"snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
|
| 323 |
+
"snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
|
| 324 |
+
})
|
| 325 |
+
snapshot_taken[detection["violation"]] = True
|
| 326 |
+
matched_worker["bbox"] = detection["bounding_box"]
|
| 327 |
+
matched_worker["last_frame"] = frame_count
|
| 328 |
+
else:
|
| 329 |
+
# New worker
|
| 330 |
+
worker_id = len(workers) + 1
|
| 331 |
+
workers.append({
|
| 332 |
+
"id": worker_id,
|
| 333 |
+
"violations": {detection["violation"]},
|
| 334 |
+
"bbox": detection["bounding_box"],
|
| 335 |
+
"last_frame": frame_count
|
| 336 |
+
})
|
| 337 |
+
violations.append({
|
| 338 |
+
"frame": frame_count,
|
| 339 |
+
"violation": detection["violation"],
|
| 340 |
+
"confidence": detection["confidence"],
|
| 341 |
+
"bounding_box": detection["bounding_box"],
|
| 342 |
+
"timestamp": detection["timestamp"],
|
| 343 |
+
"worker_id": worker_id
|
| 344 |
+
})
|
| 345 |
+
# Take snapshot if not already taken for this violation type
|
| 346 |
+
if not snapshot_taken[detection["violation"]]:
|
| 347 |
+
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{detection['violation']}.jpg")
|
| 348 |
cv2.imwrite(snapshot_path, frame)
|
| 349 |
with open(snapshot_path, "rb") as img_file:
|
| 350 |
img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
|
| 351 |
snapshots.append({
|
| 352 |
+
"violation": detection["violation"],
|
| 353 |
"frame": frame_count,
|
| 354 |
"snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
|
| 355 |
"snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
|
| 356 |
})
|
| 357 |
+
snapshot_taken[detection["violation"]] = True
|
| 358 |
|
| 359 |
frame_count += 1
|
| 360 |
|
|
|
|
| 412 |
|
| 413 |
violation_table = "No violations detected."
|
| 414 |
if result["violations"]:
|
| 415 |
+
header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
|
| 416 |
+
separator = "|------------------------|---------------|------------|-----------|\n"
|
| 417 |
rows = []
|
| 418 |
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 419 |
for v in result["violations"]:
|
| 420 |
display_name = violation_name_map.get(v["violation"], v["violation"])
|
| 421 |
+
row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
|
| 422 |
rows.append(row)
|
| 423 |
violation_table = header + separator + "\n".join(rows)
|
| 424 |
|
|
|
|
| 458 |
|
| 459 |
if __name__ == "__main__":
|
| 460 |
logger.info("Launching Safety Analyzer App...")
|
| 461 |
+
interface.launch()
|