File size: 23,702 Bytes
dc66d57
 
 
 
6c6d7c5
dc66d57
 
64f085e
 
 
 
 
 
6c6d7c5
238b4a9
40b5e84
04fdfdf
dc66d57
238b4a9
dc66d57
f225f91
8aef0a6
238b4a9
f225f91
 
 
238b4a9
8aef0a6
 
 
2d7e132
8aef0a6
0615d03
 
 
 
 
238b4a9
 
 
 
 
 
 
f4592c4
8f347c4
8aef0a6
 
 
 
f4592c4
238b4a9
763d258
40b5e84
12dad16
 
550ca2a
 
 
40b5e84
238b4a9
c49fe29
0615d03
763d258
f225f91
bcf309d
238b4a9
 
 
 
27626a3
238b4a9
64f085e
238b4a9
f4592c4
238b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4592c4
 
238b4a9
8aef0a6
 
 
d7cc76c
 
 
238b4a9
 
 
 
 
8aef0a6
 
 
238b4a9
 
 
 
8aef0a6
 
238b4a9
 
 
8aef0a6
238b4a9
 
 
 
8aef0a6
40b5e84
238b4a9
 
 
8aef0a6
238b4a9
8aef0a6
6827f40
40b5e84
6827f40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7cc76c
 
6827f40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7cc76c
6827f40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7cc76c
6827f40
 
 
64f085e
238b4a9
64f085e
238b4a9
64f085e
238b4a9
 
 
 
 
 
 
 
 
 
 
 
550ca2a
238b4a9
 
0615d03
550ca2a
 
238b4a9
 
 
0615d03
238b4a9
0615d03
 
 
 
 
 
 
 
 
 
 
 
 
 
238b4a9
0615d03
238b4a9
 
40b5e84
 
0615d03
 
 
 
550ca2a
238b4a9
0615d03
238b4a9
 
 
 
 
 
 
 
 
 
d7cc76c
238b4a9
 
40b5e84
d7cc76c
238b4a9
 
 
 
 
0615d03
550ca2a
238b4a9
 
 
 
550ca2a
238b4a9
0615d03
d7cc76c
238b4a9
d7cc76c
 
 
 
 
40b5e84
 
 
 
 
 
 
 
 
550ca2a
 
 
 
 
 
 
 
 
 
 
40b5e84
238b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550ca2a
 
238b4a9
40b5e84
d7cc76c
40b5e84
 
238b4a9
40b5e84
 
 
238b4a9
550ca2a
 
238b4a9
 
d7cc76c
238b4a9
 
 
 
 
 
 
 
 
 
40b5e84
 
 
550ca2a
 
 
 
 
 
 
40b5e84
238b4a9
 
 
40b5e84
 
 
 
238b4a9
 
 
 
550ca2a
 
 
 
 
238b4a9
550ca2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238b4a9
 
 
12dad16
 
238b4a9
 
 
 
 
12dad16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7cc76c
12dad16
 
 
 
 
 
 
64f085e
238b4a9
12dad16
ba9ee16
64f085e
f4592c4
64f085e
238b4a9
27626a3
238b4a9
ba9ee16
238b4a9
 
 
12dad16
 
ba9ee16
238b4a9
 
ba9ee16
921f6bb
238b4a9
ba9ee16
12dad16
238b4a9
 
 
 
 
ba9ee16
238b4a9
 
8aef0a6
ba9ee16
 
 
238b4a9
6827f40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
import uuid

