File size: 45,109 Bytes
08fbbb7
d047708
 
 
 
dc66d57
b0446d7
a87de62
 
a409c73
 
a87de62
 
 
 
 
 
 
 
c6381a2
 
 
9f5a84e
 
54ff363
a409c73
c92dc21
d047708
 
 
 
9f5a84e
 
 
 
 
 
 
 
 
 
 
a409c73
d047708
15feb42
d047708
02b45d6
6f75206
d047708
 
a00ff02
 
 
a409c73
 
 
d047708
 
 
 
 
15feb42
a409c73
 
15feb42
 
 
 
 
 
 
a409c73
 
15feb42
 
a409c73
 
15feb42
 
 
d047708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f75206
a409c73
 
 
 
6f75206
a409c73
d047708
 
 
 
 
 
 
 
 
15feb42
a409c73
 
15feb42
d047708
 
 
 
 
15feb42
 
6f75206
a409c73
 
 
 
6f75206
d047708
15feb42
d047708
 
 
 
15feb42
 
d047708
 
15feb42
 
 
a409c73
15feb42
 
 
 
 
 
6f75206
a409c73
 
 
 
6f75206
15feb42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f75206
a409c73
 
 
 
6f75206
15feb42
 
 
 
 
 
 
d047708
 
a00ff02
d047708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f75206
d047708
 
a409c73
d047708
6f75206
 
 
a409c73
 
 
d047708
a87de62
01caa2c
a87de62
a409c73
 
 
 
 
a87de62
 
a00ff02
 
 
 
 
a87de62
 
7ecb9d9
 
a87de62
7ecb9d9
a87de62
 
 
4c4b661
 
 
a87de62
 
 
15feb42
a409c73
af9bf78
 
 
 
a87de62
a409c73
9f5a84e
6f75206
c7afc58
 
 
d047708
6f75206
d047708
6f75206
9f5a84e
6f75206
c7afc58
a409c73
a87de62
 
 
 
01caa2c
 
a87de62
 
5acdd5f
 
 
 
 
 
 
a87de62
a409c73
 
02b45d6
a409c73
 
 
 
a87de62
5acdd5f
a87de62
 
a409c73
a87de62
d047708
220ca2f
 
 
a409c73
c7afc58
220ca2f
d047708
a409c73
 
 
 
 
5acdd5f
a409c73
 
 
 
 
5acdd5f
a409c73
 
 
 
 
5acdd5f
a409c73
 
 
a87de62
d047708
a87de62
c6381a2
 
 
d047708
c6381a2
 
 
 
 
6f75206
 
d047708
 
 
 
 
 
 
577f505
d047708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a409c73
 
a87de62
54ff363
a87de62
c6381a2
54ff363
c6381a2
 
d047708
577f505
d047708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6381a2
d047708
 
 
 
c6381a2
d047708
c6381a2
d047708
 
 
 
c6381a2
d047708
 
 
 
 
357b766
d047708
 
 
 
 
357b766
 
 
 
c6381a2
 
a87de62
c6381a2
 
 
 
a87de62
c6381a2
 
a87de62
 
 
 
 
c6381a2
 
a87de62
 
 
 
 
c6381a2
7ecb9d9
c6381a2
 
 
 
 
 
 
 
 
7ecb9d9
c6381a2
 
 
 
 
 
 
 
7ecb9d9
c6381a2
 
7ecb9d9
c6381a2
a87de62
7ecb9d9
d047708
 
 
 
 
 
 
 
 
c6381a2
7ecb9d9
a87de62
 
7ecb9d9
 
c6381a2
7ecb9d9
c6381a2
7ecb9d9
 
5acdd5f
c6381a2
 
7ecb9d9
c6381a2
 
 
 
 
 
5acdd5f
c6381a2
 
01caa2c
c6381a2
5acdd5f
c6381a2
 
a87de62
9f5a84e
 
a87de62
54ff363
 
 
 
 
 
 
 
 
 
 
 
 
a00ff02
54ff363
 
 
 
 
6f75206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a409c73
6f75206
 
 
 
 
 
 
 
 
54ff363
 
 
 
 
a87de62
9f5a84e
 
068739f
 
 
 
 
54ff363
068739f
54ff363
068739f
 
 
54ff363
7ecb9d9
d047708
c6381a2
 
 
577f505
357b766
9f5a84e
 
 
d047708
 
 
 
 
 
 
02b45d6
 
af9bf78
a409c73
c7afc58
c6381a2
 
c7afc58
 
220ca2f
15feb42
7ecb9d9
c7afc58
7ecb9d9
c6381a2
a409c73
c6381a2
7ecb9d9
c6381a2
c7afc58
c6381a2
c7afc58
c6381a2
7ecb9d9
 
c7afc58
 
7ecb9d9
6f75206
220ca2f
c6381a2
 
c7afc58
 
c6381a2
c7afc58
a409c73
c6381a2
6f75206
c7afc58
7ecb9d9
 
c7afc58
a87de62
7ecb9d9
01caa2c
 
 
 
 
 
9f5a84e
a409c73
02b45d6
a409c73
 
 
 
 
9f5a84e
 
a409c73
02b45d6
a277b4f
02b45d6
 
9f5a84e
220ca2f
 
c7afc58
 
220ca2f
 
c7afc58
220ca2f
 
6f75206
c6381a2
d047708
a409c73
 
 
c7afc58
01caa2c
a409c73
 
 
 
 
 
 
 
 
01caa2c
 
6f75206
a409c73
01caa2c
a409c73
 
01caa2c
a409c73
 
 
 
 
01caa2c
a409c73
 
01caa2c
a409c73
 
 
01caa2c
15e98c2
d047708
 
 
 
 
 
 
 
af9bf78
d047708
 
a00ff02
a409c73
d047708
01caa2c
d047708
a409c73
d047708
 
af9bf78
 
 
a00ff02
 
 
6f75206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02b45d6
a3d6280
c6381a2
d047708
af9bf78
d047708
c7afc58
 
 
 
 
d047708
a00ff02
01caa2c
 
220ca2f
a00ff02
 
 
01caa2c
220ca2f
 
4c4b661
7ecb9d9
d047708
01caa2c
a87de62
7ecb9d9
01caa2c
 
 
 
 
 
 
 
 
 
220ca2f
 
 
01caa2c
 
 
 
 
 
 
220ca2f
01caa2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220ca2f
 
 
7ecb9d9
01caa2c
 
 
 
 
 
 
7ecb9d9
01caa2c
 
 
 
 
d047708
 
 
 
 
 
 
 
7ecb9d9
d047708
 
7ecb9d9
12dad16
7ecb9d9
a87de62
c6381a2
9f5a84e
5acdd5f
01caa2c
12dad16
7ecb9d9
a87de62
9f5a84e
01caa2c
9f5a84e
068739f
 
 
 
 
 
