PrashanthB461 commited on
Commit
c7370b8
·
verified ·
1 Parent(s): 3a06ab6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -45
app.py CHANGED
@@ -38,10 +38,9 @@ CONFIG = {
38
  "domain": "login"
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
- "FRAME_SKIP": 1, # Process every frame (set to 1 for full video processing)
42
  "MAX_PROCESSING_TIME": 30,
43
- "CONFIDENCE_THRESHOLD": 0.5,
44
- "TEMPORAL_THRESHOLD": 1.0 # Time threshold in seconds to avoid counting the same violation
45
  }
46
 
47
  # Setup logging
@@ -210,13 +209,11 @@ def calculate_safety_score(violations):
210
 
211
  def process_video(video_data):
212
  try:
213
- # Save video to temporary file
214
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
215
  with open(video_path, "wb") as f:
216
  f.write(video_data)
217
  logger.info(f"Video saved: {video_path}")
218
 
219
- # Read the video
220
  video = cv2.VideoCapture(video_path)
221
  if not video.isOpened():
222
  raise ValueError("Could not open video file")
@@ -227,14 +224,13 @@ def process_video(video_data):
227
  fps = video.get(cv2.CAP_PROP_FPS)
228
 
229
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
230
- last_detected = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()} # Tracks last violation time
231
 
232
  while True:
233
  ret, frame = video.read()
234
  if not ret:
235
- break # Break if the video has ended
236
 
237
- # Process every frame (or based on FRAME_SKIP)
238
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
239
  frame_count += 1
240
  continue
@@ -243,37 +239,27 @@ def process_video(video_data):
243
  logger.info("Processing time limit reached")
244
  break
245
 
246
- # Model inference
247
  results = model(frame, device=device)
248
- seen_violations = set() # Track violations detected in the current frame
249
-
250
  for result in results:
251
  for box in result.boxes:
252
  cls, conf = int(box.cls), float(box.conf)
253
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
254
-
255
- # Skip if it's not a relevant violation or if confidence is too low
256
- if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
257
  continue
258
-
259
- # Skip if the same violation is detected again within the temporal threshold
260
- if time.time() - last_detected[label] < CONFIG["TEMPORAL_THRESHOLD"]:
261
  continue
 
 
 
262
 
263
- # Update last detected time
264
- last_detected[label] = time.time()
265
-
266
- # Save the violation data
267
  violation = {
268
  "frame": frame_count,
269
  "violation": label,
270
  "confidence": round(conf, 2),
271
- "bounding_box": [round(x, 2) for x in box.xywh.cpu().numpy()[0]],
272
  "timestamp": frame_count / fps
273
  }
274
  violations.append(violation)
275
 
276
- # Snapshot for the first occurrence of each violation
277
  if not snapshot_taken[label]:
278
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
279
  cv2.imwrite(snapshot_path, frame)
@@ -292,7 +278,6 @@ def process_video(video_data):
292
  video.release()
293
  os.remove(video_path)
294
 
295
- # If no violations were detected, return a message
296
  if not violations:
297
  logger.info("No violations detected")
298
  return {
@@ -304,7 +289,6 @@ def process_video(video_data):
304
  "message": "No violations detected in the video."
305
  }
306
 
307
- # Calculate compliance score
308
  score = calculate_safety_score(violations)
309
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
310
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
@@ -330,8 +314,7 @@ def process_video(video_data):
330
 
331
  def gradio_interface(video_file):
332
  if not video_file:
333
- return "", "", "", "", "" # Return empty outputs when no file uploaded
334
-
335
  try:
336
  yield "Processing video... please wait.", "", "", "", ""
337
 
@@ -340,26 +323,25 @@ def gradio_interface(video_file):
340
 
341
  result = process_video(video_data)
342
 
343
- # Only output violations if detected, else output nothing or minimal
344
- if not result["violations"]:
345
- # No violations detected — return empty or minimal outputs
346
- yield "", "", "", "", ""
347
  return
348
 
