PrashanthB461 commited on
Commit
0ee122d
·
verified ·
1 Parent(s): 9377a06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -98
app.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import cv2
3
  import gradio as gr
4
  import torch
5
- import numpy as np
6
  from ultralytics import YOLO
7
  import time
8
  from reportlab.lib.pagesizes import letter
@@ -47,11 +46,6 @@ except Exception as e:
47
  # Video Processing
48
  # ==========================
49
  def process_video(video_data, frame_skip=10, max_frames=100):
50
- """
51
- Processes uploaded video data to detect safety violations using YOLO model.
52
- frame_skip: number of frames to skip between inferences to speed up processing.
53
- max_frames: max frames to process before stopping.
54
- """
55
  try:
56
  print("Processing video data...")
57
  video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
@@ -78,7 +72,6 @@ def process_video(video_data, frame_skip=10, max_frames=100):
78
  frame_count += 1
79
  continue
80
 
81
- # Model inference
82
  results = model(frame, device=device)
83
 
84
  for result in results:
@@ -97,20 +90,18 @@ def process_video(video_data, frame_skip=10, max_frames=100):
97
  }
98
  violations.append(violation)
99
 
100
- # Save snapshot locally with a filename pattern
101
  snapshot_filename = f"snapshot_{frame_count}_{label}.jpg"
102
  snapshot_path = os.path.join(OUTPUT_DIR, snapshot_filename)
103
  cv2.imwrite(snapshot_path, frame)
104
  snapshots.append({
105
  "violation": label,
106
  "frame": frame_count,
107
- "snapshot_url": snapshot_filename # just filename, local path
108
  })
109
 
110
  frame_count += 1
111
  processed_frame_count += 1
112
 
113
- # Stop if max frames or 30 seconds elapsed
114
  if processed_frame_count >= max_frames:
115
  break
116
  if time.time() - start_time > 30:
@@ -121,9 +112,9 @@ def process_video(video_data, frame_skip=10, max_frames=100):
121
  os.remove(video_path)
122
 
123
  score = calculate_safety_score(violations)
124
- pdf_base64 = generate_pdf_report(violations, snapshots, score)
125
 
126
- # Clean up snapshot images
127
  for snap in snapshots:
128
  snap_file = os.path.join(OUTPUT_DIR, snap["snapshot_url"])
129
  if os.path.exists(snap_file):
@@ -133,7 +124,7 @@ def process_video(video_data, frame_skip=10, max_frames=100):
133
  "violations": violations,
134
  "snapshots": snapshots,
135
  "score": score,
136
- "pdf_base64": pdf_base64
137
  }
138
 
139
  except Exception as e:
@@ -142,7 +133,7 @@ def process_video(video_data, frame_skip=10, max_frames=100):
142
  "violations": [],
143
  "snapshots": [],
144
  "score": 0,
145
- "pdf_base64": "",
146
  "error": str(e)
147
  }
148
 
@@ -166,98 +157,17 @@ def calculate_safety_score(violations):
166
  # ==========================
167
  def generate_pdf_report(violations, snapshots, score):
168
  try:
169
- pdf_path = os.path.join(OUTPUT_DIR, f"report_{int(time.time())}.pdf")
 
 
170
  c = canvas.Canvas(pdf_path, pagesize=letter)
171
  width, height = letter
172
 
173
- # Title
174
  c.setFont("Helvetica-Bold", 16)
175
  c.drawString(50, height - 50, "Worksite Safety Compliance Report")
176
 
177
- # Compliance Score
178
  c.setFont("Helvetica", 12)
179
  c.drawString(50, height - 80, f"Compliance Score: {score}%")
180
 
181
- # Violations Table header
182
  y = height - 120
183
  c.setFont("Helvetica-Bold", 12)
