PrashanthB461 commited on
Commit
6c397e1
·
verified ·
1 Parent(s): ba9ee16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -26
app.py CHANGED
@@ -18,8 +18,8 @@ from retrying import retry
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8_safety.pt", # Make sure this file exists in your directory
22
- "FALLBACK_MODEL": "yolov8n.pt", # Fallback model for testing
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
@@ -40,7 +40,8 @@ CONFIG = {
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
- "CONFIDENCE_THRESHOLD": 0.5
 
44
  }
45
 
46
  # Setup logging
@@ -54,14 +55,12 @@ logger.info(f"Using device: {device}")
54
 
55
  def load_model():
56
  try:
57
- # Check if the model file exists
58
  if os.path.isfile(CONFIG["MODEL_PATH"]):
59
  model_path = CONFIG["MODEL_PATH"]
60
  logger.info(f"Model loaded: {model_path}")
61
  else:
62
  model_path = CONFIG["FALLBACK_MODEL"]
63
  logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
64
- # Download fallback model if necessary
65
  if not os.path.isfile(model_path):
66
  logger.info(f"Downloading fallback model: {model_path}")
67
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
@@ -73,6 +72,33 @@ def load_model():
73
 
74
  model = load_model()
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # ==========================
77
  # Salesforce Integration
78
  # ==========================
@@ -81,7 +107,7 @@ def connect_to_salesforce():
81
  try:
82
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
83
  logger.info("Connected to Salesforce")
84
- sf.describe() # verify connection and metadata fetch
85
  return sf
86
  except Exception as e:
87
  logger.error(f"Salesforce connection failed: {e}")
@@ -229,7 +255,7 @@ def process_video(video_data):
229
  fps = video.get(cv2.CAP_PROP_FPS)
230
 
231
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
232
- detected_violations = {}
233
 
234
  while True:
235
  ret, frame = video.read()
@@ -245,6 +271,7 @@ def process_video(video_data):
245
  break
246
 
247
  results = model(frame, device=device)
 
248
  for result in results:
249
  for box in result.boxes:
250
  cls, conf = int(box.cls), float(box.conf)
@@ -252,31 +279,82 @@ def process_video(video_data):
252
  if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
253
  continue
254
 
255
- violation_key = f"{label}_{frame_count}"
256
- if violation_key in detected_violations:
257
- continue
258
-
259
- detected_violations[violation_key] = {
260
- "frame": frame_count,
261
  "violation": label,
262
  "confidence": round(conf, 2),
263
- "bounding_box": [round(x, 2) for x in box.xywh.cpu().numpy()[0]],
264
- "timestamp": frame_count / fps
265
- }
266
- violations.append(detected_violations[violation_key])
267
-
268
- if not snapshot_taken[label]:
269
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  cv2.imwrite(snapshot_path, frame)
271
  with open(snapshot_path, "rb") as img_file:
272
  img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
273
  snapshots.append({
274
- "violation": label,
275
  "frame": frame_count,
276
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
277
  "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
278
  })
279
- snapshot_taken[label] = True
280
 
281
  frame_count += 1
282
 
@@ -334,13 +412,13 @@ def gradio_interface(video_file):
334
 
335
  violation_table = "No violations detected."
336
  if result["violations"]:
337
- header = "| Violation | Timestamp (s) | Confidence | Bounding Box |\n"
338
- separator = "|------------------------|---------------|------------|--------------------------|\n"
339
  rows = []
340
  violation_name_map = CONFIG["DISPLAY_NAMES"]
341
  for v in result["violations"]:
342
  display_name = violation_name_map.get(v["violation"], v["violation"])
343
- row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['bounding_box']} |"
344
  rows.append(row)
345
  violation_table = header + separator + "\n".join(rows)
346
 
@@ -380,4 +458,4 @@ interface = gr.Interface(
380
 
381
  if __name__ == "__main__":
382
  logger.info("Launching Safety Analyzer App...")
383
- interface.launch()
 
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8_safety.pt",
22
+ "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
 
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
  "MAX_PROCESSING_TIME": 30,
43
+ "CONFIDENCE_THRESHOLD": 0.5,
44
+ "IOU_THRESHOLD": 0.5 # Added for worker tracking
45
  }
46
 
47
  # Setup logging
 
55
 
56
  def load_model():
57
  try:
 
58
  if os.path.isfile(CONFIG["MODEL_PATH"]):
59
  model_path = CONFIG["MODEL_PATH"]
60
  logger.info(f"Model loaded: {model_path}")
61
  else:
62
  model_path = CONFIG["FALLBACK_MODEL"]
63
  logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
 
64
  if not os.path.isfile(model_path):
65
  logger.info(f"Downloading fallback model: {model_path}")
66
  torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
 
72
 
73
  model = load_model()
74
 
