| |
| |
|
|
| import base64 |
| import logging |
| import math |
| import os |
| import tempfile |
| from io import BytesIO |
|
|
| import librosa |
| import numpy as np |
| import torch |
| from decord import cpu |
| from decord import VideoReader |
|
|
| try: |
| from moviepy import VideoFileClip |
| except ImportError: |
| from moviepy.editor import VideoFileClip |
|
|
| from PIL import Image |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
    """
    Incrementally decode tokens from an iterator, handling partial multi-byte characters.

    Multi-byte characters (e.g. CJK) can be split across streamed tokens; decoding a
    partial token produces U+FFFD replacement characters. This generator buffers all
    token ids seen so far, re-decodes the full buffer on every chunk, and only emits
    the suffix that is guaranteed complete (trailing U+FFFD is held back until more
    tokens arrive). On the final chunk everything left over is flushed as-is.

    Args:
        token_iterator: Iterator of (token_ids, is_finished) pairs; token_ids may be
            a torch.Tensor, any iterable of ints, or a single int.
        tokenizer: Object with a ``decode(ids, skip_special_tokens=...)`` method.
        skip_special_tokens: Forwarded to ``tokenizer.decode``.

    Yields:
        (decoded_text, is_finished) pairs, where decoded_text is only the text newly
        produced since the previous yield.
    """
    buffered_ids = []
    emitted_chars = 0

    for token_ids, is_finished in token_iterator:
        # Normalize the incoming chunk into a flat list of ints.
        if torch.is_tensor(token_ids):
            buffered_ids.extend(token_ids.reshape(-1).tolist())
        elif hasattr(token_ids, "__iter__"):
            buffered_ids.extend(list(token_ids))
        else:
            buffered_ids.append(token_ids)

        # Re-decode the whole buffer; the already-emitted prefix is sliced off,
        # so token boundaries never leak partial characters.
        decoded = tokenizer.decode(buffered_ids, skip_special_tokens=skip_special_tokens)
        pending = decoded[emitted_chars:]

        if is_finished:
            # Final chunk: flush everything that remains, U+FFFD included.
            yield pending, is_finished
        else:
            # Trailing replacement characters mean the last token ends in the
            # middle of a multi-byte character — hold them back for now.
            complete = pending.rstrip("\ufffd")
            emitted_chars += len(complete)
            yield complete, is_finished
