| from dataclasses import dataclass |
| from typing import Callable, List, Optional |
|
|
|
|
@dataclass
class ChunkPayload:
    """One request's worth of content, as returned by ``RequestChunker.chunk``."""

    # Text segments placed in this request (the same items handed to the message builder).
    segments: list
    # Image URLs attached to this request; may be empty.
    image_urls: list
|
|
|
|
class RequestChunker:
    """Split text segments and image URLs into requests that fit a byte budget.

    Every size check builds the actual messages with ``message_builder`` and
    measures them with :meth:`estimate`, so the budget applies to exactly what
    the builder produces.
    """

    def __init__(self, message_builder: Callable, max_bytes: int, size_estimator: Optional[Callable] = None):
        """Create a chunker.

        Args:
            message_builder: Called as ``message_builder(segments, image_urls, **kwargs)``
                to build the messages whose size is checked.
            max_bytes: Inclusive upper bound on the estimated size of one chunk.
            size_estimator: Optional ``callable(messages) -> int``. When not
                supplied, the UTF-8 length of the JSON-serialized messages is used.
        """
        self.message_builder = message_builder
        self.max_bytes = max_bytes
        self.size_estimator = size_estimator

    def estimate(self, messages) -> int:
        """Return the estimated size of *messages*.

        Uses ``size_estimator`` when one was supplied (and is truthy); otherwise
        the byte length of ``json.dumps(messages)`` encoded as UTF-8.
        """
        if self.size_estimator:
            return self.size_estimator(messages)
        import json

        # ensure_ascii=False so non-ASCII text is counted by its UTF-8 byte
        # length instead of as \uXXXX escapes.
        return len(json.dumps(messages, ensure_ascii=False).encode("utf-8"))

    def _messages_size(self, segments, image_urls, **kwargs) -> int:
        """Build messages for *segments* and *image_urls* and return their estimated size."""
        return self.estimate(self.message_builder(segments, image_urls, **kwargs))

    def _get_text(self, segment) -> str:
        """Return the text carried by *segment*.

        Plain strings are their own text; dicts use their ``"text"`` key and
        other objects their ``text`` attribute, defaulting to ``""``.
        """
        if isinstance(segment, str):
            return segment
        if isinstance(segment, dict):
            return segment.get("text", "")
        return getattr(segment, "text", "")

    def _make_segment(self, segment, text: str):
        """Return a copy of *segment* whose text is replaced by *text*.

        Handles plain strings, dicts, dataclass instances (including slotted
        ones), namedtuples (preserving their field order), objects whose
        constructor accepts their instance attributes as keywords, and, as a
        last resort, positional ``(start, end, text)`` constructors.
        """
        import dataclasses

        if isinstance(segment, str):
            return text
        if isinstance(segment, dict):
            return {**segment, "text": text}
        if dataclasses.is_dataclass(segment) and not isinstance(segment, type):
            return dataclasses.replace(segment, text=text)
        if isinstance(segment, tuple) and hasattr(segment, "_fields"):
            # namedtuple: replace by field name so the field order does not matter.
            return segment._replace(text=text)
        if hasattr(segment, "__dict__"):
            data = dict(segment.__dict__)
            data["text"] = text
            return type(segment)(**data)
        return type(segment)(segment.start, segment.end, text)

    def _split_segment_to_fit(self, segment, **kwargs):
        """Split *segment* into the longest prefix that fits and the remainder.

        Binary-searches the largest prefix length whose single-segment request
        (with no images) fits in ``max_bytes``. This assumes the estimated size
        grows monotonically with the prefix length.

        Returns:
            ``(head, tail)``. ``tail`` is ``None`` when the whole text fits, so
            no empty-text segment is ever produced.

        Raises:
            ValueError: if the segment has no text, or even a one-character
                prefix does not fit.
        """
        text = self._get_text(segment)
        if not text:
            raise ValueError("empty segment cannot be split")
        lo, hi = 1, len(text)
        best = None
        while lo <= hi:
            mid = (lo + hi) // 2
            candidate = self._make_segment(segment, text[:mid])
            if self._messages_size([candidate], [], **kwargs) <= self.max_bytes:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1
        if best is None:
            raise ValueError("single segment too large to fit request")
        head = self._make_segment(segment, text[:best])
        if best == len(text):
            return head, None
        return head, self._make_segment(segment, text[best:])

    def chunk(self, segments: list, image_urls: list, **kwargs) -> "List[ChunkPayload]":
        """Split *segments* and *image_urls* into payloads that each fit ``max_bytes``.

        Segments are packed greedily, in order, into text-only chunks. A segment
        too large for a request on its own is split by text prefix. Images are
        then attached:

        - with text, they are spread proportionally across the text chunks;
        - without text, they are packed sequentially.

        Any image that fits nowhere gets its own image-only chunk. The input
        lists are copied, not modified.

        Returns:
            The chunks in order; an empty list when both inputs are empty.

        Raises:
            ValueError: if a segment cannot be split to fit, or a single image
                exceeds ``max_bytes`` on its own.
        """
        segments = list(segments or [])
        image_urls = list(image_urls or [])
        if not segments and not image_urls:
            return []

        chunks = [
            ChunkPayload(segments=batch, image_urls=[])
            for batch in self._batch_segments(segments, **kwargs)
        ]
        if not image_urls:
            return chunks
        if not chunks:
            return self._pack_images_sequentially(image_urls, **kwargs)
        self._distribute_images(chunks, image_urls, **kwargs)
        return chunks

    def _batch_segments(self, segments: list, **kwargs) -> List[list]:
        """Greedily group *segments* into text-only batches that fit ``max_bytes``.

        *segments* is modified in place when an oversized segment is split.
        """
        batches = []
        seg_idx = 0
        while seg_idx < len(segments):
            batch = []
            while seg_idx < len(segments):
                candidate = batch + [segments[seg_idx]]
                if self._messages_size(candidate, [], **kwargs) <= self.max_bytes:
                    batch = candidate
                    seg_idx += 1
                    continue
                if batch:
                    break  # this batch is full; start the next one
                # The segment does not fit even alone: replace it with its largest
                # fitting head, queue the remainder, and retry.
                head, tail = self._split_segment_to_fit(segments[seg_idx], **kwargs)
                segments[seg_idx] = head
                if tail is not None:
                    segments.insert(seg_idx + 1, tail)
            if not batch:
                # Defensive: the inner loop always adds a segment or raises.
                raise ValueError("unable to fit any content into chunk")
            batches.append(batch)
        return batches

    def _pack_images_sequentially(self, image_urls: list, **kwargs) -> "List[ChunkPayload]":
        """Pack images (with no text) into chunks, starting a new one when the last is full."""
        chunks = [ChunkPayload(segments=[], image_urls=[])]
        for image in image_urls:
            last = chunks[-1]
            candidate_images = last.image_urls + [image]
            if self._messages_size(last.segments, candidate_images, **kwargs) <= self.max_bytes:
                last.image_urls = candidate_images
                continue
            if self._messages_size([], [image], **kwargs) > self.max_bytes:
                raise ValueError("single image payload exceeds max_bytes")
            chunks.append(ChunkPayload(segments=[], image_urls=[image]))
        return chunks

    def _distribute_images(self, chunks: "List[ChunkPayload]", image_urls: list, **kwargs) -> None:
        """Attach images to *chunks* in place, spreading them proportionally.

        Image ``idx`` first tries chunk ``idx * len(chunks) // len(image_urls)``,
        then every later chunk. It never tries an earlier chunk. When no chunk
        has room, a new image-only chunk is appended for it.
        """
        chunk_count = len(chunks)
        total_images = len(image_urls)
        for idx, image in enumerate(image_urls):
            # idx < total_images, so this is always <= chunk_count - 1.
            preferred_idx = (idx * chunk_count) // total_images
            for chunk in chunks[preferred_idx:]:
                candidate_images = chunk.image_urls + [image]
                if self._messages_size(chunk.segments, candidate_images, **kwargs) <= self.max_bytes:
                    chunk.image_urls = candidate_images
                    break
            else:
                if self._messages_size([], [image], **kwargs) > self.max_bytes:
                    raise ValueError("single image payload exceeds max_bytes")
                chunks.append(ChunkPayload(segments=[], image_urls=[image]))

    def group_texts_by_budget(self, texts: List[str], build_messages: Callable, **kwargs) -> List[List[str]]:
        """Greedily group *texts*, in order, so each group's messages fit ``max_bytes``.

        *build_messages* is called as ``build_messages(group, [], **kwargs)`` when
        its signature accepts that form, and as ``build_messages(group, **kwargs)``
        otherwise. The form is chosen once, so a ``TypeError`` raised inside the
        builder propagates instead of triggering a retry with a different arity.

        Raises:
            ValueError: if a single text does not fit on its own.
        """
        build = self._resolve_text_builder(build_messages, **kwargs)
        groups: List[List[str]] = []
        idx = 0
        while idx < len(texts):
            group: List[str] = []
            while idx < len(texts):
                candidate = group + [texts[idx]]
                if self.estimate(build(candidate)) <= self.max_bytes:
                    group = candidate
                    idx += 1
                    continue
                if not group:
                    raise ValueError("single text block exceeds max_bytes")
                break
            groups.append(group)
        return groups

    @staticmethod
    def _resolve_text_builder(build_messages: Callable, **kwargs) -> Callable:
        """Return ``f(texts) -> messages`` that calls *build_messages* with the arity it accepts."""
        import inspect

        try:
            signature = inspect.signature(build_messages)
        except (TypeError, ValueError):
            signature = None

        if signature is None:
            # Cannot introspect the callable: fall back to trying the
            # two-argument form first and the one-argument form on TypeError.
            def build(texts):
                try:
                    return build_messages(texts, [], **kwargs)
                except TypeError:
                    return build_messages(texts, **kwargs)

            return build

        try:
            signature.bind([], [], **kwargs)
        except TypeError:
            return lambda texts: build_messages(texts, **kwargs)
        return lambda texts: build_messages(texts, [], **kwargs)
|
|