PrashanthB461 commited on
Commit
9377a06
·
verified ·
1 Parent(s): dc66d57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -22
app.py CHANGED
@@ -9,7 +9,6 @@ from reportlab.lib.pagesizes import letter
9
  from reportlab.pdfgen import canvas
10
  from reportlab.lib.utils import ImageReader
11
  import base64
12
- from PIL import Image
13
 
14
  # ==========================
15
  # Configuration
@@ -47,10 +46,14 @@ except Exception as e:
47
  # ==========================
48
  # Video Processing
49
  # ==========================
50
- def process_video(video_data, frame_skip=5, max_frames=100):
 
 
 
 
 
51
  try:
52
  print("Processing video data...")
53
- # Save uploaded video data to a temporary file
54
  video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
55
  with open(video_path, "wb") as f:
56
  f.write(video_data)
@@ -94,21 +97,22 @@ def process_video(video_data, frame_skip=5, max_frames=100):
94
  }
95
  violations.append(violation)
96
 
97
- # Save snapshot
98
- snapshot_path = os.path.join(OUTPUT_DIR, f"snapshot_{frame_count}_{label}.jpg")
 
99
  cv2.imwrite(snapshot_path, frame)
100
  snapshots.append({
101
  "violation": label,
102
  "frame": frame_count,
103
- "snapshot_url": f"https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/output/{os.path.basename(snapshot_path)}"
104
  })
105
 
106
  frame_count += 1
107
  processed_frame_count += 1
108
 
 
109
  if processed_frame_count >= max_frames:
110
  break
111
-
112
  if time.time() - start_time > 30:
113
  print("⏰ Exceeded 30 seconds of processing time.")
114
  break
@@ -119,6 +123,12 @@ def process_video(video_data, frame_skip=5, max_frames=100):
119
  score = calculate_safety_score(violations)
120
  pdf_base64 = generate_pdf_report(violations, snapshots, score)
121
 
 
 
 
 
 
 