02b45d6
 
a87de62
c6381a2
54ff363
 
7ecb9d9
54ff363
5acdd5f
54ff363
a409c73
54ff363
 
c6381a2
7ecb9d9
54ff363
 
068739f
5acdd5f
068739f
54ff363
 
 
 
 
 
9f5a84e
5acdd5f
9f5a84e
01caa2c
 
d047708
7ecb9d9
c6381a2
01caa2c
54ff363
 
 
 
 
 
 
 
 
 
 
02b45d6
 
7ecb9d9
d047708
c6381a2
 
 
 
 
 
 
 
5acdd5f
01caa2c
c6381a2
 
d047708
c6381a2
 
ba9ee16
 
a409c73
c6381a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
import os
import sys
import subprocess
import logging
import warnings
import cv2
import gradio as gr
import torch
import numpy as np
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
from retrying import retry
import uuid
from multiprocessing import Pool, cpu_count
from functools import partial
import tempfile
import shutil
import tenacity
from scipy.spatial import distance

# ========================== # Configuration and Setup # ==========================
# Module-wide logger; basicConfig configures the root handler once at import time.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def check_ffmpeg():
    """Probe for a working ffmpeg binary on PATH.

    Runs ``ffmpeg -version`` with output suppressed and returns True on
    success; logs an error and returns False when the binary is missing
    or exits non-zero.
    """
    try:
        subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.error("FFmpeg is not installed or not found in PATH. Video processing may fail.")
        return False
    logger.info("FFmpeg is available.")
    return True

FFMPEG_AVAILABLE = check_ffmpeg()

# ========================== # BYTETracker Implementation # ==========================
class BYTETracker:
    """Lightweight IoU + proximity tracker for workers across video frames.

    Despite the name this is not the full BYTETrack algorithm: matching is a
    greedy best-IoU pass, followed by two proximity-based fallback stages --
    first against recently removed tracks (re-identification within ~1s of
    removal), then against the last known positions of live tracks.  The
    tracker also accumulates sticky per-track helmet/harness violation flags.

    NOTE(review): update() reads the module-level CONFIG dict (confidence
    thresholds and MAX_WORKER_DISTANCE), which is defined later in this file;
    this works because the lookup happens at call time, not class-definition
    time.
    """

    def __init__(self, track_thresh=0.3, track_buffer=90, match_thresh=0.5, frame_rate=30):
        # Minimum detection score for a detection to be considered at all.
        self.track_thresh = track_thresh
        # Absence tolerance before a track is pruned; interpreted in update()
        # as track_buffer / frame_rate seconds.
        self.track_buffer = track_buffer
        # Minimum IoU for a detection to match an existing track.
        self.match_thresh = match_thresh
        self.frame_rate = frame_rate
        self.next_id = 1
        self.tracks = {}            # track_id -> {'bbox', 'score', 'cls', 'last_seen'}
        self.worker_history = {}    # track_id -> list of observed [x, y] centres
        self.last_positions = {}    # track_id -> most recent [x, y] centre
        self.recently_removed = {}  # pruned tracks kept ~1s for re-identification
        self.helmet_status = {}     # track_id -> True once a no_helmet sighting sticks
        self.harness_status = {}    # track_id -> True once a no_harness sighting sticks

    def update(self, dets, scores, cls):
        """Match detections to tracks and return the active track list.

        Args:
            dets: iterable of [x, y, w, h] boxes (centre + size).
            scores: per-detection confidence scores.
            cls: per-detection class label strings (e.g. "no_helmet").

        Returns:
            List of dicts with 'id', 'bbox', 'score', 'cls' for every
            detection that was matched or newly assigned a track this call.
        """
        tracks = []
        current_time = time.time()
        
        # Prune stale tracks
        stale_ids = [track_id for track_id, track_info in self.tracks.items()
                     if current_time - track_info['last_seen'] > self.track_buffer / self.frame_rate]
        for track_id in stale_ids:
            # Park the pruned track in recently_removed so a reappearing
            # worker can reclaim the same id within the next second.
            self.recently_removed[track_id] = {
                'bbox': self.tracks[track_id]['bbox'],
                'last_seen': current_time,
                'last_position': self.last_positions.get(track_id, [0, 0])
            }
            del self.tracks[track_id]
            self.worker_history.pop(track_id, None)
            self.last_positions.pop(track_id, None)

        # Clean up recently_removed tracks older than 1 second
        to_remove = [track_id for track_id, info in self.recently_removed.items()
                     if current_time - info['last_seen'] > 1.0]
        for track_id in to_remove:
            del self.recently_removed[track_id]

        for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
            if score < self.track_thresh:
                continue
                
            x, y, w, h = det
            matched = False
            best_iou = 0
            best_track_id = None
            
            # Stage 1: greedy best-IoU match against live tracks.
            for track_id, track_info in self.tracks.items():
                tx, ty, tw, th = track_info['bbox']
                iou = self._calculate_iou([x, y, w, h], [tx, ty, tw, th])
                
                if iou > self.match_thresh and iou > best_iou:
                    best_iou = iou
                    best_track_id = track_id
                    matched = True
            
            if matched:
                self.tracks[best_track_id].update({
                    'bbox': [x, y, w, h],
                    'score': score,
                    'cls': cl,
                    'last_seen': current_time
                })
                
                # Latch violation flags once the per-class threshold is crossed.
                if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
                    self.helmet_status[best_track_id] = True
                elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
                    self.harness_status[best_track_id] = True
                
                self.worker_history[best_track_id] = self.worker_history.get(best_track_id, []) + [[x, y]]
                self.last_positions[best_track_id] = [x, y]
                
                tracks.append({
                    'id': best_track_id,
                    'bbox': [x, y, w, h],
                    'score': score,
                    'cls': cl
                })
            else:
                # Stage 2: try to reclaim a recently removed track by proximity.
                reidentified = False
                for track_id, info in list(self.recently_removed.items()):
                    if self._is_same_worker([x, y], info['last_position'], threshold=CONFIG["MAX_WORKER_DISTANCE"]):
                        self.tracks[track_id] = {
                            'bbox': [x, y, w, h],
                            'score': score,
                            'cls': cl,
                            'last_seen': current_time
                        }
                        self.worker_history[track_id] = [[x, y]]
                        self.last_positions[track_id] = [x, y]
                        
                        if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
                            self.helmet_status[track_id] = True
                        elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
                            self.harness_status[track_id] = True
                        
                        tracks.append({
                            'id': track_id,
                            'bbox': [x, y, w, h],
                            'score': score,
                            'cls': cl
                        })
                        reidentified = True
                        del self.recently_removed[track_id]
                        break
                
                if not reidentified:
                    # Stage 3: proximity match against live tracks' last positions
                    # (catches cases where IoU failed but the worker barely moved).
                    same_worker = False
                    for worker_id, last_pos in self.last_positions.items():
                        if self._is_same_worker([x, y], last_pos, threshold=CONFIG["MAX_WORKER_DISTANCE"]):
                            self.tracks[worker_id] = {
                                'bbox': [x, y, w, h],
                                'score': score,
                                'cls': cl,
                                'last_seen': current_time
                            }
                            
                            if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
                                self.helmet_status[worker_id] = True
                            elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
                                self.harness_status[worker_id] = True
                            
                            tracks.append({
                                'id': worker_id,
                                'bbox': [x, y, w, h],
                                'score': score,
                                'cls': cl
                            })
                            same_worker = True
                            break
                    
                    if not same_worker:
                        # No match anywhere: start a brand-new track.
                        self.tracks[self.next_id] = {
                            'bbox': [x, y, w, h],
                            'score': score,
                            'cls': cl,
                            'last_seen': current_time
                        }
                        self.worker_history[self.next_id] = [[x, y]]
                        self.last_positions[self.next_id] = [x, y]
                        
                        if cl == "no_helmet" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
                            self.helmet_status[self.next_id] = True
                        elif cl == "no_harness" and score > CONFIG["CONFIDENCE_THRESHOLDS"]["no_harness"]:
                            self.harness_status[self.next_id] = True
                        
                        tracks.append({
                            'id': self.next_id,
                            'bbox': [x, y, w, h],
                            'score': score,
                            'cls': cl
                        })
                        self.next_id += 1
        
        return tracks

    def _calculate_iou(self, box1, box2):
        """Intersection-over-union of two centre-format [x, y, w, h] boxes."""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2
        # Convert centre/size to corner coordinates for the overlap test.
        x_left = max(x1 - w1/2, x2 - w2/2)
        y_top = max(y1 - h1/2, y2 - h2/2)
        x_right = min(x1 + w1/2, x2 + w2/2)
        y_bottom = min(y1 + h1/2, y2 + h2/2)
        if x_right < x_left or y_bottom < y_top:
            return 0.0
        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        box1_area = w1 * h1
        box2_area = w2 * h2
        # NOTE(review): divides by the union area; two zero-area boxes would
        # raise ZeroDivisionError -- assumed not to occur with real detections.
        iou = intersection_area / (box1_area + box2_area - intersection_area)
        return iou
    
    def _is_same_worker(self, pos1, pos2, threshold=150):
        """True when two [x, y] centres are within ``threshold`` pixels."""
        x1, y1 = pos1
        x2, y2 = pos2
        return np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < threshold

    def validate_helmet_violation(self, worker_id, current_confidence):
        """True if a no-helmet sighting was ever latched for this worker.

        ``current_confidence`` is accepted but not used -- the latched flag
        alone decides.
        """
        return worker_id in self.helmet_status and self.helmet_status[worker_id]

    def validate_harness_violation(self, worker_id, current_confidence):
        """True if a no-harness sighting was ever latched for this worker.

        ``current_confidence`` is accepted but not used -- the latched flag
        alone decides.
        """
        return worker_id in self.harness_status and self.harness_status[worker_id]

