PrashanthB461 committed on
Commit
f267b6b
·
verified ·
1 Parent(s): 5a5a760

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -142
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  import torch
5
  import numpy as np
6
  from ultralytics import YOLO
 
7
  import time
8
  from simple_salesforce import Salesforce
9
  from reportlab.lib.pagesizes import letter
@@ -60,11 +61,14 @@ CONFIG = {
60
  "improper_tool_use": 0.4
61
  },
62
  "MIN_VIOLATION_FRAMES": 3,
63
- "WORKER_TRACKING_DURATION": 3.0,
64
  "MAX_PROCESSING_TIME": 60, # 1 minute limit
65
- "FRAME_SKIP": 2, # Process every 2nd frame for speed
66
- "BATCH_SIZE": 16, # Frames per batch
67
- "PARALLEL_WORKERS": max(1, cpu_count() - 1) # Use all CPU cores except one
 
 
 
68
  }
69
 
70
  # Setup logging
@@ -116,54 +120,17 @@ def draw_detections(frame, detections):
116
  cv2.putText(frame, display_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
117
  return frame
118
 
119
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Each box is an ``(x, y, w, h)`` sequence in which ``(x, y)`` is treated
    as the box centre: the edges are computed as ``x -/+ w/2`` and
    ``y -/+ h/2``.

    Returns:
        A float in ``[0.0, 1.0]``. ``0.0`` is returned when the boxes do not
        overlap, and also when the union area is not positive (for example
        two zero-sized boxes), instead of dividing by zero.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2

    # Edges of the overlap rectangle, if there is one.
    x_left = max(x1 - w1 / 2, x2 - w2 / 2)
    y_top = max(y1 - h1 / 2, y2 - h2 / 2)
    x_right = min(x1 + w1 / 2, x2 + w2 / 2)
    y_bottom = min(y1 + h1 / 2, y2 + h2 / 2)

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    union_area = w1 * h1 + w2 * h2 - intersection_area

    # Degenerate boxes (zero width/height) can produce an empty union.
    if union_area <= 0:
        return 0.0

    return intersection_area / union_area
137
-
138
def process_frame_batch(frame_batch, frame_indices, fps):
    """Run the module-level ``model`` on a batch of frames and collect detections.

    Args:
        frame_batch: Frames passed to ``model`` in a single call.
        frame_indices: Frame index for each entry of ``frame_batch``; paired
            with the model results by position.
        fps: Frames per second, used to convert a frame index to a timestamp
            (``frame_idx / fps``).

    Returns:
        A list of ``(frame_idx, detections)`` tuples, one per result. Each
        detection is a dict with the keys ``frame``, ``violation``,
        ``confidence``, ``bounding_box`` and ``timestamp``.
    """
    labels = CONFIG["VIOLATION_LABELS"]
    thresholds = CONFIG["CONFIDENCE_THRESHOLDS"]

    # conf=0.1 keeps low-confidence boxes so the per-label thresholds below
    # make the final decision.
    results = model(frame_batch, device=device, conf=0.1, verbose=False)

    batch_results = []
    for result, frame_idx in zip(results, frame_indices):
        current_time = frame_idx / fps
        detections = []

        for box in result.boxes:
            cls = int(box.cls)
            conf = float(box.conf)
            label = labels.get(cls, None)

            # Skip classes that are not mapped to a violation and boxes
            # below that violation's confidence threshold.
            if label is None or conf < thresholds.get(label, 0.25):
                continue

            bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
            detections.append({
                "frame": frame_idx,
                "violation": label,
                "confidence": round(conf, 2),
                "bounding_box": bbox,
                "timestamp": current_time
            })

        batch_results.append((frame_idx, detections))

    return batch_results
167
 
168
  def generate_violation_pdf(violations, score):
169
  try:
@@ -193,7 +160,7 @@ def generate_violation_pdf(violations, score):
193
  else:
194
  for v in violations:
195
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
196
- text = f"{display_name} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
197
  c.drawString(1 * inch, y_position, text)
198
  y_position -= 0.3 * inch
199
  if y_position < 1 * inch:
@@ -214,18 +181,6 @@ def generate_violation_pdf(violations, score):
214
  logger.error(f"Error generating PDF: {e}")
215
  return "", "", None
216
 
217
def calculate_safety_score(violations):
    """Return a safety score between 0 and 100 for a list of violations.

    The score starts at 100 and a fixed penalty is subtracted for every
    record in ``violations`` according to its ``"violation"`` key. Records
    with a missing or unrecognised violation type cost nothing. The result
    is floored at 0.

    Args:
        violations: Iterable of dict-like records exposing ``.get``.

    Returns:
        The remaining score, never below 0.
    """
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    total_penalty = sum(
        penalties.get(v.get("violation", "Unknown"), 0) for v in violations
    )
    return max(100 - total_penalty, 0)
228
-
229
  # ==========================
230
  # Salesforce Integration
231
  # ==========================
@@ -268,7 +223,7 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
268
  try:
269
  sf = connect_to_salesforce()
270
  violations_text = "\n".join(
271
- f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
272
  for v in violations
273
  ) or "No violations detected."
274
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
@@ -325,18 +280,22 @@ def process_video(video_data):
325
 
326
  # Get video properties
327
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
328
- fps = cap.get(cv2.CAP_PROP_FPS)
329
- if fps <= 0:
330
- fps = 30
331
  duration = total_frames / fps
332
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
333
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
334
-
335
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
336
 
337
- workers = []
338
- violations = []
339
- helmet_violations = {}
 
 
 
 
 
 
 
340
  snapshots = []
341
  start_time = time.time()
342
  frame_skip = CONFIG["FRAME_SKIP"]
@@ -375,14 +334,15 @@ def process_video(video_data):
375
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
376
  current_time = frame_idx / fps
377
 
378
- # Update progress periodically
379
- if time.time() - start_time > 1.0: # Update every second
380
  progress = (frame_idx / total_frames) * 100
381
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
382
  start_time = time.time()
383
 
384
- # Process detections in this frame
385
  boxes = result.boxes
 
386
  for box in boxes:
387
  cls = int(box.cls)
388
  conf = float(box.conf)
@@ -391,78 +351,75 @@ def process_video(video_data):
391
  if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
392
  continue
393
 
394
- bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  detection = {
396
  "frame": frame_idx,
397
  "violation": label,
398
  "confidence": round(conf, 2),
399
- "bounding_box": bbox,
400
- "timestamp": current_time
 
401
  }
402
 
403
- # Worker tracking
404
- worker_id = None
405
- max_iou = 0
406
- for idx, worker in enumerate(workers):
407
- iou = calculate_iou(bbox, worker["bbox"])
408
- if iou > max_iou and iou > 0.4: # IOU threshold
409
- max_iou = iou
410
- worker_id = worker["id"]
411
- workers[idx]["bbox"] = bbox
412
- workers[idx]["last_seen"] = current_time
413
-
414
- if worker_id is None:
415
- worker_id = len(workers) + 1
416
- workers.append({
417
- "id": worker_id,
418
- "bbox": bbox,
419
- "first_seen": current_time,
420
- "last_seen": current_time
421
- })
422
-
423
- detection["worker_id"] = worker_id
424
-
425
- # Track helmet violations with stricter criteria
426
- if detection["violation"] == "no_helmet":
427
- # Only include high-confidence no_helmet detections
428
- if conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
429
- if worker_id not in helmet_violations:
430
- helmet_violations[worker_id] = []
431
- helmet_violations[worker_id].append(detection)
432
- else:
433
- violations.append(detection)
434
-
435
- # Remove inactive workers
436
- workers = [w for w in workers if current_time - w["last_seen"] < CONFIG["WORKER_TRACKING_DURATION"]]
437
 
438
  cap.release()
439
  os.remove(video_path)
440
  processing_time = time.time() - start_time
441
- logger.info(f"Processing complete in {processing_time:.2f}s. {len(violations)} violations found.")
442
-
443
- # Confirm helmet violations (require multiple detections)
444
- for worker_id, detections in helmet_violations.items():
445
- if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
446
- # Select the detection with the highest confidence
447
- best_detection = max(detections, key=lambda x: x["confidence"])
448
- violations.append(best_detection)
449
-
450
- # Capture snapshot for confirmed no_helmet violation
451
- cap = cv2.VideoCapture(video_path)
452
- cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
453
- ret, snapshot_frame = cap.read()
454
- if ret:
455
- snapshot_frame = draw_detections(snapshot_frame, [best_detection])
456
- snapshot_filename = f"no_helmet_{best_detection['frame']}.jpg"
457
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
458
- cv2.imwrite(snapshot_path, snapshot_frame)
459
- snapshots.append({
460
- "violation": "no_helmet",
461
- "frame": best_detection["frame"],
462
- "snapshot_path": snapshot_path,
463
- "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
464
- })
465
- cap.release()
 
 
 
 
466
 
467
  # Generate results
468
  if not violations:
@@ -473,11 +430,11 @@ def process_video(video_data):
473
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
474
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
475
 
476
- violation_table = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
477
- violation_table += "|------------------------|---------------|------------|-----------|\n"
478
- for v in sorted(violations, key=lambda x: x["timestamp"]):
479
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
480
- row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
481
  violation_table += row
482
 
483
  snapshots_text = "\n".join(
 
4
  import torch
5
  import numpy as np
6
  from ultralytics import YOLO
7
+ from bytetrack import BYTETracker # Added ByteTrack
8
  import time
9
  from simple_salesforce import Salesforce
10
  from reportlab.lib.pagesizes import letter
 
61
  "improper_tool_use": 0.4
62
  },
63
  "MIN_VIOLATION_FRAMES": 3,
64
+ "WORKER_TRACKING_DURATION": 5.0, # Increased to 5s for better continuity
65
  "MAX_PROCESSING_TIME": 60, # 1 minute limit
66
+ "FRAME_SKIP": 1, # Process all frames to avoid missing violations
67
+ "BATCH_SIZE": 32, # Increased for performance
68
+ "PARALLEL_WORKERS": max(1, cpu_count() - 1), # Use all CPU cores except one
69
+ "TRACK_BUFFER": 30, # Frames to keep a track alive
70
+ "TRACK_THRESH": 0.4, # Tracking confidence threshold
71
+ "MATCH_THRESH": 0.8 # IOU threshold for matching
72
  }
