PrashanthB461 commited on
Commit
238b4a9
·
verified ·
1 Parent(s): ac968c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +446 -218
app.py CHANGED
@@ -12,295 +12,523 @@ from reportlab.lib.units import inch
12
  from io import BytesIO
13
  import base64
14
  import logging
15
- from concurrent.futures import ThreadPoolExecutor
16
 
17
  # ==========================
18
- # Optimized Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
 
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
25
- 1: "no_harness",
26
  2: "unsafe_posture",
27
  3: "unsafe_zone",
28
  4: "improper_tool_use"
29
  },
30
  "CLASS_COLORS": {
31
- "no_helmet": (0, 0, 255),
32
- "no_harness": (0, 165, 255),
33
- "unsafe_posture": (0, 255, 0),
34
- "unsafe_zone": (255, 0, 0),
35
- "improper_tool_use": (255, 255, 0)
 
 
 
 
 
 
 
36
  },
37
- "FRAME_SKIP": 8, # Balanced speed/accuracy
38
- "CONFIDENCE_THRESHOLD": 0.35,
39
- "MIN_DETECTIONS": 2,
40
- "MAX_WORKERS": 4,
41
- "PROCESSING_RESOLUTION": (640, 640),
42
  "SF_CREDENTIALS": {
43
  "username": "prashanth1ai@safety.com",
44
  "password": "SaiPrash461",
45
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
46
  "domain": "login"
47
  },
48
- "MAX_PROCESSING_SECONDS": 30 # Timeout
 
 
 
 
 
49
  }
50
 
51
- # Setup
 
 
 
52
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
 
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- logging.basicConfig(level=logging.INFO)
55
- logger = logging.getLogger(__name__)
56
 