184
- c.drawString(50, y, "Detected Violations:")
185
- y -= 20
186
-
187
- for v in violations:
188
- if y < 150:
189
- c.showPage()
190
- y = height - 50
191
-
192
- c.setFont("Helvetica", 10)
193
- text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
194
- c.drawString(50, y, text)
195
- y -= 20
196
-
197
- # Find matching snapshot by frame and violation label
198
- snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
199
- if snapshot:
200
- snapshot_file = os.path.join(OUTPUT_DIR, snapshot["snapshot_url"])
201
- if os.path.exists(snapshot_file):
202
- img = ImageReader(snapshot_file)
203
- img_width = 200
204
- img_height = 150
205
- if y - img_height < 50:
206
- c.showPage()
207
- y = height - 50
208
- c.drawImage(img, 50, y - img_height, width=img_width, height=img_height)
209
- y -= img_height + 20
210
-
211
- c.save()
212
- print(f"PDF generated at {pdf_path}")
213
-
214
- # Convert PDF to base64 for returning to frontend
215
- with open(pdf_path, "rb") as f:
216
- pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
217
-
218
- # Clean up the PDF file after encoding
219
- os.remove(pdf_path)
220
- print("PDF converted to base64 and file removed")
221
- return pdf_base64
222
- except Exception as e:
223
- print(f"❌ Error generating PDF: {e}")
224
- return ""
225
-
226
- # ==========================
227
- # Gradio Interface
228
- # ==========================
229
- def gradio_interface(video_file):
230
- try:
231
- if not video_file:
232
- return {"error": "Please upload a video file."}, "", "", []
233
-
234
- with open(video_file, "rb") as f:
235
- video_data = f.read()
236
-
237
- result = process_video(video_data)
238
- return (
239
- result["violations"],
240
- f"Safety Score: {result['score']}%",
241
- result["pdf_base64"],
242
- result["snapshots"]
243
- )
244
- except Exception as e:
245
- print(f"❌ Error in gradio_interface: {e}")
246
- return {"error": str(e)}, "", "", []
247
-
248
- interface = gr.Interface(
249
- fn=gradio_interface,
250
- inputs=gr.Video(label="Upload Site Video"),
251
- outputs=[
252
- gr.JSON(label="Detected Safety Violations"),
253
- gr.Textbox(label="Compliance Score"),
254
- gr.File(label="Download PDF Report"), # Changed from Textbox to File for PDF download
255
- gr.JSON(label="Snapshots")
256
- ],
257
- title="Worksite Safety Violation Analyzer",
258
- description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
259
- )
260
-
261
- if __name__ == "__main__":
262
- print("🚀 Launching Safety Analyzer App...")
263
- interface.launch()
 
2
  import cv2
3
  import gradio as gr
4
  import torch
 
5
  from ultralytics import YOLO
6
  import time
7
  from reportlab.lib.pagesizes import letter
 
46
  # Video Processing
47
  # ==========================
48
  def process_video(video_data, frame_skip=10, max_frames=100):
 
 
 
 
 
49
  try:
50
  print("Processing video data...")
51
  video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
 
72
  frame_count += 1
73
  continue
74
 
 
75
  results = model(frame, device=device)
76
 
77
  for result in results:
 
90
  }
91
  violations.append(violation)
92
 
 
93
  snapshot_filename = f"snapshot_{frame_count}_{label}.jpg"
94
  snapshot_path = os.path.join(OUTPUT_DIR, snapshot_filename)
95
  cv2.imwrite(snapshot_path, frame)
96
  snapshots.append({
97
  "violation": label,
98
  "frame": frame_count,
99
+ "snapshot_url": snapshot_filename
100
  })
101
 
102
  frame_count += 1
103
  processed_frame_count += 1
104
 
 
105
  if processed_frame_count >= max_frames:
106
  break
107
  if time.time() - start_time > 30:
 
112
  os.remove(video_path)
113
 
114
  score = calculate_safety_score(violations)
115
+ pdf_url = generate_pdf_report(violations, snapshots, score)
116
 
117
+ # Clean up snapshots
118
  for snap in snapshots:
119
  snap_file = os.path.join(OUTPUT_DIR, snap["snapshot_url"])
120
  if os.path.exists(snap_file):
 
124
  "violations": violations,
125
  "snapshots": snapshots,
126
  "score": score,
127
+ "pdf_url": pdf_url
128
  }
129
 
130
  except Exception as e:
 
133
  "violations": [],
134
  "snapshots": [],
135
  "score": 0,
136
+ "pdf_url": "",
137
  "error": str(e)
138
  }
139
 
 
157
  # ==========================
158
  def generate_pdf_report(violations, snapshots, score):
159
  try:
160
+ timestamp = int(time.time())
161
+ pdf_filename = f"report_{timestamp}.pdf"
162
+ pdf_path = os.path.join(OUTPUT_DIR, pdf_filename)
163
  c = canvas.Canvas(pdf_path, pagesize=letter)
164
  width, height = letter
165
 
 
166
  c.setFont("Helvetica-Bold", 16)
167
  c.drawString(50, height - 50, "Worksite Safety Compliance Report")
168
 
 
169
  c.setFont("Helvetica", 12)
170
  c.drawString(50, height - 80, f"Compliance Score: {score}%")
171
 
 
172
  y = height - 120
173
  c.setFont("Helvetica-Bold", 12)