PrashanthB461 commited on
Commit
7d35a2c
·
verified ·
1 Parent(s): 6fafb13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -53
app.py CHANGED
@@ -18,7 +18,7 @@ from PIL import Image
18
  DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
19
  FALLBACK_MODEL = "yolov8n.pt"
20
  MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
21
- OUTPUT_DIR = "output" # Directory to store snapshots and PDFs
22
  os.makedirs(OUTPUT_DIR, exist_ok=True)
23
 
24
  VIOLATION_LABELS = {
@@ -37,18 +37,25 @@ print(f"✅ Using device: {device}")
37
  # ==========================
38
  # Load Model
39
  # ==========================
40
- selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
41
- model = YOLO(selected_model)
 
 
 
 
 
42
 
43
  # ==========================
44
  # Video Processing
45
  # ==========================
46
  def process_video(video_data, frame_skip=5, max_frames=100):
47
  try:
 
48
  # Save uploaded video data to a temporary file
49
  video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
50
  with open(video_path, "wb") as f:
51
  f.write(video_data)
 
52
 
53
  video = cv2.VideoCapture(video_path)
54
  if not video.isOpened():
@@ -94,7 +101,7 @@ def process_video(video_data, frame_skip=5, max_frames=100):
94
  snapshots.append({
95
  "violation": label,
96
  "frame": frame_count,
97
- "snapshot_url": snapshot_path
98
  })
99
 
100
  frame_count += 1
@@ -108,7 +115,7 @@ def process_video(video_data, frame_skip=5, max_frames=100):
108
  break
109
 
110
  video.release()
111
- os.remove(video_path) # Clean up temporary video file
112
 
113
  score = calculate_safety_score(violations)
114
  pdf_report_path = generate_pdf_report(violations, snapshots, score)
@@ -149,61 +156,74 @@ def calculate_safety_score(violations):
149
  # PDF Report Generation
150
  # ==========================
151
  def generate_pdf_report(violations, snapshots, score):
152
- pdf_path = os.path.join(OUTPUT_DIR, f"report_{int(time.time())}.pdf")
153
- c = canvas.Canvas(pdf_path, pagesize=letter)
154
- width, height = letter
155
-
156
- # Title
157
- c.setFont("Helvetica-Bold", 16)
158
- c.drawString(50, height - 50, "Worksite Safety Compliance Report")
159
-
160
- # Compliance Score
161
- c.setFont("Helvetica", 12)
162
- c.drawString(50, height - 80, f"Compliance Score: {score}%")
163
-
164
- # Violations Table
165
- y = height - 120
166
- c.setFont("Helvetica-Bold", 12)
167
- c.drawString(50, y, "Detected Violations:")
168
- y -= 20
169
-
170
- for v in violations:
171
- c.setFont("Helvetica", 10)
172
- text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
173
- c.drawString(50, y, text)
174
  y -= 20
175
 
176
- # Add snapshot if available
177
- snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
178
- if snapshot and os.path.exists(snapshot["snapshot_url"]):
179
- img = ImageReader(snapshot["snapshot_url"])
180
- c.drawImage(img, 50, y - 100, width=200, height=150)
181
- y -= 170
182
-
183
- if y < 50:
184
- c.showPage()
185
- y = height - 50
186
-
187
- c.save()
188
- return pdf_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  # ==========================
191
  # Gradio Interface
192
  # ==========================
193
  def gradio_interface(video_file):
194
- if not video_file:
195
- return {"error": "Please upload a video file."}, "", ""
196
-
197
- with open(video_file, "rb") as f:
198
- video_data = f.read()
199
-
200
- result = process_video(video_data)
201
- return (
202
- result["violations"],
203
- f"Safety Score: {result['score']}%",
204
- result["pdf_report_url"],
205
- result["snapshots"]
206
- )
 
 
 
 
207
 
208
  interface = gr.Interface(
209
  fn=gradio_interface,
 
18
  DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
19
  FALLBACK_MODEL = "yolov8n.pt"
20
  MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
21
+ OUTPUT_DIR = "output"
22
  os.makedirs(OUTPUT_DIR, exist_ok=True)
23
 
24
  VIOLATION_LABELS = {
 
37
  # ==========================
38
  # Load Model
39
  # ==========================
40
+ try:
41
+ selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
42
+ model = YOLO(selected_model)
43
+ print(f"✅ Model loaded: {selected_model}")
44
+ except Exception as e:
45
+ print(f"❌ Failed to load model: {e}")
46
+ raise
47
 
48
  # ==========================
49
  # Video Processing
50
  # ==========================
51
  def process_video(video_data, frame_skip=5, max_frames=100):
52
  try:
53
+ print("Processing video data...")
54
  # Save uploaded video data to a temporary file
55
  video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
56
  with open(video_path, "wb") as f:
57
  f.write(video_data)
58
+ print(f"Video saved to {video_path}")
59
 
60
  video = cv2.VideoCapture(video_path)
61
  if not video.isOpened():
 
101
  snapshots.append({
102
  "violation": label,
103
  "frame": frame_count,
104
+ "snapshot_url": f"https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/output/{os.path.basename(snapshot_path)}"
105
  })
106
 
107
  frame_count += 1
 
115
  break
116
 
117
  video.release()
118
+ os.remove(video_path)
119
 
120
  score = calculate_safety_score(violations)
121
  pdf_report_path = generate_pdf_report(violations, snapshots, score)
 
156
  # PDF Report Generation
157
  # ==========================
158
  def generate_pdf_report(violations, snapshots, score):
159
+ try:
160
+ pdf_path = os.path.join(OUTPUT_DIR, f"report_{int(time.time())}.pdf")
161
+ c = canvas.Canvas(pdf_path, pagesize=letter)
162
+ width, height = letter
163
+
164
+ # Title
165
+ c.setFont("Helvetica-Bold", 16)
166
+ c.drawString(50, height - 50, "Worksite Safety Compliance Report")
167
+
168
+ # Compliance Score
169
+ c.setFont("Helvetica", 12)
170
+ c.drawString(50, height - 80, f"Compliance Score: {score}%")
171
+
172
+ # Violations Table
173
+ y = height - 120
174
+ c.setFont("Helvetica-Bold", 12)
175
+ c.drawString(50, y, "Detected Violations:")
 
 
 
 
 
176
  y -= 20
177
 
178
+ for v in violations:
179
+ c.setFont("Helvetica", 10)
180
+ text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
181
+ c.drawString(50, y, text)
182
+ y -= 20
183
+
184
+ # Add snapshot if available
185
+ snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
186
+ if snapshot and os.path.exists(snapshot["snapshot_url"].split('/')[-1]):
187
+ img = ImageReader(snapshot["snapshot_url"].split('/')[-1])
188
+ c.drawImage(img, 50, y - 100, width=200, height=150)
189
+ y -= 170
190
+
191
+ if y < 50:
192
+ c.showPage()
193
+ y = height - 50
194
+
195
+ c.save()
196
+ print(f"PDF generated at {pdf_path}")
197
+ # Return a publicly accessible URL
198
+ base_url = "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1"
199
+ pdf_url = f"{base_url}/output/{os.path.basename(pdf_path)}"
200
+ print(f"PDF URL: {pdf_url}")
201
+ return pdf_url
202
+ except Exception as e:
203
+ print(f"❌ Error generating PDF: {e}")
204
+ return ""
205
 
206
  # ==========================
207
  # Gradio Interface
208
  # ==========================
209
  def gradio_interface(video_file):
210
+ try:
211
+ if not video_file:
212
+ return {"error": "Please upload a video file."}, "", "", []
213
+
214
+ with open(video_file, "rb") as f:
215
+ video_data = f.read()
216
+
217
+ result = process_video(video_data)
218
+ return (
219
+ result["violations"],
220
+ f"Safety Score: {result['score']}%",
221
+ result["pdf_report_url"],
222
+ result["snapshots"]
223
+ )
224
+ except Exception as e:
225
+ print(f"❌ Error in gradio_interface: {e}")
226
+ return {"error": str(e)}, "", "", []
227
 
228
  interface = gr.Interface(
229
  fn=gradio_interface,