Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -209,11 +209,13 @@ def calculate_safety_score(violations):
|
|
| 209 |
|
| 210 |
def process_video(video_data):
|
| 211 |
try:
|
|
|
|
| 212 |
video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
|
| 213 |
with open(video_path, "wb") as f:
|
| 214 |
f.write(video_data)
|
| 215 |
logger.info(f"Video saved: {video_path}")
|
| 216 |
|
|
|
|
| 217 |
video = cv2.VideoCapture(video_path)
|
| 218 |
if not video.isOpened():
|
| 219 |
raise ValueError("Could not open video file")
|
|
@@ -228,8 +230,9 @@ def process_video(video_data):
|
|
| 228 |
while True:
|
| 229 |
ret, frame = video.read()
|
| 230 |
if not ret:
|
| 231 |
-
break
|
| 232 |
|
|
|
|
| 233 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 234 |
frame_count += 1
|
| 235 |
continue
|
|
@@ -238,20 +241,25 @@ def process_video(video_data):
|
|
| 238 |
logger.info("Processing time limit reached")
|
| 239 |
break
|
| 240 |
|
|
|
|
| 241 |
results = model(frame, device=device)
|
| 242 |
-
seen_violations = set()
|
|
|
|
| 243 |
for result in results:
|
| 244 |
for box in result.boxes:
|
| 245 |
cls, conf = int(box.cls), float(box.conf)
|
| 246 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 250 |
continue
|
|
|
|
| 251 |
if label in seen_violations:
|
| 252 |
-
continue
|
|
|
|
| 253 |
seen_violations.add(label)
|
| 254 |
|
|
|
|
| 255 |
violation = {
|
| 256 |
"frame": frame_count,
|
| 257 |
"violation": label,
|
|
@@ -261,6 +269,7 @@ def process_video(video_data):
|
|
| 261 |
}
|
| 262 |
violations.append(violation)
|
| 263 |
|
|
|
|
| 264 |
if not snapshot_taken[label]:
|
| 265 |
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
|
| 266 |
cv2.imwrite(snapshot_path, frame)
|
|
@@ -279,6 +288,7 @@ def process_video(video_data):
|
|
| 279 |
video.release()
|
| 280 |
os.remove(video_path)
|
| 281 |
|
|
|
|
| 282 |
if not violations:
|
| 283 |
logger.info("No violations detected")
|
| 284 |
return {
|
|
@@ -290,6 +300,7 @@ def process_video(video_data):
|
|
| 290 |
"message": "No violations detected in the video."
|
| 291 |
}
|
| 292 |
|
|
|
|
| 293 |
score = calculate_safety_score(violations)
|
| 294 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 295 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|
|
|
|
| 209 |
|
| 210 |
def process_video(video_data):
|
| 211 |
try:
|
| 212 |
+
# Save video to temporary file
|
| 213 |
video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
|
| 214 |
with open(video_path, "wb") as f:
|
| 215 |
f.write(video_data)
|
| 216 |
logger.info(f"Video saved: {video_path}")
|
| 217 |
|
| 218 |
+
# Read the video
|
| 219 |
video = cv2.VideoCapture(video_path)
|
| 220 |
if not video.isOpened():
|
| 221 |
raise ValueError("Could not open video file")
|
|
|
|
| 230 |
while True:
|
| 231 |
ret, frame = video.read()
|
| 232 |
if not ret:
|
| 233 |
+
break # Break if the video has ended
|
| 234 |
|
| 235 |
+
# Process every frame (or based on FRAME_SKIP)
|
| 236 |
if frame_count % CONFIG["FRAME_SKIP"] != 0:
|
| 237 |
frame_count += 1
|
| 238 |
continue
|
|
|
|
| 241 |
logger.info("Processing time limit reached")
|
| 242 |
break
|
| 243 |
|
| 244 |
+
# Model inference
|
| 245 |
results = model(frame, device=device)
|
| 246 |
+
seen_violations = set() # Track violations detected in the current frame
|
| 247 |
+
|
| 248 |
for result in results:
|
| 249 |
for box in result.boxes:
|
| 250 |
cls, conf = int(box.cls), float(box.conf)
|
| 251 |
label = CONFIG["VIOLATION_LABELS"].get(cls, None)
|
| 252 |
+
|
| 253 |
+
# Skip if it's not a relevant violation or if confidence is too low
|
| 254 |
+
if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 255 |
continue
|
| 256 |
+
|
| 257 |
if label in seen_violations:
|
| 258 |
+
continue # Avoid duplicates in the same frame
|
| 259 |
+
|
| 260 |
seen_violations.add(label)
|
| 261 |
|
| 262 |
+
# Save the violation data
|
| 263 |
violation = {
|
| 264 |
"frame": frame_count,
|
| 265 |
"violation": label,
|
|
|
|
| 269 |
}
|
| 270 |
violations.append(violation)
|
| 271 |
|
| 272 |
+
# Snapshot for the first occurrence of each violation
|
| 273 |
if not snapshot_taken[label]:
|
| 274 |
snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
|
| 275 |
cv2.imwrite(snapshot_path, frame)
|
|
|
|
| 288 |
video.release()
|
| 289 |
os.remove(video_path)
|
| 290 |
|
| 291 |
+
# If no violations were detected, return a message
|
| 292 |
if not violations:
|
| 293 |
logger.info("No violations detected")
|
| 294 |
return {
|
|
|
|
| 300 |
"message": "No violations detected in the video."
|
| 301 |
}
|
| 302 |
|
| 303 |
+
# Calculate compliance score
|
| 304 |
score = calculate_safety_score(violations)
|
| 305 |
pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
|
| 306 |
report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
|