# Provenance: uploaded by YummyYum via huggingface_hub (revision be99bcf, verified).
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import logging
import math
import os
import tempfile
from io import BytesIO
import librosa
import numpy as np
import torch
from decord import cpu
from decord import VideoReader
try:
from moviepy import VideoFileClip # moviepy >= 2.0
except ImportError:
from moviepy.editor import VideoFileClip # moviepy < 2.0
from PIL import Image
logger = logging.getLogger(__name__)
def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
    """Decode a token stream incrementally without emitting partial characters.

    Multi-byte characters (e.g. CJK) may be split across tokens; decoding an
    incomplete byte sequence produces U+FFFD replacement characters. This
    generator re-decodes the full accumulated token list on every step and
    withholds any trailing replacement characters until the sequence completes.

    Args:
        token_iterator: iterator yielding (token_ids, is_finished) pairs, where
            token_ids may be a torch.Tensor, an iterable of ints, or a single int.
        tokenizer: object exposing decode(ids, skip_special_tokens=...).
        skip_special_tokens: forwarded to tokenizer.decode.

    Yields:
        (new_text, is_finished) pairs; new_text is the text produced since the
        previous yield (possibly empty while a character is still incomplete).
    """
    token_buffer = []
    emitted_chars = 0
    for token_ids, is_finished in token_iterator:
        # Normalize the incoming chunk into a flat list of ints and accumulate.
        if torch.is_tensor(token_ids):
            token_buffer.extend(token_ids.reshape(-1).tolist())
        elif hasattr(token_ids, "__iter__"):
            token_buffer.extend(token_ids)
        else:
            token_buffer.append(token_ids)
        # Re-decode everything seen so far; slicing off emitted_chars gives the delta.
        decoded = tokenizer.decode(token_buffer, skip_special_tokens=skip_special_tokens)
        pending = decoded[emitted_chars:]
        if is_finished:
            # Final chunk: flush whatever remains, replacement chars included.
            yield pending, is_finished
        else:
            # Hold back trailing U+FFFD — it marks an incomplete UTF-8 tail
            # that the next token(s) will complete.
            stable = pending.rstrip("\ufffd")
            emitted_chars += len(stable)
            yield stable, is_finished
