Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,46 +18,40 @@ from retrying import retry
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
-
"MODEL_PATH": "yolov8_safety.pt",
|
| 22 |
-
"FALLBACK_MODEL_PATH": "yolov8n.pt",
|
| 23 |
"OUTPUT_DIR": "static/output",
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
| 26 |
1: "no_harness",
|
| 27 |
2: "unsafe_posture"
|
| 28 |
},
|
| 29 |
-
"DISPLAY_NAMES": {
|
| 30 |
-
"no_helmet": "
|
| 31 |
-
"no_harness": "
|
| 32 |
-
"unsafe_posture": "Unsafe Posture"
|
| 33 |
},
|
| 34 |
"SF_CREDENTIALS": {
|
| 35 |
"username": "your_username@safety.com",
|
| 36 |
"password": "your_password",
|
| 37 |
"security_token": "your_security_token",
|
| 38 |
-
"domain": "login"
|
| 39 |
},
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 41 |
-
"FRAME_SKIP": 15,
|
| 42 |
-
"
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
# Setup logging
|
| 46 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 47 |
logger = logging.getLogger(__name__)
|
| 48 |
|
| 49 |
-
# Ensure output directory exists
|
| 50 |
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
|
| 51 |
|
| 52 |
-
# ==========================
|
| 53 |
-
# Device Setup
|
| 54 |
-
# ==========================
|
| 55 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 56 |
logger.info(f"Using device: {device}")
|
| 57 |
|
| 58 |
-
# ==========================
|
| 59 |
-
# Model Loading
|
| 60 |
-
# ==========================
|
| 61 |
def load_model():
|
| 62 |
try:
|
| 63 |
model_path = CONFIG["MODEL_PATH"]
|
|
@@ -75,9 +69,6 @@ def load_model():
|
|
| 75 |
|
| 76 |
model = load_model()
|
| 77 |
|
| 78 |
-
# ==========================
|
| 79 |
-
# Salesforce Integration
|
| 80 |
-
# ==========================
|
| 81 |
@retry(stop_max_attempt_number=2, wait_fixed=1000)
|
| 82 |
def connect_to_salesforce():
|
| 83 |
try:
|
|
@@ -207,9 +198,6 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
|
| 207 |
logger.error(f"Salesforce record creation failed: {e}")
|
| 208 |
return None, ""
|
| 209 |
|
| 210 |
-
# ==========================
|
| 211 |
-
# Safety Score Calculation
|
| 212 |
-
# ==========================
|
| 213 |
def calculate_safety_score(violations):
|
| 214 |
penalties = {
|
| 215 |
"no_helmet": 25,
|
|
@@ -222,9 +210,6 @@ def calculate_safety_score(violations):
|
|
| 222 |
score -= penalties[v["violation"]]
|
| 223 |
return max(score, 0)
|
| 224 |
|
| 225 |
-
# ==========================
|
| 226 |
-
# Video Processing
|
| 227 |
-
# ==========================
|
| 228 |
def process_video(video_data):
|
| 229 |
try:
|
| 230 |
video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
|
|
@@ -238,30 +223,32 @@ def process_video(video_data):
|
|
| 238 |
|
| 239 |
violations, snapshots = [], []
|
| 240 |
frame_count = 0
|
|
|
|
| 241 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 242 |
|
| 243 |
-
# Track one snapshot per violation type
|
| 244 |
snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
|
| 245 |
|
| 246 |
while True:
|
| 247 |
ret, frame = video.read()
|
| 248 |
if not ret:
|
| 249 |
-
break
|
| 250 |
|
| 251 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 252 |
frame_count += 1
|
| 253 |
continue
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
results = model(frame, device=device)
|
| 256 |
seen_violations = set()
|
| 257 |
for result in results:
|
| 258 |
for box in result.boxes:
|
| 259 |
cls, conf = int(box.cls), float(box.conf)
|
| 260 |
-
label = CONFIG["VIOLATION_LABELS"].get(cls,
|
| 261 |
-
# Only process specified violations
|
| 262 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 263 |
continue
|
| 264 |
-
# Apply confidence threshold
|
| 265 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 266 |
continue
|
| 267 |
if label in seen_violations:
|
|
@@ -277,7 +264,6 @@ def process_video(video_data):
|
|
| 277 |
}
|
| 278 |
violations.append(violation)
|
| 279 |
|
| 280 |
-
# Save only one snapshot per violation type
|
| 281 |
if not snapshot_taken[label]:
|
| 282 |
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
|
| 283 |
cv2.imwrite(snapshot_path, frame)
|
|
@@ -304,7 +290,7 @@ def process_video(video_data):
|
|
| 304 |
"score": 100,
|
| 305 |
"salesforce_record_id": None,
|
| 306 |
"violation_details_url": "",
|
| 307 |
-
"message": "No violations detected
|
| 308 |
}
|
| 309 |
|
| 310 |
score = calculate_safety_score(violations)
|
|
@@ -327,12 +313,9 @@ def process_video(video_data):
|
|
| 327 |
"score": 100,
|
| 328 |
"salesforce_record_id": None,
|
| 329 |
"violation_details_url": "",
|
| 330 |
-
"message": "Error processing video
|
| 331 |
}
|
| 332 |
|
| 333 |
-
# ==========================
|
| 334 |
-
# Gradio Interface
|
| 335 |
-
# ==========================
|
| 336 |
def gradio_interface(video_file):
|
| 337 |
if not video_file:
|
| 338 |
return "No file uploaded.", "", "No file uploaded.", "", ""
|
|
@@ -342,34 +325,26 @@ def gradio_interface(video_file):
|
|
| 342 |
result = process_video(video_data)
|
| 343 |
|
| 344 |
if result.get("message"):
|
| 345 |
-
#
|
| 346 |
-
return result["message"],
|
| 347 |
|
| 348 |
violation_table = "No violations detected."
|
| 349 |
if result["violations"]:
|
| 350 |
-
header = "| Violation
|
| 351 |
-
separator = "
|
| 352 |
rows = []
|
|
|
|
| 353 |
for v in result["violations"]:
|
| 354 |
-
display_name =
|
| 355 |
-
|
| 356 |
-
if v["violation"] == "no_helmet":
|
| 357 |
-
details = "Employee not wearing helmet"
|
| 358 |
-
elif v["violation"] == "no_harness":
|
| 359 |
-
details = "Employee not wearing proper harness"
|
| 360 |
-
elif v["violation"] == "unsafe_posture":
|
| 361 |
-
details = "Employee in unsafe posture/zone"
|
| 362 |
-
else:
|
| 363 |
-
details = "Violation detected"
|
| 364 |
-
|
| 365 |
-
row = f"| {display_name:<17} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {details:<31} |"
|
| 366 |
rows.append(row)
|
| 367 |
violation_table = header + separator + "\n".join(rows)
|
| 368 |
|
| 369 |
snapshots_text = "No snapshots captured."
|
| 370 |
if result["snapshots"]:
|
|
|
|
| 371 |
snapshots_text = "\n".join(
|
| 372 |
-
f"- Snapshot for {
|
| 373 |
for s in result["snapshots"]
|
| 374 |
)
|
| 375 |
|
|
@@ -395,7 +370,7 @@ interface = gr.Interface(
|
|
| 395 |
gr.Textbox(label="Violation Details URL")
|
| 396 |
],
|
| 397 |
title="Worksite Safety Violation Analyzer",
|
| 398 |
-
description="Upload site videos to detect safety violations (
|
| 399 |
)
|
| 400 |
|
| 401 |
if __name__ == "__main__":
|
|
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
+
"MODEL_PATH": "yolov8_safety.pt",
|
| 22 |
+
"FALLBACK_MODEL_PATH": "yolov8n.pt",
|
| 23 |
"OUTPUT_DIR": "static/output",
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
| 26 |
1: "no_harness",
|
| 27 |
2: "unsafe_posture"
|
| 28 |
},
|
| 29 |
+
"DISPLAY_NAMES": {
|
| 30 |
+
"no_helmet": "No Helmet Violation",
|
| 31 |
+
"no_harness": "No Harness Violation",
|
| 32 |
+
"unsafe_posture": "Unsafe Posture Violation"
|
| 33 |
},
|
| 34 |
"SF_CREDENTIALS": {
|
| 35 |
"username": "your_username@safety.com",
|
| 36 |
"password": "your_password",
|
| 37 |
"security_token": "your_security_token",
|
| 38 |
+
"domain": "login"
|
| 39 |
},
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 41 |
+
"FRAME_SKIP": 15,
|
| 42 |
+
"MAX_PROCESSING_TIME": 30, # Updated to 30 seconds
|
| 43 |
+
"CONFIDENCE_THRESHOLD": 0.5
|
| 44 |
}
|
| 45 |
|
| 46 |
# Setup logging
|
| 47 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 48 |
logger = logging.getLogger(__name__)
|
| 49 |
|
|
|
|
| 50 |
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
logger.info(f"Using device: {device}")
|
| 54 |
|
|
|
|
|
|
|
|
|
|
| 55 |
def load_model():
|
| 56 |
try:
|
| 57 |
model_path = CONFIG["MODEL_PATH"]
|
|
|
|
| 69 |
|
| 70 |
model = load_model()
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
@retry(stop_max_attempt_number=2, wait_fixed=1000)
|
| 73 |
def connect_to_salesforce():
|
| 74 |
try:
|
|
|
|
| 198 |
logger.error(f"Salesforce record creation failed: {e}")
|
| 199 |
return None, ""
|
| 200 |
|
|
|
|
|
|
|
|
|
|
| 201 |
def calculate_safety_score(violations):
|
| 202 |
penalties = {
|
| 203 |
"no_helmet": 25,
|
|
|
|
| 210 |
score -= penalties[v["violation"]]
|
| 211 |
return max(score, 0)
|
| 212 |
|
|
|
|
|
|
|
|
|
|
| 213 |
def process_video(video_data):
|
| 214 |
try:
|
| 215 |
video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
|
|
|
|
| 223 |
|
| 224 |
violations, snapshots = [], []
|
| 225 |
frame_count = 0
|
| 226 |
+
start_time = time.time()
|
| 227 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 228 |
|
|
|
|
| 229 |
snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
|
| 230 |
|
| 231 |
while True:
|
| 232 |
ret, frame = video.read()
|
| 233 |
if not ret:
|
| 234 |
+
break # End of video
|
| 235 |
|
| 236 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 237 |
frame_count += 1
|
| 238 |
continue
|
| 239 |
|
| 240 |
+
if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 241 |
+
logger.info("Processing time limit of 30 seconds reached")
|
| 242 |
+
break
|
| 243 |
+
|
| 244 |
results = model(frame, device=device)
|
| 245 |
seen_violations = set()
|
| 246 |
for result in results:
|
| 247 |
for box in result.boxes:
|
| 248 |
cls, conf = int(box.cls), float(box.conf)
|
| 249 |
+
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
|
|
|
| 250 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 251 |
continue
|
|
|
|
| 252 |
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 253 |
continue
|
| 254 |
if label in seen_violations:
|
|
|
|
| 264 |
}
|
| 265 |
violations.append(violation)
|
| 266 |
|
|
|
|
| 267 |
if not snapshot_taken[label]:
|
| 268 |
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
|
| 269 |
cv2.imwrite(snapshot_path, frame)
|
|
|
|
| 290 |
"score": 100,
|
| 291 |
"salesforce_record_id": None,
|
| 292 |
"violation_details_url": "",
|
| 293 |
+
"message": "No violations detected in the video."
|
| 294 |
}
|
| 295 |
|
| 296 |
score = calculate_safety_score(violations)
|
|
|
|
| 313 |
"score": 100,
|
| 314 |
"salesforce_record_id": None,
|
| 315 |
"violation_details_url": "",
|
| 316 |
+
"message": f"Error processing video: {e}"
|
| 317 |
}
|
| 318 |
|
|
|
|
|
|
|
|
|
|
| 319 |
def gradio_interface(video_file):
|
| 320 |
if not video_file:
|
| 321 |
return "No file uploaded.", "", "No file uploaded.", "", ""
|
|
|
|
| 325 |
result = process_video(video_data)
|
| 326 |
|
| 327 |
if result.get("message"):
|
| 328 |
+
# If message present (either no violations or error), show it plainly
|
| 329 |
+
return result["message"], "", "", "", ""
|
| 330 |
|
| 331 |
violation_table = "No violations detected."
|
| 332 |
if result["violations"]:
|
| 333 |
+
header = "| Violation | Timestamp (s) | Confidence | Bounding Box |\n"
|
| 334 |
+
separator = "|------------------------|---------------|------------|--------------------------|\n"
|
| 335 |
rows = []
|
| 336 |
+
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 337 |
for v in result["violations"]:
|
| 338 |
+
display_name = violation_name_map.get(v["violation"], v["violation"])
|
| 339 |
+
row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['bounding_box']} |"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
rows.append(row)
|
| 341 |
violation_table = header + separator + "\n".join(rows)
|
| 342 |
|
| 343 |
snapshots_text = "No snapshots captured."
|
| 344 |
if result["snapshots"]:
|
| 345 |
+
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 346 |
snapshots_text = "\n".join(
|
| 347 |
+
f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: "
|
| 348 |
for s in result["snapshots"]
|
| 349 |
)
|
| 350 |
|
|
|
|
| 370 |
gr.Textbox(label="Violation Details URL")
|
| 371 |
],
|
| 372 |
title="Worksite Safety Violation Analyzer",
|
| 373 |
+
description="Upload site videos to detect safety violations (No Helmet Violation, No Harness Violation, Unsafe Posture Violation). Non-violations are ignored."
|
| 374 |
)
|
| 375 |
|
| 376 |
if __name__ == "__main__":
|