#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Utilities for streaming token decoding, recursive tensor cloning,
image grid tiling, and per-second video frame / audio segment extraction."""
import base64
import logging
import math
import os
import tempfile
from io import BytesIO

import librosa
import numpy as np
import torch
from decord import cpu
from decord import VideoReader

try:
    from moviepy import VideoFileClip  # moviepy >= 2.0
except ImportError:
    from moviepy.editor import VideoFileClip  # moviepy < 2.0
from PIL import Image

logger = logging.getLogger(__name__)


def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
    """
    Incrementally decode tokens from an iterator, handling partial multi-byte characters.

    When streaming tokens, multi-byte characters (like Chinese) may be split across
    multiple tokens. Decoding partial tokens results in replacement characters (U+FFFD).
    This function buffers tokens and only yields complete characters.

    Args:
        token_iterator: An iterator yielding (token_ids, is_finished) tuples.
            token_ids can be torch.Tensor or any iterable of integers.
        tokenizer: The tokenizer to use for decoding.
        skip_special_tokens: Whether to skip special tokens during decoding.

    Yields:
        (decoded_text, is_finished) tuples where decoded_text is the new text
        since the last yield.
    """
    accumulated_token_ids = []
    yielded_text_len = 0

    for token_ids, is_finished in token_iterator:
        # Accumulate token IDs (tensor, iterable, or single scalar)
        if torch.is_tensor(token_ids):
            accumulated_token_ids.extend(token_ids.reshape(-1).tolist())
        else:
            accumulated_token_ids.extend(list(token_ids) if hasattr(token_ids, "__iter__") else [token_ids])

        # Re-decode everything accumulated so far; decoding from the start is
        # what lets previously-incomplete multi-byte sequences resolve.
        full_decoded = tokenizer.decode(accumulated_token_ids, skip_special_tokens=skip_special_tokens)

        if is_finished:
            # Final chunk - yield all remaining text, replacement chars and all
            new_text = full_decoded[yielded_text_len:]
            yield new_text, is_finished
        else:
            # Find a safe prefix without incomplete multi-byte characters.
            # The replacement character U+FFFD indicates incomplete decoding.
            new_text = full_decoded[yielded_text_len:]
            # Hold back trailing replacement characters (incomplete UTF-8 tail)
            safe_end = len(new_text)
            while safe_end > 0 and new_text[safe_end - 1] == "\ufffd":
                safe_end -= 1
            safe_text = new_text[:safe_end] if safe_end > 0 else ""
            yielded_text_len += len(safe_text)
            yield safe_text, is_finished


def torch_clone_recursive(obj):
    """Recursively clone nested containers of torch.Tensors.

    Supported container types: dict, list, tuple.
    Non-container non-Tensor objects are returned as-is.
    """
    if torch.is_tensor(obj):
        return obj.clone()
    elif isinstance(obj, dict):
        return {k: torch_clone_recursive(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [torch_clone_recursive(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(torch_clone_recursive(v) for v in obj)
    else:
        # BUGFIX: previously raised ValueError here, contradicting the
        # documented contract above. Scalars/strings/None/etc. are immutable
        # (or not meaningful to clone), so pass them through unchanged.
        return obj


def _fmt_bytes(n_bytes: int) -> str:
    """Format a byte count as a human-readable megabyte string."""
    mb = n_bytes / (1024**2)
    return f"{mb:.2f}MB"


def _cuda_tensor_bytes(obj):
    """Sum the storage size (numel * element_size) of every tensor nested in obj.

    NOTE(review): despite the name, this counts ALL tensors regardless of
    device, not only CUDA ones -- confirm callers only pass CUDA structures.

    Raises:
        ValueError: for unsupported container/leaf types.
    """
    total_bytes = 0
    if torch.is_tensor(obj):
        tensor_bytes = obj.numel() * obj.element_size()
        total_bytes += tensor_bytes
        # Was a bare print(); route through the module logger so the debug
        # output is configurable and silenced by default.
        logger.debug("cuda tensor: %s, total_bytes: %s", obj.shape, _fmt_bytes(tensor_bytes))
        return total_bytes
    elif isinstance(obj, dict):
        for v in obj.values():
            total_bytes += _cuda_tensor_bytes(v)
        return total_bytes
    elif isinstance(obj, (list, tuple)):
        for v in obj:
            total_bytes += _cuda_tensor_bytes(v)
        return total_bytes
    else:
        raise ValueError(f"Unsupported type: {type(obj)}")


def concat_images(images, bg_color=(255, 255, 255), cell_size=None, line_color=(0, 0, 0), line_width=6):
    """Tile a list of images into a single grid image.

    Layout rules (as implemented): 4 images -> 2x2; 3 images -> 1x3 or 3x1
    (whichever is closer to square); 2 images -> 1x2 or 2x1; 1 image -> 1x1;
    any other count -> 1xN. Separator lines are drawn only at interior seams
    (no outer border).

    Args:
        images: list of PIL.Image.Image, raw image bytes/bytearray, or base64
            strings (optionally full ``data:image/...;base64,`` URIs).
        bg_color: letterbox padding colour.
        cell_size: optional (width, height) for each grid cell; defaults to
            the max width/height over the inputs.
        line_color: separator line colour.
        line_width: separator thickness in pixels.

    Returns:
        A single RGB PIL.Image.Image.

    Raises:
        TypeError: for unsupported image input types.
        ValueError: if images is empty.
    """
    # Normalize inputs to PIL.Image: accept PIL.Image, bytes/bytearray, base64 str
    _converted_images = []
    for im in images:
        if isinstance(im, Image.Image):
            _converted_images.append(im)
        elif isinstance(im, (bytes, bytearray)):
            _converted_images.append(Image.open(BytesIO(im)).convert("RGB"))
        elif isinstance(im, str):
            # Accept 'data:image/jpeg;base64,...' URIs or plain base64
            b64 = im.split(",")[-1] if ";base64," in im else im
            img_bytes = base64.b64decode(b64)
            _converted_images.append(Image.open(BytesIO(img_bytes)).convert("RGB"))
        else:
            raise TypeError(f"Unsupported image type: {type(im)}")
    images = _converted_images

    n = len(images)
    if n == 0:
        raise ValueError("images is empty")

    if n == 4:
        rows, cols = 2, 2
    elif n == 3:
        # Dynamically pick 1x3 vs 3x1 so the final canvas is closest to square.
        # Cell size comes from the max source dims (letterbox below adapts).
        if cell_size is None:
            cell_w = max(im.width for im in images)
            cell_h = max(im.height for im in images)
        else:
            cell_w, cell_h = cell_size
        candidates = [(1, 3), (3, 1)]

        def canvas_ratio(r, c):
            W = c * cell_w + (c - 1) * line_width
            H = r * cell_h + (r - 1) * line_width
            return W / max(1, H)

        ratios = [abs(canvas_ratio(r, c) - 1.0) for (r, c) in candidates]
        best_idx = int(np.argmin(ratios))
        rows, cols = candidates[best_idx]
    elif n == 1:
        rows, cols = 1, 1
    elif n == 2:
        # Dynamically pick 1x2 vs 2x1 so the final canvas is closest to square.
        if cell_size is None:
            cell_w = max(im.width for im in images)
            cell_h = max(im.height for im in images)
        else:
            cell_w, cell_h = cell_size
        candidates = [(1, 2), (2, 1)]

        def canvas_ratio(r, c):
            W = c * cell_w + (c - 1) * line_width
            H = r * cell_h + (r - 1) * line_width
            return W / max(1, H)

        ratios = [abs(canvas_ratio(r, c) - 1.0) for (r, c) in candidates]
        # On a tie, decide by mean aspect ratio: horizontal layout suits wide
        # images, vertical suits tall ones.
        if ratios[0] == ratios[1]:
            avg_ar = np.mean([im.width / max(1, im.height) for im in images])
            rows, cols = (1, 2) if avg_ar >= 1.0 else (2, 1)
        else:
            best_idx = int(np.argmin(ratios))
            rows, cols = candidates[best_idx]
    else:
        rows, cols = 1, n

    # Cell dimensions
    if cell_size is None:
        cell_w = max(im.width for im in images)
        cell_h = max(im.height for im in images)
    else:
        cell_w, cell_h = cell_size

    # Scale into the cell preserving aspect ratio, padding with bg_color
    def letterbox(im, tw, th):
        im = im.convert("RGB")
        w, h = im.size
        s = min(tw / w, th / h)
        nw, nh = max(1, int(round(w * s))), max(1, int(round(h * s)))
        try:
            im_r = im.resize((nw, nh), Image.Resampling.BICUBIC)
        except AttributeError:
            # Pillow < 9.1 has no Image.Resampling enum
            im_r = im.resize((nw, nh), Image.BICUBIC)
        canvas = Image.new("RGB", (tw, th), bg_color)
        canvas.paste(im_r, ((tw - nw) // 2, (th - nh) // 2))
        return canvas

    # Reserve line_width bands only at interior seams (no outer border)
    W = cols * cell_w + (cols - 1) * line_width
    H = rows * cell_h + (rows - 1) * line_width
    canvas = Image.new("RGB", (W, H), line_color)
    for i, im in enumerate(images[: rows * cols]):
        r, c = divmod(i, cols)
        cell = letterbox(im, cell_w, cell_h)
        x = c * (cell_w + line_width)
        y = r * (cell_h + line_width)
        canvas.paste(cell, (x, y))
    return canvas


MAX_NUM_FRAMES = int(os.getenv("MAX_NUM_FRAMES", 64))
# Not referenced in this module; presumably read by callers -- verify before removing.
VIDEO_MME_DURATION = os.getenv("VIDEO_MME_DURATION", "ALL")


def uniform_sample(l, n):
    """Uniformly sample at most n elements from sequence l, preserving order."""
    if len(l) <= n:
        return l
    idxs = np.linspace(0, len(l) - 1, n, dtype=int)
    return [l[i] for i in idxs]


def get_video_frame_audio_segments(video_path, audio_path=None, last_vad_timestamp=None, stack_frames=1):
    """Extract per-second video frames and aligned 16 kHz audio segments.

    Args:
        video_path: path to the video file.
        audio_path: optional separate audio file; when None the audio track is
            read from the video itself (falling back to moviepy extraction).
        last_vad_timestamp: optional cutoff in seconds overriding the video
            duration (e.g. end of voice activity).
        stack_frames: when > 1, additionally sample at ``stack_frames`` fps and
            tile each second's extra frames into one "stack image".

    Returns:
        (video_segments, audio_segments, stacked_video_segments):
        video_segments is a list of RGB PIL images (~1 per second, capped at
        MAX_NUM_FRAMES); audio_segments is a list of mono float arrays aligned
        with the frame timestamps; stacked_video_segments is None unless
        stack_frames > 1 (entries may be None for seconds with no extra frames).
    """
    vr = VideoReader(str(video_path), ctx=cpu(0))
    avg_fps = vr.get_avg_fps()
    duration = len(vr) / avg_fps
    if last_vad_timestamp is not None:
        duration = last_vad_timestamp

    # Per-second timestamps (used for audio segmentation)
    num_seconds = math.ceil(duration)
    second_timestamps = list(range(num_seconds))

    # Extract base frames (1 fps, at the start of each second: 0.0s, 1.0s, 2.0s...)
    if duration > MAX_NUM_FRAMES:
        # Too long for 1 fps within the budget: sample on a 0.1 s grid, then
        # thin frame indices and timestamps uniformly in lockstep.
        timestamps = [round(i * 0.1, 1) for i in range(int(duration / 0.1))]
        frame_idx = [min(int(ts * avg_fps), len(vr) - 1) for ts in timestamps]
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
        timestamps = uniform_sample(timestamps, MAX_NUM_FRAMES)
    else:
        # BUGFIX: clamp to the last frame like the branch above; the unclamped
        # index could exceed len(vr) - 1 when last_vad_timestamp is larger
        # than the real video duration.
        frame_idx = [min(int(i * avg_fps), len(vr) - 1) for i in range(num_seconds)]
        timestamps = second_timestamps
    video = vr.get_batch(frame_idx).asnumpy()
    video_segments = [Image.fromarray(v.astype("uint8")).convert("RGB") for v in video]

    # When stack_frames > 1, extract extra high-rate frames and merge them into
    # stack images. The first frame of each second duplicates the 1 fps frame,
    # so skip it and keep only the remaining (stack_frames - 1) frames.
    stacked_video_segments = None
    if stack_frames > 1:
        # Sample at stack_frames fps but skip i=0 each second; e.g. for
        # stack_frames=5 keep i=1,2,3,4 -> 0.2s, 0.4s, 0.6s, 0.8s
        all_frame_timestamps = []
        for sec in range(num_seconds):
            for i in range(1, stack_frames):  # start at 1, skipping the 1 fps frame
                ts = sec + i / stack_frames
                if ts < duration:
                    all_frame_timestamps.append(ts)
        stack_frame_idx = [min(int(ts * avg_fps), len(vr) - 1) for ts in all_frame_timestamps]

        # Thin uniformly if the total exceeds the stacked-frame budget
        max_stack_frames = MAX_NUM_FRAMES * (stack_frames - 1)
        if len(stack_frame_idx) > max_stack_frames:
            stack_frame_idx = uniform_sample(stack_frame_idx, max_stack_frames)
            all_frame_timestamps = uniform_sample(all_frame_timestamps, max_stack_frames)

        stack_video = vr.get_batch(stack_frame_idx).asnumpy()
        all_frames = [Image.fromarray(v.astype("uint8")).convert("RGB") for v in stack_video]

        # Merge each second's frames into one stack image
        stacked_video_segments = []
        frame_cursor = 0
        for sec in range(num_seconds):
            # Collect the frames whose timestamps fall in [sec, sec + 1)
            frames_this_second = []
            while frame_cursor < len(all_frame_timestamps) and all_frame_timestamps[frame_cursor] < sec + 1:
                frames_this_second.append(all_frames[frame_cursor])
                frame_cursor += 1
            if len(frames_this_second) > 0:
                stacked_frame = concat_images(frames_this_second)
                stacked_video_segments.append(stacked_frame)
            else:
                # No frames for this second (trailing remainder): placeholder
                stacked_video_segments.append(None)

    # Load audio
    if audio_path is None:
        try:
            audio_np, sr = librosa.load(video_path, sr=16000, mono=True)
        except Exception:
            # BUGFIX: was a bare ``except:`` that also swallowed
            # KeyboardInterrupt/SystemExit. librosa cannot demux some
            # containers directly; extract the audio track with moviepy first.
            logger.debug("librosa could not read %s directly; extracting audio via moviepy", video_path)
            video_clip = VideoFileClip(video_path)
            try:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
                    temp_audio_file_path = temp_audio_file.name
                    video_clip.audio.write_audiofile(temp_audio_file_path, codec="pcm_s16le", fps=16000)
                    audio_np, sr = librosa.load(temp_audio_file_path, sr=16000, mono=True)
            finally:
                # BUGFIX: the clip's ffmpeg reader was previously leaked
                video_clip.close()
    else:
        audio_np, sr = librosa.load(audio_path, sr=16000, mono=True)

    # Cut audio into segments aligned with the frame timestamps
    audio_segments = []
    for i in range(len(timestamps)):
        start_time = timestamps[i]
        if i < len(timestamps) - 1:
            end_time = timestamps[i + 1]
        else:
            end_time = duration
        start_sample = int(start_time * sr)
        end_sample = int(end_time * sr)
        segment = audio_np[start_sample:end_sample]
        # Pad the trailing remainder so it is at least 0.1 s (1600 samples @ 16 kHz)
        if i == len(timestamps) - 1 and len(segment) < 1600:
            segment = np.concatenate([segment, np.zeros(1600 - len(segment), dtype=segment.dtype)])
        audio_segments.append(segment)

    return video_segments, audio_segments, stacked_video_segments