Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -39,7 +39,6 @@ CONFIG = {
|
|
| 39 |
},
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 41 |
"FRAME_SKIP": 15, # Process every 15th frame
|
| 42 |
-
"MAX_PROCESSING_TIME": 25, # Cap video processing at 25s
|
| 43 |
"CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
|
| 44 |
}
|
| 45 |
|
|
@@ -239,27 +238,20 @@ def process_video(video_data):
|
|
| 239 |
|
| 240 |
violations, snapshots = [], []
|
| 241 |
frame_count = 0
|
| 242 |
-
start_time = time.time()
|
| 243 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 244 |
-
max_frames = int(60 * fps) # Process up to 1 minute
|
| 245 |
|
| 246 |
# Track one snapshot per violation type
|
| 247 |
snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
|
| 248 |
|
| 249 |
while True:
|
| 250 |
ret, frame = video.read()
|
| 251 |
-
if not ret
|
| 252 |
break
|
| 253 |
|
| 254 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 255 |
frame_count += 1
|
| 256 |
continue
|
| 257 |
|
| 258 |
-
# Stop if processing time exceeds 25 seconds
|
| 259 |
-
if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 260 |
-
logger.info("Processing time limit reached")
|
| 261 |
-
break
|
| 262 |
-
|
| 263 |
results = model(frame, device=device)
|
| 264 |
seen_violations = set()
|
| 265 |
for result in results:
|
|
@@ -268,11 +260,9 @@ def process_video(video_data):
|
|
| 268 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
|
| 269 |
# Only process specified violations
|
| 270 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 271 |
-
logger.warning(f"Unexpected detection: {label} (cls: {cls}, conf: {conf}) - ignored")
|
| 272 |
continue
|
| 273 |
# Apply confidence threshold
|
| 274 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 275 |
-
logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
|
| 276 |
continue
|
| 277 |
if label in seen_violations:
|
| 278 |
continue
|
|
@@ -313,7 +303,8 @@ def process_video(video_data):
|
|
| 313 |
"snapshots": [],
|
| 314 |
"score": 100,
|
| 315 |
"salesforce_record_id": None,
|
| 316 |
-
"violation_details_url": ""
|
|
|
|
| 317 |
}
|
| 318 |
|
| 319 |
score = calculate_safety_score(violations)
|
|
@@ -325,7 +316,8 @@ def process_video(video_data):
|
|
| 325 |
"snapshots": snapshots,
|
| 326 |
"score": score,
|
| 327 |
"salesforce_record_id": report_id,
|
| 328 |
-
"violation_details_url": final_pdf_url
|
|
|
|
| 329 |
}
|
| 330 |
except Exception as e:
|
| 331 |
logger.error(f"Error processing video: {e}")
|
|
@@ -334,7 +326,8 @@ def process_video(video_data):
|
|
| 334 |
"snapshots": [],
|
| 335 |
"score": 100,
|
| 336 |
"salesforce_record_id": None,
|
| 337 |
-
"violation_details_url": ""
|
|
|
|
| 338 |
}
|
| 339 |
|
| 340 |
# ==========================
|
|
@@ -348,14 +341,28 @@ def gradio_interface(video_file):
|
|
| 348 |
video_data = f.read()
|
| 349 |
result = process_video(video_data)
|
| 350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
violation_table = "No violations detected."
|
| 352 |
if result["violations"]:
|
| 353 |
-
header = "| Violation
|
| 354 |
-
separator = "
|
| 355 |
rows = []
|
| 356 |
for v in result["violations"]:
|
| 357 |
display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
rows.append(row)
|
| 360 |
violation_table = header + separator + "\n".join(rows)
|
| 361 |
|
|
|
|
| 39 |
},
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 41 |
"FRAME_SKIP": 15, # Process every 15th frame
|
|
|
|
| 42 |
"CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
|
| 43 |
}
|
| 44 |
|
|
|
|
| 238 |
|
| 239 |
violations, snapshots = [], []
|
| 240 |
frame_count = 0
|
|
|
|
| 241 |
fps = video.get(cv2.CAP_PROP_FPS)
|
|
|
|
| 242 |
|
| 243 |
# Track one snapshot per violation type
|
| 244 |
snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
|
| 245 |
|
| 246 |
while True:
|
| 247 |
ret, frame = video.read()
|
| 248 |
+
if not ret:
|
| 249 |
break
|
| 250 |
|
| 251 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 252 |
frame_count += 1
|
| 253 |
continue
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
results = model(frame, device=device)
|
| 256 |
seen_violations = set()
|
| 257 |
for result in results:
|
|
|
|
| 260 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
|
| 261 |
# Only process specified violations
|
| 262 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
|
|
|
| 263 |
continue
|
| 264 |
# Apply confidence threshold
|
| 265 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
|
|
|
| 266 |
continue
|
| 267 |
if label in seen_violations:
|
| 268 |
continue
|
|
|
|
| 303 |
"snapshots": [],
|
| 304 |
"score": 100,
|
| 305 |
"salesforce_record_id": None,
|
| 306 |
+
"violation_details_url": "",
|
| 307 |
+
"message": "No violations detected here."
|
| 308 |
}
|
| 309 |
|
| 310 |
score = calculate_safety_score(violations)
|
|
|
|
| 316 |
"snapshots": snapshots,
|
| 317 |
"score": score,
|
| 318 |
"salesforce_record_id": report_id,
|
| 319 |
+
"violation_details_url": final_pdf_url,
|
| 320 |
+
"message": ""
|
| 321 |
}
|
| 322 |
except Exception as e:
|
| 323 |
logger.error(f"Error processing video: {e}")
|
|
|
|
| 326 |
"snapshots": [],
|
| 327 |
"score": 100,
|
| 328 |
"salesforce_record_id": None,
|
| 329 |
+
"violation_details_url": "",
|
| 330 |
+
"message": "Error processing video."
|
| 331 |
}
|
| 332 |
|
| 333 |
# ==========================
|
|
|
|
| 341 |
video_data = f.read()
|
| 342 |
result = process_video(video_data)
|
| 343 |
|
| 344 |
+
if result.get("message"):
|
| 345 |
+
# Show message (like "No violations detected here.")
|
| 346 |
+
return result["message"], f"Safety Score: {result['score']}%", "", "N/A", "N/A"
|
| 347 |
+
|
| 348 |
violation_table = "No violations detected."
|
| 349 |
if result["violations"]:
|
| 350 |
+
header = "| Violation | Timestamp | Confidence | Violation Details |\n"
|
| 351 |
+
separator = "|-------------------|-----------|------------|---------------------------------|\n"
|
| 352 |
rows = []
|
| 353 |
for v in result["violations"]:
|
| 354 |
display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
|
| 355 |
+
# Provide clearer human-readable violation explanation
|
| 356 |
+
if v["violation"] == "no_helmet":
|
| 357 |
+
details = "Employee not wearing helmet"
|
| 358 |
+
elif v["violation"] == "no_harness":
|
| 359 |
+
details = "Employee not wearing proper harness"
|
| 360 |
+
elif v["violation"] == "unsafe_posture":
|
| 361 |
+
details = "Employee in unsafe posture/zone"
|
| 362 |
+
else:
|
| 363 |
+
details = "Violation detected"
|
| 364 |
+
|
| 365 |
+
row = f"| {display_name:<17} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {details:<31} |"
|
| 366 |
rows.append(row)
|
| 367 |
violation_table = header + separator + "\n".join(rows)
|
| 368 |
|