PrashanthB461 commited on
Commit
08e92c3
·
verified ·
1 Parent(s): df5bb31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -222
app.py CHANGED
@@ -12,263 +12,235 @@ import base64
12
  from PIL import Image
13
 
14
  # ==========================
15
-
16
  # Configuration
17
-
18
  # ==========================
19
-
20
- DEFAULT\_MODEL\_PATH = "models/yolov8\_safety.pt"
21
- FALLBACK\_MODEL = "yolov8n.pt"
22
- MODEL\_PATH = os.getenv("SAFETY\_MODEL\_PATH", DEFAULT\_MODEL\_PATH)
23
- OUTPUT\_DIR = "output"
24
- os.makedirs(OUTPUT\_DIR, exist\_ok=True)
25
-
26
- VIOLATION\_LABELS = {
27
- 0: "no\_helmet",
28
- 1: "no\_harness",
29
- 2: "unsafe\_posture",
30
- 3: "unsafe\_zone"
31
  }
32
 
33
  # ==========================
34
-
35
  # Device Setup
36
-
37
  # ==========================
38
-
39
- device = torch.device("cuda" if torch.cuda.is\_available() else "cpu")
40
  print(f"✅ Using device: {device}")
41
 
42
  # ==========================
43
-
44
  # Load Model
45
-
46
  # ==========================
47
-
48
  try:
49
- selected\_model = MODEL\_PATH if os.path.isfile(MODEL\_PATH) else FALLBACK\_MODEL
50
- model = YOLO(selected\_model)
51
- print(f"✅ Model loaded: {selected\_model}")
52
  except Exception as e:
53
- print(f"❌ Failed to load model: {e}")
54
- raise
55
 
56
  # ==========================
57
-
58
  # Video Processing
59
-
60
  # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- def process\_video(video\_data, frame\_skip=5, max\_frames=100):
63
- try:
64
- print("Processing video data...")
65
- \# Save uploaded video data to a temporary file
66
- video\_path = os.path.join(OUTPUT\_DIR, f"temp\_{int(time.time())}.mp4")
67
- with open(video\_path, "wb") as f:
68
- f.write(video\_data)
69
- print(f"Video saved to {video\_path}")
70
-
71
- ```
72
- video = cv2.VideoCapture(video_path)
73
- if not video.isOpened():
74
- raise ValueError("Could not open video file.")
75
-
76
- frame_count = 0
77
- violations = []
78
- snapshots = []
79
- processed_frame_count = 0
80
- start_time = time.time()
81
-
82
- while True:
83
- ret, frame = video.read()
84
- if not ret:
85
- break
86
-
87
- if frame_count % frame_skip != 0:
88
  frame_count += 1
89
- continue
90
-
91
- # Model inference
92
- results = model(frame, device=device)
93
-
94
- for result in results:
95
- for box in result.boxes:
96
- cls = int(box.cls)
97
- conf = float(box.conf)
98
- xywh = box.xywh.cpu().numpy()[0]
99
-
100
- label = VIOLATION_LABELS.get(cls, f"class_{cls}")
101
- violation = {
102
- "frame": frame_count,
103
- "violation": label,
104
- "confidence": round(conf, 2),
105
- "bounding_box": [round(x, 2) for x in xywh],
106
- "timestamp": frame_count / video.get(cv2.CAP_PROP_FPS)
107
- }
108
- violations.append(violation)
109
-
110
- # Save snapshot
111
- snapshot_path = os.path.join(OUTPUT_DIR, f"snapshot_{frame_count}_{label}.jpg")
112
- cv2.imwrite(snapshot_path, frame)
113
- snapshots.append({
114
- "violation": label,
115
- "frame": frame_count,
116
- "snapshot_url": f"https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/output/{os.path.basename(snapshot_path)}"
117
- })
118
-
119
- frame_count += 1
120
- processed_frame_count += 1
121
-
122
- if processed_frame_count >= max_frames:
123
- break
124
-
125
- if time.time() - start_time > 30:
126
- print("⏰ Exceeded 30 seconds of processing time.")
127
- break
128
-
129
- video.release()
130
- os.remove(video_path)
131
-
132
- score = calculate_safety_score(violations)
133
- pdf_base64 = generate_pdf_report(violations, snapshots, score)
134
-
135
- return {
136
- "violations": violations,
137
- "snapshots": snapshots,
138
- "score": score,
139
- "pdf_base64": pdf_base64
140
- }
141
-
142
- except Exception as e:
143
- print(f"❌ Error processing video: {e}")
144
- return {
145
- "violations": [],
146
- "snapshots": [],
147
- "score": 0,
148
- "pdf_base64": "",
149
- "error": str(e)
150
- }
151
- ```
152
 
