PrashanthB461 commited on
Commit
547ffb5
·
verified ·
1 Parent(s): 2b4f925

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -209,11 +209,13 @@ def calculate_safety_score(violations):
209
 
210
  def process_video(video_data):
211
  try:
 
212
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
213
  with open(video_path, "wb") as f:
214
  f.write(video_data)
215
  logger.info(f"Video saved: {video_path}")
216
 
 
217
  video = cv2.VideoCapture(video_path)
218
  if not video.isOpened():
219
  raise ValueError("Could not open video file")
@@ -228,8 +230,9 @@ def process_video(video_data):
228
  while True:
229
  ret, frame = video.read()
230
  if not ret:
231
- break
232
 
 
233
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
234
  frame_count += 1
235
  continue
@@ -238,20 +241,25 @@ def process_video(video_data):
238
  logger.info("Processing time limit reached")
239
  break
240
 
 
241
  results = model(frame, device=device)
242
- seen_violations = set()
 
243
  for result in results:
244
  for box in result.boxes:
245
  cls, conf = int(box.cls), float(box.conf)
246
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
247
- if label not in CONFIG["VIOLATION_LABELS"].values():
248
- continue
249
- if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
250
  continue
 
251
  if label in seen_violations:
252
- continue
 
253
  seen_violations.add(label)
254
 
 
255
  violation = {
256
  "frame": frame_count,
257
  "violation": label,
@@ -261,6 +269,7 @@ def process_video(video_data):
261
  }
262
  violations.append(violation)
263
 
 
264
  if not snapshot_taken[label]:
265
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
266
  cv2.imwrite(snapshot_path, frame)
@@ -279,6 +288,7 @@ def process_video(video_data):
279
  video.release()
280
  os.remove(video_path)
281
 
 
282
  if not violations:
283
  logger.info("No violations detected")
284
  return {
@@ -290,6 +300,7 @@ def process_video(video_data):
290
  "message": "No violations detected in the video."
291
  }
292
 
 
293
  score = calculate_safety_score(violations)
294
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
295
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
 
209
 
210
  def process_video(video_data):
211
  try:
212
+ # Save video to temporary file
213
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
214
  with open(video_path, "wb") as f:
215
  f.write(video_data)
216
  logger.info(f"Video saved: {video_path}")
217
 
218
+ # Read the video
219
  video = cv2.VideoCapture(video_path)
220
  if not video.isOpened():
221
  raise ValueError("Could not open video file")
 
230
  while True:
231
  ret, frame = video.read()
232
  if not ret:
233
+ break # Break if the video has ended
234
 
235
+ # Process every frame (or based on FRAME_SKIP)
236
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
237
  frame_count += 1
238
  continue
 
241
  logger.info("Processing time limit reached")
242
  break
243
 
244
+ # Model inference
245
  results = model(frame, device=device)
246
+ seen_violations = set() # Track violations detected in the current frame
247
+
248
  for result in results:
249
  for box in result.boxes:
250
  cls, conf = int(box.cls), float(box.conf)
251
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
252
+
253
+ # Skip if it's not a relevant violation or if confidence is too low
254
+ if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
255
  continue
256
+
257
  if label in seen_violations:
258
+ continue # Avoid duplicates in the same frame
259
+
260
  seen_violations.add(label)
261
 
262
+ # Save the violation data
263
  violation = {
264
  "frame": frame_count,
265
  "violation": label,
 
269
  }
270
  violations.append(violation)
271
 
272
+ # Snapshot for the first occurrence of each violation
273
  if not snapshot_taken[label]:
274
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
275
  cv2.imwrite(snapshot_path, frame)
 
288
  video.release()
289
  os.remove(video_path)
290
 
291
+ # If no violations were detected, return a message
292
  if not violations:
293
  logger.info("No violations detected")
294
  return {
 
300
  "message": "No violations detected in the video."
301
  }
302
 
303
+ # Calculate compliance score
304
  score = calculate_safety_score(violations)
305
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
306
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)