122
  return {
123
  "violations": violations,
124
  "snapshots": snapshots,
@@ -168,37 +178,44 @@ def generate_pdf_report(violations, snapshots, score):
168
  c.setFont("Helvetica", 12)
169
  c.drawString(50, height - 80, f"Compliance Score: {score}%")
170
 
171
- # Violations Table
172
  y = height - 120
173
  c.setFont("Helvetica-Bold", 12)
174
  c.drawString(50, y, "Detected Violations:")
175
  y -= 20
176
 
177
  for v in violations:
 
 
 
 
178
  c.setFont("Helvetica", 10)
179
  text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
180
  c.drawString(50, y, text)
181
  y -= 20
182
 
183
- # Add snapshot if available
184
  snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
185
- if snapshot and os.path.exists(snapshot["snapshot_url"].split('/')[-1]):
186
- img = ImageReader(snapshot["snapshot_url"].split('/')[-1])
187
- c.drawImage(img, 50, y - 100, width=200, height=150)
188
- y -= 170
189
-
190
- if y < 50:
191
- c.showPage()
192
- y = height - 50
 
 
 
193
 
194
  c.save()
195
  print(f"PDF generated at {pdf_path}")
196
 
197
- # Convert PDF to base64
198
  with open(pdf_path, "rb") as f:
199
  pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
200
-
201
- # Clean up
202
  os.remove(pdf_path)
203
  print("PDF converted to base64 and file removed")
204
  return pdf_base64
@@ -234,7 +251,7 @@ interface = gr.Interface(
234
  outputs=[
235
  gr.JSON(label="Detected Safety Violations"),
236
  gr.Textbox(label="Compliance Score"),
237
- gr.Textbox(label="PDF Base64 (for API use)"),
238
  gr.JSON(label="Snapshots")
239
  ],
240
  title="Worksite Safety Violation Analyzer",
@@ -243,4 +260,4 @@ interface = gr.Interface(
243
 
244
  if __name__ == "__main__":
245
  print("🚀 Launching Safety Analyzer App...")
246
- interface.launch()
 
9
  from reportlab.pdfgen import canvas
10
  from reportlab.lib.utils import ImageReader
11
  import base64
 
12
 
13
  # ==========================
14
  # Configuration
 
46
  # ==========================
47
  # Video Processing
48
  # ==========================
49
+ def process_video(video_data, frame_skip=10, max_frames=100):
50
+ """
51
+ Processes uploaded video data to detect safety violations using YOLO model.
52
+ frame_skip: number of frames to skip between inferences to speed up processing.
53
+ max_frames: max frames to process before stopping.
54
+ """
55
  try:
56
  print("Processing video data...")
 
57
  video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
58
  with open(video_path, "wb") as f:
59
  f.write(video_data)
 
97
  }
98
  violations.append(violation)
99
 
100
+ # Save snapshot locally with a filename pattern
101
+ snapshot_filename = f"snapshot_{frame_count}_{label}.jpg"
102
+ snapshot_path = os.path.join(OUTPUT_DIR, snapshot_filename)
103
  cv2.imwrite(snapshot_path, frame)
104
  snapshots.append({
105
  "violation": label,
106
  "frame": frame_count,
107
+ "snapshot_url": snapshot_filename # just filename, local path
108
  })
109
 
110
  frame_count += 1
111
  processed_frame_count += 1
112
 
113
+ # Stop if max frames or 30 seconds elapsed
114
  if processed_frame_count >= max_frames:
115
  break
 
116
  if time.time() - start_time > 30:
117
  print("⏰ Exceeded 30 seconds of processing time.")
118
  break
 
123
  score = calculate_safety_score(violations)
124
  pdf_base64 = generate_pdf_report(violations, snapshots, score)
125
 
126
+ # Clean up snapshot images
127
+ for snap in snapshots:
128
+ snap_file = os.path.join(OUTPUT_DIR, snap["snapshot_url"])
129
+ if os.path.exists(snap_file):
130
+ os.remove(snap_file)
131
+
132
  return {
133
  "violations": violations,
134
  "snapshots": snapshots,
 
178
  c.setFont("Helvetica", 12)
179
  c.drawString(50, height - 80, f"Compliance Score: {score}%")
180
 
181
+ # Violations Table header
182
  y = height - 120
183
  c.setFont("Helvetica-Bold", 12)
184
  c.drawString(50, y, "Detected Violations:")
185
  y -= 20
186
 
187
  for v in violations:
188
+ if y < 150:
189
+ c.showPage()
190
+ y = height - 50
191
+
192
  c.setFont("Helvetica", 10)
193
  text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
194
  c.drawString(50, y, text)
195
  y -= 20
196
 
197
+ # Find matching snapshot by frame and violation label
198
  snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
199
+ if snapshot:
200
+ snapshot_file = os.path.join(OUTPUT_DIR, snapshot["snapshot_url"])
201
+ if os.path.exists(snapshot_file):
202
+ img = ImageReader(snapshot_file)
203
+ img_width = 200
204
+ img_height = 150
205
+ if y - img_height < 50:
206
+ c.showPage()
207
+ y = height - 50
208
+ c.drawImage(img, 50, y - img_height, width=img_width, height=img_height)
209
+ y -= img_height + 20
210
 
211
  c.save()
212
  print(f"PDF generated at {pdf_path}")
213
 
214
+ # Convert PDF to base64 for returning to frontend
215
  with open(pdf_path, "rb") as f:
216
  pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
217
+
218
+ # Clean up the PDF file after encoding
219
  os.remove(pdf_path)
220
  print("PDF converted to base64 and file removed")
221
  return pdf_base64
 
251
  outputs=[
252
  gr.JSON(label="Detected Safety Violations"),
253
  gr.Textbox(label="Compliance Score"),
254
+ gr.File(label="Download PDF Report"), # Changed from Textbox to File for PDF download
255
  gr.JSON(label="Snapshots")
256
  ],
257
  title="Worksite Safety Violation Analyzer",
 
260
 
261
  if __name__ == "__main__":
262
  print("🚀 Launching Safety Analyzer App...")
263
+ interface.launch()