PrashanthB461 commited on
Commit
efdcd31
·
verified ·
1 Parent(s): 0ee122d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -133
app.py CHANGED
@@ -1,173 +1,275 @@
 
1
  import os
2
  import cv2
3
  import gradio as gr
4
  import torch
 
5
  from ultralytics import YOLO
6
  import time
7
  from reportlab.lib.pagesizes import letter
8
  from reportlab.pdfgen import canvas
9
  from reportlab.lib.utils import ImageReader
10
  import base64
 
11
 
12
  # ==========================
 
13
  # Configuration
 
14
  # ==========================
15
- DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
16
- FALLBACK_MODEL = "yolov8n.pt"
17
- MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
18
- OUTPUT_DIR = "output"
19
- os.makedirs(OUTPUT_DIR, exist_ok=True)
20
-
21
- VIOLATION_LABELS = {
22
- 0: "no_helmet",
23
- 1: "no_harness",
24
- 2: "unsafe_posture",
25
- 3: "unsafe_zone"
 
26
  }
27
 
28
  # ==========================
 
29
  # Device Setup
 
30
  # ==========================
31
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
32
  print(f"✅ Using device: {device}")
33
 
34
  # ==========================
 
35
  # Load Model
 
36
  # ==========================
 
37
  try:
38
- selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
39
- model = YOLO(selected_model)
40
- print(f"✅ Model loaded: {selected_model}")
41
  except Exception as e:
42
- print(f"❌ Failed to load model: {e}")
43
- raise
44
 
45
  # ==========================
 
46
  # Video Processing
 
47
  # ==========================