# ========================== # Optimized Configuration # ==========================
# Central configuration for model selection, detection thresholds, tracking,
# Salesforce integration and rendering.  Read throughout the module at call time.
CONFIG = {
    # Hugging Face checkpoint used for object detection.
    "MODEL_NAME": "facebook/detr-resnet-50",
    # Internal violation keys -> short labels.
    "VIOLATION_LABELS": {
        "no_helmet": "No Helmet",
        "no_harness": "No Harness",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone",
        "improper_tool_use": "Improper Tool Use"
    },
    # Per-violation BGR box colours for drawing.
    "CLASS_COLORS": {
        "no_helmet": (0, 0, 255),
        "no_harness": (0, 165, 255),
        "unsafe_posture": (0, 255, 0),
        "unsafe_zone": (255, 0, 0),
        "improper_tool_use": (255, 255, 0)
    },
    # Human-readable names used in overlays, PDFs and Salesforce records.
    "DISPLAY_NAMES": {
        "no_helmet": "No Helmet Violation",
        "no_harness": "No Harness Violation",
        "unsafe_posture": "Unsafe Posture",
        "unsafe_zone": "Unsafe Zone Entry",
        "improper_tool_use": "Improper Tool Use"
    },
    # SECURITY(review): the hard-coded fallback credentials below should be
    # removed -- anyone with access to this file can read them.  Rely on the
    # SF_* environment variables only and fail fast when they are unset.
    "SF_CREDENTIALS": {
        "username": os.getenv("SF_USERNAME", "prashanth1ai@safety.com"),
        "password": os.getenv("SF_PASSWORD", "SaiPrash461"),
        "security_token": os.getenv("SF_SECURITY_TOKEN", "AP4AQnPoidIKPvSvNEfAHyoK"),
        "domain": "login"
    },
    # Base URL under which generated artifacts (PDFs) are publicly served.
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
    # Per-class minimum detection confidence before a violation is latched.
    "CONFIDENCE_THRESHOLDS": {
        "no_helmet": 0.45,
        "no_harness": 0.25,
        "unsafe_posture": 0.25,
        "unsafe_zone": 0.25,
        "improper_tool_use": 0.25
    },
    "MIN_VIOLATION_FRAMES": 2,
    "VIOLATION_COOLDOWN": 30.0,
    "WORKER_TRACKING_DURATION": 10.0,
    "MAX_PROCESSING_TIME": 120,  # Increased to allow more time for CPU processing
    "FRAME_SKIP": 4,
    "BATCH_SIZE": 2,  # Reduced for better CPU performance
    "PARALLEL_WORKERS": max(1, cpu_count() - 1),
    # BYTETracker parameters (frames of absence, score and IoU thresholds).
    "TRACK_BUFFER": 150,
    "TRACK_THRESH": 0.3,
    "MATCH_THRESH": 0.5,
    "SNAPSHOT_QUALITY": 95,
    # Max pixel distance for proximity-based worker re-identification.
    "MAX_WORKER_DISTANCE": 150,
    "TARGET_RESOLUTION": (256, 256),  # Further reduced for faster inference
    "HELMET_VALIDATION_FRAMES": 3
}

# Select the inference device once at import; load_model() moves the model here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
if device.type == "cpu":
    logger.warning("Running on CPU, which may lead to slower processing. Consider using a GPU for better performance.")

