File size: 3,637 Bytes
fcfea15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
class BucketGroup:
    """Manages dynamic batch grouping buckets for image inference.

    Every bucket is a ``(batch_size, num_items, num_frames, height, width)``
    tuple. Lookup selects the bucket whose height/width ratio is closest to
    the ratio of the requested media shape.
    """

    def __init__(
        self,
        bucket_configs: list[tuple[int, int, int, int, int]],
        prioritize_frame_matching: bool = True,
    ):
        """
        Initialize bucket group with predefined configurations.

        Args:
            bucket_configs: List of (batch_size, num_items, num_frames, height, width)
                entries. Each entry is converted to a tuple, so list entries
                are accepted as well.
            prioritize_frame_matching: Unused, kept for API compatibility.
        """
        self.bucket_configs = [tuple(b) for b in bucket_configs]

    def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]:
        """
        Find the best matching bucket for given media dimensions.

        Only buckets with the same ``num_items`` as the media and with
        ``num_frames == 1`` are considered. Among those, the bucket with the
        smallest absolute difference between its height/width ratio and the
        media's height/width ratio is returned; on ties the bucket that
        appears first in ``bucket_configs`` wins.

        Args:
            media_shape: (num_items, num_frames, height, width) of input media

        Returns:
            Best matching bucket as (batch_size, num_items, num_frames, height, width)

        Raises:
            ValueError: If ``num_frames`` is not 1, if height or width is not
                positive, or if no bucket matches ``num_items``.
        """
        num_items, num_frames, height, width = media_shape

        if num_frames != 1:
            raise ValueError(
                f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}")

        # Validate before dividing: a zero width would otherwise surface as a
        # bare ZeroDivisionError instead of the ValueError used for every
        # other unsupported shape.
        if height <= 0 or width <= 0:
            raise ValueError(
                f"Media height and width must be positive, got shape {media_shape}")

        target_aspect_ratio = height / width

        valid_buckets = [
            b for b in self.bucket_configs
            if b[1] == num_items and b[2] == 1
        ]
        if not valid_buckets:
            raise ValueError(
                f"No image buckets found for shape {media_shape}")

        # min() returns the first minimal element, so configuration order
        # decides between buckets with identical aspect ratios.
        return min(
            valid_buckets,
            key=lambda bucket: abs(
                (bucket[3] / bucket[4]) - target_aspect_ratio)
        )

    def __repr__(self) -> str:
        """Return a debug string with the bucket count and all configurations."""
        return (
            f"BucketGroup("
            f"total_buckets={len(self.bucket_configs)}, "
            f"configs={self.bucket_configs})"
        )


def _generate_hw_buckets(base_height=256, base_width=256, step_width=16, step_height=16, max_ratio=4.0) -> list[tuple[int, int, int, int, int]]:
    """Generate dimension buckets based on aspect ratios."""
    buckets = []
    target_pixels = base_height * base_width

    height = target_pixels // step_width
    width = step_width

    while height >= step_height:
        if max(height, width) / min(height, width) <= max_ratio:
            buckets.append((1, 1, 1, height, width))
        if height * (width + step_width) <= target_pixels:
            width += step_width
        else:
            height -= step_height

    return buckets


def generate_video_image_bucket(basesize=256, min_temporal=65, max_temporal=129, bs_img=8, bs_vid=1, bs_mimg=4, min_items=1, max_items=1):
    """Build the bucket configurations used for image inference.

    The height/width templates come from ``_generate_hw_buckets()`` called
    with its defaults. One entry per template is emitted with batch size
    ``bs_img``, followed by one entry per template for every item count in
    ``[min_items, max_items]`` with batch size ``bs_mimg``. When ``basesize``
    exceeds 256, height and width of every entry are multiplied by
    ``basesize // 256``.

    Args:
        basesize: Target base resolution; asserted to be 256, 512, 768 or 1024.
        min_temporal: Not referenced by this function.
        max_temporal: Not referenced by this function.
        bs_img: Batch size written into the single-image entries.
        bs_vid: Not referenced by this function.
        bs_mimg: Batch size written into the multi-image entries.
        min_items: Smallest item count for the multi-image entries (inclusive).
        max_items: Largest item count for the multi-image entries (inclusive).

    Returns:
        List of ``[batch_size, num_items, num_frames, height, width]`` entries;
        each entry is a list.
    """
    assert basesize in [
        256, 512, 768, 1024], f"[generate_video_image_bucket] wrong basesize {basesize}"

    hw_templates = _generate_hw_buckets()

    # Single-image entries: keep the template, swap in the image batch size.
    configs = [[bs_img, *template[1:]] for template in hw_templates]

    # Multi-image entries: override both batch size and item count.
    for item_count in range(min_items, max_items + 1):
        configs.extend(
            [bs_mimg, item_count, *template[2:]] for template in hw_templates
        )

    # Scale the spatial dimensions for base sizes above 256.
    if basesize > 256:
        scale = basesize // 256
        for config in configs:
            config[-2] *= scale
            config[-1] *= scale

    return configs