# ==========================
# Enhanced Configuration
# ==========================
CONFIG = {
    "MODEL_PATH": "yolov8_safety.pt",
    "FALLBACK_MODEL": "yolov8n.pt",
    "OUTPUT_DIR": "static/output",
    # Model class id -> internal violation key. CLASS_COLORS, DISPLAY_NAMES and
    # CONFIDENCE_THRESHOLDS below are keyed by these same names.
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone",
        4: "improper_tool_use"
    },
    # Colour tuples handed to cv2 drawing calls (OpenCV channel order is BGR).
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),
        "no_harness": (0, 165, 255),
        "unsafe_posture": (0, 255, 0),
        "unsafe_zone": (255, 0, 0),
        "improper_tool_use": (255, 255, 0)
    },
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture Violation",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use"
    },
    # Salesforce credentials are read from the environment so secrets are not
    # committed to source control. The values previously hard-coded here were
    # exposed in the repository and should be rotated.
    "SF_CREDENTIALS": {
        "username": os.environ.get("SF_USERNAME", ""),
        "password": os.environ.get("SF_PASSWORD", ""),
        "security_token": os.environ.get("SF_SECURITY_TOKEN", ""),
        "domain": os.environ.get("SF_DOMAIN", "login"),
    },
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Sample roughly one frame out of every FRAME_SKIP frames of the video.
    "FRAME_SKIP": 50,
    # Per-class minimum detection confidence; weaker boxes are discarded.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.5,
        "no_harness": 0.15,
        "unsafe_posture": 0.15,
        "unsafe_zone": 0.15,
        "improper_tool_use": 0.15
    },
    # Used both as the detector's NMS IoU and as the minimum IoU for matching a
    # detection to an already-tracked worker box.
    "IOU_THRESHOLD": 0.4,
    # A worker must show a violation in at least this many sampled frames
    # before it is reported (no_helmet additionally needs twice as many).
    "MIN_VIOLATION_FRAMES": 2,
    # A no_helmet detection below this confidence marks the worker as compliant.
    "HELMET_CONFIDENCE_THRESHOLD": 0.7,
    # Wall-clock budget, in seconds, for the frame-analysis loop.
    "MAX_PROCESSING_TIME": 120,
}

# Configure root logging once at import time and take this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The temp videos, snapshots and PDFs are all written here; create it up front.
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# Run inference on the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

def load_model():
    """Load the YOLO detector used for violation detection.

    Uses the custom checkpoint at ``CONFIG["MODEL_PATH"]`` when that file
    exists; otherwise falls back to ``CONFIG["FALLBACK_MODEL"]``, downloading
    it first if it is not on disk. The model is moved to the module-level
    ``device``.

    Returns:
        The loaded ``ultralytics.YOLO`` model.

    Raises:
        Whatever the download or the ``YOLO`` constructor raised; the error is
        logged (with traceback) and re-raised.
    """
    try:
        if os.path.isfile(CONFIG["MODEL_PATH"]):
            model_path = CONFIG["MODEL_PATH"]
        else:
            model_path = CONFIG["FALLBACK_MODEL"]
            # NOTE(review): the fallback is a generic pretrained checkpoint
            # (presumably COCO-trained, where ids 0-4 are person, bicycle, ...).
            # process_video maps raw class ids through VIOLATION_LABELS, so with
            # this model ordinary objects would be reported as violations.
            logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
            if not os.path.isfile(model_path):
                logger.info(f"Downloading fallback model: {model_path}")
                torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
        model = YOLO(model_path).to(device)
        # Log only after the weights actually loaded (previously this was
        # logged before loading, and never for the fallback path).
        logger.info(f"Model loaded: {model_path}")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        raise

model = load_model()