def load_model():
    """Load the DETR processor/model pair onto the configured device.

    Verifies the ``timm`` backbone dependency first (DETR's conv encoder
    requires it), then downloads/initialises the processor and model named
    in CONFIG.  On CUDA the model is cast to half precision.

    Returns:
        (DetrImageProcessor, DetrForObjectDetection) tuple.

    Raises:
        ImportError: when timm is not installed.
        Exception: any model load failure, re-raised after logging.
    """
    try:
        import timm  # noqa: F401 -- presence check only
        logger.info("timm library is available.")
    except ImportError as e:
        logger.error("timm library is not installed. Install it with: pip install timm")
        raise ImportError("timm is required for DetrConvEncoder. Run `pip install timm` and restart your runtime.") from e

    try:
        image_processor = DetrImageProcessor.from_pretrained(CONFIG["MODEL_NAME"])
        detector = DetrForObjectDetection.from_pretrained(CONFIG["MODEL_NAME"]).to(device)
        if device.type == "cuda":
            # fp16 halves memory use and speeds up GPU inference.
            detector = detector.half()
        logger.info(f"Loaded DETR model: {CONFIG['MODEL_NAME']}")
        logger.info(f"Model classes: {detector.config.id2label}")
        return image_processor, detector
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise

processor, model = load_model()

# ========================== # Helper Functions # ==========================
def preprocess_frame(frame):
    """Downscale and brighten a BGR frame ahead of DETR inference.

    Resizing to CONFIG["TARGET_RESOLUTION"] keeps CPU inference fast; the
    contrast/brightness boost (alpha=1.3, beta=20) helps detection on dim
    worksite footage.  (A filter2D sharpening pass was removed on purpose
    to cut processing time.)
    """
    resized = cv2.resize(frame, CONFIG["TARGET_RESOLUTION"], interpolation=cv2.INTER_LINEAR)
    return cv2.convertScaleAbs(resized, alpha=1.3, beta=20)

def is_unsafe_posture(box, frame_shape):
    """Heuristic posture check: flag a detection whose box is unusually tall.

    A height/width aspect ratio above 2.0 is treated as an unsafe posture.
    ``box`` is [x1, y1, x2, y2]; ``frame_shape`` is accepted for interface
    symmetry with the other checks but is not used.
    """
    left, top, right, bottom = box
    box_height = bottom - top
    box_width = right - left
    # max(..., 1) guards against degenerate zero-width boxes.
    ratio = box_height / max(box_width, 1)
    return ratio > 2.0

def is_improper_tool_use(person_box, tool_box):
    """Flag improper tool use when the tool is far (>100 px) from the person.

    Both boxes are [x1, y1, x2, y2]; the Euclidean distance between the two
    box centres is compared against a fixed 100-pixel threshold.
    """
    px = (person_box[0] + person_box[2]) / 2
    py = (person_box[1] + person_box[3]) / 2
    tx = (tool_box[0] + tool_box[2]) / 2
    ty = (tool_box[1] + tool_box[3]) / 2
    return distance.euclidean((px, py), (tx, ty)) > 100

def is_unsafe_zone(person_box, frame_shape):
    """Return True when the worker's centre lies inside the unsafe zone.

    ``person_box`` is [x, y, w, h] (top-left corner plus size) and
    ``frame_shape`` is (height, width).  The unsafe zone is hard-coded as
    the top-left quadrant in fractional frame coordinates.
    """
    x, y, w, h = person_box
    height, width = frame_shape
    cx = x + w / 2
    cy = y + h / 2
    # Zone bounds as fractions of frame size: (x_min, y_min, x_max, y_max).
    zx1, zy1, zx2, zy2 = (0, 0, 0.5, 0.5)
    inside_x = zx1 * width < cx < zx2 * width
    inside_y = zy1 * height < cy < zy2 * height
    return inside_x and inside_y

