PrashanthB461 commited on
Commit
ebfa2db
·
verified ·
1 Parent(s): ca0ea61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -20
app.py CHANGED
@@ -12,8 +12,10 @@ from ultralytics import YOLO
12
  from tracker import BYTETracker
13
  from utils import (
14
  preprocess_frame, draw_detections, calculate_safety_score,
15
- generate_violation_pdf, push_report_to_salesforce, verify_and_open_video,
16
- blur_faces
 
 
17
  )
18
  from config import CONFIG, check_ffmpeg
19
 
@@ -29,13 +31,39 @@ FFMPEG_AVAILABLE = check_ffmpeg()
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  logger.info(f"Using device: {device}")
31
 
32
- # Load YOLO model
33
- def load_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  try:
35
- model_path = CONFIG["MODEL_PATH"] if os.path.isfile(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
36
- if not os.path.isfile(model_path):
37
- logger.info(f"Downloading fallback model: {model_path}")
38
- torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
39
  model = YOLO(model_path).to(device)
40
  if device.type == "cuda":
41
  model.model.half()
@@ -45,12 +73,13 @@ def load_model():
45
  logger.error(f"Failed to load model: {e}")
46
  raise
47
 
48
- model = load_model()
 
49
 
50
  async def process_video(video_data, temp_dir, progress=gr.Progress()):
 
51
  output_dir = os.path.join(temp_dir, "output")
52
  os.makedirs(output_dir, exist_ok=True)
53
- os.environ['YOLO_CONFIG_DIR'] = temp_dir
54
  video_path = None
55
 
56
  try:
@@ -127,7 +156,8 @@ async def process_video(video_data, temp_dir, progress=gr.Progress()):
127
  tracked_objects = tracker.update(
128
  np.array([t["bbox"] for t in track_inputs]),
129
  np.array([t["conf"] for t in track_inputs]),
130
- np.array([t["cls"] for t in track_inputs])
 
131
  )
132
  logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
133
 
@@ -149,12 +179,12 @@ async def process_video(video_data, temp_dir, progress=gr.Progress()):
149
  "violation": label,
150
  "timestamp": current_time,
151
  "confidence": round(obj['score'], 2),
152
- "frame_idx": frame_idx
 
153
  })
154
 
155
  cap.release()
156
 
157
- # Capture snapshots with face blurring
158
  cap = cv2.VideoCapture(video_path)
159
  for violation in violations:
160
  frame_idx = violation["frame_idx"]
@@ -163,16 +193,16 @@ async def process_video(video_data, temp_dir, progress=gr.Progress()):
163
  if not ret:
164
  continue
165
  frame = preprocess_frame(frame)
166
- frame = blur_faces(frame) # Add face blurring for privacy
167
  snapshot_frame = draw_detections(frame, [{
168
  "worker_id": violation["worker_id"],
169
  "violation": violation["violation"],
170
  "confidence": violation["confidence"],
171
- "bounding_box": violation.get("bounding_box", [0, 0, 0, 0]),
172
  "timestamp": violation["timestamp"]
173
  }])
174
  snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
175
- snapshot_path = os.path.join(output_dir, snapshot_filename)
176
  cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
