PrashanthB461 commited on
Commit
60028e1
·
verified ·
1 Parent(s): cc1585c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -19
app.py CHANGED
@@ -24,8 +24,7 @@ CONFIG = {
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
- 2: "unsafe_posture",
28
- 3: "unsafe_zone"
29
  },
30
  "SF_CREDENTIALS": {
31
  "username": "prashanth1ai@safety.com",
@@ -34,6 +33,7 @@ CONFIG = {
34
  "domain": "login"
35
  },
36
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
 
37
  }
38
 
39
  # Setup logging
@@ -196,15 +196,17 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
196
  # ==========================
197
  # Safety Score Calculation
198
  # ==========================
199
- def calculate_safety_score(violations):
200
  penalties = {
201
- "no_helmet": 25,
202
- "no_harness": 30,
203
- "unsafe_posture": 20
204
  }
205
- # Only penalize for detected violations (no_helmet, no_harness, unsafe_posture)
206
- score = 100 - sum(penalties.get(v["violation"], 0) for v in violations)
207
- return max(score, 0)
 
 
208
 
209
  # ==========================
210
  # Video Processing
@@ -222,6 +224,8 @@ def process_video(video_data):
222
 
223
  violations, snapshots = [], []
224
  frame_count = 0
 
 
225
 
226
  while True:
227
  ret, frame = video.read()
@@ -229,16 +233,13 @@ def process_video(video_data):
229
  break
230
 
231
  results = model(frame, device=device)
232
- seen_violations = set() # Track unique violations in this frame
233
  for result in results:
234
  for box in result.boxes:
235
  cls, conf = int(box.cls), float(box.conf)
236
  label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
237
- # Only consider "no_helmet", "no_harness", "unsafe_posture" as violations
238
- if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
239
  continue
240
- if label in seen_violations:
241
- continue # Skip if this violation type was already recorded in this frame
242
  seen_violations.add(label)
243
 
244
  violation = {
@@ -262,10 +263,13 @@ def process_video(video_data):
262
  })
263
 
264
  frame_count += 1
 
 
 
265
 
266
  video.release()
267
  os.remove(video_path)
268
- score = calculate_safety_score(violations)
269
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
270
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
271
 
@@ -305,9 +309,11 @@ def gradio_interface(video_file):
305
  for v in result["violations"]:
306
  violation_name = v["violation"]
307
  if violation_name == "no_helmet":
308
- violation_name = "no_helmet"
 
 
309
  else:
310
- violation_name = violation_name.replace("no_", "").replace("unsafe_", "")
311
  row = f"| {violation_name:<13} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
312
  rows.append(row)
313
  violation_table = header + separator + "\n".join(rows)
@@ -315,7 +321,7 @@ def gradio_interface(video_file):
315
  snapshots_text = "No snapshots captured."
316
  if result["snapshots"]:
317
  snapshots_text = "\n".join(
318
- f"- Snapshot for {s['violation'].replace('no_', '').replace('unsafe_', '')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
319
  for s in result["snapshots"]
320
  )
321
 
@@ -341,7 +347,7 @@ interface = gr.Interface(
341
  gr.Textbox(label="Violation Details URL")
342
  ],
343
  title="Worksite Safety Violation Analyzer",
344
- description="Upload site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
345
  )
346
 
347
  if __name__ == "__main__":
 
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
  1: "no_harness",
27
+ 2: "unsafe_posture"
 
28
  },
29
  "SF_CREDENTIALS": {
30
  "username": "prashanth1ai@safety.com",
 
33
  "domain": "login"
34
  },
35
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
36
+ "MAX_PROCESSING_TIME": 30 # Reduced to 30 seconds
37
  }
38
 
39
  # Setup logging
 
196
  # ==========================
197
  # Safety Score Calculation
198
  # ==========================
199
+ def calculate_safety_score(violations, total_frames):
200
  penalties = {
201
+ "no_helmet": 25 / total_frames, # Normalize penalty by video length
202
+ "no_harness": 30 / total_frames,
203
+ "unsafe_posture": 20 / total_frames
204
  }
205
+ score = 100
206
+ for v in violations:
207
+ penalty = penalties.get(v["violation"], 0)
208
+ score -= penalty
209
+ return max(round(score, 2), 0)
210
 
211
  # ==========================
212
  # Video Processing
 
224
 
225
  violations, snapshots = [], []
226
  frame_count = 0
227
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
228
+ start_time = time.time()
229
 
230
  while True:
231
  ret, frame = video.read()
 
233
  break
234
 
235
  results = model(frame, device=device)
236
+ seen_violations = set()
237
  for result in results:
238
  for box in result.boxes:
239
  cls, conf = int(box.cls), float(box.conf)
240
  label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
241
+ if label in seen_violations or label not in ["no_helmet", "no_harness", "unsafe_posture"]:
 
242
  continue
 
 
243
  seen_violations.add(label)
244
 
245
  violation = {
 
263
  })
264
 
265
  frame_count += 1
266
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
267
+ logger.warning("Processing time limit of 30 seconds exceeded")
268
+ break
269
 
270
  video.release()
271
  os.remove(video_path)
272
+ score = calculate_safety_score(violations, max(total_frames, 1))
273
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
274
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
275
 
 
309
  for v in result["violations"]:
310
  violation_name = v["violation"]
311
  if violation_name == "no_helmet":
312
+ violation_name = "Missing Helmet"
313
+ elif violation_name == "no_harness":
314
+ violation_name = "Missing Harness"
315
  else:
316
+ violation_name = "Unsafe Posture"
317
  row = f"| {violation_name:<13} | {v['timestamp']:.2f}s | {v['confidence']:.2f} | {v['bounding_box']} | {result['violation_details_url']} |"
318
  rows.append(row)
319
  violation_table = header + separator + "\n".join(rows)
 
321
  snapshots_text = "No snapshots captured."
322
  if result["snapshots"]:
323
  snapshots_text = "\n".join(
324
+ f"- Snapshot for {s['violation'].replace('no_helmet', 'Missing Helmet').replace('no_harness', 'Missing Harness').replace('unsafe_posture', 'Unsafe Posture')} at frame {s['frame']}: ![]({s['snapshot_base64']})"
325
  for s in result["snapshots"]
326
  )
327
 
 
347
  gr.Textbox(label="Violation Details URL")
348
  ],
349
  title="Worksite Safety Violation Analyzer",
350
+ description="Upload site videos to detect safety violations (e.g., missing helmet, missing harness, unsafe posture)."
351
  )
352
 
353
  if __name__ == "__main__":