73
 
74
  # Setup logging
 
120
  cv2.putText(frame, display_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
121
  return frame
122
 
123
def calculate_safety_score(violations):
    """Return a safety score between 0 and 100 for a list of violations.

    The score starts at 100 and a fixed penalty is subtracted for every
    record in ``violations`` according to its ``"violation"`` key. Records
    with a missing or unrecognised violation type cost nothing. The result
    is floored at 0.

    Args:
        violations: Iterable of dict-like records exposing ``.get``.

    Returns:
        The remaining score, never below 0.
    """
    penalties = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    total_penalty = sum(
        penalties.get(v.get("violation", "Unknown"), 0) for v in violations
    )
    return max(100 - total_penalty, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  def generate_violation_pdf(violations, score):
136
  try:
 
160
  else:
161
  for v in violations:
162
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
163
+ text = f"{display_name} from {v.get('start_timestamp', 0.0):.2f}s to {v.get('end_timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f}, Worker ID: {v.get('worker_id', 'N/A')})"
164
  c.drawString(1 * inch, y_position, text)
165
  y_position -= 0.3 * inch
166
  if y_position < 1 * inch:
 
181
  logger.error(f"Error generating PDF: {e}")
182
  return "", "", None
183
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  # ==========================
185
  # Salesforce Integration
186
  # ==========================
 
223
  try:
224
  sf = connect_to_salesforce()
225
  violations_text = "\n".join(
226
+ f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} from {v.get('start_timestamp', 0.0):.2f}s to {v.get('end_timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f}, Worker ID: {v.get('worker_id', 'N/A')})"
227
  for v in violations
228
  ) or "No violations detected."
229
  pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
 
280
 
281
  # Get video properties
282
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
283
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
 
284
  duration = total_frames / fps
285
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
286
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
287
  logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")
288
 
289
+ # Initialize ByteTrack
290
+ tracker = BYTETracker(
291
+ track_thresh=CONFIG["TRACK_THRESH"],
292
+ track_buffer=CONFIG["TRACK_BUFFER"],
293
+ match_thresh=CONFIG["MATCH_THRESH"],
294
+ frame_rate=fps
295
+ )
296
+
297
+ # Track violations by worker ID and type
298
+ violation_tracker = {} # {worker_id: {violation_type: [detections]}}
299
  snapshots = []
300
  start_time = time.time()
301
  frame_skip = CONFIG["FRAME_SKIP"]
 
334
  for i, (result, frame_idx) in enumerate(zip(results, batch_indices)):
335
  current_time = frame_idx / fps
336
 
337
+ # Update progress
338
+ if time.time() - start_time > 1.0:
339
  progress = (frame_idx / total_frames) * 100
340
  yield f"Processing video... {progress:.1f}% complete (Frame {frame_idx}/{total_frames})", "", "", "", ""
341
  start_time = time.time()
342
 
343
+ # Prepare detections for ByteTrack
344
  boxes = result.boxes
345
+ track_inputs = []
346
  for box in boxes:
347
  cls = int(box.cls)
348
  conf = float(box.conf)
 
351
  if label is None or conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
352
  continue
353
 
354
+ bbox = box.xywh.cpu().numpy()[0]
355
+ track_inputs.append({
356
+ "bbox": bbox, # [x, y, w, h]
357
+ "conf": conf,
358
+ "cls": cls
359
+ })
360
+
361
+ # Update tracker
362
+ tracked_objects = tracker.update(
363
+ np.array([t["bbox"] for t in track_inputs]),
364
+ np.array([t["conf"] for t in track_inputs]),
365
+ np.array([t["cls"] for t in track_inputs])
366
+ )
367
+
368
+ # Process tracked objects
369
+ for obj, track_input in zip(tracked_objects, track_inputs):
370
+ worker_id = obj.id
371
+ label = CONFIG["VIOLATION_LABELS"].get(int(obj.cls), None)
372
+ bbox = track_input["bbox"]
373
+ conf = track_input["conf"]
374
+
375
  detection = {
376
  "frame": frame_idx,
377
  "violation": label,
378
  "confidence": round(conf, 2),
379
+ "bounding_box": [round(x, 2) for x in bbox],
380
+ "timestamp": current_time,
381
+ "worker_id": worker_id
382
  }
383
 
384
+ # Track violations by worker_id and type
385
+ if worker_id not in violation_tracker:
386
+ violation_tracker[worker_id] = {}
387
+ if label not in violation_tracker[worker_id]:
388
+ violation_tracker[worker_id][label] = []
389
+ violation_tracker[worker_id][label].append(detection)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  cap.release()
392
  os.remove(video_path)
393
  processing_time = time.time() - start_time
394
+ logger.info(f"Processing complete in {processing_time:.2f}s")
395
+
396
+ # Consolidate violations
397
+ violations = []
398
+ for worker_id, worker_violations in violation_tracker.items():
399
+ for label, detections in worker_violations.items():
400
+ if len(detections) >= CONFIG["MIN_VIOLATION_FRAMES"]:
401
+ # Select highest-confidence detection
402
+ best_detection = max(detections, key=lambda x: x["confidence"])
403
+ best_detection["start_timestamp"] = min(d["timestamp"] for d in detections)
404
+ best_detection["end_timestamp"] = max(d["timestamp"] for d in detections)
405
+ violations.append(best_detection)
406
+
407
+ # Capture snapshot for confirmed violation
408
+ cap = cv2.VideoCapture(video_path)
409
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
410
+ ret, snapshot_frame = cap.read()
411
+ if ret:
412
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
413
+ snapshot_filename = f"{label}_{best_detection['frame']}.jpg"
414
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
415
+ cv2.imwrite(snapshot_path, snapshot_frame)
416
+ snapshots.append({
417
+ "violation": label,
418
+ "frame": best_detection["frame"],
419
+ "snapshot_path": snapshot_path,
420
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
421
+ })
422
+ cap.release()
423
 
424
  # Generate results
425
  if not violations:
 
430
  pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
431
  report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
432
 
433
+ violation_table = "| Violation | Time Range (s) | Confidence | Worker ID |\n"
434
+ violation_table += "|------------------------|----------------|------------|-----------|\n"
435
+ for v in sorted(violations, key=lambda x: x["start_timestamp"]):
436
  display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
437
+ row = f"| {display_name:<22} | {v.get('start_timestamp', 0.0):.2f}-{v.get('end_timestamp', 0.0):.2f} | {v.get('confidence', 0.0):.2f} | {v.get('worker_id', 'N/A')} |\n"
438
  violation_table += row
439
 
440
  snapshots_text = "\n".join(