PrashanthB461 commited on
Commit
04fdfdf
·
verified ·
1 Parent(s): a3c6abf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -23
app.py CHANGED
@@ -5,13 +5,21 @@ import torch
5
  import numpy as np
6
  from ultralytics import YOLO
7
  import time
 
 
 
 
 
 
8
 
9
  # ==========================
10
  # Configuration
11
  # ==========================
12
  DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
13
- FALLBACK_MODEL = "yolov8n.pt" # Use nano model if custom one is missing
14
  MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
 
 
15
 
16
  VIOLATION_LABELS = {
17
  0: "no_helmet",
@@ -27,22 +35,28 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  print(f"✅ Using device: {device}")
28
 
29
  # ==========================
30
- # Load Model (Use YOLOv8n for Faster Inference)
31
  # ==========================
32
  selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
33
  model = YOLO(selected_model)
34
 
35
  # ==========================
36
- # Video Processing with Optimizations
37
  # ==========================
38
- def process_video(video_path, frame_skip=5, max_frames=100):
39
  try:
 
 
 
 
 
40
  video = cv2.VideoCapture(video_path)
41
  if not video.isOpened():
42
  raise ValueError("Could not open video file.")
43
 
44
  frame_count = 0
45
  violations = []
 
46
  processed_frame_count = 0
47
  start_time = time.time()
48
 
@@ -55,7 +69,7 @@ def process_video(video_path, frame_skip=5, max_frames=100):
55
  frame_count += 1
56
  continue
57
 
58
- # Model inference for detecting violations
59
  results = model(frame, device=device)
60
 
61
  for result in results:
@@ -65,11 +79,22 @@ def process_video(video_path, frame_skip=5, max_frames=100):
65
  xywh = box.xywh.cpu().numpy()[0]
66
 
67
  label = VIOLATION_LABELS.get(cls, f"class_{cls}")
68
- violations.append({
69
  "frame": frame_count,
70
  "violation": label,
71
  "confidence": round(conf, 2),
72
- "bounding_box": [round(x, 2) for x in xywh]
 
 
 
 
 
 
 
 
 
 
 
73
  })
74
 
75
  frame_count += 1
@@ -78,22 +103,32 @@ def process_video(video_path, frame_skip=5, max_frames=100):
78
  if processed_frame_count >= max_frames:
79
  break
80
 
81
- elapsed_time = time.time() - start_time
82
- if elapsed_time > 30:
83
  print("⏰ Exceeded 30 seconds of processing time.")
84
  break
85
 
86
  video.release()
87
- score = calculate_safety_score(violations)
88
 
89
- # Generate the PDF report URL (using an external method or library)
90
- pdf_report_url = generate_pdf_report(violations, score)
91
 
92
- return violations, score, pdf_report_url
 
 
 
 
 
93
 
94
  except Exception as e:
95
  print(f"❌ Error processing video: {e}")
96
- return [], f"Error: {e}", None
 
 
 
 
 
 
97
 
98
  # ==========================
99
  # Safety Score Calculation
@@ -113,29 +148,76 @@ def calculate_safety_score(violations):
113
  # ==========================
114
  # PDF Report Generation
115
  # ==========================
116
- def generate_pdf_report(violations, score):
117
- # Create a PDF with violation details (This can be done using an external library or template)
118
- pdf_url = "http://path_to_pdf_report" # URL to the generated PDF (replace with actual URL)
119
- return pdf_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # ==========================
122
  # Gradio Interface
123
  # ==========================
124
  def gradio_interface(video_file):
125
  if not video_file:
126
- return "Please upload a video file.", ""
 
 
 
127
 
128
- violations, score, pdf_url = process_video(video_file)
129
- return violations, f"Safety Score: {score}%", pdf_url
 
 
 
 
 
130
 
131
  interface = gr.Interface(
132
  fn=gradio_interface,
133
  inputs=gr.Video(label="Upload Site Video"),
134
- outputs=[gr.JSON(label="Detected Safety Violations"), gr.Textbox(label="Compliance Score"), gr.Textbox(label="PDF Report URL")],
 
 
 
 
 
135
  title="Worksite Safety Violation Analyzer",
136
  description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
137
  )
138
 
139
  if __name__ == "__main__":
140
  print("🚀 Launching Safety Analyzer App...")
141
- interface.launch()
 
5
  import numpy as np
6
  from ultralytics import YOLO
7
  import time
8
+ from reportlab.lib.pagesizes import letter
9
+ from reportlab.pdfgen import canvas
10
+ from reportlab.lib.utils import ImageReader
11
+ from io import BytesIO
12
+ import base64
13
+ from PIL import Image
14
 
15
  # ==========================
16
  # Configuration
17
  # ==========================
18
  DEFAULT_MODEL_PATH = "models/yolov8_safety.pt"
19
+ FALLBACK_MODEL = "yolov8n.pt"
20
  MODEL_PATH = os.getenv("SAFETY_MODEL_PATH", DEFAULT_MODEL_PATH)
21
+ OUTPUT_DIR = "output" # Directory to store snapshots and PDFs
22
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
23
 
24
  VIOLATION_LABELS = {
25
  0: "no_helmet",
 
35
  print(f"✅ Using device: {device}")
36
 
37
  # ==========================
38
+ # Load Model
39
  # ==========================
40
  selected_model = MODEL_PATH if os.path.isfile(MODEL_PATH) else FALLBACK_MODEL
41
  model = YOLO(selected_model)
42
 
43
  # ==========================
44
+ # Video Processing
45
  # ==========================
46
+ def process_video(video_data, frame_skip=5, max_frames=100):
47
  try:
48
+ # Save uploaded video data to a temporary file
49
+ video_path = os.path.join(OUTPUT_DIR, f"temp_{int(time.time())}.mp4")
50
+ with open(video_path, "wb") as f:
51
+ f.write(video_data)
52
+
53
  video = cv2.VideoCapture(video_path)
54
  if not video.isOpened():
55
  raise ValueError("Could not open video file.")
56
 
57
  frame_count = 0
58
  violations = []
59
+ snapshots = []
60
  processed_frame_count = 0
61
  start_time = time.time()
62
 
 
69
  frame_count += 1
70
  continue
71
 
72
+ # Model inference
73
  results = model(frame, device=device)
74
 
75
  for result in results:
 
79
  xywh = box.xywh.cpu().numpy()[0]
80
 
81
  label = VIOLATION_LABELS.get(cls, f"class_{cls}")
82
+ violation = {
83
  "frame": frame_count,
84
  "violation": label,
85
  "confidence": round(conf, 2),
86
+ "bounding_box": [round(x, 2) for x in xywh],
87
+ "timestamp": frame_count / video.get(cv2.CAP_PROP_FPS)
88
+ }
89
+ violations.append(violation)
90
+
91
+ # Save snapshot
92
+ snapshot_path = os.path.join(OUTPUT_DIR, f"snapshot_{frame_count}_{label}.jpg")
93
+ cv2.imwrite(snapshot_path, frame)
94
+ snapshots.append({
95
+ "violation": label,
96
+ "frame": frame_count,
97
+ "snapshot_url": snapshot_path
98
  })
99
 
100
  frame_count += 1
 
103
  if processed_frame_count >= max_frames:
104
  break
105
 
106
+ if time.time() - start_time > 30:
 
107
  print("⏰ Exceeded 30 seconds of processing time.")
108
  break
109
 
110
  video.release()
111
+ os.remove(video_path) # Clean up temporary video file
112
 
113
+ score = calculate_safety_score(violations)
114
+ pdf_report_path = generate_pdf_report(violations, snapshots, score)
115
 
116
+ return {
117
+ "violations": violations,
118
+ "snapshots": snapshots,
119
+ "score": score,
120
+ "pdf_report_url": pdf_report_path
121
+ }
122
 
123
  except Exception as e:
124
  print(f"❌ Error processing video: {e}")
125
+ return {
126
+ "violations": [],
127
+ "snapshots": [],
128
+ "score": 0,
129
+ "pdf_report_url": "",
130
+ "error": str(e)
131
+ }
132
 
133
  # ==========================
134
  # Safety Score Calculation
 
148
  # ==========================
149
  # PDF Report Generation
150
  # ==========================
151
+ def generate_pdf_report(violations, snapshots, score):
152
+ pdf_path = os.path.join(OUTPUT_DIR, f"report_{int(time.time())}.pdf")
153
+ c = canvas.Canvas(pdf_path, pagesize=letter)
154
+ width, height = letter
155
+
156
+ # Title
157
+ c.setFont("Helvetica-Bold", 16)
158
+ c.drawString(50, height - 50, "Worksite Safety Compliance Report")
159
+
160
+ # Compliance Score
161
+ c.setFont("Helvetica", 12)
162
+ c.drawString(50, height - 80, f"Compliance Score: {score}%")
163
+
164
+ # Violations Table
165
+ y = height - 120
166
+ c.setFont("Helvetica-Bold", 12)
167
+ c.drawString(50, y, "Detected Violations:")
168
+ y -= 20
169
+
170
+ for v in violations:
171
+ c.setFont("Helvetica", 10)
172
+ text = f"Violation: {v['violation']}, Timestamp: {v['timestamp']:.2f}s, Confidence: {v['confidence']}"
173
+ c.drawString(50, y, text)
174
+ y -= 20
175
+
176
+ # Add snapshot if available
177
+ snapshot = next((s for s in snapshots if s["frame"] == v["frame"] and s["violation"] == v["violation"]), None)
178
+ if snapshot and os.path.exists(snapshot["snapshot_url"]):
179
+ img = ImageReader(snapshot["snapshot_url"])
180
+ c.drawImage(img, 50, y - 100, width=200, height=150)
181
+ y -= 170
182
+
183
+ if y < 50:
184
+ c.showPage()
185
+ y = height - 50
186
+
187
+ c.save()
188
+ return pdf_path
189
 
190
  # ==========================
191
  # Gradio Interface
192
  # ==========================
193
  def gradio_interface(video_file):
194
  if not video_file:
195
+ return {"error": "Please upload a video file."}, "", ""
196
+
197
+ with open(video_file, "rb") as f:
198
+ video_data = f.read()
199
 
200
+ result = process_video(video_data)
201
+ return (
202
+ result["violations"],
203
+ f"Safety Score: {result['score']}%",
204
+ result["pdf_report_url"],
205
+ result["snapshots"]
206
+ )
207
 
208
  interface = gr.Interface(
209
  fn=gradio_interface,
210
  inputs=gr.Video(label="Upload Site Video"),
211
+ outputs=[
212
+ gr.JSON(label="Detected Safety Violations"),
213
+ gr.Textbox(label="Compliance Score"),
214
+ gr.Textbox(label="PDF Report URL"),
215
+ gr.JSON(label="Snapshots")
216
+ ],
217
  title="Worksite Safety Violation Analyzer",
218
  description="Upload short site videos to detect safety violations (e.g., no helmet, no harness, unsafe posture)."
219
  )
220
 
221
  if __name__ == "__main__":
222
  print("🚀 Launching Safety Analyzer App...")
223
+ interface.launch()