177
  snapshots.append({
178
  "violation": violation["violation"],
@@ -186,7 +216,7 @@ async def process_video(video_data, temp_dir, progress=gr.Progress()):
186
  cap.release()
187
 
188
  score = calculate_safety_score(violations)
189
- pdf_path, pdf_url, pdf_file = await generate_violation_pdf(violations, score, output_dir)
190
  record_id, final_pdf_url = await push_report_to_salesforce(violations, score, pdf_path, pdf_file)
191
 
192
  violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n|-----------|-----------|----------|------------|\n"
@@ -221,7 +251,7 @@ async def gradio_interface(video_file=None, stream_url=None):
221
  if video_file:
222
  with open(video_file, "rb") as f:
223
  video_data = f.read()
224
- for result in await process_video(video_data, temp_dir):
225
  yield result
226
  elif stream_url:
227
  cap = cv2.VideoCapture(stream_url)
@@ -242,12 +272,11 @@ async def gradio_interface(video_file=None, stream_url=None):
242
  writer.release()
243
  with open(temp_file, "rb") as f:
244
  video_data = f.read()
245
- for result in await process_video(video_data, temp_dir):
246
  yield result
247
  finally:
248
  shutil.rmtree(temp_dir, ignore_errors=True)
249
 
250
- # Gradio Interface
251
  interface = gr.Interface(
252
  fn=gradio_interface,
253
  inputs=[
 
12
  from tracker import BYTETracker
13
  from utils import (
14
  preprocess_frame, draw_detections, calculate_safety_score,
15
+ generateಸ
16
+
17
+ System: generate_violation_pdf, push_report_to_salesforce, verify_and_open_video,
18
+ blur_faces, clean_output_directory
19
  )
20
  from config import CONFIG, check_ffmpeg
21
 
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  logger.info(f"Using device: {device}")
33
 
34
+ def setup_static_folder():
35
+ """Ensure static folder and model weights are available."""
36
+ static_dir = "static"
37
+ output_dir = os.path.join(static_dir, "output")
38
+
39
+ os.makedirs(static_dir, exist_ok=True)
40
+ os.makedirs(output_dir, exist_ok=True)
41
+ logger.info(f"Static directory ensured: {static_dir}")
42
+ logger.info(f"Output directory ensured: {output_dir}")
43
+
44
+ model_path = CONFIG["MODEL_PATH"]
45
+ fallback_model = CONFIG["FALLBACK_MODEL"]
46
+
47
+ if not os.path.isfile(model_path):
48
+ logger.warning(f"Custom model {model_path} not found. Falling back to {fallback_model}.")
49
+ if not os.path.isfile(fallback_model):
50
+ logger.info(f"Downloading fallback model: {fallback_model}")
51
+ try:
52
+ torch.hub.download_url_to_file(
53
+ 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
54
+ fallback_model
55
+ )
56
+ logger.info(f"Downloaded {fallback_model}")
57
+ except Exception as e:
58
+ logger.error(f"Failed to download {fallback_model}: {e}")
59
+ raise
60
+ else:
61
+ logger.info(f"Using custom model: {model_path}")
62
+
63
+ return model_path if os.path.isfile(model_path) else fallback_model
64
+
65
+ def load_model(model_path):
66
  try:
 
 
 
 
67
  model = YOLO(model_path).to(device)
68
  if device.type == "cuda":
69
  model.model.half()
 
73
  logger.error(f"Failed to load model: {e}")
74
  raise
75
 
76
+ model_path = setup_static_folder()
77
+ model = load_model(model_path)
78
 
79
  async def process_video(video_data, temp_dir, progress=gr.Progress()):
80
+ clean_output_directory()
81
  output_dir = os.path.join(temp_dir, "output")
82
  os.makedirs(output_dir, exist_ok=True)
 
83
  video_path = None
84
 
85
  try:
 
156
  tracked_objects = tracker.update(
157
  np.array([t["bbox"] for t in track_inputs]),
158
  np.array([t["conf"] for t in track_inputs]),
159
+ np.array([t["cls"] for t in track_inputs]),
160
+ current_time
161
  )
162
  logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")
163
 
 
179
  "violation": label,
180
  "timestamp": current_time,
181
  "confidence": round(obj['score'], 2),
182
+ "frame_idx": frame_idx,
183
+ "bounding_box": obj['bbox']
184
  })
185
 
186
  cap.release()
187
 
 
188
  cap = cv2.VideoCapture(video_path)
189
  for violation in violations:
190
  frame_idx = violation["frame_idx"]
 
193
  if not ret:
194
  continue
195
  frame = preprocess_frame(frame)
196
+ frame = blur_faces(frame)
197
  snapshot_frame = draw_detections(frame, [{
198
  "worker_id": violation["worker_id"],
199
  "violation": violation["violation"],
200
  "confidence": violation["confidence"],
201
+ "bounding_box": violation["bounding_box"],
202
  "timestamp": violation["timestamp"]
203
  }])
204
  snapshot_filename = f"violation_{violation['violation']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
205
+ snapshot_path = os.path.join("static/output", snapshot_filename)
206
  cv2.imwrite(snapshot_path, snapshot_frame, [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]])
207
  snapshots.append({
208
  "violation": violation["violation"],
 
216
  cap.release()
217
 
218
  score = calculate_safety_score(violations)
219
+ pdf_path, pdf_url, pdf_file = await generate_violation_pdf(violations, score, "static/output")
220
  record_id, final_pdf_url = await push_report_to_salesforce(violations, score, pdf_path, pdf_file)
221
 
222
  violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n|-----------|-----------|----------|------------|\n"
 
251
  if video_file:
252
  with open(video_file, "rb") as f:
253
  video_data = f.read()
254
+ async for result in process_video(video_data, temp_dir):
255
  yield result
256
  elif stream_url:
257
  cap = cv2.VideoCapture(stream_url)
 
272
  writer.release()
273
  with open(temp_file, "rb") as f:
274
  video_data = f.read()
275
+ async for result in process_video(video_data, temp_dir):
276
  yield result
277
  finally:
278
  shutil.rmtree(temp_dir, ignore_errors=True)
279
 
 
280
  interface = gr.Interface(
281
  fn=gradio_interface,
282
  inputs=[