def draw_detections(frame, detections):
    """Render violation boxes and labels onto a copy of ``frame``.

    Each detection dict may carry ``violation``, ``confidence``,
    ``bounding_box`` ([cx, cy, w, h] centre format) and ``worker_id``;
    missing keys fall back to safe defaults.  Returns the annotated copy;
    the input frame is left untouched.
    """
    annotated = frame.copy()
    for det in detections:
        violation = det.get("violation", "Unknown")
        conf = det.get("confidence", 0.0)
        cx, cy, bw, bh = det.get("bounding_box", [0, 0, 0, 0])
        worker = det.get("worker_id", "Unknown")
        # Convert centre/size to corner coordinates for drawing.
        left = int(cx - bw/2)
        top = int(cy - bh/2)
        right = int(cx + bw/2)
        bottom = int(cy + bh/2)
        box_color = CONFIG["CLASS_COLORS"].get(violation, (0, 0, 255))
        # Missing-helmet boxes are drawn thicker to stand out.
        thickness = 4 if violation == "no_helmet" else 3
        cv2.rectangle(annotated, (left, top), (right, bottom), box_color, thickness)
        caption = f"{CONFIG['DISPLAY_NAMES'].get(violation, violation)} (Worker {worker})"
        (text_w, text_h), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # Solid black backdrop keeps the caption readable on any frame content.
        cv2.rectangle(annotated, (left, top - text_h - 10), (left + text_w + 10, top), (0, 0, 0), -1)
        cv2.putText(annotated, caption, (left + 5, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        cv2.putText(annotated, f"Conf: {conf:.2f}", (left + 5, bottom + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    return annotated

def calculate_safety_score(violations):
    """Compute a 0-100 compliance score from a list of violation dicts.

    Each distinct violation type counts once per worker (repeat sightings
    of the same type do not stack); the per-type penalties below are summed
    across all workers and subtracted from 100, floored at 0.
    """
    penalty_table = {
        "no_helmet": 25,
        "no_harness": 30,
        "unsafe_posture": 20,
        "unsafe_zone": 35,
        "improper_tool_use": 25
    }
    # Deduplicate: worker_id -> set of violation types seen for that worker.
    per_worker = {}
    for violation in violations:
        wid = violation.get("worker_id", "Unknown")
        per_worker.setdefault(wid, set()).add(violation.get("violation", "Unknown"))
    total_penalty = sum(
        penalty_table.get(vtype, 0)
        for vtypes in per_worker.values()
        for vtype in vtypes
    )
    return max(0, 100 - total_penalty)

def generate_violation_pdf(violations, score, output_dir):
    """Render a per-worker safety violation report as a PDF.

    The document is drawn into an in-memory buffer first so the same bytes
    can later be uploaded to Salesforce without re-reading the file, then
    written to ``output_dir``.

    Args:
        violations: list of violation dicts (worker_id/violation/timestamp/confidence).
        score: compliance score (0-100) printed in the report header.
        output_dir: directory the PDF file is saved into.

    Returns:
        (local_path, public_url, BytesIO buffer) on success;
        ("", "", None) on any error (logged, never raised).
    """
    try:
        pdf_filename = f"violations_{int(time.time())}.pdf"
        pdf_path = os.path.join(output_dir, pdf_filename)
        # Draw into memory first; flushed to disk at the end.
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica-Bold", 16)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 9.5 * inch, f"Date: {time.strftime('%Y-%m-%d')}")
        c.drawString(1 * inch, 9.2 * inch, f"Time: {time.strftime('%H:%M:%S')}")
        c.setFont("Helvetica-Bold", 14)
        c.drawString(1 * inch, 8.7 * inch, f"Safety Compliance Score: {score}%")
        y_position = 8.2 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Summary:")
        y_position -= 0.3 * inch
        # Group violations by worker for the per-worker listing below.
        worker_violations = {}
        for v in violations:
            worker_id = v.get("worker_id", "Unknown")
            if worker_id not in worker_violations:
                worker_violations[worker_id] = []
            worker_violations[worker_id].append(v)
        c.setFont("Helvetica", 10)
        summary_data = {
            "Total Workers with Violations": len(worker_violations),
            "Total Violations Found": len(violations),
            "Analysis Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in summary_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.25 * inch
        y_position -= 0.5 * inch
        c.setFont("Helvetica-Bold", 12)
        c.drawString(1 * inch, y_position, "Violations by Worker:")
        y_position -= 0.3 * inch
        c.setFont("Helvetica", 10)
        for worker_id, worker_vios in worker_violations.items():
            c.drawString(1 * inch, y_position, f"Worker {worker_id}:")
            y_position -= 0.2 * inch
            for v in worker_vios:
                display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
                time_str = f"{v.get('timestamp', 0.0):.2f}s"
                conf_str = f"{v.get('confidence', 0.0):.2f}"
                violation_text = f"  - {display_name} at {time_str} (Confidence: {conf_str})"
                c.drawString(1.2 * inch, y_position, violation_text)
                y_position -= 0.2 * inch
                # Start a new page when we run out of vertical space.
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch
        c.save()
        pdf_file.seek(0)
        # Persist the in-memory document to disk alongside other outputs.
        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None

@retry(stop_max_attempt_number=3, wait_fixed=2000)
def connect_to_salesforce():
    """Open a Salesforce session (retried 3 times, 2s apart, via `retrying`).

    Credentials come from CONFIG["SF_CREDENTIALS"].  A describe() call
    verifies the session is actually usable before it is returned.

    Raises:
        Exception: re-raised after logging so the retry decorator can fire.
    """
    try:
        sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
        logger.info("Connected to Salesforce")
        sf.describe()
        return sf
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise

def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the report PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: live simple_salesforce Salesforce session.
        pdf_file: BytesIO holding the rendered PDF.
        report_id: Id of the record the file is published against
            (FirstPublishLocationId links the document to that record).

    Returns:
        A shepherd download URL for the uploaded file, or "" on any failure
        (all errors are logged, never raised).
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""
        # ContentVersion.VersionData must be base64-encoded.
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        content_version_data = {
            "Title": f"Safety_Violation_Report_{int(time.time())}",
            "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id
        }
        content_version = sf.ContentVersion.create(content_version_data)
        # Sanity check that the version exists.  NOTE(review): the fetched
        # ContentDocumentId is unused -- the download URL below is built from
        # the ContentVersion Id instead; confirm that is intended.
        result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
        if not result['records']:
            logger.error("Failed to retrieve ContentVersion")
            return ""
        file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{content_version['id']}"
        logger.info(f"PDF uploaded to Salesforce: {file_url}")
        return file_url
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""

def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Safety_Video_Report__c record and attach the PDF report.

    Falls back to creating a plain Account record when the custom object is
    unavailable in the org.  When ``pdf_file`` is provided it is uploaded as
    a ContentVersion and the record's PDF URL field is updated.

    Args:
        violations: list of violation dicts (worker_id/violation/timestamp/confidence).
        score: compliance score (0-100) stored on the record.
        pdf_path: local path of the generated PDF (used to build the public URL).
        pdf_file: BytesIO of the PDF, uploaded when truthy.

    Returns:
        (record_id, pdf_url) on success; ("N/A", error message) when the
        Salesforce integration fails entirely.
    """
    try:
        sf = connect_to_salesforce()
        # Build a human-readable multi-line summary for the long-text field.
        violations_text = ""
        for v in violations:
            display_name = CONFIG['DISPLAY_NAMES'].get(v.get('violation', 'Unknown'), 'Unknown')
            worker_id = v.get('worker_id', 'Unknown')
            timestamp = v.get('timestamp', 0.0)
            confidence = v.get('confidence', 0.0)
            violations_text += f"Worker {worker_id}: {display_name} at {timestamp:.2f}s (Conf: {confidence:.2f})\n"
        if not violations_text:
            violations_text = "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""
        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            logger.info(f"Created record: {record['id']}")
        except Exception as e:
            # The custom object may be missing in the org; keep a paper trail
            # on a generic Account record instead.
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
        if pdf_file:
            uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
            if uploaded_url:
                try:
                    # Bug fix: the field API name must carry the __c suffix to
                    # match the name used at record creation; the previous bare
                    # "PDF_Report_URL" key made this update fail every time.
                    sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                    logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
                except Exception as e:
                    # Bug fix: log the object actually being updated
                    # (was mislabeled "Safety_Violation_Report__c").
                    logger.error(f"Failed to update Safety_Video_Report__c: {e}")
                    sf.Account.update(record_id, {"Description": uploaded_url})
                    logger.info(f"Updated account record {record_id} with PDF URL")
                pdf_url = uploaded_url
        return record_id, pdf_url
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return "N/A", "Salesforce integration failed."

@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_fixed(1),
    retry=tenacity.retry_if_exception_type((IOError, OSError)),
    before_sleep=lambda retry_state: logger.info(f"Retrying file access (attempt {retry_state.attempt_number}/3)...")
)
def verify_and_open_video(video_path):
    """Validate a temp video file and return an opened cv2.VideoCapture.

    Retries up to 3 times (1s apart) on IOError/OSError to ride out files
    that are still being flushed to disk by the uploader.

    Raises:
        FileNotFoundError: the file does not exist (subclass of OSError, so
            it is also retried before finally propagating).
        ValueError: the file is empty or OpenCV cannot open it.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Temporary video file not found: {video_path}")
    file_size = os.path.getsize(video_path)
    if file_size == 0:
        raise ValueError(f"Temporary video file is empty: {video_path}")
    # Touch the file so transient I/O errors surface for the retry decorator.
    with open(video_path, "rb") as f:
        f.read(1)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Could not open video file. Ensure the video format is supported (e.g., MP4) and FFmpeg is installed.")
    return cap

def validate_helmet_detection(frame, bbox, confidence_threshold=0.45):
    """Second-stage heuristic filter for 'no_helmet' detections.

    Crops the head region described by *bbox* out of *frame* and rejects the
    detection (returns False) when the crop looks like it actually contains a
    helmet, or is too uniform to be a real head.

    Args:
        frame: BGR image (numpy array) the detection came from.
        bbox: Center-format box [cx, cy, w, h] in pixel coordinates.
        confidence_threshold: Retained for interface compatibility; no longer
            consulted — the original branches on it were dead code (see below).

    Returns:
        bool: True to keep the 'no_helmet' detection, False to drop it as a
        likely false positive.
    """
    cx, cy, w, h = bbox
    # Convert center format to a corner box clamped to the frame bounds.
    x1 = int(max(0, cx - w / 2))
    y1 = int(max(0, cy - h / 2))
    x2 = int(min(frame.shape[1], cx + w / 2))
    y2 = int(min(frame.shape[0], cy + h / 2))
    head_region = frame[y1:y2, x1:x2]
    if head_region.size == 0:
        return False

    # Gate 1: a large fraction of typical helmet colours (yellow, white,
    # orange, blue) in HSV space suggests a helmet IS present -> reject.
    hsv = cv2.cvtColor(head_region, cv2.COLOR_BGR2HSV)
    yellow_mask = cv2.inRange(hsv, np.array([20, 100, 100]), np.array([30, 255, 255]))
    white_mask = cv2.inRange(hsv, np.array([0, 0, 200]), np.array([180, 30, 255]))
    orange_mask = cv2.inRange(hsv, np.array([5, 100, 100]), np.array([15, 255, 255]))
    blue_mask = cv2.inRange(hsv, np.array([100, 100, 100]), np.array([130, 255, 255]))
    helmet_mask = cv2.bitwise_or(yellow_mask, white_mask)
    helmet_mask = cv2.bitwise_or(helmet_mask, orange_mask)
    helmet_mask = cv2.bitwise_or(helmet_mask, blue_mask)
    helmet_percentage = np.sum(helmet_mask > 0) / (head_region.shape[0] * head_region.shape[1])
    if helmet_percentage > 0.25:
        return False

    # Gate 2: hair/faces show texture; a near-uniform crop is more likely a
    # smooth helmet shell (or background), so reject it.
    gray = cv2.cvtColor(head_region, cv2.COLOR_BGR2GRAY)
    if np.std(gray) < 15:
        return False

    # The original code then computed Canny edge density and branched on it
    # and on confidence_threshold, but every branch returned True — dead
    # code, so both the branches and the edge computation are removed.
    # All rejection gates passed: accept the detection.
    return True

def process_video(video_data, temp_dir):
    """Detect safety violations in raw video bytes and stream progress.

    Generator that yields 6-tuples:
    (status / violation table markdown, safety score text, snapshots markdown,
     Salesforce record id text, PDF/details URL, processing log text).

    Pipeline: write bytes to a temp file -> batched DETR inference ->
    heuristic validators (helmet/posture/zone/tool) -> ByteTrack per-worker
    dedup -> violation snapshots -> PDF report -> Salesforce push.

    Args:
        video_data: Raw bytes of the uploaded video (expected MP4).
        temp_dir: Existing directory for the temp video file; snapshots are
            written to an "output" subdirectory inside it.
    """
    video_path = None
    cap = None  # created lazily; released in finally even on error paths
    # Defined BEFORE the try block so the except handler's "Failed after ..."
    # message cannot raise NameError when a failure happens during setup
    # (e.g. empty video_data). Re-stamped below when processing really starts,
    # so the reported processing time is unchanged.
    start_time = time.time()
    output_dir = os.path.join(temp_dir, "output")
    os.makedirs(output_dir, exist_ok=True)
    
    try:
        if not video_data:
            raise ValueError("Empty video data provided.")
        
        logger.info(f"Received video data size: {len(video_data)} bytes")
        if len(video_data) == 0:
            raise ValueError("Video data is empty.")

        with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
            temp_file.write(video_data)
            temp_file.flush()
            video_path = temp_file.name
        logger.info(f"Video saved to temporary file: {video_path}")

        cap = verify_and_open_video(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        duration = total_frames / fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        logger.info(f"Video properties: {duration:.2f}s, {total_frames} frames, {fps:.1f} FPS, {width}x{height}")

        if total_frames <= 0:
            raise ValueError("Video has no frames.")

        tracker = BYTETracker(
            track_thresh=CONFIG["TRACK_THRESH"],
            track_buffer=CONFIG["TRACK_BUFFER"],
            match_thresh=CONFIG["MATCH_THRESH"],
            frame_rate=fps
        )

        worker_id_mapping = {}   # tracker id -> sequential human-friendly worker id
        unique_violations = {}   # (worker_id, label) -> first detection time (s)
        violation_frames = {}    # (worker_id, label) -> frame index of first detection
        helmet_detections = {}   # worker_id -> raw no_helmet hits awaiting validation
        frame_detections = {}    # frame index -> list of detections in that frame
        start_time = time.time()
        frame_skip = CONFIG["FRAME_SKIP"]
        processed_frames = 0  # Track actual frames processed
        frames_read = 0  # Track frames read from video
        last_yield_time = start_time
        worker_counter = 1

        while True:
            batch_frames = []
            batch_indices = []
            batch_originals = []
            
            for _ in range(CONFIG["BATCH_SIZE"]):
                frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                frames_read = frame_idx
                if frame_idx >= total_frames:
                    logger.info("Reached end of video.")
                    break
                ret, frame = cap.read()
                if not ret:
                    logger.warning(f"Failed to read frame {frame_idx}. Assuming end of video.")
                    frames_read = total_frames  # Assume we've reached the end
                    break
                original_frame = frame.copy()
                frame = preprocess_frame(frame)
                # Skip intermediate frames without decoding them (grab only).
                for _ in range(frame_skip - 1):
                    if not cap.grab():
                        logger.warning(f"Failed to grab frame after {frame_idx}. Assuming end of video.")
                        frames_read = total_frames
                        break
                    frames_read += 1
                batch_frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                batch_indices.append(frame_idx)
                batch_originals.append(original_frame)
                processed_frames += 1  # Increment for each frame actually processed

            if not batch_frames:
                logger.info("No more frames to process in this batch.")
                break

            # Check for timeout
            elapsed_time = time.time() - start_time
            if elapsed_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.warning(f"Processing exceeded time limit of {CONFIG['MAX_PROCESSING_TIME']}s. Terminating early.")
                break

            try:
                inputs = processor(images=batch_frames, return_tensors="pt").to(device)
                if device.type == "cuda":
                    inputs = {k: v.half() for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = model(**inputs)
                target_sizes = torch.tensor([frame.size[::-1] for frame in batch_frames]).to(device)
                results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.1)
            except Exception as e:
                logger.error(f"Model inference failed: {e}")
                raise ValueError(f"Failed to process video frames with DETR model: {str(e)}")
            finally:
                batch_frames = []
                if device.type == "cuda":
                    torch.cuda.empty_cache()

            # Throttled progress updates (at most ~10/s).
            current_time = time.time()
            if current_time - last_yield_time > 0.1:
                progress = (frames_read / total_frames) * 100
                progress = min(progress, 100.0)  # Ensure progress doesn't exceed 100%
                elapsed_time = current_time - start_time
                fps_processed = processed_frames / elapsed_time if elapsed_time > 0 else 0
                yield f"Processing video... {progress:.1f}% complete (Frame {frames_read}/{total_frames}, {fps_processed:.1f} FPS)", "", "", "", "", f"Elapsed: {elapsed_time:.1f}s"
                last_yield_time = current_time

            for result, frame_idx, original_frame in zip(results, batch_indices, batch_originals):
                # NOTE: current_time is reused here as the frame timestamp in
                # seconds (frame_idx / fps), not wall-clock time.
                current_time = frame_idx / fps
                track_inputs = []
                person_boxes = []
                tool_boxes = []

                frame_detections[frame_idx] = []

                for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
                    label_name = model.config.id2label[label.item()]
                    conf = float(score)
                    bbox = box.cpu().numpy()
                    x, y, x2, y2 = bbox
                    w, h = x2 - x, y2 - y
                    bbox_xywh = [x + w/2, y + h/2, w, h]

                    if label_name in ["no_helmet", "no_harness"] and conf >= CONFIG["CONFIDENCE_THRESHOLDS"].get(label_name, 0.25):
                        if label_name == "no_helmet" and not validate_helmet_detection(original_frame, bbox_xywh, conf):
                            logger.info(f"Frame {frame_idx}: Helmet false positive filtered at {conf:.2f} confidence")
                            continue
                        track_inputs.append({"bbox": bbox_xywh, "conf": conf, "cls": label_name})
                        frame_detections[frame_idx].append({"label": label_name, "conf": conf, "bbox": bbox_xywh})
                    elif label_name == "person":
                        person_boxes.append(bbox_xywh)
                    elif label_name in ["hammer", "wrench"]:
                        tool_boxes.append(bbox_xywh)

                # Derived (rule-based) violations from person/tool geometry.
                for pbox in person_boxes:
                    if is_unsafe_posture(pbox, original_frame.shape[:2]):
                        track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_posture"})
                        frame_detections[frame_idx].append({"label": "unsafe_posture", "conf": 0.9, "bbox": pbox})
                    if is_unsafe_zone(pbox, original_frame.shape[:2]):
                        track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "unsafe_zone"})
                        frame_detections[frame_idx].append({"label": "unsafe_zone", "conf": 0.9, "bbox": pbox})
                    for tbox in tool_boxes:
                        if is_improper_tool_use(pbox, tbox):
                            track_inputs.append({"bbox": pbox, "conf": 0.9, "cls": "improper_tool_use"})
                            frame_detections[frame_idx].append({"label": "improper_tool_use", "conf": 0.9, "bbox": pbox})

                if not track_inputs:
                    continue
                    
                tracked_objects = tracker.update(
                    np.array([t["bbox"] for t in track_inputs]),
                    np.array([t["conf"] for t in track_inputs]),
                    np.array([t["cls"] for t in track_inputs])
                )
                logger.info(f"Frame {frame_idx}: Detected {len(tracked_objects)} workers")

                for obj in tracked_objects:
                    tracker_id = obj['id']
                    label = obj['cls']
                    conf = obj['score']
                    bbox = obj['bbox']
                    
                    if label not in CONFIG["VIOLATION_LABELS"]:
                        continue
                    
                    if tracker_id not in worker_id_mapping:
                        worker_id_mapping[tracker_id] = worker_counter
                        worker_counter += 1
                    
                    worker_id = worker_id_mapping[tracker_id]
                    
                    if label == "no_helmet":
                        # Helmet violations need multi-frame confirmation
                        # before being recorded, to cut false positives.
                        if worker_id not in helmet_detections:
                            helmet_detections[worker_id] = []
                        helmet_detections[worker_id].append({
                            "frame_idx": frame_idx,
                            "confidence": conf,
                            "bbox": bbox
                        })
                        if len(helmet_detections[worker_id]) >= CONFIG["HELMET_VALIDATION_FRAMES"]:
                            avg_conf = sum(d["confidence"] for d in helmet_detections[worker_id]) / len(helmet_detections[worker_id])
                            if avg_conf >= CONFIG["CONFIDENCE_THRESHOLDS"]["no_helmet"]:
                                violation_key = (worker_id, label)
                                if violation_key not in unique_violations:
                                    unique_violations[violation_key] = current_time
                                    violation_frames[violation_key] = frame_idx
                                    logger.info(f"Frame {frame_idx}: Valid helmet violation for worker {worker_id} with avg conf {avg_conf:.2f}")
                    else:
                        # Other violations are recorded once per (worker, label).
                        violation_key = (worker_id, label)
                        if violation_key not in unique_violations:
                            unique_violations[violation_key] = current_time
                            violation_frames[violation_key] = frame_idx

        cap.release()
        processing_time = time.time() - start_time
        logger.info(f"Processing complete in {processing_time:.2f}s")
        logger.info(f"Total unique workers detected: {len(set(worker_id_mapping.values()))}")

        # Ensure final progress update
        final_progress = (frames_read / total_frames) * 100
        final_progress = min(final_progress, 100.0)
        yield f"Processing video... {final_progress:.1f}% complete (Frame {frames_read}/{total_frames})", "", "", "", "", f"Elapsed: {processing_time:.1f}s"

        violations = []
        for (worker_id, label), detection_time in unique_violations.items():
            frame_idx = violation_frames[(worker_id, label)]
            conf = next((d["conf"] for d in frame_detections.get(frame_idx, []) if d["label"] == label), 0.0)
            violations.append({
                "worker_id": worker_id,
                "violation": label,
                "timestamp": detection_time,
                "confidence": conf,
                "frame_idx": violation_frames[(worker_id, label)]
            })

        if not violations:
            logger.info("No violations detected after processing")
            yield "No violations detected in the video.", "Safety Score: 100%", "No snapshots captured.", "N/A", "N/A", f"Completed in {processing_time:.1f}s"
            return

        violation_table = "| Violation | Worker ID | Time (s) | Confidence |\n"
        violation_table += "|-----------|-----------|----------|------------|\n"
        for v in sorted(violations, key=lambda x: (x.get("worker_id", "Unknown"), x.get("timestamp", 0.0))):
            display_name = CONFIG["DISPLAY_NAMES"].get(v.get("violation", "Unknown"), "Unknown")
            worker_id = v.get("worker_id", "Unknown")
            timestamp = v.get("timestamp", 0.0)
            confidence = v.get("confidence", 0.0)
            violation_table += f"| {display_name} | {worker_id} | {timestamp:.2f} | {confidence:.2f} |\n"
        yield violation_table, "", "", "", "", f"Violations detected in {processing_time:.1f}s"

        # Re-open the video to seek back and capture annotated snapshots.
        snapshots = []
        cap = cv2.VideoCapture(video_path)
        for violation in violations:
            try:
                frame_idx = violation["frame_idx"]
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if not ret:
                    logger.warning(f"Failed to read frame {frame_idx} for snapshot.")
                    continue

                frame = preprocess_frame(frame)
                detections = frame_detections.get(frame_idx, [])
                for det in detections:
                    if det["label"] == violation["violation"]:
                        violation["confidence"] = round(det["conf"], 2)
                        detection = {
                            "worker_id": violation["worker_id"],
                            "violation": det["label"],
                            "confidence": violation["confidence"],
                            "bounding_box": det["bbox"],
                            "timestamp": violation["timestamp"]
                        }
                        snapshot_frame = frame.copy()
                        snapshot_frame = draw_detections(snapshot_frame, [detection])
                        cv2.putText(
                            snapshot_frame,
                            f"Time: {violation['timestamp']:.2f}s",
                            (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.7,
                            (255, 255, 255),
                            2
                        )
                        snapshot_filename = f"violation_{det['label']}_worker{violation['worker_id']}_{int(violation['timestamp']*100)}.jpg"
                        snapshot_path = os.path.join(output_dir, snapshot_filename)
                        cv2.imwrite(
                            snapshot_path,
                            snapshot_frame,
                            [cv2.IMWRITE_JPEG_QUALITY, CONFIG["SNAPSHOT_QUALITY"]]
                        )
                        snapshots.append({
                            "violation": det["label"],
                            "worker_id": violation["worker_id"],
                            "timestamp": violation["timestamp"],
                            "snapshot_path": snapshot_path,
                            "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{snapshot_filename}",
                            "confidence": violation["confidence"]
                        })
                        logger.info(f"Captured snapshot for {det['label']} violation by worker {violation['worker_id']} at {violation['timestamp']:.2f}s")
                        break
            except Exception as e:
                logger.error(f"Error generating snapshot for violation: {e}")
                continue

        cap.release()

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = "", "", None
        try:
            pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score, output_dir)
        except Exception as e:
            logger.error(f"PDF generation failed: {e}")
            yield violation_table, f"Safety Score: {score}%", "Failed to generate snapshots due to PDF error.", "N/A", "N/A", f"Completed in {processing_time:.1f}s\nError: {str(e)}"
            return

        record_id, final_pdf_url = "N/A", "Salesforce integration failed."
        try:
            record_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)
        except Exception as e:
            logger.error(f"Salesforce integration failed: {e}")

        snapshots_text = ""
        for s in snapshots:
            display_name = CONFIG["DISPLAY_NAMES"].get(s["violation"], "Unknown")
            worker_id = s.get("worker_id", "Unknown")
            timestamp = s.get("timestamp", 0.0)
            snapshots_text += f"### {display_name} - Worker {worker_id} at {timestamp:.2f}s\n\n"
            snapshots_text += f"![Violation]({s['snapshot_url']})\n\n"

        if not snapshots_text:
            snapshots_text = "No snapshots captured."

        yield (
            violation_table,
            f"Safety Score: {score}%",
            snapshots_text,
            f"Salesforce Record ID: {record_id}",
            final_pdf_url,
            f"Completed in {processing_time:.1f}s"
        )

    except Exception as e:
        logger.error(f"Error processing video: {str(e)}", exc_info=True)
        yield f"Error processing video: {str(e)}", "", "", "", "", f"Failed after {time.time() - start_time:.1f}s"
    finally:
        # release() on an already-released capture is a no-op, so this is
        # safe even after the explicit release() calls on the success path.
        if cap is not None:
            cap.release()
        if video_path and os.path.exists(video_path):
            try:
                os.remove(video_path)
                logger.info(f"Cleaned up temporary video file: {video_path}")
            except Exception as e:
                logger.error(f"Failed to clean up temporary video file {video_path}: {e}")
        if device.type == "cuda":
            torch.cuda.empty_cache()

def gradio_interface(video_file):
    """Gradio entry point: read the uploaded video and stream pipeline output.

    Generator yielding 6-tuples (status, score, snapshots, record id,
    details URL, log) matching the interface's six output components.

    Bug fixed: this function contains ``yield`` and is therefore a generator,
    so the former bare ``return <tuple>`` early exits stored the tuple in
    StopIteration — Gradio never displayed those messages. Early exits now
    ``yield`` the message first, then return.
    """
    temp_dir = None
    local_video_path = None
    try:
        if not video_file:
            yield "No file uploaded.", "", "No file uploaded.", "", "", ""
            return
        
        temp_dir = tempfile.mkdtemp(prefix="DETR_")
        logger.info(f"Created temporary directory for video processing: {temp_dir}")

        with open(video_file, "rb") as f:
            video_data = f.read()
        logger.info(f"Read Gradio video file: {video_file}, size: {len(video_data)} bytes")
        
        if len(video_data) == 0:
            yield "Uploaded video file is empty.", "", "", "", "", ""
            return

        with tempfile.NamedTemporaryFile(suffix=".mp4", dir=temp_dir, delete=False) as temp_file:
            temp_file.write(video_data)
            temp_file.flush()
            local_video_path = temp_file.name
        logger.info(f"Copied Gradio video to local temporary file: {local_video_path}")

        if not FFMPEG_AVAILABLE:
            yield "FFmpeg is not available in the environment. Please install FFmpeg to process videos.", "", "", "", "", ""
            return

        # Forward every progress/result tuple from the processing generator.
        for status, score, snapshots_text, record_id, details_url, log in process_video(video_data, temp_dir):
            yield status, score, snapshots_text, record_id, details_url, log
            
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}", exc_info=True)
        yield f"Error: {str(e)}", "", "Error in processing.", "", "", str(e)
    finally:
        if local_video_path and os.path.exists(local_video_path):
            try:
                os.remove(local_video_path)
                logger.info(f"Cleaned up local temporary video file: {local_video_path}")
            except Exception as e:
                logger.error(f"Failed to clean up local temporary video file {local_video_path}: {e}")
        
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
            logger.info(f"Cleaned up temporary directory: {temp_dir}")
        if device.type == "cuda":
            torch.cuda.empty_cache()

# ========================== # Gradio Interface # ==========================
# The six outputs mirror, in order, the 6-tuple yielded by gradio_interface
# (which streams from process_video): violation table, score, snapshots,
# Salesforce record id, details URL, processing log.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
        gr.Textbox(label="Processing Log")
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (No Helmet, No Harness, Unsafe Posture, Unsafe Zone, Improper Tool Use). Each unique violation is detected only once per worker.",
    allow_flagging="never"
)

if __name__ == "__main__":
    # Launch the app only when executed as a script, not on import.
    logger.info("Launching Enhanced Safety Analyzer App with DETR...")
    interface.launch()