PrashanthB461 committed on
Commit
0b5f150
·
verified ·
1 Parent(s): 366a953

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -49,7 +49,9 @@ CONFIG = {
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
- "FRAME_SKIP": 10, # Increased to process every 10th frame (faster)
 
 
53
  "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
54
  "CONFIDENCE_THRESHOLD": { # Per-class thresholds
55
  "no_helmet": 0.4,
@@ -58,8 +60,7 @@ CONFIG = {
58
  "unsafe_zone": 0.3,
59
  "improper_tool_use": 0.35
60
  },
61
- "IOU_THRESHOLD": 0.4, # For worker tracking
62
- "MIN_VIOLATION_FRAMES": 2 # Reduced to 2 frames for faster confirmation
63
  }
64
 
65
  # Setup logging
@@ -106,18 +107,6 @@ def draw_detections(frame, detections):
106
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
107
  return frame
108
 
109
def calculate_iou(box1, box2):
    """Return the Intersection-over-Union (IoU) of two boxes.

    Boxes are given in center format ``(cx, cy, w, h)``.  The result is
    a float in ``[0, 1]``; degenerate boxes (zero union area) yield 0.
    """
    cx_a, cy_a, w_a, h_a = box1
    cx_b, cy_b, w_b, h_b = box2
    # Convert both boxes to corner coordinates and intersect the extents.
    left = max(cx_a - w_a / 2, cx_b - w_b / 2)
    top = max(cy_a - h_a / 2, cy_b - h_b / 2)
    right = min(cx_a + w_a / 2, cx_b + w_b / 2)
    bottom = min(cy_a + h_a / 2, cy_b + h_b / 2)
    # Clamp at zero so disjoint boxes produce no negative overlap.
    overlap = max(0, right - left) * max(0, bottom - top)
    union = w_a * h_a + w_b * h_b - overlap
    return overlap / union if union > 0 else 0
120
-
121
  # ==========================
122
  # Salesforce Integration
123
  # ==========================
@@ -235,9 +224,9 @@ def process_video(video_path):
235
  cap = cv2.VideoCapture(video_path)
236
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
237
  frame_count = 0
 
238
  violations = []
239
  snapshots = []
240
- workers = []
241
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
242
  frames_to_save = [] # Store frames for snapshot saving later
243
 
@@ -246,10 +235,17 @@ def process_video(video_path):
246
  if not ret:
247
  break
248
 
 
 
 
 
249
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
250
  frame_count += 1
251
  continue
252
 
 
 
 
253
  # Run detection
254
  results = model(frame, device=device)
255
  current_time = frame_count / fps
@@ -272,10 +268,8 @@ def process_video(video_path):
272
  "timestamp": current_time
273
  }
274
 
275
- # Simplified worker tracking
276
- worker_id = len(workers) + 1 # Assign new ID without IoU for speed
277
- workers.append({"id": worker_id, "bbox": bbox})
278
- detection["worker_id"] = worker_id
279
  violations.append(detection)
280
 
281
  # Store frame for snapshot if first detection of this type
@@ -288,6 +282,7 @@ def process_video(video_path):
288
  snapshot_taken[label] = True
289
 
290
  frame_count += 1
 
291
 
292
  cap.release()
293
 
 
49
  "domain": "login"
50
  },
51
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 15, # Increased to process every 15th frame
53
+ "MAX_FRAMES": 100, # Limit the number of frames to process
54
+ "FRAME_RESIZE": (640, 480), # Downscale frames to this resolution
55
  "MAX_PROCESSING_TIME": 60, # Max processing time (seconds)
56
  "CONFIDENCE_THRESHOLD": { # Per-class thresholds
57
  "no_helmet": 0.4,
 
60
  "unsafe_zone": 0.3,
61
  "improper_tool_use": 0.35
62
  },
63
+ "MIN_VIOLATION_FRAMES": 2 # Min frames to confirm a violation
 
64
  }
65
 
66
  # Setup logging
 
107
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
108
  return frame
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  # ==========================
111
  # Salesforce Integration
112
  # ==========================
 
224
  cap = cv2.VideoCapture(video_path)
225
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
226
  frame_count = 0
227
+ processed_frames = 0
228
  violations = []
229
  snapshots = []
 
230
  snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
231
  frames_to_save = [] # Store frames for snapshot saving later
232
 
 
235
  if not ret:
236
  break
237
 
238
+ # Stop if max frames reached
239
+ if processed_frames >= CONFIG["MAX_FRAMES"]:
240
+ break
241
+
242
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
243
  frame_count += 1
244
  continue
245
 
246
+ # Downscale frame for faster inference
247
+ frame = cv2.resize(frame, CONFIG["FRAME_RESIZE"], interpolation=cv2.INTER_AREA)
248
+
249
  # Run detection
250
  results = model(frame, device=device)
251
  current_time = frame_count / fps
 
268
  "timestamp": current_time
269
  }
270
 
271
+ # Assign a generic worker ID (skipping tracking for speed)
272
+ detection["worker_id"] = f"Worker_{frame_count}"
 
 
273
  violations.append(detection)
274
 
275
  # Store frame for snapshot if first detection of this type
 
282
  snapshot_taken[label] = True
283
 
284
  frame_count += 1
285
+ processed_frames += 1
286
 
287
  cap.release()
288