48
- def process_video(video_data, frame_skip=10, max_frames=100):
49
- try:
50
- print("Processing video data...")
51
- video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
52
- with open(video_path, "wb") as f:
53
- f.write(video_data)
54
- print(f"Video saved to {video_path}")
55
-
56
- video = cv2.VideoCapture(video_path)
57
- if not video.isOpened():
58
- raise ValueError("Could not open video file.")
59
-
60
- frame_count = 0
61
- violations = []
62
- snapshots = []
63
- processed_frame_count = 0
64
- start_time = time.time()
65
-
66
- while True:
67
- ret, frame = video.read()
68
- if not ret:
69
- break
70
-
71
- if frame_count % frame_skip != 0:
72
- frame_count += 1
73
- continue
74
-
75
- results = model(frame, device=device)
76
-
77
- for result in results:
78
- for box in result.boxes:
79
- cls = int(box.cls)
80
- conf = float(box.conf)
81
- xywh = box.xywh.cpu().numpy()[0]
82
-
83
- label = VIOLATION_LABELS.get(cls, f"class_{cls}")
84
- violation = {
85
- "frame": frame_count,
86
- "violation": label,
87
- "confidence": round(conf, 2),
88
- "bounding_box": [round(x, 2) for x in xywh],
89
- "timestamp": frame_count / video.get(cv2.CAP_PROP_FPS)
90
- }
91
- violations.append(violation)
92
-
93
- snapshot_filename = f"snapshot_{frame_count}_{label}.jpg"
94
- snapshot_path = os.path.join(OUTPUT_DIR, snapshot_filename)
95
- cv2.imwrite(snapshot_path, frame)
96
- snapshots.append({
97
- "violation": label,
98
- "frame": frame_count,
99
- "snapshot_url": snapshot_filename
100
- })
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  frame_count += 1
103
- processed_frame_count += 1
104
-
105
- if processed_frame_count >= max_frames:
106
- break
107
- if time.time() - start_time > 30:
108
- print("⏰ Exceeded 30 seconds of processing time.")
109
- break
110
-
111
- video.release()
112
- os.remove(video_path)
113
-
114
- score = calculate_safety_score(violations)
115
- pdf_url = generate_pdf_report(violations, snapshots, score)
116
-
117
- # Clean up snapshots
118
- for snap in snapshots:
119
- snap_file = os.path.join(OUTPUT_DIR, snap["snapshot_url"])
120
- if os.path.exists(snap_file):
121
- os.remove(snap_file)
122
-
123
- return {
124
- "violations": violations,
125
- "snapshots": snapshots,
126
- "score": score,
127
- "pdf_url": pdf_url
128
- }
129
-
130
- except Exception as e:
131
- print(f"❌ Error processing video: {e}")
132
- return {
133
- "violations": [],
134
- "snapshots": [],
135
- "score": 0,
136
- "pdf_url": "",
137
- "error": str(e)
138
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  # ==========================
 
141
  # Safety Score Calculation
 
142
  # ==========================
143
- def calculate_safety_score(violations):
144
- base_score = 100
145
- penalties = {
146
- "no_helmet": 25,
147
- "no_harness": 30,
148
- "unsafe_posture": 20,
149
- "unsafe_zone": 25
150
- }
151
- for v in violations:
152
- base_score -= penalties.get(v["violation"], 0)
153
- return max(base_score, 0)
 
154
 
155
  # ==========================
 
156
  # PDF Report Generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  # ==========================
158
- def generate_pdf_report(violations, snapshots, score):
159
- try:
160
- timestamp = int(time.time())
161
- pdf_filename = f"report_{timestamp}.pdf"
162
- pdf_path = os.path.join(OUTPUT_DIR, pdf_filename)
163
- c = canvas.Canvas(pdf_path, pagesize=letter)
164
- width, height = letter
165
-
166
- c.setFont("Helvetica-Bold", 16)
167
- c.drawString(50, height - 50, "Worksite Safety Compliance Report")
168
-
169
- c.setFont("Helvetica", 12)
170
- c.drawString(50, height - 80, f"Compliance Score: {score}%")
171
-
172
- y = height - 120
173
- c.setFont("Helvetica-Bold", 12)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Here is my app.py code
2
  import os
3
  import cv2
4
  import gradio as gr
5
  import torch
6
+ import numpy as np
7
  from ultralytics import YOLO
8
  import time
9
  from reportlab.lib.pagesizes import letter
10
  from reportlab.pdfgen import canvas
11
  from reportlab.lib.utils import ImageReader
12
  import base64
13
+ from PIL import Image
14
 
15
  # ==========================
16
+
17
  # Configuration
18
+
19
  # ==========================
20
+
21
+ DEFAULT\_MODEL\_PATH = "models/yolov8\_safety.pt"
22
+ FALLBACK\_MODEL = "yolov8n.pt"
23
+ MODEL\_PATH = os.getenv("SAFETY\_MODEL\_PATH", DEFAULT\_MODEL\_PATH)
24
+ OUTPUT\_DIR = "output"
25
+ os.makedirs(OUTPUT\_DIR, exist\_ok=True)
26
+
27
+ VIOLATION\_LABELS = {
28
+ 0: "no\_helmet",
29
+ 1: "no\_harness",
30
+ 2: "unsafe\_posture",
31
+ 3: "unsafe\_zone"
32
  }
33
 
34
  # ==========================
35
+
36
  # Device Setup
37
+
38
  # ==========================
39
+
40
+ device = torch.device("cuda" if torch.cuda.is\_available() else "cpu")
41
  print(f"✅ Using device: {device}")
42
 
43
  # ==========================
44
+
45
  # Load Model
46
+
47
  # ==========================
48
+
49
  try:
50
+ selected\_model = MODEL\_PATH if os.path.isfile(MODEL\_PATH) else FALLBACK\_MODEL
51
+ model = YOLO(selected\_model)
52
+ print(f"✅ Model loaded: {selected\_model}")
53
  except Exception as e:
54
+ print(f"❌ Failed to load model: {e}")
55
+ raise
56
 
57
  # ==========================
58
+
59
  # Video Processing
60
+
61
  # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ def process\_video(video\_data, frame\_skip=5, max\_frames=100):
64
+ try:
65
+ print("Processing video data...")
66
+ \# Save uploaded video data to a temporary file
67
+ video\_path = os.path.join(OUTPUT\_DIR, f"temp\_{int(time.time())}.mp4")
68
+ with open(video\_path, "wb") as f:
69
+ f.write(video\_data)
70
+ print(f"Video saved to {video\_path}")
71
+
72
+ ```
73
+ video = cv2.VideoCapture(video_path)
74
+ if not video.isOpened():
75
+ raise ValueError("Could not open video file.")
76
+
77
+ frame_count = 0
78
+ violations = []
79
+ snapshots = []
80
+ processed_frame_count = 0
81
+ start_time = time.time()
82
+
83
+ while True:
84
+ ret, frame = video.read()
85
+ if not ret:
86
+ break
87
+
88
+ if frame_count % frame_skip != 0:
89
  frame_count += 1
90
+ continue
91
+
92
+ # Model inference
93
+ results = model(frame, device=device)
94
+
95
+ for result in results:
96
+ for box in result.boxes:
97
+ cls = int(box.cls)
98
+ conf = float(box.conf)
99
+ xywh = box.xywh.cpu().numpy()[0]
100
+
101
+ label = VIOLATION_LABELS.get(cls, f"class_{cls}")
102
+ violation = {
103
+ "frame": frame_count,
104
+ "violation": label,
105
+ "confidence": round(conf, 2),
106
+ "bounding_box": [round(x, 2) for x in xywh],
107
+ "timestamp": frame_count / video.get(cv2.CAP_PROP_FPS)
108
+ }
109
+ violations.append(violation)
110
+
111
+ # Save snapshot
112
+ snapshot_path = os.path.join(OUTPUT_DIR, f"snapshot_{frame_count}_{label}.jpg")
113
+ cv2.imwrite(snapshot_path, frame)
114
+ snapshots.append({
115
+ "violation": label,
116
+ "frame": frame_count,
117
+ "snapshot_url": f"https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/output/{os.path.basename(snapshot_path)}"
118
+ })
119
+
120
+ frame_count += 1
121
+ processed_frame_count += 1
122
+
123
+ if processed_frame_count >= max_frames:
124
+ break
125
+
126
+ if time.time() - start_time > 30:
127
+ print("⏰ Exceeded 30 seconds of processing time.")
128
+ break
129
+
130
+ video.release()
131
+ os.remove(video_path)
132
+
133
+ score = calculate_safety_score(violations)
134
+ pdf_base64 = generate_pdf_report(violations, snapshots, score)
135
+
136
+ return {
137
+ "violations": violations,
138
+ "snapshots": snapshots,
139
+ "score": score,
140
+ "pdf_base64": pdf_base64
141
+ }
142
+
143
+ except Exception as e:
144
+ print(f"❌ Error processing video: {e}")
145
+ return {
146
+ "violations": [],
147
+ "snapshots": [],
148
+ "score": 0,
149
+ "pdf_base64": "",
150
+ "error": str(e)
151
+ }
152
+ ```
153
 
154
  # ==========================
155
+
156
  # Safety Score Calculation
157
+
158
  # ==========================
159
+
160
+ def calculate\_safety\_score(violations):
161
+ base\_score = 100
162
+ penalties = {
163
+ "no\_helmet": 25,
164
+ "no\_harness": 30,
165
+ "unsafe\_posture": 20,
166
+ "unsafe\_zone": 25
167
+ }
168
+ for v in violations:
169
+ base\_score -= penalties.get(v\["violation"], 0)
170
+ return max(base\_score, 0)
171
 
172
  # ==========================
173
+
174
  # PDF Report Generation
175
+
176
+ # ==========================
177
+
178
+ def generate\_pdf\_report(violations, snapshots, score):
179
+ try:
180
+ pdf\_path = os.path.join(OUTPUT\_DIR, f"report\_{int(time.time())}.pdf")
181
+ c = canvas.Canvas(pdf\_path, pagesize=letter)
182
+ width, height = letter
183
+
184
+ ```
185
+ # Title
186
+ c.setFont("Helvetica-Bold", 16)
187
+ c.drawString(50, height - 50, "Worksite Safety Compliance Report")
188
+
189
+ # Compliance Score
190
+ c.setFont("Helvetica", 12)
191
+ c.drawString(50, height - 80, f"Compliance Score: {score}%")
192
+
193
+ # Violations Table
194
+ y = height - 120
195
+ c.setFont("Helvetica-Bold", 12)
196
+ c.drawString(50, y, "Detected Violations:")
197
+ y -= 20
198
+
199
+ for v in violations:
200
+ c.setFont("Helvetica", 10)
201
+ text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
202
+ c.drawString(50, y, text)
203
+ y -= 20
204
+
205
+ # Add snapshot if available
206
+ snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
207
+ if snapshot and os.path.exists(snapshot["snapshot_url"].split('/')[-1]):
208
+ img = ImageReader(snapshot["snapshot_url"].split('/')[-1])
209
+ c.drawImage(img, 50, y - 100, width=200, height=150)
210
+ y -= 170
211
+
212
+ if y < 50:
213
+ c.showPage()
214
+ y = height - 50
215
+
216
+ c.save()
217
+ print(f"PDF generated at {pdf_path}")
218
+
219
+ # Convert PDF to base64
220
+ with open(pdf_path, "rb") as f:
221
+ pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
222
+
223
+ # Clean up
224
+ os.remove(pdf_path)
225
+ print("PDF converted to base64 and file removed")
226
+ return pdf_base64
227
+ except Exception as e:
228
+ print(f"❌ Error generating PDF: {e}")
229
+ return ""
230
+ ```
231
+
232
+ # ==========================
233
+
234
+ # Gradio Interface
235
+
236
  # ==========================
237
+
238
+ def gradio\_interface(video\_file):
239
+ try:
240
+ if not video\_file:
241
+ return {"error": "Please upload a video file."}, "", "", \[]
242
+
243
+ ```
244
+ with open(video_file, "rb") as f:
245
+ video_data = f.read()
246
+
247
+ result = process_video(video_data)
248
+ return (
249
+ result["violations"],
250
+ f"Safety Score: {result['score']}%",
251
+ result["pdf_base64"],
252
+ result["snapshots"]
253
+ )
254
+ except Exception as e:
255
+ print(f"❌ Error in gradio_interface: {e}")
256
+ return {"error": str(e)}, "", "", []
257
+ ```
258
+
259
+ interface = gr.Interface(
260
+ fn=gradio\_interface,
261
+ inputs=gr.Video(label="Upload Site Video"),
262
+ outputs=\[
263
+ gr.JSON(label="Detected Safety Violations"),
264
+ gr.Textbox(label="Compliance Score"),
265
+ gr.Textbox(label="PDF Base64 (for API use)"),
266
+ gr.JSON(label="Snapshots")
267
+ ],
268
+ title="Worksite Safety Violation Analyzer",
269
+ description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
270
+ )
271
+
272
+ if **name** == "**main**":
273
+ print("🚀 Launching Safety Analyzer App...")
274
+ interface.launch()
275
+