349
- # Build violation table only if violations exist
350
- header = "| Violation | Timestamp (s) | Confidence |\n"
351
- separator = "|------------------------|---------------|------------|\n"
352
- rows = []
353
- violation_name_map = CONFIG["DISPLAY_NAMES"]
354
- for v in result["violations"]:
355
- display_name = violation_name_map.get(v["violation"], v["violation"])
356
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} |"
357
- rows.append(row)
358
- violation_table = header + separator + "\n".join(rows)
359
-
360
- # Prepare snapshots if available
361
- snapshots_text = ""
362
  if result["snapshots"]:
 
363
  snapshots_text = "\n".join(
364
  f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
365
  for s in result["snapshots"]
@@ -374,7 +356,7 @@ def gradio_interface(video_file):
374
  )
375
  except Exception as e:
376
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
377
- yield f"Error: {str(e)}", "", "", "", ""
378
 
379
  interface = gr.Interface(
380
  fn=gradio_interface,
 
38
  "domain": "login"
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
+ "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
+ "CONFIDENCE_THRESHOLD": 0.5
 
44
  }
45
 
46
  # Setup logging
 
209
 
210
  def process_video(video_data):
211
  try:
 
212
  video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
213
  with open(video_path, "wb") as f:
214
  f.write(video_data)
215
  logger.info(f"Video saved: {video_path}")
216
 
 
217
  video = cv2.VideoCapture(video_path)
218
  if not video.isOpened():
219
  raise ValueError("Could not open video file")
 
224
  fps = video.get(cv2.CAP_PROP_FPS)
225
 
226
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
227
+ seen_violations = set() # to track violations and avoid repeating
228
 
229
  while True:
230
  ret, frame = video.read()
231
  if not ret:
232
+ break
233
 
 
234
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
235
  frame_count += 1
236
  continue
 
239
  logger.info("Processing time limit reached")
240
  break
241
 
 
242
  results = model(frame, device=device)
 
 
243
  for result in results:
244
  for box in result.boxes:
245
  cls, conf = int(box.cls), float(box.conf)
246
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
247
+ if label not in CONFIG["VIOLATION_LABELS"].values():
 
 
248
  continue
249
+ if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
 
 
250
  continue
251
+ if label in seen_violations:
252
+ continue
253
+ seen_violations.add(label)
254
 
 
 
 
 
255
  violation = {
256
  "frame": frame_count,
257
  "violation": label,
258
  "confidence": round(conf, 2),
 
259
  "timestamp": frame_count / fps
260
  }
261
  violations.append(violation)
262
 
 
263
  if not snapshot_taken[label]:
264
  snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
265
  cv2.imwrite(snapshot_path, frame)
 
278
  video.release()
279
  os.remove(video_path)
280
 
 
281
  if not violations:
282
  logger.info("No violations detected")
283
  return {
 
289
  "message": "No violations detected in the video."
290
  }
291
 
 
292
  score = calculate_safety_score(violations)
293
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
294
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
 
314
 
315
  def gradio_interface(video_file):
316
  if not video_file:
317
+ return "No file uploaded.", "", "No file uploaded.", "", ""
 
318
  try:
319
  yield "Processing video... please wait.", "", "", "", ""
320
 
 
323
 
324
  result = process_video(video_data)
325
 
326
+ if result.get("message"):
327
+ yield result["message"], "", "", "", ""
 
 
328
  return
329
 
330
+ violation_table = "No violations detected."
331
+ if result["violations"]:
332
+ header = "| Violation | Timestamp (s) | Confidence | \n"
333
+ separator = "|------------------------|---------------|------------|\n"
334
+ rows = []
335
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
336
+ for v in result["violations"]:
337
+ display_name = violation_name_map.get(v["violation"], v["violation"])
338
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} |"
339
+ rows.append(row)
340
+ violation_table = header + separator + "\n".join(rows)
341
+
342
+ snapshots_text = "No snapshots captured."
343
  if result["snapshots"]:
344
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
345
  snapshots_text = "\n".join(
346
  f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
347
  for s in result["snapshots"]
 
356
  )
357
  except Exception as e:
358
  logger.error(f"Error in Gradio interface: {e}", exc_info=True)
359
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
360
 
361
  interface = gr.Interface(
362
  fn=gradio_interface,