Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -38,10 +38,9 @@ CONFIG = {
|
|
| 38 |
"domain": "login"
|
| 39 |
},
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 41 |
-
"FRAME_SKIP":
|
| 42 |
"MAX_PROCESSING_TIME": 30,
|
| 43 |
-
"CONFIDENCE_THRESHOLD": 0.5
|
| 44 |
-
"TEMPORAL_THRESHOLD": 1.0 # Time threshold in seconds to avoid counting the same violation
|
| 45 |
}
|
| 46 |
|
| 47 |
# Setup logging
|
|
@@ -210,13 +209,11 @@ def calculate_safety_score(violations):
|
|
| 210 |
|
| 211 |
def process_video(video_data):
|
| 212 |
try:
|
| 213 |
-
# Save video to temporary file
|
| 214 |
video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
|
| 215 |
with open(video_path, "wb") as f:
|
| 216 |
f.write(video_data)
|
| 217 |
logger.info(f"Video saved: {video_path}")
|
| 218 |
|
| 219 |
-
# Read the video
|
| 220 |
video = cv2.VideoCapture(video_path)
|
| 221 |
if not video.isOpened():
|
| 222 |
raise ValueError("Could not open video file")
|
|
@@ -227,14 +224,13 @@ def process_video(video_data):
|
|
| 227 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 228 |
|
| 229 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 230 |
-
|
| 231 |
|
| 232 |
while True:
|
| 233 |
ret, frame = video.read()
|
| 234 |
if not ret:
|
| 235 |
-
break
|
| 236 |
|
| 237 |
-
# Process every frame (or based on FRAME_SKIP)
|
| 238 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 239 |
frame_count += 1
|
| 240 |
continue
|
|
@@ -243,37 +239,27 @@ def process_video(video_data):
|
|
| 243 |
logger.info("Processing time limit reached")
|
| 244 |
break
|
| 245 |
|
| 246 |
-
# Model inference
|
| 247 |
results = model(frame, device=device)
|
| 248 |
-
seen_violations = set() # Track violations detected in the current frame
|
| 249 |
-
|
| 250 |
for result in results:
|
| 251 |
for box in result.boxes:
|
| 252 |
cls, conf = int(box.cls), float(box.conf)
|
| 253 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 254 |
-
|
| 255 |
-
# Skip if it's not a relevant violation or if confidence is too low
|
| 256 |
-
if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 257 |
continue
|
| 258 |
-
|
| 259 |
-
# Skip if the same violation is detected again within the temporal threshold
|
| 260 |
-
if time.time() - last_detected[label] < CONFIG["TEMPORAL_THRESHOLD"]:
|
| 261 |
continue
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
-
# Update last detected time
|
| 264 |
-
last_detected[label] = time.time()
|
| 265 |
-
|
| 266 |
-
# Save the violation data
|
| 267 |
violation = {
|
| 268 |
"frame": frame_count,
|
| 269 |
"violation": label,
|
| 270 |
"confidence": round(conf, 2),
|
| 271 |
-
"bounding_box": [round(x, 2) for x in box.xywh.cpu().numpy()[0]],
|
| 272 |
"timestamp": frame_count / fps
|
| 273 |
}
|
| 274 |
violations.append(violation)
|
| 275 |
|
| 276 |
-
# Snapshot for the first occurrence of each violation
|
| 277 |
if not snapshot_taken[label]:
|
| 278 |
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
|
| 279 |
cv2.imwrite(snapshot_path, frame)
|
|
@@ -292,7 +278,6 @@ def process_video(video_data):
|
|
| 292 |
video.release()
|
| 293 |
os.remove(video_path)
|
| 294 |
|
| 295 |
-
# If no violations were detected, return a message
|
| 296 |
if not violations:
|
| 297 |
logger.info("No violations detected")
|
| 298 |
return {
|
|
@@ -304,7 +289,6 @@ def process_video(video_data):
|
|
| 304 |
"message": "No violations detected in the video."
|
| 305 |
}
|
| 306 |
|
| 307 |
-
# Calculate compliance score
|
| 308 |
score = calculate_safety_score(violations)
|
| 309 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 310 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
|
@@ -330,8 +314,7 @@ def process_video(video_data):
|
|
| 330 |
|
| 331 |
def gradio_interface(video_file):
|
| 332 |
if not video_file:
|
| 333 |
-
return "", "", "", "", ""
|
| 334 |
-
|
| 335 |
try:
|
| 336 |
yield "Processing video... please wait.", "", "", "", ""
|
| 337 |
|
|
@@ -340,26 +323,25 @@ def gradio_interface(video_file):
|
|
| 340 |
|
| 341 |
result = process_video(video_data)
|
| 342 |
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
# No violations detected — return empty or minimal outputs
|
| 346 |
-
yield "", "", "", "", ""
|
| 347 |
return
|
| 348 |
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
snapshots_text = ""
|
| 362 |
if result["snapshots"]:
|
|
|
|
| 363 |
snapshots_text = "\n".join(
|
| 364 |
f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: "
|
| 365 |
for s in result["snapshots"]
|
|
@@ -374,7 +356,7 @@ def gradio_interface(video_file):
|
|
| 374 |
)
|
| 375 |
except Exception as e:
|
| 376 |
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 377 |
-
yield f"Error: {str(e)}", "", "", "", ""
|
| 378 |
|
| 379 |
interface = gr.Interface(
|
| 380 |
fn=gradio_interface,
|
|
|
|
| 38 |
"domain": "login"
|
| 39 |
},
|
| 40 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
|
| 41 |
+
"FRAME_SKIP": 15,
|
| 42 |
"MAX_PROCESSING_TIME": 30,
|
| 43 |
+
"CONFIDENCE_THRESHOLD": 0.5
|
|
|
|
| 44 |
}
|
| 45 |
|
| 46 |
# Setup logging
|
|
|
|
| 209 |
|
| 210 |
def process_video(video_data):
|
| 211 |
try:
|
|
|
|
| 212 |
video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
|
| 213 |
with open(video_path, "wb") as f:
|
| 214 |
f.write(video_data)
|
| 215 |
logger.info(f"Video saved: {video_path}")
|
| 216 |
|
|
|
|
| 217 |
video = cv2.VideoCapture(video_path)
|
| 218 |
if not video.isOpened():
|
| 219 |
raise ValueError("Could not open video file")
|
|
|
|
| 224 |
fps = video.get(cv2.CAP_PROP_FPS)
|
| 225 |
|
| 226 |
snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
|
| 227 |
+
seen_violations = set() # to track violations and avoid repeating
|
| 228 |
|
| 229 |
while True:
|
| 230 |
ret, frame = video.read()
|
| 231 |
if not ret:
|
| 232 |
+
break
|
| 233 |
|
|
|
|
| 234 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 235 |
frame_count += 1
|
| 236 |
continue
|
|
|
|
| 239 |
logger.info("Processing time limit reached")
|
| 240 |
break
|
| 241 |
|
|
|
|
| 242 |
results = model(frame, device=device)
|
|
|
|
|
|
|
| 243 |
for result in results:
|
| 244 |
for box in result.boxes:
|
| 245 |
cls, conf = int(box.cls), float(box.conf)
|
| 246 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 247 |
+
if label not in CONFIG["VIOLATION_LABELS"].values():
|
|
|
|
|
|
|
| 248 |
continue
|
| 249 |
+
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
|
|
|
|
|
|
| 250 |
continue
|
| 251 |
+
if label in seen_violations:
|
| 252 |
+
continue
|
| 253 |
+
seen_violations.add(label)
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
violation = {
|
| 256 |
"frame": frame_count,
|
| 257 |
"violation": label,
|
| 258 |
"confidence": round(conf, 2),
|
|
|
|
| 259 |
"timestamp": frame_count / fps
|
| 260 |
}
|
| 261 |
violations.append(violation)
|
| 262 |
|
|
|
|
| 263 |
if not snapshot_taken[label]:
|
| 264 |
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
|
| 265 |
cv2.imwrite(snapshot_path, frame)
|
|
|
|
| 278 |
video.release()
|
| 279 |
os.remove(video_path)
|
| 280 |
|
|
|
|
| 281 |
if not violations:
|
| 282 |
logger.info("No violations detected")
|
| 283 |
return {
|
|
|
|
| 289 |
"message": "No violations detected in the video."
|
| 290 |
}
|
| 291 |
|
|
|
|
| 292 |
score = calculate_safety_score(violations)
|
| 293 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 294 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
|
|
|
| 314 |
|
| 315 |
def gradio_interface(video_file):
|
| 316 |
if not video_file:
|
| 317 |
+
return "No file uploaded.", "", "No file uploaded.", "", ""
|
|
|
|
| 318 |
try:
|
| 319 |
yield "Processing video... please wait.", "", "", "", ""
|
| 320 |
|
|
|
|
| 323 |
|
| 324 |
result = process_video(video_data)
|
| 325 |
|
| 326 |
+
if result.get("message"):
|
| 327 |
+
yield result["message"], "", "", "", ""
|
|
|
|
|
|
|
| 328 |
return
|
| 329 |
|
| 330 |
+
violation_table = "No violations detected."
|
| 331 |
+
if result["violations"]:
|
| 332 |
+
header = "| Violation | Timestamp (s) | Confidence | \n"
|
| 333 |
+
separator = "|------------------------|---------------|------------|\n"
|
| 334 |
+
rows = []
|
| 335 |
+
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 336 |
+
for v in result["violations"]:
|
| 337 |
+
display_name = violation_name_map.get(v["violation"], v["violation"])
|
| 338 |
+
row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} |"
|
| 339 |
+
rows.append(row)
|
| 340 |
+
violation_table = header + separator + "\n".join(rows)
|
| 341 |
+
|
| 342 |
+
snapshots_text = "No snapshots captured."
|
| 343 |
if result["snapshots"]:
|
| 344 |
+
violation_name_map = CONFIG["DISPLAY_NAMES"]
|
| 345 |
snapshots_text = "\n".join(
|
| 346 |
f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: "
|
| 347 |
for s in result["snapshots"]
|
|
|
|
| 356 |
)
|
| 357 |
except Exception as e:
|
| 358 |
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 359 |
+
yield f"Error: {str(e)}", "", "Error in processing.", "", ""
|
| 360 |
|
| 361 |
interface = gr.Interface(
|
| 362 |
fn=gradio_interface,
|