File size: 28,667 Bytes
d686824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
import cv2
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from decord import VideoReader, cpu
from scipy.interpolate import UnivariateSpline
import copy
from tqdm import tqdm
import os
import sys
import cv2
import copy
import logging

# Assuming YoloWorldInterface is defined elsewhere and imported correctly
# from your_project.yolo_interface import YoloWorldInterface
# 导入自定义的 TStar 接口
# from TStar.interface_yolo import YoloWorldInterface, YoloV5Interface, YoloInterface
from .interface_owl import OWLInterface, owlInterface
class TStarSearcher:
    """
    Keyframe search over a video driven by object detection and adaptive sampling.

    Frames are sampled at ``fps`` (1 frame per second) according to the weight
    vector ``P``, tiled into an image grid and passed to the detector held in
    ``owl``. Per-frame detection scores update ``score_distribution``; a spline
    fitted over the visited frames then reshapes ``P`` so later iterations focus
    on promising regions. The loop stops when every target object has been
    confirmed on a single frame or when the frame budget is spent.

    Attributes:
        video_path (str): Path to the video file.
        target_objects (List[str]): Objects the search tries to confirm.
        cue_objects (List[str]): Contextual objects that raise frame scores.
        search_nframes (int): Number of keyframes returned by a search.
        image_grid_shape (Tuple[int, int]): (rows, cols) of the detection grid.
        output_dir (Optional[str]): Output directory (forwarded to scoring).
        profix (Optional[str]): Prefix for output files (attribute name kept for
            backward compatibility; set from the ``prefix`` argument).
        confidence_threshold (float): Minimum weighted confidence that confirms a target.
        object2weight (dict): Per-object multiplier applied to detector confidences.
        fps (int): Sampling rate in frames per second (fixed at 1).
        model_choice (str): Detector backend, ``"owl"`` or ``"yolo"``.
        raw_fps (float): Native frame rate reported by OpenCV.
        duration (float): Video duration in seconds.
        total_frame_num (int): Number of frames at the ``fps`` sampling rate.
        remaining_targets (List[str]): Targets not yet confirmed.
        search_budget (float): Remaining number of frames the search may sample.
        score_distribution (np.ndarray): Score per sampled-rate frame.
        P (np.ndarray): Current sampling weights per sampled-rate frame.
        P_history (List[List[float]]): Snapshot of ``P`` after every update.
        non_visiting_frames (np.ndarray): 1 for frames not yet sampled, 0 otherwise.
        owl: Detector interface supplied by the caller (may be ``None`` until set).
        image_grid_iters, detect_annotot_iters, detect_bbox_iters (list):
            Visualization history: input images, annotated images and raw
            detections recorded per grid iteration / confirmed target.
    """

    def __init__(
        self,
        video_path: str,
        target_objects: List[str],
        cue_objects: List[str],
        search_nframes: int = 8,
        image_grid_shape: Tuple[int, int] = (8, 8),
        search_budget: float = 0.1,
        output_dir: Optional[str] = None,
        prefix: Optional[str] = None,
        confidence_threshold: float = 0.5,
        object2weight: Optional[dict] = None,
        model_choice: str = "owl",
        owl=None,
    ):
        """
        Read video metadata and initialise the search state.

        Args:
            video_path (str): Path to the input video file.
            target_objects (List[str]): Objects to confirm (weight 1.0).
            cue_objects (List[str]): Contextual objects (weight 0.5).
            search_nframes (int): Number of keyframes to return.
            image_grid_shape (Tuple[int, int]): (rows, cols) of the detection grid.
            search_budget (float): Fraction of the 1-fps frames the search may
                sample; the resulting frame count is capped at 1000.
            output_dir (Optional[str]): Directory to store results.
            prefix (Optional[str]): Prefix for saved output files (stored as ``profix``).
            confidence_threshold (float): Threshold on weighted detector confidence.
            object2weight (Optional[dict]): Extra per-object weights; copied, never mutated.
            model_choice (str): ``"owl"`` or ``"yolo"``; selects the inference call.
            owl: Detector interface exposing ``inference_detector``, ``texts``,
                ``reparameterize_object_list``, ``bbox_visualization`` and
                ``detections_inbatch``.

        Raises:
            ValueError: If the video cannot be opened or reports no frame rate.
        """
        self.video_path = video_path
        self.target_objects = target_objects
        self.cue_objects = cue_objects
        self.search_nframes = search_nframes
        self.image_grid_shape = image_grid_shape
        self.output_dir = output_dir
        self.profix = prefix  # historical attribute name, kept for callers
        self.confidence_threshold = confidence_threshold
        # Copy so the weights assigned below never leak into the caller's dict.
        self.object2weight = dict(object2weight) if object2weight else {}
        self.fps = 1  # TODO: sampling rate is fixed at 1 frame per second
        self.model_choice = model_choice

        # Video properties; always release the capture handle.
        cap = cv2.VideoCapture(self.video_path)
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video file: {self.video_path}")
            self.raw_fps = cap.get(cv2.CAP_PROP_FPS)
            raw_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        if not self.raw_fps or self.raw_fps <= 0:
            raise ValueError(f"Cannot determine frame rate of video file: {self.video_path}")
        self.duration = raw_frame_count / self.raw_fps

        # Number of frames at the sampling rate (one per second).
        self.total_frame_num = int(self.duration * self.fps)
        self.remaining_targets = list(target_objects)
        self.search_budget = min(1000, self.total_frame_num * search_budget)

        # Search state over the sampled-rate frames.
        self.score_distribution = np.zeros(self.total_frame_num)
        self.P_history = []
        self.non_visiting_frames = np.ones(self.total_frame_num)
        self.P = np.ones(self.total_frame_num) * self.confidence_threshold * 0.3

        # Targets first, then cues: an object listed in both ends up with 0.5.
        for name in target_objects:
            self.object2weight[name] = 1.0
        for name in cue_objects:
            self.object2weight[name] = 0.5

        # Detector is injected by the caller; this class does not build one.
        self.owl = owl

        self._reset_visualization_history()

    def _reset_visualization_history(self) -> None:
        """Start empty visualization histories (one entry appended per record)."""
        self.image_grid_iters = []  # per record: [input image]
        self.detect_annotot_iters = []  # per record: result of owl.bbox_visualization
        self.detect_bbox_iters = []  # per record: owl.detections_inbatch

    def _ensure_visualization_history(self) -> None:
        """Create the visualization histories if this instance lacks them."""
        if not hasattr(self, "image_grid_iters"):
            self._reset_visualization_history()

    def reset_yolo_vocabulary(self, target_objects: List[str], cue_objects: List[str]):
        """
        Replace the target/cue objects and push them to the detector vocabulary.

        Used for whichever backend ``self.owl`` holds, despite the name.

        Args:
            target_objects (List[str]): New list of target objects.
            cue_objects (List[str]): New list of cue objects.
        """
        self.target_objects = target_objects
        self.cue_objects = cue_objects
        self.owl.reparameterize_object_list(target_objects, cue_objects)

    ### --- Detection Methods --- ###

    def imageGridScoreFunction(
        self,
        images: List[np.ndarray],
        output_dir: Optional[str],
        image_grids: Tuple[int, int]
    ) -> Tuple[np.ndarray, List[List[List[str]]]]:
        """
        Run the detector on each image and bin detections into grid cells.

        Each detection's bbox centre selects a cell; the cell keeps the maximum
        weighted confidence (detector confidence times ``object2weight``,
        default weight 0.5) and the list of detected object names.

        Args:
            images (List[np.ndarray]): Images to process (all assumed the same size
                as ``images[0]``).
            output_dir (Optional[str]): Accepted for interface compatibility; unused.
            image_grids (Tuple[int, int]): (rows, cols) of the grid.

        Returns:
            Tuple[np.ndarray, List[List[List[str]]]]:
                - confidence maps, shape (num_images, rows, cols)
                - per image, a list of ``rows * cols`` cells (row-major), each a
                  list of detected object names

        Raises:
            ValueError: If ``model_choice`` is neither ``"yolo"`` nor ``"owl"``.
            RuntimeError: If no detector has been configured.
        """
        if len(images) == 0:
            return np.array([]), []
        if self.model_choice not in ("yolo", "owl"):
            raise ValueError(
                f"Unsupported model_choice: {self.model_choice!r} (expected 'yolo' or 'owl')"
            )
        if self.owl is None:
            raise RuntimeError(
                "No detector configured: pass `owl=` when constructing TStarSearcher"
            )

        grid_rows, grid_cols = image_grids
        grid_height = images[0].shape[0] / grid_rows
        grid_width = images[0].shape[1] / grid_cols

        confidence_maps = []
        detected_objects_maps = []

        for image in images:
            if self.model_choice == "yolo":
                detections = self.owl.inference_detector(
                    images=[image],  # single image as a batch
                    max_dets=50,
                    use_amp=False
                )
            else:  # "owl"
                detections = self.owl.inference_detector(
                    images=[image],  # single image as a batch
                    use_amp=False
                )

            confidence_map = np.zeros((grid_rows, grid_cols))
            detected_objects_map = [[] for _ in range(grid_rows * grid_cols)]

            for detection in detections:
                for bbox, label, confidence in zip(detection.xyxy, detection.class_id, detection.confidence):
                    # Both backends map class ids through `texts`. TODO: confirm for YOLO-World.
                    object_name = self.owl.texts[label][0]
                    weight = self.object2weight.get(object_name, 0.5)
                    adjusted_confidence = confidence * weight

                    x_min, y_min, x_max, y_max = bbox
                    center_x = (x_min + x_max) / 2
                    center_y = (y_min + y_max) / 2

                    # Clamp both ends: a centre just outside the image must not
                    # wrap around to the opposite cell via negative indexing.
                    grid_x = min(max(int(center_x // grid_width), 0), grid_cols - 1)
                    grid_y = min(max(int(center_y // grid_height), 0), grid_rows - 1)

                    confidence_map[grid_y, grid_x] = max(confidence_map[grid_y, grid_x], adjusted_confidence)
                    detected_objects_map[grid_y * grid_cols + grid_x].append(object_name)

            confidence_maps.append(confidence_map)
            detected_objects_maps.append(detected_objects_map)

        return np.stack(confidence_maps), detected_objects_maps

    def read_frame_batch(self, video_path: str, frame_indices: List[int]) -> Tuple[List[int], np.ndarray]:
        """
        Read frames at the given native-rate indices with decord.

        Args:
            video_path (str): Path to the video file.
            frame_indices (List[int]): Native frame indices to read.

        Returns:
            Tuple[List[int], np.ndarray]: The requested indices and the stacked frames.
        """
        vr = VideoReader(video_path, ctx=cpu(0))
        return frame_indices, vr.get_batch(frame_indices).asnumpy()

    def create_image_grid(
        self,
        frames: List[np.ndarray],
        rows: int,
        cols: int,
        cell_size: Tuple[int, int] = (160, 120),
    ) -> np.ndarray:
        """
        Tile frames row-major into a single grid image.

        When fewer than ``rows * cols`` frames are given (e.g. a video shorter
        than the grid), the remaining cells are filled with black images so the
        occupied cells keep their row-major positions.

        Args:
            frames (List[np.ndarray]): Frame images.
            rows (int): Number of grid rows.
            cols (int): Number of grid columns.
            cell_size (Tuple[int, int]): (width, height) every frame is resized to.

        Returns:
            np.ndarray: The combined grid image.

        Raises:
            ValueError: If no frames are given or there are more frames than cells.
        """
        num_cells = rows * cols
        if len(frames) == 0 or len(frames) > num_cells:
            raise ValueError("Frame count does not match grid dimensions")

        cells = [cv2.resize(frame, cell_size) for frame in frames]
        if len(cells) < num_cells:
            blank = np.zeros_like(cells[0])
            cells += [blank] * (num_cells - len(cells))
        grid_rows = [np.hstack(cells[r * cols:(r + 1) * cols]) for r in range(rows)]
        return np.vstack(grid_rows)

    ### --- Scoring Methods --- ###

    def score_image_grids(
        self,
        images: List[np.ndarray],
        image_grids: Tuple[int, int]
    ) -> Tuple[np.ndarray, List[List[List[str]]]]:
        """
        Score grid images with ``imageGridScoreFunction`` using ``self.output_dir``.

        Args:
            images (List[np.ndarray]): Grid images to run detection on.
            image_grids (Tuple[int, int]): (rows, cols) of each grid.

        Returns:
            Tuple[np.ndarray, List[List[List[str]]]]: Confidence maps and detected objects maps.
        """
        return self.imageGridScoreFunction(
            images=images,
            output_dir=self.output_dir,
            image_grids=image_grids
        )

    def store_score_distribution(self):
        """Append a list copy of the current sampling weights ``P`` to ``P_history``."""
        self.P_history.append(np.asarray(self.P).tolist())

    def update_top_25_with_window(
        self,
        frame_confidences: List[float],
        sampled_frame_indices: List[int],
        window_size: int = 5
    ):
        """
        Spread the score of the top-quartile sampled frames to their neighbours.

        For each sampled frame whose confidence is at or above the 75th
        percentile, every frame within ``window_size`` gets
        ``max(current, score / (|offset| + 1))``.

        Args:
            frame_confidences (List[float]): Confidence scores of the sampled frames.
            sampled_frame_indices (List[int]): Sampled-rate indices of those frames.
            window_size (int): Neighbourhood radius in frames.
        """
        if len(frame_confidences) == 0:
            return  # nothing sampled, nothing to spread

        top_25_threshold = np.percentile(frame_confidences, 75)
        top_25_indices = [
            frame_idx for frame_idx, confidence in zip(sampled_frame_indices, frame_confidences)
            if confidence >= top_25_threshold
        ]

        num_frames = len(self.score_distribution)
        for frame_idx in top_25_indices:
            for offset in range(-window_size, window_size + 1):
                neighbor_idx = frame_idx + offset
                if 0 <= neighbor_idx < num_frames:
                    self.score_distribution[neighbor_idx] = max(
                        self.score_distribution[neighbor_idx],
                        self.score_distribution[frame_idx] / (abs(offset) + 1)
                    )

    def spline_keyframe_distribution(
        self,
        non_visiting_frames: np.ndarray,
        score_distribution: np.ndarray,
        video_length: int
    ) -> np.ndarray:
        """
        Interpolate visited-frame scores over all frames into a sampling distribution.

        Visited frames (``non_visiting_frames == 0``) are fitted with a smoothing
        cubic spline (``s=0.5``). A cubic spline needs at least 4 points, so
        with fewer visited frames linear interpolation is used instead. Scores
        are floored at ``1 / video_length``, passed through a sigmoid and
        normalised to sum to 1.

        Args:
            non_visiting_frames (np.ndarray): 1 for unvisited frames, 0 for visited.
            score_distribution (np.ndarray): Current score per frame.
            video_length (int): Number of frames in the output distribution.

        Returns:
            np.ndarray: Probability distribution of length ``video_length``
            (uniform if nothing has been visited, empty if ``video_length <= 0``).
        """
        if video_length <= 0:
            return np.zeros(0)

        frame_indices = np.flatnonzero(np.asarray(non_visiting_frames) == 0)
        if len(frame_indices) == 0:
            return np.ones(video_length) / video_length
        observed_scores = np.asarray(score_distribution)[frame_indices]

        all_frames = np.arange(video_length)
        if len(frame_indices) > 3:
            spline = UnivariateSpline(frame_indices, observed_scores, s=0.5)
            spline_scores = spline(all_frames)
        else:
            # UnivariateSpline (k=3) requires more than 3 points.
            spline_scores = np.interp(all_frames, frame_indices, observed_scores)

        adjusted_scores = np.maximum(1 / video_length, spline_scores)
        p_distribution = 1 / (1 + np.exp(-adjusted_scores))  # sigmoid
        p_distribution /= p_distribution.sum()
        return p_distribution

    def update_frame_distribution(
        self,
        sampled_frame_indices: List[int],
        confidence_maps: np.ndarray,
        detected_objects_maps: List[List[List[str]]]
    ) -> Tuple[List[float], List[List[str]]]:
        """
        Fold one grid's detection results into the search state.

        The i-th sampled frame occupies grid cell i (row-major) of the first
        grid. Marks the frames visited, records their scores, spreads top
        scores to neighbours, refits ``P`` and snapshots it in ``P_history``.

        Args:
            sampled_frame_indices (List[int]): Sampled-rate indices of the grid frames.
            confidence_maps (np.ndarray): Confidence maps; only the first is used.
            detected_objects_maps (List[List[List[str]]]): Detected objects; only the first is used.

        Returns:
            Tuple[List[float], List[List[str]]]: Per-frame confidence and detected objects.
        """
        confidence_map = confidence_maps[0]  # only one image grid per call (TBD)
        detected_objects_map = detected_objects_maps[0]

        _, grid_cols = self.image_grid_shape

        frame_confidences = []
        frame_detected_objects = []
        for idx in range(len(sampled_frame_indices)):
            row, col = divmod(idx, grid_cols)
            frame_confidences.append(confidence_map[row, col])
            frame_detected_objects.append(detected_objects_map[idx])

        for frame_idx, confidence in zip(sampled_frame_indices, frame_confidences):
            self.non_visiting_frames[frame_idx] = 0  # mark as visited
            self.score_distribution[frame_idx] = confidence

        self.update_top_25_with_window(frame_confidences, sampled_frame_indices)

        self.P = self.spline_keyframe_distribution(
            self.non_visiting_frames,
            self.score_distribution,
            len(self.score_distribution)
        )
        self.store_score_distribution()

        return frame_confidences, frame_detected_objects

    ### --- Sampling Methods --- ###

    def sample_frames(self, num_samples: int) -> Tuple[List[int], np.ndarray]:
        """
        Draw unvisited frames without replacement, weighted by ``P``.

        The request is capped at the number of unvisited frames, because
        sampling without replacement cannot draw more distinct frames than
        have non-zero weight.

        Args:
            num_samples (int): Number of frames requested.

        Returns:
            Tuple[List[int], np.ndarray]: Sampled-rate indices (seconds) and the
            decoded frames; ``([], empty array)`` when nothing is left to sample.
        """
        num_unvisited = int(np.count_nonzero(self.non_visiting_frames))
        num_samples = min(num_samples, self.total_frame_num, num_unvisited)
        if num_samples <= 0:
            return [], np.empty((0,))

        # Offset keeps every unvisited frame reachable; visited frames get weight 0.
        weights = (self.P + num_samples / self.total_frame_num) * self.non_visiting_frames
        weights /= weights.sum()

        sampled_frame_secs = np.random.choice(
            self.total_frame_num,
            size=num_samples,
            replace=False,
            p=weights
        )
        native_indices = [int(sec * self.raw_fps / self.fps) for sec in sampled_frame_secs]

        _, frames = self.read_frame_batch(
            video_path=self.video_path,
            frame_indices=native_indices
        )
        return sampled_frame_secs.tolist(), frames

    ### --- Verification Methods --- ###

    def verify_and_remove_target(
        self,
        frame_sec: int,
        detected_objects: List[str],
        confidence_threshold: float,
    ) -> bool:
        """
        Re-check a grid hit on the full frame and confirm the target if it passes.

        If any remaining target appears in ``detected_objects``, the frame is
        read once and detected as a 1x1 grid; its score overwrites
        ``score_distribution[frame_sec]``. The first such target (in
        ``remaining_targets`` order) also detected on the full frame with
        weighted confidence above ``confidence_threshold`` is removed and
        recorded in the visualization history.

        Args:
            frame_sec (int): Sampled-rate index (seconds) of the frame.
            detected_objects (List[str]): Objects detected in this frame's grid cell.
            confidence_threshold (float): Threshold to confirm a target.

        Returns:
            bool: True if a target was confirmed and removed, False otherwise.
        """
        candidates = [target for target in self.remaining_targets if target in detected_objects]
        if not candidates:
            return False

        frame_idx = int(frame_sec * self.raw_fps / self.fps)
        _, frames = self.read_frame_batch(self.video_path, [frame_idx])
        frame = frames[0]

        # One detection pass serves every candidate target.
        single_confidence_maps, single_detected_objects_maps = self.score_image_grids(
            [frame], (1, 1)
        )
        single_confidence = single_confidence_maps[0, 0, 0]
        single_detected_objects = single_detected_objects_maps[0][0]
        self.score_distribution[frame_sec] = single_confidence

        if not single_confidence > confidence_threshold:
            return False

        for target in candidates:
            if target in single_detected_objects:
                self.remaining_targets.remove(target)
                print(f"Found target '{target}' in frame {frame_idx}, score {single_confidence:.2f}")
                self._ensure_visualization_history()
                self.image_grid_iters.append([frame])
                self.detect_annotot_iters.append(self.owl.bbox_visualization(images=[frame], detections_inbatch=self.owl.detections_inbatch))
                self.detect_bbox_iters.append(self.owl.detections_inbatch)
                return True

        return False

    ### --- Visualization Methods --- ###

    def plot_score_distribution(self, save_path: Optional[str] = None):
        """
        Plot ``score_distribution`` against time, optionally saving it as PNG.

        Args:
            save_path (Optional[str]): Output file; missing parent directories are created.
        """
        time_axis = np.linspace(0, self.duration, len(self.score_distribution))

        plt.figure(figsize=(12, 6))
        plt.plot(time_axis, self.score_distribution, label="Score Distribution")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Score")
        plt.title("Score Distribution Over Time")
        plt.grid(True)
        plt.legend()

        if save_path:
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, format='png', dpi=300)
            print(f"Plot saved to {save_path}")

        plt.show()
        plt.close()  # release the figure so repeated calls do not accumulate them

    ### --- Main Search Logic --- ###

    def _run_search_loop(self, record_grids: bool) -> None:
        """Sample, detect and verify until all targets are confirmed, the budget is spent, or no frames remain."""
        grid_rows, grid_cols = self.image_grid_shape
        num_frames_in_grid = grid_rows * grid_cols

        progress_bar = tqdm(total=int(self.total_frame_num), desc="Searching Iterations / video_length", unit="iter", dynamic_ncols=True)
        try:
            while self.remaining_targets and self.search_budget > 0:
                sampled_frame_secs, frames = self.sample_frames(num_frames_in_grid)
                if not sampled_frame_secs:
                    break  # every frame has been visited
                self.search_budget -= num_frames_in_grid

                grid_image = self.create_image_grid(frames, grid_rows, grid_cols)
                confidence_maps, detected_objects_maps = self.score_image_grids(
                    images=[grid_image],
                    image_grids=self.image_grid_shape
                )

                if record_grids:
                    self.image_grid_iters.append([grid_image])
                    self.detect_annotot_iters.append(self.owl.bbox_visualization(images=[grid_image], detections_inbatch=self.owl.detections_inbatch))
                    self.detect_bbox_iters.append(self.owl.detections_inbatch)

                _, frame_detected_objects = self.update_frame_distribution(
                    sampled_frame_indices=sampled_frame_secs,
                    confidence_maps=confidence_maps,
                    detected_objects_maps=detected_objects_maps
                )

                for frame_sec, detected_objects in zip(sampled_frame_secs, frame_detected_objects):
                    self.verify_and_remove_target(
                        frame_sec=frame_sec,
                        detected_objects=detected_objects,
                        confidence_threshold=self.confidence_threshold,
                    )
                progress_bar.update(1)
        finally:
            progress_bar.close()

    def _collect_top_k_frames(self, k: int) -> Tuple[List[np.ndarray], List[float]]:
        """Read the ``k`` highest-scoring frames (best first) in one batch with their timestamps."""
        if k <= 0 or len(self.score_distribution) == 0:
            return [], []  # note: argsort(...)[-0:] would select every frame
        top_k_indices = np.argsort(self.score_distribution)[-k:][::-1]
        native_indices = [int(idx * self.raw_fps / self.fps) for idx in top_k_indices]
        _, frames = self.read_frame_batch(self.video_path, native_indices)
        time_stamps = [idx / self.fps for idx in top_k_indices]
        return list(frames), time_stamps

    def search(self) -> Tuple[List[np.ndarray], List[float]]:
        """
        Run the keyframe search without recording per-grid visualizations.

        Returns:
            Tuple[List[np.ndarray], List[float]]: Top ``search_nframes`` frames
            (highest score first) and their timestamps in seconds.
        """
        self._ensure_visualization_history()
        self._run_search_loop(record_grids=False)
        return self._collect_top_k_frames(self.search_nframes)

    def search_with_visualization(self) -> Tuple[List[np.ndarray], List[float]]:
        """
        Run the keyframe search, recording each grid image and its detections.

        The visualization histories are reset first, then receive one entry
        per grid iteration and per confirmed target.

        Returns:
            Tuple[List[np.ndarray], List[float]]: Top ``search_nframes`` frames
            (highest score first) and their timestamps in seconds.
        """
        self._reset_visualization_history()
        self._run_search_loop(record_grids=True)
        return self._collect_top_k_frames(self.search_nframes)





# Example usage
if __name__ == "__main__":
    # Define video path and target objects
    video_path = "/home/anabella/projects/MLLM/TSTAR/data/friend_clip_t.mp4"
    query = "what is the color of the couch?"
    target_objects = ["couch"]  # Target objects to find
    cue_objects = ["table", "woman"]

    # TStarSearcher does not build a detector itself (its `owl` argument
    # defaults to None and detection fails without one), so create it here.
    # NOTE(review): constructor arguments mirror the call that is commented
    # out in TStarSearcher.__init__ — confirm against OWLInterface.
    owl = OWLInterface(
        config_path="google/owlvit-base-patch32",
        checkpoint_path=None,
        device="cuda:0",
    )

    # Create VideoSearcher instance
    searcher = TStarSearcher(
        video_path=video_path,
        target_objects=target_objects,
        cue_objects=cue_objects,
        search_nframes=8,
        image_grid_shape=(4, 4),
        confidence_threshold=0.6,
        owl=owl,
    )
    # Push the object list to the detector; presumably this populates the
    # `owl.texts` vocabulary that class ids are mapped through — verify.
    searcher.reset_yolo_vocabulary(target_objects=target_objects, cue_objects=cue_objects)

    # Perform the search
    all_frames, time_stamps = searcher.search()

    # Process results
    print(f"Found {len(all_frames)} frames, timestamps: {time_stamps}")

    # Plot the score distribution (make sure the output directory exists)
    save_path = './output/score/score_distribution.png'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    searcher.plot_score_distribution(save_path=save_path)