# ==========================
# Enhanced Helper Functions
# ==========================
def draw_detections(frame, detections):
    """Draw a coloured, captioned box for every detection onto ``frame``.

    Args:
        frame: OpenCV image array; it is drawn on in place.
        detections: iterable of dicts. Reads "violation" (default "Unknown"),
            "confidence" (default 0.0) and "bounding_box" given as
            ``[centre_x, centre_y, width, height]`` (default all zeros).

    Returns:
        The same ``frame`` object, for call chaining.
    """
    for det in detections:
        label = det.get("violation", "Unknown")
        confidence = det.get("confidence", 0.0)
        cx, cy, w, h = det.get("bounding_box", [0, 0, 0, 0])

        # Centre/size -> integer corner coordinates for cv2.
        x1, y1 = int(cx - w / 2), int(cy - h / 2)
        x2, y2 = int(cx + w / 2), int(cy + h / 2)

        color = CONFIG["CLASS_COLORS"].get(label, (0, 0, 255))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        display_text = f"{CONFIG['DISPLAY_NAMES'].get(label, label)}: {confidence:.2f}"
        # The caption normally sits just above the box; clamp its anchor so it
        # stays inside the image when the box touches the top or left edge
        # (otherwise the text is drawn off-screen and lost).
        text_origin = (max(x1, 0), max(y1 - 10, 15))
        cv2.putText(frame, display_text, text_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return frame

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two centre-format boxes.

    Each box is ``(centre_x, centre_y, width, height)``, the same layout the
    detections store in "bounding_box".

    Returns:
        The overlap area divided by the union area (in [0, 1] for
        non-negative sizes), or 0.0 when the union is empty.
    """
    cx1, cy1, w1, h1 = box1
    cx2, cy2, w2, h2 = box2

    # Corner coordinates of both boxes.
    left1, top1, right1, bottom1 = cx1 - w1 / 2, cy1 - h1 / 2, cx1 + w1 / 2, cy1 + h1 / 2
    left2, top2, right2, bottom2 = cx2 - w2 / 2, cy2 - h2 / 2, cx2 + w2 / 2, cy2 + h2 / 2

    # The overlap is bounded by the inner-most edges of the two boxes; clamp to
    # zero when they do not overlap. (The previous version used box1's own
    # extent here, so every pair looked like it overlapped.)
    overlap_w = max(0.0, min(right1, right2) - max(left1, left2))
    overlap_h = max(0.0, min(bottom1, bottom2) - max(top1, top2))
    intersection = overlap_w * overlap_h

    union = w1 * h1 + w2 * h2 - intersection
    return intersection / union if union > 0 else 0.0

# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Log in to Salesforce with ``CONFIG["SF_CREDENTIALS"]`` and verify the session.

    ``sf.describe()`` is issued as a cheap API call to prove the session works;
    success is logged only after it returns. The ``retry`` decorator makes up
    to 3 attempts with a fixed 2000 ms wait; the last failure propagates.

    Returns:
        The authenticated ``Salesforce`` client.
    """
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        sf.describe()
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
    logger.info("Connected to Salesforce")
    return sf

def generate_violation_pdf(violations, score):
    """Render the violation summary as a PDF and save it under OUTPUT_DIR.

    The document is built in memory with reportlab and then written to
    ``CONFIG["OUTPUT_DIR"]/violations_<unix-seconds>.pdf``. Long violation
    lists continue onto additional pages.

    Args:
        violations: list of violation dicts; "violation", "timestamp" and
            "confidence" are read with defaults when absent.
        score: compliance score, printed as a percentage.

    Returns:
        ``(pdf_path, public_url, buffer)`` where ``buffer`` is the in-memory
        ``BytesIO`` rewound to offset 0, or ``("", "", None)`` on any error
        (which is logged).
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)

        margin = 1 * inch
        line_gap = 0.3 * inch
        buffer = BytesIO()
        doc = canvas.Canvas(buffer, pagesize=letter)

        # Title at 12pt, everything else at 10pt.
        doc.setFont("Helvetica", 12)
        doc.drawString(margin, 10 * inch, "Worksite Safety Violation Report")
        doc.setFont("Helvetica", 10)

        cursor_y = 9.5 * inch
        summary_lines = [
            f"Compliance Score: {score}%",
            f"Violations Found: {len(violations)}",
            f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        ]
        for line in summary_lines:
            doc.drawString(margin, cursor_y, line)
            cursor_y -= line_gap

        # One blank line, then the details section header.
        cursor_y -= line_gap
        doc.drawString(margin, cursor_y, "Violation Details:")
        cursor_y -= line_gap

        if violations:
            for item in violations:
                name = CONFIG["DISPLAY_NAMES"].get(item.get("violation", "Unknown"), "Unknown")
                doc.drawString(
                    margin,
                    cursor_y,
                    f"{name} at {item.get('timestamp', 0.0):.2f}s (Confidence: {item.get('confidence', 0.0):.2f})",
                )
                cursor_y -= line_gap
                # Start a fresh page before text would cross the bottom margin.
                if cursor_y < margin:
                    doc.showPage()
                    doc.setFont("Helvetica", 10)
                    cursor_y = 10 * inch
        else:
            doc.drawString(margin, cursor_y, "No violations detected.")

        doc.showPage()
        doc.save()
        buffer.seek(0)

        with open(pdf_path, "wb") as out:
            out.write(buffer.getvalue())
        public_url = CONFIG["PUBLIC_URL_BASE"] + pdf_filename
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, buffer
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None

def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: authenticated ``simple_salesforce.Salesforce`` client.
        pdf_file: ``BytesIO`` holding the PDF bytes (falsy -> nothing uploaded).
        report_id: id of the record the file is first published to.

    Returns:
        A download URL for the uploaded version on this Salesforce instance,
        or "" if there was nothing to upload or any step failed (logged).
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""

        # One timestamp for both names so Title and PathOnClient always agree,
        # even if the call straddles a second boundary.
        stamp = int(time.time())
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        content_version = sf.ContentVersion.create({
            "Title": f"Safety_Violation_Report_{stamp}",
            "PathOnClient": f"safety_violation_{stamp}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id,
        })
        version_id = content_version['id']

        # Read the version back to confirm the upload is visible before
        # handing out a URL for it.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{version_id}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""

        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{version_id}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""

def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce report record and attach the PDF to it.

    Tries to create a ``Safety_Video_Report__c``; if that fails, falls back to
    creating an ``Account``. The PDF is then uploaded, and its URL written back
    to whichever object was actually created.

    Args:
        violations: confirmed violation dicts.
        score: compliance score (0-100).
        pdf_path: local path of the generated PDF ("" if generation failed).
        pdf_file: in-memory PDF buffer, or None.

    Returns:
        ``(record_id, pdf_url)``. ``pdf_url`` is the Salesforce download URL
        when the upload succeeded, otherwise the public static URL (or "").
        Returns ``(None, "")`` only if no record could be created at all.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')} at {v.get('timestamp', 0.0):.2f}s (Confidence: {v.get('confidence', 0.0):.2f})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""

        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")

        # Remember which sObject holds the report so the later URL update
        # targets the right type instead of making a call that cannot succeed.
        is_account_fallback = False
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
            is_account_fallback = True
        record_id = record["id"]

        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                # The record already exists at this point: a failed URL update is
                # logged but must not turn the whole push into a failure.
                if is_account_fallback:
                    try:
                        sf.Account.update(record_id, {"Description": uploaded_url})
                        logger.info(f"Updated Account record {record_id} with PDF URL")
                    except Exception as e:
                        logger.error(f"Failed to update Account {record_id} with PDF URL: {e}")
                else:
                    try:
                        sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                        logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                    except Exception as e:
                        logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                pdf_url = uploaded_url

        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
        return None, ""

def calculate_safety_score(violations):
    """Return a compliance score: 100 minus a fixed penalty per violation, floored at 0.

    The penalty is chosen by each dict's "violation" key; unknown or missing
    violation types cost nothing.
    """
    penalty_for = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25,
    }
    deducted = 0
    for entry in violations:
        deducted += penalty_for.get(entry.get("violation", "Unknown"), 0)
    remaining = 100 - deducted
    return remaining if remaining > 0 else 0

# ==========================
# Enhanced Video Processing
# ==========================
def process_video(video_data):
    """Analyse an uploaded video for safety violations, streaming UI updates.

    This is a generator. It yields 5-tuples matching the Gradio outputs
    (violations table / status, score text, snapshots markdown, Salesforce
    record text, PDF URL). Progress tuples come first while frames are
    analysed, followed by one final result tuple (or an error tuple).

    Pipeline:
      1. Write ``video_data`` to a uniquely named temp file in OUTPUT_DIR.
      2. Sample frames evenly (about one per ``FRAME_SKIP`` frames), stopping
         early once ``MAX_PROCESSING_TIME`` seconds of wall-clock time elapse.
      3. Run the detector, drop boxes below the per-class confidence threshold,
         and associate each box with a tracked "worker" by IoU.
      4. Confirm a violation for a worker once it appears in at least
         ``MIN_VIOLATION_FRAMES`` sampled frames (no_helmet has extra checks),
         keeping its highest-confidence detection and one snapshot per type.
      5. Score, build the PDF, push to Salesforce and yield the final outputs.

    The capture handle and temp file are released/removed on every exit path,
    including errors and early closing of the generator.

    Args:
        video_data: raw bytes of the uploaded video file.

    Yields:
        tuple of five strings for the Gradio output components.
    """
    video_path = None
    video = None
    try:
        # A random name stops two uploads in the same second from overwriting
        # (and then deleting) each other's temp file.
        video_path = os.path.join(CONFIG["OUTPUT_DIR"], f"temp_{uuid.uuid4().hex}.mp4")
        with open(video_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved: {video_path}")

        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError("Could not open video file")

        violations = []
        snapshots = []
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = video.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30
        video_duration = total_frames / fps
        logger.info(f"Video duration: {video_duration:.2f} seconds, Total frames: {total_frames}, FPS: {fps}")
        # Without a usable frame count the sampler would pick no frames (or make
        # np.linspace raise) and the video would be reported as clean.
        if total_frames <= 0:
            raise ValueError("Could not determine the number of frames in the video")

        workers = []  # tracked boxes: {"id", "bbox", "first_seen", "last_seen"}
        violation_history = {label: [] for label in CONFIG["VIOLATION_LABELS"].values()}
        confirmed_violations = {}  # worker_id -> set of confirmed violation types
        snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
        helmet_compliance = {}  # worker_id -> {"no_helmet_frames": int, "compliant": bool}
        detection_counts = {label: 0 for label in CONFIG["VIOLATION_LABELS"].values()}
        start_time = time.time()

        # About one frame per FRAME_SKIP frames, but at least
        # MIN_VIOLATION_FRAMES (capped at the clip length) so clips shorter than
        # FRAME_SKIP frames are still analysed and can produce a violation.
        target_frames = min(
            total_frames,
            max(total_frames // CONFIG["FRAME_SKIP"], CONFIG["MIN_VIOLATION_FRAMES"], 1),
        )
        frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)

        processed_frames = 0
        for idx in frame_indices:
            elapsed_time = time.time() - start_time
            if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info(f"Processing time limit of {CONFIG['MAX_PROCESSING_TIME']} seconds reached. Processed {processed_frames}/{target_frames} frames.")
                break

            video.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = video.read()
            if not ret:
                continue

            processed_frames += 1
            current_time = idx / fps
            progress = (processed_frames / target_frames) * 100
            yield f"Processing video... {progress:.1f}% complete (Frame {idx}/{total_frames})", "", "", "", ""

            # The model-level conf is deliberately low; the per-class thresholds
            # below do the real filtering.
            results = model(frame, device=device, conf=0.1, iou=CONFIG["IOU_THRESHOLD"])

            current_detections = []
            for result in results:
                for box in result.boxes:
                    cls = int(box.cls)
                    conf = float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, None)

                    if label is None:
                        logger.warning(f"Unknown class ID {cls} detected, skipping")
                        continue

                    if conf < CONFIG["CONFIDENCE_THRESHOLDS"].get(label, 0.25):
                        logger.debug(f"Detection {label} with confidence {conf:.2f} below threshold, skipping")
                        continue

                    # xywh: centre-x, centre-y, width, height.
                    bbox = [round(x, 2) for x in box.xywh.cpu().numpy()[0]]

                    current_detections.append({
                        "frame": idx,
                        "violation": label,
                        "confidence": round(conf, 2),
                        "bounding_box": bbox,
                        "timestamp": current_time
                    })
                    detection_counts[label] += 1

            logger.debug(f"Frame {idx}: Detected {len(current_detections)} violations: {[d['violation'] for d in current_detections]}")

            for detection in current_detections:
                violation_type = detection.get("violation", None)
                if violation_type is None:
                    logger.error(f"Invalid detection, missing 'violation' key: {detection}")
                    continue

                if violation_type == "no_helmet":
                    # Heuristic (kept as-is): one no_helmet detection below
                    # HELMET_CONFIDENCE_THRESHOLD marks the matched worker as
                    # compliant for the rest of the video.
                    matched_worker = None
                    max_iou = 0
                    for worker in workers:
                        iou = calculate_iou(detection["bounding_box"], worker["bbox"])
                        if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
                            max_iou = iou
                            matched_worker = worker

                    if matched_worker:
                        worker_id = matched_worker["id"]
                        if worker_id not in helmet_compliance:
                            helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}
                        helmet_compliance[worker_id]["no_helmet_frames"] += 1
                        if detection["confidence"] < CONFIG["HELMET_CONFIDENCE_THRESHOLD"]:
                            helmet_compliance[worker_id]["compliant"] = True
                            logger.debug(f"Worker {worker_id} marked as helmet compliant due to low no_helmet confidence")
                        if helmet_compliance[worker_id]["compliant"]:
                            logger.debug(f"Worker {worker_id} has helmet, skipping no_helmet violation")
                            continue

                # Associate the detection with the best-overlapping tracked
                # worker, or start tracking a new one.
                matched_worker = None
                max_iou = 0

                for worker in workers:
                    iou = calculate_iou(detection["bounding_box"], worker["bbox"])
                    if iou > max_iou and iou > CONFIG["IOU_THRESHOLD"]:
                        max_iou = iou
                        matched_worker = worker

                if matched_worker:
                    matched_worker["bbox"] = detection["bounding_box"]
                    matched_worker["last_seen"] = current_time
                    worker_id = matched_worker["id"]
                else:
                    worker_id = len(workers) + 1
                    workers.append({
                        "id": worker_id,
                        "bbox": detection["bounding_box"],
                        "first_seen": current_time,
                        "last_seen": current_time
                    })
                    if worker_id not in helmet_compliance:
                        helmet_compliance[worker_id] = {"no_helmet_frames": 0, "compliant": False}

                if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
                    logger.debug(f"Violation {violation_type} already confirmed for worker {worker_id}, skipping")
                    continue

                detection["worker_id"] = worker_id
                violation_history[violation_type].append(detection)

            # Forget workers not seen in the last 5 seconds of video time.
            workers = [w for w in workers if current_time - w["last_seen"] < 5.0]

        logger.info(f"Detection counts: {detection_counts}")

        for violation_type, detections in violation_history.items():
            if not detections:
                logger.info(f"No detections for {violation_type}")
                continue

            worker_violations = {}
            for det in detections:
                if det["worker_id"] not in worker_violations:
                    worker_violations[det["worker_id"]] = []
                worker_violations[det["worker_id"]].append(det)

            for worker_id, worker_dets in worker_violations.items():
                if len(worker_dets) >= CONFIG["MIN_VIOLATION_FRAMES"]:
                    if worker_id in confirmed_violations and violation_type in confirmed_violations[worker_id]:
                        continue

                    if violation_type == "no_helmet":
                        if worker_id in helmet_compliance and helmet_compliance[worker_id]["compliant"]:
                            logger.debug(f"Skipping no_helmet for worker {worker_id} due to helmet compliance")
                            continue
                        if helmet_compliance[worker_id]["no_helmet_frames"] < CONFIG["MIN_VIOLATION_FRAMES"] * 2:
                            logger.debug(f"Skipping no_helmet for worker {worker_id}, not enough persistent detections")
                            continue

                    best_detection = max(worker_dets, key=lambda x: x["confidence"])
                    violations.append(best_detection)

                    if worker_id not in confirmed_violations:
                        confirmed_violations[worker_id] = set()
                    confirmed_violations[worker_id].add(violation_type)

                    if not snapshot_taken[violation_type]:
                        # Re-read the frame from the capture that is already open
                        # instead of reopening the file for every snapshot.
                        video.set(cv2.CAP_PROP_POS_FRAMES, best_detection["frame"])
                        ret, snapshot_frame = video.read()
                        if not ret:
                            logger.error(f"Failed to capture snapshot for {violation_type} at frame {best_detection['frame']}")
                            continue
                        snapshot_frame = draw_detections(snapshot_frame, [best_detection])

                        snapshot_filename = f"{violation_type}_{best_detection['frame']}.jpg"
                        snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], snapshot_filename)
                        cv2.imwrite(snapshot_path, snapshot_frame)
                        snapshots.append({
                            "violation": violation_type,
                            "frame": best_detection["frame"],
                            "snapshot_path": snapshot_path,
                            "snapshot_base64": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}"
                        })
                        snapshot_taken[violation_type] = True
                        logger.info(f"Snapshot taken for {violation_type} at frame {best_detection['frame']}")

        # Frame reading is finished: free the capture and the temp file now,
        # before the slower PDF / Salesforce steps. The `finally` below is the
        # safety net for every other exit path.
        video.release()
        video = None
        os.remove(video_path)
        logger.info(f"Video file {video_path} removed")
        video_path = None

        if not violations:
            logger.info("No persistent violations detected")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A"
            return

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        violation_table = "| Violation              | Timestamp (s) | Confidence | Worker ID |\n"
        violation_table += "|------------------------|---------------|------------|-----------|\n"
        for v in violations:
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            row = f"| {display_name:<22} | {v.get('timestamp', 0.0):.2f}        | {v.get('confidence', 0.0):.2f}      | {v.get('worker_id', 'N/A')} |\n"
            violation_table += row

        snapshots_text = "No snapshots captured."
        if snapshots:
            violation_name_map = CONFIG["DISPLAY_NAMES"]
            snapshots_text = "\n".join(
                f"- Snapshot for {violation_name_map.get(s.get('violation', 'Unknown'), 'Unknown')} at frame {s.get('frame', 0)}: ![]({s.get('snapshot_base64', '')})"
                for s in snapshots
            )

        logger.info(f"Processing complete: {len(violations)} violations detected, score: {score}%")
        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {report_id or 'N/A'}",
            final_pdf_url or "N/A"
        )
    except Exception as e:
        logger.error(f"Error processing video: {e}", exc_info=True)
        yield f"Error processing video: {e}", "", "", "", ""
    finally:
        # Runs on success, on error and when the consumer stops iterating, so
        # neither the capture handle nor the temp video is ever leaked.
        if video is not None:
            video.release()
        if video_path is not None and os.path.exists(video_path):
            try:
                os.remove(video_path)
                logger.info(f"Video file {video_path} removed")
            except OSError as cleanup_err:
                logger.warning(f"Could not remove temp video {video_path}: {cleanup_err}")

# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: read the uploaded video and stream ``process_video`` results.

    This function is a generator (it contains ``yield``), so every outcome,
    including "no file uploaded", must be *yielded*. A ``return value`` inside
    a generator only sets ``StopIteration.value``, which never reaches the
    output components.

    Args:
        video_file: filesystem path provided by ``gr.Video``, or None/"" when
            nothing was uploaded.

    Yields:
        5-tuples of strings for the five output components.
    """
    if not video_file:
        yield "No file uploaded.", "", "No file uploaded.", "", ""
        return
    try:
        with open(video_file, "rb") as f:
            video_data = f.read()

        # `yield from` also closes process_video if Gradio stops iterating,
        # so its cleanup runs promptly.
        yield from process_video(video_data)
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", ""

# Gradio UI: one video input; the five outputs are listed in the same order as
# the tuples yielded by gradio_interface / process_video.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description=(
        "Upload site videos to detect safety violations "
        "(No Helmet, No Harness, Unsafe Posture, Unsafe Zone, "
        "Improper Tool Use). Non-violations are ignored."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Script entry point: start the Gradio server with default launch options.
    logger.info("Launching Enhanced Safety Analyzer App...")
    interface.launch()