|
|
|
|
def torch_clone_recursive(obj):
    """Recursively clone nested containers of torch.Tensors.

    Supported container types: dict, list, tuple. Every Tensor found is
    replaced with ``tensor.clone()``; container structure is rebuilt.

    Args:
        obj: A torch.Tensor, a (possibly nested) dict/list/tuple of them,
            or any other leaf value.

    Returns:
        A structurally identical object with all Tensors cloned. Non-container,
        non-Tensor leaves (ints, strings, None, ...) are returned as-is, as
        this docstring has always promised — the previous implementation
        contradicted it by raising ValueError on such leaves.
    """
    if torch.is_tensor(obj):
        return obj.clone()
    if isinstance(obj, dict):
        return {k: torch_clone_recursive(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [torch_clone_recursive(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(torch_clone_recursive(v) for v in obj)
    # Plain leaf: nothing to clone, pass through unchanged.
    return obj
|
|
|
|
| def _fmt_bytes(n_bytes: int) -> str: |
| mb = n_bytes / (1024**2) |
| return f"{mb:.2f}MB" |
|
|
|
|
| def _cuda_tensor_bytes(obj): |
| total_bytes = 0 |
| if torch.is_tensor(obj): |
| total_bytes += obj.numel() * obj.element_size() |
| print(f"cuda tensor: {obj.shape}, total_bytes: {_fmt_bytes(obj.numel() * obj.element_size())}") |
| return total_bytes |
| elif isinstance(obj, dict): |
| for v in obj.values(): |
| total_bytes += _cuda_tensor_bytes(v) |
| return total_bytes |
| elif isinstance(obj, (list, tuple)): |
| for v in obj: |
| total_bytes += _cuda_tensor_bytes(v) |
| return total_bytes |
| else: |
| raise ValueError(f"Unsupported type: {type(obj)}") |
|
|
|
|
def concat_images(images, bg_color=(255, 255, 255), cell_size=None, line_color=(0, 0, 0), line_width=6):
    """Tile a list of images into one grid image with divider lines.

    Grid selection: 1 image -> 1x1; 4 -> 2x2; 2 or 3 -> whichever of the
    horizontal/vertical arrangements yields the most nearly square canvas
    (for 2 images an exact tie is broken by the inputs' mean aspect ratio);
    any other count -> a single row (1xN). Divider lines are drawn only
    between cells, never as an outer border.

    Args:
        images: list of PIL.Image.Image, raw encoded image bytes, or base64
            strings (optionally a data URL containing ";base64,").
        bg_color: letterbox padding color inside each cell.
        cell_size: optional (width, height) per cell; defaults to the max
            width/height over all inputs.
        line_color: divider line color.
        line_width: divider thickness in pixels.

    Returns:
        A new RGB PIL.Image containing the tiled grid.

    Raises:
        ValueError: if ``images`` is empty.
        TypeError: on an unsupported element type.
    """
    decoded = []
    for im in images:
        if isinstance(im, Image.Image):
            decoded.append(im)
        elif isinstance(im, (bytes, bytearray)):
            decoded.append(Image.open(BytesIO(im)).convert("RGB"))
        elif isinstance(im, str):
            # Accept both bare base64 and data-URL strings.
            payload = im.split(",")[-1] if ";base64," in im else im
            decoded.append(Image.open(BytesIO(base64.b64decode(payload))).convert("RGB"))
        else:
            raise TypeError(f"Unsupported image type: {type(im)}")
    images = decoded

    n = len(images)
    if n == 0:
        raise ValueError("images is empty")

    # Cell size: explicit, or the bounding box of the largest input.
    # (Computed once here; the original duplicated this in three places.)
    if cell_size is None:
        cell_w = max(im.width for im in images)
        cell_h = max(im.height for im in images)
    else:
        cell_w, cell_h = cell_size

    def squareness(r, c):
        # |aspect - 1| of the full canvas for an r x c layout; 0 means square.
        W = c * cell_w + (c - 1) * line_width
        H = r * cell_h + (r - 1) * line_width
        return abs(W / max(1, H) - 1.0)

    if n == 1:
        rows, cols = 1, 1
    elif n == 4:
        rows, cols = 2, 2
    elif n == 3:
        # min() keeps the first candidate on ties, matching np.argmin.
        rows, cols = min([(1, 3), (3, 1)], key=lambda rc: squareness(*rc))
    elif n == 2:
        s_h, s_v = squareness(1, 2), squareness(2, 1)
        if s_h == s_v:
            # Exact tie: fall back to the inputs' mean aspect ratio.
            avg_ar = np.mean([im.width / max(1, im.height) for im in images])
            rows, cols = (1, 2) if avg_ar >= 1.0 else (2, 1)
        else:
            rows, cols = (1, 2) if s_h < s_v else (2, 1)
    else:
        rows, cols = 1, n

    def letterbox(im, tw, th):
        # Scale to fit (tw, th) preserving aspect ratio, pad with bg_color.
        im = im.convert("RGB")
        w, h = im.size
        s = min(tw / w, th / h)
        nw, nh = max(1, int(round(w * s))), max(1, int(round(h * s)))
        try:
            resized = im.resize((nw, nh), Image.Resampling.BICUBIC)
        except AttributeError:
            # Pillow < 9.1 has no Image.Resampling namespace.
            resized = im.resize((nw, nh), Image.BICUBIC)
        cell = Image.new("RGB", (tw, th), bg_color)
        cell.paste(resized, ((tw - nw) // 2, (th - nh) // 2))
        return cell

    # Fill the canvas with the line color; pasted cells leave only the
    # divider gaps showing through, so no explicit line drawing is needed.
    W = cols * cell_w + (cols - 1) * line_width
    H = rows * cell_h + (rows - 1) * line_width
    canvas = Image.new("RGB", (W, H), line_color)

    for i, im in enumerate(images[: rows * cols]):
        r, c = divmod(i, cols)
        x = c * (cell_w + line_width)
        y = r * (cell_h + line_width)
        canvas.paste(letterbox(im, cell_w, cell_h), (x, y))

    return canvas
|
|
|
|
# Cap on the number of frames sampled per video; override via the
# MAX_NUM_FRAMES environment variable (default 64).
MAX_NUM_FRAMES = int(os.getenv("MAX_NUM_FRAMES", 64))
# Duration bucket for the Video-MME benchmark ("ALL" = no filtering) —
# presumably; not consumed anywhere in this file, confirm against callers.
VIDEO_MME_DURATION = os.getenv("VIDEO_MME_DURATION", "ALL")
|
|
|
|
def uniform_sample(l, n):
    """Return ``l`` unchanged when it has at most ``n`` items; otherwise
    pick ``n`` elements at evenly spaced (truncated) indices across it."""
    count = len(l)
    if count <= n:
        return l
    positions = np.linspace(0, count - 1, num=n, dtype=int)
    return [l[pos] for pos in positions]
|
|
|
|
def get_video_frame_audio_segments(video_path, audio_path=None, last_vad_timestamp=None, stack_frames=1):
    """Extract per-timestamp video frames and aligned 16 kHz audio segments.

    Args:
        video_path: Path to the video file (frames read with decord; audio via
            librosa, falling back to moviepy extraction).
        audio_path: Optional separate audio file; when given, audio is loaded
            from it instead of from the video.
        last_vad_timestamp: Optional cutoff in seconds that overrides the
            video's duration (e.g. end of the last voice-activity segment).
        stack_frames: When > 1, additionally sample ``stack_frames - 1`` extra
            frames inside each second and tile each second's extras into one
            composite image via concat_images().

    Returns:
        (video_segments, audio_segments, stacked_video_segments):
        video_segments — list of PIL RGB frames, one per sampled timestamp;
        audio_segments — list of numpy waveform slices aligned with those
        timestamps; stacked_video_segments — None unless stack_frames > 1,
        otherwise one composite (or None) per whole second.
    """
    vr = VideoReader(str(video_path), ctx=cpu(0))
    avg_fps = vr.get_avg_fps()
    duration = len(vr) / avg_fps

    if last_vad_timestamp is not None:
        duration = last_vad_timestamp

    num_seconds = math.ceil(duration)
    second_timestamps = list(range(num_seconds))

    # Frame sampling: long videos are sampled on a 0.1 s grid then thinned
    # uniformly to MAX_NUM_FRAMES; otherwise one frame per whole second.
    if duration > MAX_NUM_FRAMES:
        timestamps = [round(i * 0.1, 1) for i in range(int(duration / 0.1))]
        frame_idx = [min(int(ts * avg_fps), len(vr) - 1) for ts in timestamps]
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
        timestamps = uniform_sample(timestamps, MAX_NUM_FRAMES)
    else:
        # Clamp to the last frame: last_vad_timestamp may exceed the real
        # video duration, and int(i * avg_fps) could otherwise index past the
        # end (the long-video branch above already clamps this way).
        frame_idx = [min(int(i * avg_fps), len(vr) - 1) for i in range(num_seconds)]
        timestamps = second_timestamps

    video = vr.get_batch(frame_idx).asnumpy()
    video_segments = [Image.fromarray(v.astype("uint8")).convert("RGB") for v in video]

    # Optional intra-second frame stacking.
    stacked_video_segments = None
    if stack_frames > 1:
        # Evenly spaced sub-second timestamps; whole seconds are skipped
        # because video_segments already covers them.
        all_frame_timestamps = []
        for sec in range(num_seconds):
            for i in range(1, stack_frames):
                ts = sec + i / stack_frames
                if ts < duration:
                    all_frame_timestamps.append(ts)

        stack_frame_idx = [min(int(ts * avg_fps), len(vr) - 1) for ts in all_frame_timestamps]

        # Cap the extra frames proportionally to the base-frame cap.
        max_stack_frames = MAX_NUM_FRAMES * (stack_frames - 1)
        if len(stack_frame_idx) > max_stack_frames:
            stack_frame_idx = uniform_sample(stack_frame_idx, max_stack_frames)
            all_frame_timestamps = uniform_sample(all_frame_timestamps, max_stack_frames)

        stack_video = vr.get_batch(stack_frame_idx).asnumpy()
        all_frames = [Image.fromarray(v.astype("uint8")).convert("RGB") for v in stack_video]

        # Group the extra frames by the second they fall in and tile each
        # group into one composite image.
        stacked_video_segments = []
        frame_cursor = 0
        for sec in range(num_seconds):
            frames_this_second = []
            while frame_cursor < len(all_frame_timestamps) and all_frame_timestamps[frame_cursor] < sec + 1:
                frames_this_second.append(all_frames[frame_cursor])
                frame_cursor += 1

            if len(frames_this_second) > 0:
                stacked_video_segments.append(concat_images(frames_this_second))
            else:
                # No extra frame landed in this second (possible near the
                # duration cutoff); keep the slot so indices stay aligned.
                stacked_video_segments.append(None)

    # Audio: prefer reading the track straight from the video container;
    # fall back to extracting it with moviepy when librosa cannot open it.
    if audio_path is None:
        try:
            audio_np, sr = librosa.load(video_path, sr=16000, mono=True)
        except Exception as e:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt.
            logger.warning("librosa failed to read audio from %s (%s); falling back to moviepy", video_path, e)
            video_clip = VideoFileClip(video_path)
            try:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
                    video_clip.audio.write_audiofile(temp_audio_file.name, codec="pcm_s16le", fps=16000)
                    audio_np, sr = librosa.load(temp_audio_file.name, sr=16000, mono=True)
            finally:
                # The original leaked the clip's underlying ffmpeg readers.
                video_clip.close()
    else:
        audio_np, sr = librosa.load(audio_path, sr=16000, mono=True)

    # Slice the waveform into segments aligned with the frame timestamps.
    audio_segments = []
    for i in range(len(timestamps)):
        start_time = timestamps[i]
        end_time = timestamps[i + 1] if i < len(timestamps) - 1 else duration

        segment = audio_np[int(start_time * sr):int(end_time * sr)]

        # Pad the tail segment to at least 1600 samples (0.1 s at 16 kHz) so
        # downstream processing gets a minimum-length window.
        if i == len(timestamps) - 1 and len(segment) < 1600:
            segment = np.concatenate([segment, np.zeros(1600 - len(segment), dtype=segment.dtype)])
        audio_segments.append(segment)

    return video_segments, audio_segments, stacked_video_segments
|
|