57
- # Load and warm up model
58
- model = YOLO(CONFIG["MODEL_PATH"]).to(device)
59
- model.warmup(imgsz=[1, 3, *CONFIG["PROCESSING_RESOLUTION"]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # ==========================
62
- # Core Processing Functions
63
  # ==========================
64
  def draw_detections(frame, detections):
65
- """Draw bounding boxes on frame"""
66
  for det in detections:
67
- label = det["label"]
68
- x, y, w, h = det["bbox"]
69
- x1, y1 = int(x - w/2), int(y - h/2)
70
- x2, y2 = int(x + w/2), int(y + h/2)
 
 
 
 
 
71
 
72
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
73
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
74
- cv2.putText(frame, f"{label}: {det['confidence']:.2f}",
75
- (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
 
 
76
  return frame
77
 
78
- def process_frame(frame, frame_count, fps):
79
- """Process single frame for violations"""
80
- resized = cv2.resize(frame, CONFIG["PROCESSING_RESOLUTION"])
81
- results = model(resized, verbose=False)
82
 
83
- frame_violations = []
84
- for result in results:
85
- for box in result.boxes:
86
- conf = box.conf.item()
87
- if conf > CONFIG["CONFIDENCE_THRESHOLD"]:
88
- label = CONFIG["VIOLATION_LABELS"].get(int(box.cls.item()))
89
- if label:
90
- frame_violations.append({
91
- "label": label,
92
- "confidence": conf,
93
- "timestamp": frame_count / fps,
94
- "frame": frame_count,
95
- "bbox": box.xywh.cpu().numpy()[0]
96
- })
97
- return frame_violations
98
-
99
- def process_video(video_path):
100
- """Optimized video processing pipeline"""
101
- start_time = time.time()
102
- cap = cv2.VideoCapture(video_path)
103
- fps = cap.get(cv2.CAP_PROP_FPS) or 30
104
- frame_count = 0
105
- all_violations = []
106
- snapshots = {}
107
- last_update = time.time()
108
 
109
- with ThreadPoolExecutor(max_workers=CONFIG["MAX_WORKERS"]) as executor:
110
- futures = []
111
-
112
- while cap.isOpened():
113
- if time.time() - start_time > CONFIG["MAX_PROCESSING_SECONDS"]:
114
- logger.warning("Processing timeout reached")
115
- break
116
-
117
- ret, frame = cap.read()
118
- if not ret:
119
- break
120
-
121
- if frame_count % CONFIG["FRAME_SKIP"] == 0:
122
- futures.append(executor.submit(
123
- process_frame,
124
- frame.copy(),
125
- frame_count,
126
- fps
127
- ))
128
-
129
- frame_count += 1
130
-
131
- # Process completed frames
132
- while futures and futures[0].done():
133
- for violation in futures.pop(0).result():
134
- all_violations.append(violation)
135
- if violation["label"] not in snapshots:
136
- snapshots[violation["label"]] = {
137
- "frame": draw_detections(frame.copy(), [violation]),
138
- "timestamp": violation["timestamp"],
139
- "confidence": violation["confidence"]
140
- }
141
-
142
- cap.release()
143
 
144
- # Filter violations
145
- confirmed_violations = []
146
- violation_counts = {}
147
- for v in all_violations:
148
- key = v["label"]
149
- violation_counts[key] = violation_counts.get(key, 0) + 1
150
-
151
- for v in all_violations:
152
- if violation_counts[v["label"]] >= CONFIG["MIN_DETECTIONS"]:
153
- confirmed_violations.append(v)
154
-
155
- # Save snapshots
156
- snapshot_paths = []
157
- for label, data in snapshots.items():
158
- if violation_counts.get(label, 0) >= CONFIG["MIN_DETECTIONS"]:
159
- path = os.path.join(CONFIG["OUTPUT_DIR"], f"{label}_{int(time.time())}.jpg")
160
- cv2.imwrite(path, data["frame"])
161
- snapshot_paths.append({
162
- "label": label,
163
- "path": path,
164
- "timestamp": data["timestamp"],
165
- "confidence": data["confidence"]
166
- })
167
 
168
- return confirmed_violations, snapshot_paths, time.time() - start_time
169
 
170
  # ==========================
171
- # Reporting Functions
172
  # ==========================
173
- def generate_report(violations, snapshots, processing_time):
174
- """Generate PDF report"""
 
 
 
 
 
 
 
 
 
 
175
  try:
 
 
176
  pdf_file = BytesIO()
177
  c = canvas.Canvas(pdf_file, pagesize=letter)
178
-
179
- # Header
180
- c.setFont("Helvetica-Bold", 14)
181
- c.drawString(1*inch, 10.5*inch, "Safety Violation Report")
182
  c.setFont("Helvetica", 12)
183
-
184
- # Summary
185
- y_pos = 9.8*inch
186
- c.drawString(1*inch, y_pos, f"Processing Time: {processing_time:.1f}s")
187
- y_pos -= 0.4*inch
188
- c.drawString(1*inch, y_pos, f"Total Violations: {len(violations)}")
189
- y_pos -= 0.6*inch
190
-
191
- # Violations
192
- c.setFont("Helvetica-Bold", 12)
193
- c.drawString(1*inch, y_pos, "Violation Details:")
194
- y_pos -= 0.3*inch
195
  c.setFont("Helvetica", 10)
196
-
197
- for v in violations[:50]: # Limit to first 50
198
- text = f"{v['label']} at {v['timestamp']:.1f}s (Confidence: {v['confidence']:.2f})"
199
- c.drawString(1*inch, y_pos, text)
200
- y_pos -= 0.2*inch
201
- if y_pos < 1*inch:
202
- c.showPage()
203
- y_pos = 10*inch
204
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  c.save()
206
  pdf_file.seek(0)
207
- return pdf_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  except Exception as e:
209
- logger.error(f"Report generation failed: {e}")
210
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # ==========================
213
- # Salesforce Integration
214
  # ==========================
215
- def upload_to_salesforce(violations, snapshots, processing_time):
216
- """Upload results to Salesforce"""
217
  try:
218
- sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- # Create record
221
- record_data = {
222
- "Processing_Time__c": processing_time,
223
- "Violation_Count__c": len(violations),
224
- "Status__c": "Completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  }
226
- record = sf.Safety_Video_Report__c.create(record_data)
227
-
228
- # Upload report
229
- pdf = generate_report(violations, snapshots, processing_time)
230
- if pdf:
231
- encoded = base64.b64encode(pdf.getvalue()).decode("utf-8")
232
- sf.ContentVersion.create({
233
- "Title": f"Safety_Report_{record['id']}",
234
- "PathOnClient": "report.pdf",
235
- "VersionData": encoded,
236
- "FirstPublishLocationId": record['id']
237
- })
238
-
239
- return record['id']
240
  except Exception as e:
241
- logger.error(f"Salesforce upload failed: {e}")
242
- return None
 
 
 
 
 
 
 
243
 
244
  # ==========================
245
  # Gradio Interface
246
  # ==========================
247
- def analyze_video(video_file):
248
- """Main processing function"""
249
  if not video_file:
250
- return "No video uploaded", "", "", ""
251
-
252
  try:
253
- # Process video
254
- violations, snapshots, proc_time = process_video(video_file)
255
-
256
- # Generate outputs
257
- violation_table = (
258
- "| Violation | Time (s) | Confidence |\n"
259
- "|-----------|----------|------------|\n" +
260
- "\n".join(
261
- f"| {v['label']} | {v['timestamp']:.1f} | {v['confidence']:.2f} |"
262
- for v in violations[:20] # Show first 20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  )
 
 
 
 
 
 
 
264
  )
265
-
266
- snapshot_markdown = "\n".join(
267
- f"**{s['label']}** ({s['timestamp']:.1f}s, {s['confidence']:.2f}):\n"
268
- f"![]({s['path']})"
269
- for s in snapshots
270
- )
271
-
272
- # Salesforce upload (async)
273
- sf_id = "Not uploaded"
274
- if violations:
275
- sf_id = upload_to_salesforce(violations, snapshots, proc_time) or "Upload failed"
276
-
277
- return (
278
- f"Processed in {proc_time:.1f}s\n{violation_table}",
279
- snapshot_markdown,
280
- f"Found {len(violations)} violations",
281
- f"Salesforce ID: {sf_id}"
282
- )
283
-
284
  except Exception as e:
285
- return f"Error: {str(e)}", "", "", ""
 
286
 
287
- # Launch interface
288
  interface = gr.Interface(
289
- fn=analyze_video,
290
  inputs=gr.Video(label="Upload Site Video"),
291
- outputs=[
292
- gr.Markdown("## Violation Results"),
293
- gr.Markdown("## Evidence Snapshots"),
294
- gr.Textbox(label="Summary"),
295
- gr.Textbox(label="Salesforce Record")
 
296
  ],
297
- title="AI Safety Compliance Inspector",
298
- description=(
299
- "Upload worksite video for automated safety violation detection. "
300
- "Detects: Missing helmets, improper harnessing, unsafe zones, and more."
301
- ),
302
  allow_flagging="never"
303
  )
304
 
305
  if __name__ == "__main__":
306
- interface.launch(server_name="0.0.0.0", server_port=7860)
 
 
12
  from io import BytesIO
13
  import base64
14
  import logging
15
+ from retrying import retry
16
 
17
  # ==========================
18
+ # Enhanced Configuration
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
22
+ "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
26
+ 1: "no_harness",
27
  2: "unsafe_posture",
28
  3: "unsafe_zone",
29
  4: "improper_tool_use"
30
  },
31
  "CLASS_COLORS": {
32
+ "no_helmet": (0, 0, 255), # Red
33
+ "no_harness": (0, 165, 255), # Orange
34
+ "unsafe_posture": (0, 255, 0), # Green
35
+ "unsafe_zone": (255, 0, 0), # Blue
36
+ "improper_tool_use": (255, 255, 0) # Yellow
37
+ },
38
+ "DISPLAY_NAMES": {
39
+ "no_helmet": "No Helmet Violation",
40
+ "no_harness": "No Harness Violation",
41
+ "unsafe_posture": "Unsafe Posture Violation",
42
+ "unsafe_zone": "Unsafe Zone Entry",
43
+ "improper_tool_use": "Improper Tool Use"
44
  },
 
 
 
 
 
45
  "SF_CREDENTIALS": {
46
  "username": "prashanth1ai@safety.com",
47
  "password": "SaiPrash461",
48
  "security_token": "AP4AQnPoidIKPvSvNEfAHyoK",
49
  "domain": "login"
50
  },
51
+ "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
52
+ "FRAME_SKIP": 5, # Reduced for better detection
53
+ "MAX_PROCESSING_TIME": 60,
54
+ "CONFIDENCE_THRESHOLD": 0.25, # Lower threshold for all violations
55
+ "IOU_THRESHOLD": 0.4,
56
+ "MIN_VIOLATION_FRAMES": 3 # Minimum consecutive frames to confirm violation
57
  }
58
 
59
+ # Setup logging
60
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
61
+ logger = logging.getLogger(__name__)
62
+
63
  os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
64
+
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ logger.info(f"Using device: {device}")
 
67
 
68
+ def load_model():
69
+ try:
70
+ if os.path.isfile(CONFIG["MODEL_PATH"]):
71
+ model_path = CONFIG["MODEL_PATH"]
72
+ logger.info(f"Model loaded: {model_path}")
73
+ else:
74
+ model_path = CONFIG["FALLBACK_MODEL"]
75
+ logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
76
+ if not os.path.isfile(model_path):
77
+ logger.info(f"Downloading fallback model: {model_path}")
78
+ torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
79
+ model = YOLO(model_path).to(device)
80
+ return model
81
+ except Exception as e:
82
+ logger.error(f"Failed to load model: {e}")
83
+ raise
84
+
85
+ model = load_model()
86
 
87
  # ==========================
88
+ # Enhanced Helper Functions
89
  # ==========================
90
  def draw_detections(frame, detections):
91
+ """Draw bounding boxes and labels on frame"""
92
  for det in detections:
93
+ label = det["violation"]
94
+ confidence = det["confidence"]
95
+ x, y, w, h = det["bounding_box"]
96
+
97
+ # Convert from center coordinates to corner coordinates
98
+ x1 = int(x - w/2)
99
+ y1 = int(y - h/2)
100
+ x2 = int(x + w/2)
101
+ y2 = int(y + h/2)
102
 
103
  color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
104
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
105
+
106
+ display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
107
+ cv2.putText(frame, display_text, (x1, y1-10),
108
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
109
  return frame
110
 
111
+ def calculate_iou(box1, box2):
112
+ """Calculate Intersection over Union (IoU) for two bounding boxes."""
113
+ x1, y1, w1, h1 = box1
114
+ x2, y2, w2, h2 = box2
115
 
116
+ # Convert to top-left and bottom-right coordinates
117
+ x1_min, y1_min = x1 - w1/2, y1 - h1/2
118
+ x1_max, y1_max = x1 + w1/2, y1 + h1/2
119
+ x2_min, y2_min = x2 - w2/2, y2 - h2/2
120
+ x2_max, y2_max = x2 + w2/2, y2 + h2/2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ # Calculate intersection
123
+ x_min = max(x1_min, x2_min)
124
+ y_min = max(y1_min, y2_min)
125
+ x_max = min(x1_max, x2_max)
126
+ y_max = min(y1_max, y2_max)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ intersection = max(0, x_max - x_min) * max(0, y_max - y_min)
129
+ area1 = w1 * h1
130
+ area2 = w2 * h2
131
+ union = area1 + area2 - intersection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ return intersection / union if union > 0 else 0
134
 
135
  # ==========================
136
+ # Salesforce Integration (unchanged)
137
  # ==========================
138
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
139
+ def connect_to_salesforce():
140
+ try:
141
+ sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
142
+ logger.info("Connected to Salesforce")
143
+ sf.describe()
144
+ return sf
145
+ except Exception as e:
146
+ logger.error(f"Salesforce connection failed: {e}")
147
+ raise
148
+
149
+ def generate_violation_pdf(violations, score):
150
  try:
151
+ pdf_filename = f"violations_{int(time.time())}.pdf"
152
+ pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
153
  pdf_file = BytesIO()
154
  c = canvas.Canvas(pdf_file, pagesize=letter)
 
 
 
 
155
  c.setFont("Helvetica", 12)
156
+ c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
 
 
 
 
 
 
 
 
 
 
 
157
  c.setFont("Helvetica", 10)
158
+
159
+ y_position = 9.5 * inch
160
+ report_data = {
161
+ "Compliance Score": f"{score}%",
162
+ "Violations Found": len(violations),
163
+ "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
164
+ }
165
+ for key, value in report_data.items():
166
+ c.drawString(1 * inch, y_position, f"{key}: {value}")
167
+ y_position -= 0.3 * inch
168
+
169
+ y_position -= 0.3 * inch
170
+ c.drawString(1 * inch, y_position, "Violation Details:")
171
+ y_position -= 0.3 * inch
172
+ if not violations:
173
+ c.drawString(1 * inch, y_position, "No violations detected.")
174
+ else:
175
+ for v in violations:
176
+ display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
177
+ text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
178
+ c.drawString(1 * inch, y_position, text)
179
+ y_position -= 0.3 * inch
180
+ if y_position < 1 * inch:
181
+ c.showPage()
182
+ c.setFont("Helvetica", 10)
183
+ y_position = 10 * inch
184
+
185
+ c.showPage()
186
  c.save()
187
  pdf_file.seek(0)
188
+
189
+ with open(pdf_path, "wb") as f:
190
+ f.write(pdf_file.getvalue())
191
+ public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
192
+ logger.info(f"PDF generated: {public_url}")
193
+ return pdf_path, public_url, pdf_file
194
+ except Exception as e:
195
+ logger.error(f"Error generating PDF: {e}")
196
+ return "", "", None
197
+
198
+ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
199
+ try:
200
+ if not pdf_file:
201
+ logger.error("No PDF file provided for upload")
202
+ return ""
203
+ encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
204
+ content_version_data = {
205
+ "Title": f"Safety_Violation_Report_{int(time.time())}",
206
+ "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
207
+ "VersionData": encoded_pdf,
208
+ "FirstPublishLocationId": report_id
209
+ }
210
+ content_version = sf.ContentVersion.create(content_version_data)
211
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
212
+ if not result['records']:
213
+ logger.error("Failed to retrieve ContentVersion")
214
+ return ""
215
+ file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
216
+ logger.info(f"PDF uploaded to Salesforce: {file_url}")
217
+ return file_url
218
+ except Exception as e:
219
+ logger.error(f"Error uploading PDF to Salesforce: {e}")
220
+ return ""
221
+
222
+ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
223
+ try:
224
+ sf = connect_to_salesforce()
225
+ violations_text = "\n".join(
226
+ f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
227
+ for v in violations
228
+ ) or "No violations detected."
229
+ pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
230
+
231
+ record_data = {
232
+ "Compliance_Score__c": score,
233
+ "Violations_Found__c": len(violations),
234
+ "Violations_Details__c": violations_text,
235
+ "Status__c": "Pending",
236
+ "PDF_Report_URL__c": pdf_url
237
+ }
238
+ logger.info(f"Creating Salesforce record with data: {record_data}")
239
+ try:
240
+ record = sf.Safety_Video_Report__c.create(record_data)
241
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
242
+ except Exception as e:
243
+ logger.error(f"Failed to create Safety_Video_Report__c: {e}")
244
+ record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
245
+ logger.warning(f"Fell back to Account record: {record['id']}")
246
+ record_id = record["id"]
247
+
248
+ if pdf_file:
249
+ uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
250
+ if uploaded_url:
251
+ try:
252
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
253
+ logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
254
+ except Exception as e:
255
+ logger.error(f"Failed to update Safety_Video_Report__c: {e}")
256
+ sf.Account.update(record_id, {"Description": uploaded_url})
257
+ logger.info(f"Updated Account record {record_id} with PDF URL")
258
+ pdf_url = uploaded_url
259
+
260
+ return record_id, pdf_url
261
  except Exception as e:
262
+ logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
263
+ return None, ""
264
+
265
+ def calculate_safety_score(violations):
266
+ penalties = {
267
+ "no_helmet": 25,
268
+ "no_harness": 30,
269
+ "unsafe_posture": 20,
270
+ "unsafe_zone": 35,
271
+ "improper_tool_use": 25
272
+ }
273
+ # Count unique violations per worker
274
+ unique_violations = set()
275
+ for v in violations:
276
+ key = (v["worker_id"], v["violation"])
277
+ unique_violations.add(key)
278
+
279
+ total_penalty = sum(penalties.get(violation, 0) for _, violation in unique_violations)
280
+ score = 100 - total_penalty
281
+ return max(score, 0)
282
 
283
  # ==========================
284
+ # Enhanced Video Processing
285
  # ==========================
286
+ def process_video(video_data):
 
287
  try:
288
+ video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}.mp4")
289
+ with open(video_path, "wb") as f:
290
+ f.write(video_data)
291
+ logger.info(f"Video saved: {video_path}")
292
+
293
+ video = cv2.VideoCapture(video_path)
294
+ if not video.isOpened():
295
+ raise ValueError("Could not open video file")
296
+
297
+ violations = []
298
+ snapshots = []
299
+ frame_count = 0
300
+ start_time = time.time()
301
+ fps = video.get(cv2.CAP_PROP_FPS)
302
+ if fps <= 0:
303
+ fps = 30 # Default assumption if FPS cannot be determined
304
+
305
+ # Structure to track workers and their violations
306
+ workers = []
307
+ violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
308
+ snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
309
+
310
+ logger.info(f"Processing video with FPS: {fps}")
311
+ logger.info(f"Looking for violations: {CONFIG['VIOLATION_LABELS']}")
312
+
313
+ while True:
314
+ ret, frame = video.read()
315
+ if not ret:
316
+ break
317
+
318
+ if frame_count % CONFIG["FRAME_SKIP"] != 0:
319
+ frame_count += 1
320
+ continue
321
+
322
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
323
+ logger.info("Processing time limit reached")
324
+ break
325
+
326
+ current_time = frame_count / fps
327
+
328
+ # Run detection on this frame
329
+ results = model(frame, device=device)
330
+
331
+ current_detections = []
332
+ for result in results:
333
+ boxes = result.boxes
334
+ for box in boxes:
335
+ cls = int(box.cls)
336
+ conf = float(box.conf)
337
+ label = CONFIG["VIOLATION_LABELS"].get(cls, None)
338
+
339
+ if label is None:
340
+ continue
341
+
342
+ if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
343
+ continue
344
+
345
+ bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]
346
+
347
+ current_detections.append({
348
+ "frame": frame_count,
349
+ "violation": label,
350
+ "confidence": round(conf, 2),
351
+ "bounding_box": bbox,
352
+ "timestamp": current_time
353
+ })
354
+
355
+ # Process detections and associate with workers
356
+ for detection in current_detections:
357
+ # Find matching worker
358
+ matched_worker = None
359
+ max_iou = 0
360
+
361
+ for worker in workers:
362
+ iou = calculate_iou(detection["bounding_box"], worker["bbox"])
363
+ if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
364
+ max_iou = iou
365
+ matched_worker = worker
366
+
367
+ if matched_worker:
368
+ # Update worker's position
369
+ matched_worker["bbox"] = detection["bounding_box"]
370
+ matched_worker["last_seen"] = current_time
371
+ worker_id = matched_worker["id"]
372
+ else:
373
+ # New worker
374
+ worker_id = len(workers) + 1
375
+ workers.append({
376
+ "id": worker_id,
377
+ "bbox": detection["bounding_box"],
378
+ "first_seen": current_time,
379
+ "last_seen": current_time
380
+ })
381
+
382
+ # Add to violation history
383
+ detection["worker_id"] = worker_id
384
+ violation_history[detection["violation"]].append(detection)
385
+
386
+ frame_count += 1
387
+
388
+ video.release()
389
+ os.remove(video_path)
390
 
391
+ # Process violation history to confirm persistent violations
392
+ for violation_type, detections in violation_history.items():
393
+ if not detections:
394
+ continue
395
+
396
+ # Group by worker
397
+ worker_violations = {}
398
+ for det in detections:
399
+ if det["worker_id"] not in worker_violations:
400
+ worker_violations[det["worker_id"]] = []
401
+ worker_violations[det["worker_id"]].append(det)
402
+
403
+ # Check each worker's violations for persistence
404
+ for worker_id, worker_dets in worker_violations.items():
405
+ if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
406
+ # Take the highest confidence detection
407
+ best_detection = max(worker_dets, key=lambda x: x["confidence"])
408
+ violations.append(best_detection)
409
+
410
+ # Capture snapshot if not already taken
411
+ if not snapshot_taken[violation_type]:
412
+ # Get the frame for this violation
413
+ cap = cv2.VideoCapture(video_path)
414
+ cap.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
415
+ ret, snapshot_frame = cap.read()
416
+ cap.release()
417
+
418
+ if ret:
419
+ # Draw detections on snapshot
420
+ snapshot_frame = draw_detections(snapshot_frame, [best_detection])
421
+
422
+ snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
423
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
424
+ cv2.imwrite(snapshot_path, snapshot_frame)
425
+ snapshots.append({
426
+ "violation": violation_type,
427
+ "frame": best_detection["frame"],
428
+ "snapshot_path": snapshot_path,
429
+ "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
430
+ })
431
+ snapshot_taken[violation_type] = True
432
+
433
+ # Final processing
434
+ if not violations:
435
+ logger.info("No persistent violations detected")
436
+ return {
437
+ "violations": [],
438
+ "snapshots": [],
439
+ "score": 100,
440
+ "salesforce_record_id": None,
441
+ "violation_details_url": "",
442
+ "message": "No violations detected in the video."
443
+ }
444
+
445
+ score = calculate_safety_score(violations)
446
+ pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
447
+ report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
448
+
449
+ return {
450
+ "violations": violations,
451
+ "snapshots": snapshots,
452
+ "score": score,
453
+ "salesforce_record_id": report_id,
454
+ "violation_details_url": final_pdf_url,
455
+ "message": ""
456
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  except Exception as e:
458
+ logger.error(f"Error processing video: {e}", exc_info=True)
459
+ return {
460
+ "violations": [],
461
+ "snapshots": [],
462
+ "score": 100,
463
+ "salesforce_record_id": None,
464
+ "violation_details_url": "",
465
+ "message": f"Error processing video: {e}"
466
+ }
467
 
468
  # ==========================
469
  # Gradio Interface
470
  # ==========================
471
+ def gradio_interface(video_file):
 
472
  if not video_file:
473
+ return "No file uploaded.", "", "No file uploaded.", "", ""
 
474
  try:
475
+ yield "Processing video... please wait.", "", "", "", ""
476
+
477
+ with open(video_file, "rb") as f:
478
+ video_data = f.read()
479
+
480
+ result = process_video(video_data)
481
+
482
+ if result.get("message"):
483
+ yield result["message"], "", "", "", ""
484
+ return
485
+
486
+ violation_table = "No violations detected."
487
+ if result["violations"]:
488
+ header = "| Violation | Timestamp (s) | Confidence | Worker ID |\n"
489
+ separator = "|------------------------|---------------|------------|-----------|\n"
490
+ rows = []
491
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
492
+ for v in result["violations"]:
493
+ display_name = violation_name_map.get(v["violation"], v["violation"])
494
+ row = f"| {display_name:<22} | {v['timestamp']:.2f} | {v['confidence']:.2f} | {v['worker_id']} |"
495
+ rows.append(row)
496
+ violation_table = header + separator + "\n".join(rows)
497
+
498
+ snapshots_text = "No snapshots captured."
499
+ if result["snapshots"]:
500
+ violation_name_map = CONFIG["DISPLAY_NAMES"]
501
+ snapshots_text = "\n".join(
502
+ f"- Snapshot for {violation_name_map.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
503
+ for s in result["snapshots"]
504
  )
505
+
506
+ yield (
507
+ violation_table,
508
+ f"Safety Score: {result['score']}%",
509
+ snapshots_text,
510
+ f"Salesforce Record ID: {result['salesforce_record_id'] or 'N/A'}",
511
+ result["violation_details_url"] or "N/A"
512
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  except Exception as e:
514
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
515
+ yield f"Error: {str(e)}", "", "Error in processing.", "", ""
516
 
 
517
  interface = gr.Interface(
518
+ fn=gradio_interface,
519
  inputs=gr.Video(label="Upload Site Video"),
520
+ outputs=[
521
+ gr.Markdown(label="Detected Safety Violations"),
522
+ gr.Textbox(label="Compliance Score"),
523
+ gr.Markdown(label="Snapshots"),
524
+ gr.Textbox(label="Salesforce Record ID"),
525
+ gr.Textbox(label="Violation Details URL")
526
  ],
527
+ title="Worksite Safety Violation Analyzer",
528
+ description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Non-violations are ignored.",
 
 
 
529
  allow_flagging="never"
530
  )
531
 
532
  if __name__ == "__main__":
533
+ logger.info("Launching Enhanced Safety Analyzer App...")
534
+ interface.launch()