def torch_clone_recursive(obj):
    """Recursively clone nested containers of torch.Tensors.

    Supported container types: dict, list, tuple. Tensors are cloned; any
    other leaf object is returned as-is.

    (Fix: the previous implementation raised ValueError on non-container,
    non-Tensor leaves, contradicting this documented contract; such leaves
    now pass through unchanged — a backward-compatible generalization.)

    Args:
        obj: a Tensor, a (possibly nested) dict/list/tuple of Tensors, or any
            other object (returned unchanged).

    Returns:
        A structure mirroring obj with every Tensor replaced by a clone.
    """
    if torch.is_tensor(obj):
        return obj.clone()
    if isinstance(obj, dict):
        return {k: torch_clone_recursive(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [torch_clone_recursive(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(torch_clone_recursive(v) for v in obj)
    # Non-container, non-Tensor leaves (ints, strings, None, ...) pass through.
    return obj
def _fmt_bytes(n_bytes: int) -> str:
mb = n_bytes / (1024**2)
return f"{mb:.2f}MB"
def _cuda_tensor_bytes(obj):
total_bytes = 0
if torch.is_tensor(obj):
total_bytes += obj.numel() * obj.element_size()
print(f"cuda tensor: {obj.shape}, total_bytes: {_fmt_bytes(obj.numel() * obj.element_size())}")
return total_bytes
elif isinstance(obj, dict):
for v in obj.values():
total_bytes += _cuda_tensor_bytes(v)
return total_bytes
elif isinstance(obj, (list, tuple)):
for v in obj:
total_bytes += _cuda_tensor_bytes(v)
return total_bytes
else:
raise ValueError(f"Unsupported type: {type(obj)}")
def _to_pil_image(im):
    """Coerce a PIL.Image, bytes/bytearray, or base64 string into a PIL image."""
    if isinstance(im, Image.Image):
        return im
    if isinstance(im, (bytes, bytearray)):
        return Image.open(BytesIO(im)).convert("RGB")
    if isinstance(im, str):
        # Accept both 'data:image/jpeg;base64,...' URLs and bare base64 payloads.
        b64 = im.split(",")[-1] if ";base64," in im else im
        return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
    raise TypeError(f"Unsupported image type: {type(im)}")


def concat_images(images, bg_color=(255, 255, 255), cell_size=None, line_color=(0, 0, 0), line_width=6):
    """Concatenate images into a single grid image with separator lines.

    Layout rules: 4 -> 2x2; 3 -> the squarer of 1x3/3x1; 2 -> the squarer of
    1x2/2x1 (ties broken by average aspect ratio); 1 -> 1x1; otherwise 1xN.
    Separator lines are drawn only between cells (no outer border).

    Args:
        images: list of PIL.Image.Image, raw bytes/bytearray, or base64 strings
            (optionally prefixed 'data:image/...;base64,').
        bg_color: RGB background used for letterbox padding inside each cell.
        cell_size: optional (width, height) of each grid cell; defaults to the
            maximum width/height over the inputs.
        line_color: RGB color of the separator lines.
        line_width: separator line thickness in pixels.

    Returns:
        A single PIL.Image.Image in RGB mode.

    Raises:
        ValueError: if images is empty.
        TypeError: if an element has an unsupported type.
    """
    images = [_to_pil_image(im) for im in images]
    n = len(images)
    if n == 0:
        raise ValueError("images is empty")

    # Cell size defaults to the largest input dimensions; letterbox adapts.
    # (Fix: previously recomputed up to twice with an identical formula.)
    if cell_size is None:
        cell_w = max(im.width for im in images)
        cell_h = max(im.height for im in images)
    else:
        cell_w, cell_h = cell_size

    def _squareness(r, c):
        # |aspect ratio - 1| of the assembled canvas; smaller is more square.
        # (Fix: this helper was previously duplicated in two branches.)
        W = c * cell_w + (c - 1) * line_width
        H = r * cell_h + (r - 1) * line_width
        return abs(W / max(1, H) - 1.0)

    if n == 4:
        rows, cols = 2, 2
    elif n == 3:
        # Pick 1x3 or 3x1, whichever is closer to square (argmin keeps the
        # first candidate on ties, matching the original behavior).
        candidates = [(1, 3), (3, 1)]
        ratios = [_squareness(r, c) for (r, c) in candidates]
        rows, cols = candidates[int(np.argmin(ratios))]
    elif n == 1:
        rows, cols = 1, 1
    elif n == 2:
        # Pick 1x2 or 2x1, whichever is closer to square.
        candidates = [(1, 2), (2, 1)]
        ratios = [_squareness(r, c) for (r, c) in candidates]
        if ratios[0] == ratios[1]:
            # Tie: landscape inputs go side by side, portrait inputs stack.
            avg_ar = np.mean([im.width / max(1, im.height) for im in images])
            rows, cols = (1, 2) if avg_ar >= 1.0 else (2, 1)
        else:
            rows, cols = candidates[int(np.argmin(ratios))]
    else:
        rows, cols = 1, n

    def _letterbox(im, tw, th):
        # Scale preserving aspect ratio, then center on a bg_color cell (tw, th).
        im = im.convert("RGB")
        w, h = im.size
        s = min(tw / w, th / h)
        nw, nh = max(1, int(round(w * s))), max(1, int(round(h * s)))
        try:
            resized = im.resize((nw, nh), Image.Resampling.BICUBIC)
        except AttributeError:
            # Pillow < 9.1 has no Image.Resampling enum.
            resized = im.resize((nw, nh), Image.BICUBIC)
        cell = Image.new("RGB", (tw, th), bg_color)
        cell.paste(resized, ((tw - nw) // 2, (th - nh) // 2))
        return cell

    # The canvas is painted line_color; the line_width gaps left between
    # pasted cells become the separator lines (no outer border).
    W = cols * cell_w + (cols - 1) * line_width
    H = rows * cell_h + (rows - 1) * line_width
    canvas = Image.new("RGB", (W, H), line_color)
    for i, im in enumerate(images[: rows * cols]):
        r, c = divmod(i, cols)
        x = c * (cell_w + line_width)
        y = r * (cell_h + line_width)
        canvas.paste(_letterbox(im, cell_w, cell_h), (x, y))
    return canvas
# Cap on sampled video frames; overridable via the MAX_NUM_FRAMES env var.
MAX_NUM_FRAMES = int(os.getenv("MAX_NUM_FRAMES", 64))
# Video-MME duration bucket filter ("ALL" keeps everything).
# NOTE(review): read from env but not referenced in this chunk — presumably used elsewhere.
VIDEO_MME_DURATION = os.getenv("VIDEO_MME_DURATION", "ALL")
def uniform_sample(l, n):
    """Return at most n evenly spaced elements of l, preserving order.

    If l already has n or fewer items, l itself is returned unchanged.
    """
    if len(l) <= n:
        return l
    picks = np.linspace(0, len(l) - 1, num=n, dtype=int).tolist()
    return [l[p] for p in picks]
def get_video_frame_audio_segments(video_path, audio_path=None, last_vad_timestamp=None, stack_frames=1):
    """Extract per-second video frames and time-aligned audio segments.

    Frames are sampled at roughly 1 fps; if the duration exceeds
    MAX_NUM_FRAMES seconds, candidate timestamps on a 0.1s grid are thinned
    uniformly down to MAX_NUM_FRAMES. Audio is loaded at 16 kHz mono and
    sliced to match the frame timestamps.

    Args:
        video_path: path to the video file.
        audio_path: optional separate audio file; if None, audio is read from
            the video itself (falling back to moviepy extraction on failure).
        last_vad_timestamp: if given, truncates the effective duration (seconds).
        stack_frames: when > 1, additionally sample (stack_frames - 1) extra
            frames per second and merge each second's extras into one image
            via concat_images.

    Returns:
        (video_segments, audio_segments, stacked_video_segments):
        video_segments is a list of PIL RGB images; audio_segments a list of
        1-D numpy arrays; stacked_video_segments is None unless
        stack_frames > 1 (entries may be None for seconds with no extra frames).
    """
    vr = VideoReader(str(video_path), ctx=cpu(0))
    avg_fps = vr.get_avg_fps()
    duration = len(vr) / avg_fps
    if last_vad_timestamp is not None:
        duration = last_vad_timestamp
    # Per-second timestamps (drive the audio slicing below).
    num_seconds = math.ceil(duration)
    second_timestamps = list(range(num_seconds))
    # Primary frames: 1 per second, taken at each second's start (0.0s, 1.0s, ...).
    if duration > MAX_NUM_FRAMES:
        # Too long for 1 fps: sample candidates on a 0.1s grid, then thin
        # indices and timestamps uniformly (same picks) to MAX_NUM_FRAMES.
        timestamps = [round(i * 0.1, 1) for i in range(int(duration / 0.1))]
        frame_idx = [min(int(ts * avg_fps), len(vr) - 1) for ts in timestamps]
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
        timestamps = uniform_sample(timestamps, MAX_NUM_FRAMES)
    else:
        frame_idx = [int(i * avg_fps) for i in range(num_seconds)]
        timestamps = second_timestamps
    video = vr.get_batch(frame_idx).asnumpy()
    video_segments = [Image.fromarray(v.astype("uint8")).convert("RGB") for v in video]
    # Optional high-rate frames merged into one "stacked" image per second.
    # The first frame of each second is skipped (it duplicates the 1 fps frame).
    stacked_video_segments = None
    if stack_frames > 1:
        # e.g. stack_frames=5 keeps i=1..4 per second -> 0.2s, 0.4s, 0.6s, 0.8s.
        all_frame_timestamps = []
        for sec in range(num_seconds):
            for i in range(1, stack_frames):  # start at 1: skip the 1 fps frame
                ts = sec + i / stack_frames
                if ts < duration:
                    all_frame_timestamps.append(ts)
        stack_frame_idx = [min(int(ts * avg_fps), len(vr) - 1) for ts in all_frame_timestamps]
        # Thin uniformly if the extra frames exceed the global budget.
        max_stack_frames = MAX_NUM_FRAMES * (stack_frames - 1)
        if len(stack_frame_idx) > max_stack_frames:
            stack_frame_idx = uniform_sample(stack_frame_idx, max_stack_frames)
            all_frame_timestamps = uniform_sample(all_frame_timestamps, max_stack_frames)
        stack_video = vr.get_batch(stack_frame_idx).asnumpy()
        all_frames = [Image.fromarray(v.astype("uint8")).convert("RGB") for v in stack_video]
        # Group extra frames by the second they belong to (range [sec, sec+1)).
        stacked_video_segments = []
        frame_cursor = 0
        for sec in range(num_seconds):
            frames_this_second = []
            while frame_cursor < len(all_frame_timestamps) and all_frame_timestamps[frame_cursor] < sec + 1:
                frames_this_second.append(all_frames[frame_cursor])
                frame_cursor += 1
            if frames_this_second:
                stacked_video_segments.append(concat_images(frames_this_second))
            else:
                # No extra frames for this second (truncated tail): placeholder.
                stacked_video_segments.append(None)
    # Load audio at 16 kHz mono.
    if audio_path is None:
        try:
            audio_np, sr = librosa.load(video_path, sr=16000, mono=True)
        except Exception:
            # Fix: was a bare `except:` that also swallowed KeyboardInterrupt/
            # SystemExit; narrowed and logged before the moviepy fallback.
            logger.warning("librosa failed to read audio from %s; falling back to moviepy", video_path)
            video_clip = VideoFileClip(video_path)
            try:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
                    # NOTE(review): writing to the still-open NamedTemporaryFile's
                    # name works on POSIX but not on Windows.
                    video_clip.audio.write_audiofile(temp_audio_file.name, codec="pcm_s16le", fps=16000)
                    audio_np, sr = librosa.load(temp_audio_file.name, sr=16000, mono=True)
            finally:
                # Fix: the clip was never closed, leaking the ffmpeg reader.
                video_clip.close()
    else:
        audio_np, sr = librosa.load(audio_path, sr=16000, mono=True)
    # Slice the audio to match the frame timestamps.
    audio_segments = []
    for i in range(len(timestamps)):
        start_time = timestamps[i]
        end_time = timestamps[i + 1] if i < len(timestamps) - 1 else duration
        start_sample = int(start_time * sr)
        end_sample = int(end_time * sr)
        segment = audio_np[start_sample:end_sample]
        # Zero-pad the final fragment to at least 0.1s (1600 samples @ 16 kHz).
        if i == len(timestamps) - 1 and len(segment) < 1600:
            segment = np.concatenate([segment, np.zeros(1600 - len(segment), dtype=segment.dtype)])
        audio_segments.append(segment)
    return video_segments, audio_segments, stacked_video_segments