Spaces:
Runtime error
Runtime error
class BucketGroup:
    """Manages dynamic batch grouping buckets for image inference.

    Each bucket is a ``(batch_size, num_items, num_frames, height, width)``
    tuple; only buckets with ``num_frames == 1`` are ever selected.
    """

    def __init__(
        self,
        bucket_configs: list[tuple[int, int, int, int, int]],
        prioritize_frame_matching: bool = True,
    ):
        """
        Initialize bucket group with predefined configurations.

        Args:
            bucket_configs: List of (batch_size, num_items, num_frames, height, width)
                entries. Each entry is copied into a tuple, so list entries are accepted.
            prioritize_frame_matching: Unused, kept for API compatibility.
        """
        self.bucket_configs = [tuple(b) for b in bucket_configs]

    def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]:
        """
        Find the best matching bucket for given media dimensions.

        Among buckets whose num_items matches and whose num_frames is 1, the one whose
        height/width ratio is closest (absolute difference) to the media's is chosen.
        Ties resolve to the earliest such bucket in ``bucket_configs``.

        Args:
            media_shape: (num_items, num_frames, height, width) of input media

        Returns:
            Best matching bucket as (batch_size, num_items, num_frames, height, width)

        Raises:
            ValueError: if num_frames != 1, if height or width is not positive, or if
                no image bucket exists for num_items.
        """
        num_items, num_frames, height, width = media_shape
        if num_frames != 1:
            raise ValueError(
                f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}")
        # Validate before dividing: width == 0 would otherwise surface as a bare
        # ZeroDivisionError, and non-positive sizes give a meaningless aspect ratio.
        if height <= 0 or width <= 0:
            raise ValueError(
                f"height and width must be positive, got shape {media_shape}")
        target_aspect_ratio = height / width

        valid_buckets = [
            b for b in self.bucket_configs
            if b[1] == num_items and b[2] == 1
        ]
        if not valid_buckets:
            raise ValueError(
                f"No image buckets found for shape {media_shape}")

        # min() returns the first minimal element, so config order breaks ties.
        return min(
            valid_buckets,
            key=lambda bucket: abs(
                (bucket[3] / bucket[4]) - target_aspect_ratio)
        )

    def __repr__(self) -> str:
        return (
            f"BucketGroup("
            f"total_buckets={len(self.bucket_configs)}, "
            f"configs={self.bucket_configs})"
        )
| def _generate_hw_buckets(base_height=256, base_width=256, step_width=16, step_height=16, max_ratio=4.0) -> list[tuple[int, int, int, int, int]]: | |
| """Generate dimension buckets based on aspect ratios.""" | |
| buckets = [] | |
| target_pixels = base_height * base_width | |
| height = target_pixels // step_width | |
| width = step_width | |
| while height >= step_height: | |
| if max(height, width) / min(height, width) <= max_ratio: | |
| buckets.append((1, 1, 1, height, width)) | |
| if height * (width + step_width) <= target_pixels: | |
| width += step_width | |
| else: | |
| height -= step_height | |
| return buckets | |
def generate_video_image_bucket(basesize=256, min_temporal=65, max_temporal=129, bs_img=8, bs_vid=1, bs_mimg=4, min_items=1, max_items=1):
    """Generate bucket configs for image inference.

    Starts from the base (256x256-area) shapes of ``_generate_hw_buckets`` and emits:
      1. one single-image bucket per base shape, with batch size ``bs_img``;
      2. for each ``num_items`` in ``[min_items, max_items]``, one bucket per base
         shape with batch size ``bs_mimg``.
    Height and width are then scaled by ``basesize // 256``.

    Args:
        basesize: Base resolution; must be one of 256, 512, 768, 1024.
        min_temporal: Unused; kept for API compatibility.
        max_temporal: Unused; kept for API compatibility.
        bs_img: Batch size for single-image buckets.
        bs_vid: Unused; kept for API compatibility.
        bs_mimg: Batch size for multi-image buckets.
        min_items: Smallest num_items for multi-image buckets (inclusive).
        max_items: Largest num_items for multi-image buckets (inclusive).

    Returns:
        List of [batch_size, num_items, num_frames, height, width] lists. Entries are
        mutable lists, not tuples.

    Raises:
        ValueError: if basesize is not one of the supported values.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if basesize not in (256, 512, 768, 1024):
        raise ValueError(f"[generate_video_image_bucket] wrong basesize {basesize}")

    bucket_list = []
    base_bucket_list = _generate_hw_buckets()

    # image
    for _bucket in base_bucket_list:
        bucket = list(_bucket)
        bucket[0] = bs_img
        bucket_list.append(bucket)

    # multiple images. NOTE: with the default min_items=1, every base shape also gets a
    # second num_items=1 entry here, differing only in batch size (bs_mimg).
    for num_items in range(min_items, max_items + 1):
        for _bucket in base_bucket_list:
            bucket = list(_bucket)
            bucket[0] = bs_mimg
            bucket[1] = num_items
            bucket_list.append(bucket)

    # spatial resize: scale height and width in place.
    if basesize > 256:
        ratio = basesize // 256
        for bucket in bucket_list:
            bucket[-2] *= ratio
            bucket[-1] *= ratio

    return bucket_list