153
  # ==========================
154
-
155
  # Safety Score Calculation
156
-
157
  # ==========================
158
-
159
- def calculate\_safety\_score(violations):
160
- base\_score = 100
161
- penalties = {
162
- "no\_helmet": 25,
163
- "no\_harness": 30,
164
- "unsafe\_posture": 20,
165
- "unsafe\_zone": 25
166
- }
167
- for v in violations:
168
- base\_score -= penalties.get(v\["violation"], 0)
169
- return max(base\_score, 0)
170
 
171
  # ==========================
172
-
173
  # PDF Report Generation
174
-
175
  # ==========================
176
-
177
- def generate\_pdf\_report(violations, snapshots, score):
178
- try:
179
- pdf\_path = os.path.join(OUTPUT\_DIR, f"report\_{int(time.time())}.pdf")
180
- c = canvas.Canvas(pdf\_path, pagesize=letter)
181
- width, height = letter
182
-
183
- ```
184
- # Title
185
- c.setFont("Helvetica-Bold", 16)
186
- c.drawString(50, height - 50, "Worksite Safety Compliance Report")
187
-
188
- # Compliance Score
189
- c.setFont("Helvetica", 12)
190
- c.drawString(50, height - 80, f"Compliance Score: {score}%")
191
-
192
- # Violations Table
193
- y = height - 120
194
- c.setFont("Helvetica-Bold", 12)
195
- c.drawString(50, y, "Detected Violations:")
196
- y -= 20
197
-
198
- for v in violations:
199
- c.setFont("Helvetica", 10)
200
- text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
201
- c.drawString(50, y, text)
202
  y -= 20
203
 
204
- # Add snapshot if available
205
- snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
206
- if snapshot and os.path.exists(snapshot["snapshot_url"].split('/')[-1]):
207
- img = ImageReader(snapshot["snapshot_url"].split('/')[-1])
208
- c.drawImage(img, 50, y - 100, width=200, height=150)
209
- y -= 170
210
-
211
- if y < 50:
212
- c.showPage()
213
- y = height - 50
214
-
215
- c.save()
216
- print(f"PDF generated at {pdf_path}")
217
-
218
- # Convert PDF to base64
219
- with open(pdf_path, "rb") as f:
220
- pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
221
-
222
- # Clean up
223
- os.remove(pdf_path)
224
- print("PDF converted to base64 and file removed")
225
- return pdf_base64
226
- except Exception as e:
227
- print(f"❌ Error generating PDF: {e}")
228
- return ""
229
- ```
 
 
 
 
 
230
 
231
  # ==========================
232
-
233
  # Gradio Interface
234
-
235
  # ==========================
236
-
237
- def gradio\_interface(video\_file):
238
- try:
239
- if not video\_file:
240
- return {"error": "Please upload a video file."}, "", "", \[]
241
-
242
- ```
243
- with open(video_file, "rb") as f:
244
- video_data = f.read()
245
-
246
- result = process_video(video_data)
247
- return (
248
- result["violations"],
249
- f"Safety Score: {result['score']}%",
250
- result["pdf_base64"],
251
- result["snapshots"]
252
- )
253
- except Exception as e:
254
- print(f"❌ Error in gradio_interface: {e}")
255
- return {"error": str(e)}, "", "", []
256
- ```
257
 
258
  interface = gr.Interface(
259
- fn=gradio\_interface,
260
- inputs=gr.Video(label="Upload Site Video"),
261
- outputs=\[
262
- gr.JSON(label="Detected Safety Violations"),
263
- gr.Textbox(label="Compliance Score"),
264
- gr.Textbox(label="PDF Base64 (for API use)"),
265
- gr.JSON(label="Snapshots")
266
- ],
267
- title="Worksite Safety Violation Analyzer",
268
- description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
269
  )
