Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,8 +24,7 @@ CONFIG = {
|
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
| 26 |
1: "no_harness",
|
| 27 |
-
2: "unsafe_posture"
|
| 28 |
-
3: "unsafe_zone"
|
| 29 |
},
|
| 30 |
"SF_CREDENTIALS": {
|
| 31 |
"username": "prashanth1ai@safety.com",
|
|
@@ -34,6 +33,7 @@ CONFIG = {
|
|
| 34 |
"domain": "login"
|
| 35 |
},
|
| 36 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
|
|
|
| 37 |
}
|
| 38 |
|
| 39 |
# Setup logging
|
|
@@ -196,15 +196,17 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
|
| 196 |
# ==========================
|
| 197 |
# Safety Score Calculation
|
| 198 |
# ==========================
|
| 199 |
-
def calculate_safety_score(violations):
|
| 200 |
penalties = {
|
| 201 |
-
"no_helmet": 25,
|
| 202 |
-
"no_harness": 30,
|
| 203 |
-
"unsafe_posture": 20
|
| 204 |
}
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
|
| 209 |
# ==========================
|
| 210 |
# Video Processing
|
|
@@ -222,6 +224,8 @@ def process_video(video_data):
|
|
| 222 |
|
| 223 |
violations, snapshots = [], []
|
| 224 |
frame_count = 0
|
|
|
|
|
|
|
| 225 |
|
| 226 |
while True:
|
| 227 |
ret, frame = video.read()
|
|
@@ -229,16 +233,13 @@ def process_video(video_data):
|
|
| 229 |
break
|
| 230 |
|
| 231 |
results = model(frame, device=device)
|
| 232 |
-
seen_violations = set()
|
| 233 |
for result in results:
|
| 234 |
for box in result.boxes:
|
| 235 |
cls, conf = int(box.cls), float(box.conf)
|
| 236 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
|
| 237 |
-
|
| 238 |
-
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 239 |
continue
|
| 240 |
-
if label in seen_violations:
|
| 241 |
-
continue # Skip if this violation type was already recorded in this frame
|
| 242 |
seen_violations.add(label)
|
| 243 |
|
| 244 |
violation = {
|
|
@@ -262,10 +263,13 @@ def process_video(video_data):
|
|
| 262 |
})
|
| 263 |
|
| 264 |
frame_count += 1
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
video.release()
|
| 267 |
os.remove(video_path)
|
| 268 |
-
score = calculate_safety_score(violations)
|
| 269 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 270 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
| 271 |
|
|
@@ -305,9 +309,11 @@ def gradio_interface(video_file):
|
|
| 305 |
for v in result["violations"]:
|
| 306 |
violation_name = v["violation"]
|
| 307 |
if violation_name == "no_helmet":
|
| 308 |
-
violation_name = "
|
|
|
|
|
|
|
| 309 |
else:
|
| 310 |
-
violation_name =
|
| 311 |
row = f"| {violation_name:<13} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
|
| 312 |
rows.append(row)
|
| 313 |
violation_table = header + separator + "\n".join(rows)
|
|
@@ -315,7 +321,7 @@ def gradio_interface(video_file):
|
|
| 315 |
snapshots_text = "No snapshots captured."
|
| 316 |
if result["snapshots"]:
|
| 317 |
snapshots_text = "\n".join(
|
| 318 |
-
f"- Snapshot for {s['violation'].replace('
|
| 319 |
for s in result["snapshots"]
|
| 320 |
)
|
| 321 |
|
|
@@ -341,7 +347,7 @@ interface = gr.Interface(
|
|
| 341 |
gr.Textbox(label="Violation Details URL")
|
| 342 |
],
|
| 343 |
title="Worksite Safety Violation Analyzer",
|
| 344 |
-
description="Upload site videos to detect safety violations (e.g.,
|
| 345 |
)
|
| 346 |
|
| 347 |
if __name__ == "__main__":
|
|
|
|
| 24 |
"VIOLATION_LABELS": {
|
| 25 |
0: "no_helmet",
|
| 26 |
1: "no_harness",
|
| 27 |
+
2: "unsafe_posture"
|
|
|
|
| 28 |
},
|
| 29 |
"SF_CREDENTIALS": {
|
| 30 |
"username": "prashanth1ai@safety.com",
|
|
|
|
| 33 |
"domain": "login"
|
| 34 |
},
|
| 35 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 36 |
+
"MAX_PROCESSING_TIME": 30 # Reduced to 30 seconds
|
| 37 |
}
|
| 38 |
|
| 39 |
# Setup logging
|
|
|
|
| 196 |
# ==========================
|
| 197 |
# Safety Score Calculation
|
| 198 |
# ==========================
|
| 199 |
+
def calculate_safety_score(violations, total_frames):
|
| 200 |
penalties = {
|
| 201 |
+
"no_helmet": 25 / total_frames, # Normalize penalty by video length
|
| 202 |
+
"no_harness": 30 / total_frames,
|
| 203 |
+
"unsafe_posture": 20 / total_frames
|
| 204 |
}
|
| 205 |
+
score = 100
|
| 206 |
+
for v in violations:
|
| 207 |
+
penalty = penalties.get(v["violation"], 0)
|
| 208 |
+
score -= penalty
|
| 209 |
+
return max(round(score, 2), 0)
|
| 210 |
|
| 211 |
# ==========================
|
| 212 |
# Video Processing
|
|
|
|
| 224 |
|
| 225 |
violations, snapshots = [], []
|
| 226 |
frame_count = 0
|
| 227 |
+
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 228 |
+
start_time = time.time()
|
| 229 |
|
| 230 |
while True:
|
| 231 |
ret, frame = video.read()
|
|
|
|
| 233 |
break
|
| 234 |
|
| 235 |
results = model(frame, device=device)
|
| 236 |
+
seen_violations = set()
|
| 237 |
for result in results:
|
| 238 |
for box in result.boxes:
|
| 239 |
cls, conf = int(box.cls), float(box.conf)
|
| 240 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
|
| 241 |
+
if label in seen_violations or label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
|
|
|
| 242 |
continue
|
|
|
|
|
|
|
| 243 |
seen_violations.add(label)
|
| 244 |
|
| 245 |
violation = {
|
|
|
|
| 263 |
})
|
| 264 |
|
| 265 |
frame_count += 1
|
| 266 |
+
if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
|
| 267 |
+
logger.warning("Processing time limit of 30 seconds exceeded")
|
| 268 |
+
break
|
| 269 |
|
| 270 |
video.release()
|
| 271 |
os.remove(video_path)
|
| 272 |
+
score = calculate_safety_score(violations, max(total_frames, 1))
|
| 273 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 274 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
| 275 |
|
|
|
|
| 309 |
for v in result["violations"]:
|
| 310 |
violation_name = v["violation"]
|
| 311 |
if violation_name == "no_helmet":
|
| 312 |
+
violation_name = "Missing Helmet"
|
| 313 |
+
elif violation_name == "no_harness":
|
| 314 |
+
violation_name = "Missing Harness"
|
| 315 |
else:
|
| 316 |
+
violation_name = "Unsafe Posture"
|
| 317 |
row = f"| {violation_name:<13} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
|
| 318 |
rows.append(row)
|
| 319 |
violation_table = header + separator + "\n".join(rows)
|
|
|
|
| 321 |
snapshots_text = "No snapshots captured."
|
| 322 |
if result["snapshots"]:
|
| 323 |
snapshots_text = "\n".join(
|
| 324 |
+
f"- Snapshot for {s['violation'].replace('no_helmet', 'Missing Helmet').replace('no_harness', 'Missing Harness').replace('unsafe_posture', 'Unsafe Posture')} at frame {s['frame']}: "
|
| 325 |
for s in result["snapshots"]
|
| 326 |
)
|
| 327 |
|
|
|
|
| 347 |
gr.Textbox(label="Violation Details URL")
|
| 348 |
],
|
| 349 |
title="Worksite Safety Violation Analyzer",
|
| 350 |
+
description="Upload site videos to detect safety violations (e.g., missing helmet, missing harness, unsafe posture)."
|
| 351 |
)
|
| 352 |
|
| 353 |
if __name__ == "__main__":
|