75
+ # ==========================
76
+ # Helper Functions
77
+ # ==========================
78
+ def calculate_iou(box1, box2):
79
+ """Calculate Intersection over Union (IoU) for two bounding boxes."""
80
+ x1, y1, w1, h1 = box1
81
+ x2, y2, w2, h2 = box2
82
+
83
+ # Convert to top-left and bottom-right coordinates
84
+ x1_min, y1_min = x1 - w1/2, y1 - h1/2
85
+ x1_max, y1_max = x1 + w1/2, y1 + h1/2
86
+ x2_min, y2_min = x2 - w2/2, y2 - h2/2
87
+ x2_max, y2_max = x2 + w2/2, y2 + h2/2
88
+
89
+ # Calculate intersection
90
+ x_min = max(x1_min, x2_min)
91
+ y_min = max(y1_min, y2_min)
92
+ x_max = min(x1_max, x2_max)
93
+ y_max = min(y1_max, y2_max)
94
+
95
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
96
+ area1 = w1 * h1
97
+ area2 = w2 * h2
98
+ union = area1 + area2 - intersection
99
+
100
+ return intersection / union if union > 0 else 0
101
+
102
  # ==========================
103
  # Salesforce Integration
104
  # ==========================
 
107
  try:
108
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
109
  logger.info("Connected to Salesforce")
110
+ sf.describe()
111
  return sf
112
  except Exception as e:
113
  logger.error(f"Salesforce connection failed: {e}")
 
255
  fps = video.get(cv2.CAP_PROP_FPS)
256
 
257
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
258
+ workers = [] # List to track workers: [{"id": int, "violations": set(), "bbox": list, "last_frame": int}]
259
 
260
  while True:
261
  ret, frame = video.read()
 
271
  break
272
 
273
  results = model(frame, device=device)
274
+ current_detections = []
275
  for result in results:
276
  for box in result.boxes:
277
  cls, conf = int(box.cls), float(box.conf)
 
279
  if label not in CONFIG["VIOLATION_LABELS"].values() or conf < CONFIG["CONFIDENCE_THRESHOLD"]:
280
  continue
281
 
282
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
283
+ current_detections.append({
 
 
 
 
284
  "violation": label,
285
  "confidence": round(conf, 2),
286
+ "bounding_box": bbox,
287
+ "timestamp": frame_count / fps,
288
+ "frame": frame_count
289
+ })
290
+
291
+ # Assign detections to workers
292
+ for detection in current_detections:
293
+ matched_worker = None
294
+ max_iou = 0
295
+ for worker in workers:
296
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
297
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
298
+ max_iou = iou
299
+ matched_worker = worker
300
+
301
+ if matched_worker:
302
+ # Update existing worker
303
+ if detection["violation"] not in matched_worker["violations"]:
304
+ matched_worker["violations"].add(detection["violation"])
305
+ violations.append({
306
+ "frame": frame_count,
307
+ "violation": detection["violation"],
308
+ "confidence": detection["confidence"],
309
+ "bounding_box": detection["bounding_box"],
310
+ "timestamp": detection["timestamp"],
311
+ "worker_id": matched_worker["id"]
312
+ })
313
+ # Take snapshot if not already taken for this violation type
314
+ if not snapshot_taken[detection["violation"]]:
315
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{detection['violation']}.jpg")
316
+ cv2.imwrite(snapshot_path, frame)
317
+ with open(snapshot_path, "rb") as img_file:
318
+ img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
319
+ snapshots.append({
320
+ "violation": detection["violation"],
321
+ "frame": frame_count,
322
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
323
+ "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
324
+ })
325
+ snapshot_taken[detection["violation"]] = True
326
+ matched_worker["bbox"] = detection["bounding_box"]
327
+ matched_worker["last_frame"] = frame_count
328
+ else:
329
+ # New worker
330
+ worker_id = len(workers) + 1
331
+ workers.append({
332
+ "id": worker_id,
333
+ "violations": {detection["violation"]},
334
+ "bbox": detection["bounding_box"],
335
+ "last_frame": frame_count
336
+ })
337
+ violations.append({
338
+ "frame": frame_count,
339
+ "violation": detection["violation"],
340
+ "confidence": detection["confidence"],
341
+ "bounding_box": detection["bounding_box"],
342
+ "timestamp": detection["timestamp"],
343
+ "worker_id": worker_id
344
+ })
345
+ # Take snapshot if not already taken for this violation type
346
+ if not snapshot_taken[detection["violation"]]:
347
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{detection['violation']}.jpg")
348
  cv2.imwrite(snapshot_path, frame)
349
  with open(snapshot_path, "rb") as img_file:
350
  img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
351
  snapshots.append({
352
+ "violation": detection["violation"],
353
  "frame": frame_count,
354
  "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
355
  "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
356
  })
357
+ snapshot_taken[detection["violation"]] = True
358
 
359
  frame_count += 1
360
 
 
412
 
413
  violation_table = "No violations detected."
414
  if result["violations"]:
415
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
416
+ separator = "|------------------------|---------------|------------|-----------|\n"
417
  rows = []
418
  violation_name_map = CONFIG["DISPLAY_NAMES"]
419
  for v in result["violations"]:
420
  display_name = violation_name_map.get(v["violation"], v["violation"])
421
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
422
  rows.append(row)
423
  violation_table = header + separator + "\n".join(rows)
424
 
 
458
 
459
  if __name__ == "__main__":
460
  logger.info("Launching Safety Analyzer App...")
461
+ interface.launch()