270
 
271
- if **name** == "**main**":
272
- print("🚀 Launching Safety Analyzer App...")
273
- interface.launch()
274
-
 
12
  from PIL import Image
13
 
14
  # ==========================
 
15
  # Configuration
 
16
  # ==========================
17
+ DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
18
+ FALLBACK_MODEL = "yolov8n.pt"
19
+ MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
20
+ OUTPUT_DIR = "output"
21
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
22
+
23
+ VIOLATION_LABELS = {
24
+ 0: "no_helmet",
25
+ 1: "no_harness",
26
+ 2: "unsafe_posture",
27
+ 3: "unsafe_zone"
 
28
  }
29
 
30
  # ==========================
 
31
  # Device Setup
 
32
  # ==========================
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
34
  print(f"✅ Using device: {device}")
35
 
36
  # ==========================
 
37
  # Load Model
 
38
  # ==========================
 
39
  try:
40
+ selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
41
+ model = YOLO(selected_model)
42
+ print(f"✅ Model loaded: {selected_model}")
43
  except Exception as e:
44
+ print(f"❌ Failed to load model: {e}")
45
+ raise
46
 
47
  # ==========================
 
48
  # Video Processing
 
49
  # ==========================
50
+ def process_video(video_data, frame_skip=5, max_frames=100):
51
+ try:
52
+ print("Processing video data...")
53
+ # Save uploaded video data to a temporary file
54
+ video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
55
+ with open(video_path, "wb") as f:
56
+ f.write(video_data)
57
+ print(f"Video saved to {video_path}")
58
+
59
+ video = cv2.VideoCapture(video_path)
60
+ if not video.isOpened():
61
+ raise ValueError("Could not open video file.")
62
+
63
+ frame_count = 0
64
+ violations = []
65
+ snapshots = []
66
+ processed_frame_count = 0
67
+ start_time = time.time()
68
+
69
+ while True:
70
+ ret, frame = video.read()
71
+ if not ret:
72
+ break
73
+
74
+ if frame_count % frame_skip != 0:
75
+ frame_count += 1
76
+ continue
77
+
78
+ # Model inference
79
+ results = model(frame, device=device)
80
+
81
+ for result in results:
82
+ for box in result.boxes:
83
+ cls = int(box.cls)
84
+ conf = float(box.conf)
85
+ xywh = box.xywh.cpu().numpy()[0]
86
+
87
+ label = VIOLATION_LABELS.get(cls, f"class_{cls}")
88
+ violation = {
89
+ "frame": frame_count,
90
+ "violation": label,
91
+ "confidence": round(conf, 2),
92
+ "bounding_box": [round(x, 2) for x in xywh],
93
+ "timestamp": frame_count / video.get(cv2.CAP_PROP_FPS)
94
+ }
95
+ violations.append(violation)
96
+
97
+ # Save snapshot
98
+ snapshot_path = os.path.join(OUTPUT_DIR, f"snapshot_{frame_count}_{label}.jpg")
99
+ cv2.imwrite(snapshot_path, frame)
100
+ snapshots.append({
101
+ "violation": label,
102
+ "frame": frame_count,
103
+ "snapshot_url": f"https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/output/{os.path.basename(snapshot_path)}"
104
+ })
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  frame_count += 1
107
+ processed_frame_count += 1
108
+
109
+ if processed_frame_count >= max_frames:
110
+ break
111
+
112
+ if time.time() - start_time > 30:
113
+ print("⏰ Exceeded 30 seconds of processing time.")
114
+ break
115
+
116
+ video.release()
117
+ os.remove(video_path)
118
+
119
+ score = calculate_safety_score(violations)
120
+ pdf_base64 = generate_pdf_report(violations, snapshots, score)
121
+
122
+ return {
123
+ "violations": violations,
124
+ "snapshots": snapshots,
125
+ "score": score,
126
+ "pdf_base64": pdf_base64
127
+ }
128
+
129
+ except Exception as e:
130
+ print(f"❌ Error processing video: {e}")
131
+ return {
132
+ "violations": [],
133
+ "snapshots": [],
134
+ "score": 0,
135
+ "pdf_base64": "",
136
+ "error": str(e)
137
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  # ==========================
 
140
  # Safety Score Calculation
 
141
  # ==========================
142
+ def calculate_safety_score(violations):
143
+ base_score = 100
144
+ penalties = {
145
+ "no_helmet": 25,
146
+ "no_harness": 30,
147
+ "unsafe_posture": 20,
148
+ "unsafe_zone": 25
149
+ }
150
+ for v in violations:
151
+ base_score -= penalties.get(v["violation"], 0)
152
+ return max(base_score, 0)
 
153
 
154
  # ==========================
 
155
  # PDF Report Generation
 
156
  # ==========================
157
+ def generate_pdf_report(violations, snapshots, score):
158
+ try:
159
+ pdf_path = os.path.join(OUTPUT_DIR, f"report_{int(time.time())}.pdf")
160
+ c = canvas.Canvas(pdf_path, pagesize=letter)
161
+ width, height = letter
162
+
163
+ # Title
164
+ c.setFont("Helvetica-Bold", 16)
165
+ c.drawString(50, height - 50, "Worksite Safety Compliance Report")
166
+
167
+ # Compliance Score
168
+ c.setFont("Helvetica", 12)
169
+ c.drawString(50, height - 80, f"Compliance Score: {score}%")
170
+
171
+ # Violations Table
172
+ y = height - 120
173
+ c.setFont("Helvetica-Bold", 12)
174
+ c.drawString(50, y, "Detected Violations:")
 
 
 
 
 
 
 
 
175
  y -= 20
176
 
177
+ for v in violations:
178
+ c.setFont("Helvetica", 10)
179
+ text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
180
+ c.drawString(50, y, text)
181
+ y -= 20
182
+
183
+ # Add snapshot if available
184
+ snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
185
+ if snapshot and os.path.exists(snapshot["snapshot_url"].split('/')[-1]):
186
+ img = ImageReader(snapshot["snapshot_url"].split('/')[-1])
187
+ c.drawImage(img, 50, y - 100, width=200, height=150)
188
+ y -= 170
189
+
190
+ if y < 50:
191
+ c.showPage()
192
+ y = height - 50
193
+
194
+ c.save()
195
+ print(f"PDF generated at {pdf_path}")
196
+
197
+ # Convert PDF to base64
198
+ with open(pdf_path, "rb") as f:
199
+ pdf_base64 = base64.b64encode(f.read()).decode('utf-8')
200
+
201
+ # Clean up
202
+ os.remove(pdf_path)
203
+ print("PDF converted to base64 and file removed")
204
+ return pdf_base64
205
+ except Exception as e:
206
+ print(f"❌ Error generating PDF: {e}")
207
+ return ""
208
 
209
  # ==========================
 
210
  # Gradio Interface
 
211
  # ==========================
212
+ def gradio_interface(video_file):
213
+ try:
214
+ if not video_file:
215
+ return {"error": "Please upload a video file."}, "", "", []
216
+
217
+ with open(video_file, "rb") as f:
218
+ video_data = f.read()
219
+
220
+ result = process_video(video_data)
221
+ return (
222
+ result["violations"],
223
+ f"Safety Score: {result['score']}%",
224
+ result["pdf_base64"],
225
+ result["snapshots"]
226
+ )
227
+ except Exception as e:
228
+ print(f"❌ Error in gradio_interface: {e}")
229
+ return {"error": str(e)}, "", "", []
 
 
 
230
 
231
  interface = gr.Interface(
232
+ fn=gradio_interface,
233
+ inputs=gr.Video(label="Upload Site Video"),
234
+ outputs=[
235
+ gr.JSON(label="Detected Safety Violations"),
236
+ gr.Textbox(label="Compliance Score"),
237
+ gr.Textbox(label="PDF Base64 (for API use)"),
238
+ gr.JSON(label="Snapshots")
239
+ ],
240
+ title="Worksite Safety Violation Analyzer",
241
+ description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
242
  )
243
 
244
+ if __name__ == "__main__":
245
+ print("🚀 Launching Safety Analyzer App...")
246
+ interface.launch()