# Source: MiniCPM-o-4.5-nvidia-FlagOS / processing_minicpmo.py
# Uploaded by YummyYum using huggingface_hub (commit be99bcf, verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2025 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for MiniCPMO.
"""
import math
import re
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput
from transformers.tokenization_utils_base import TextInput
from transformers.utils import TensorType
from .audio_utils import process_audio_batch
from .audio_utils import regroup_audio_features
from .processing_audio_minicpma import MiniCPMAAudioProcessor # noqa: F401
from .processing_image_minicpmv import MiniCPMOBatchFeature
from .processing_streaming_mel import StreamingMelProcessorExact
class MiniCPMOProcessor(ProcessorMixin):
    """
    Multimodal processor for MiniCPM-o.

    Combines image, audio and text processing. Supports batched audio feature
    extraction as well as chunk-by-chunk streaming mel-feature extraction.
    """

    attributes = ["image_processor", "audio_processor", "tokenizer"]
    audio_processor_class = "AutoFeatureExtractor"
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, audio_processor=None, tokenizer=None, **kwargs):
        """
        Initialize the processor.

        Args:
            image_processor: Image processor.
            audio_processor: Audio feature extractor.
            tokenizer: Text tokenizer.
            **kwargs: Extra options; only ``audio_pool_step`` (default 5) is read here.
        """
        super().__init__(image_processor, audio_processor, tokenizer)
        # Version tag of the image processor (None when there is no image
        # processor or it does not define ``version``).
        self.version = getattr(image_processor, "version", None)
        # Audio feature pooling stride; must be kept in sync with config.audio_pool_step.
        self.pool_step = kwargs.get("audio_pool_step", 5)
        # The streaming mel processor can only be built when an audio processor exists.
        self._streaming_mel_processor = None
        if audio_processor is not None:
            self._init_streaming_processor()

    def get_audio_placeholder(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """
        Public method to get the audio placeholder string (used for vLLM integration).

        Args:
            audio_lens: Length of the audio in samples.
            chunk_input: If True, emit one ``audio_start ... audio_end`` group per chunk
                of ``chunk_length`` seconds; otherwise emit a single group.
            chunk_length: Chunk length in seconds.

        Returns:
            The placeholder string made of ``<unk>`` tokens wrapped in the tokenizer's
            ``audio_start`` / ``audio_end`` markers.
        """
        pool_step = self.pool_step
        audio_start = self.tokenizer.audio_start
        audio_end = self.tokenizer.audio_end

        # Samples -> feature frames (one frame per hop_length samples).
        feature_lens = math.ceil(audio_lens / self.audio_processor.hop_length)
        # Halve the frame count (presumably mirrors a stride-2 conv in the audio encoder).
        feature_lens = (feature_lens - 1) // 2 + 1
        # Number of embeddings left after pooling with stride ``pool_step``.
        output_lens = (feature_lens - pool_step) // pool_step + 1

        if not chunk_input:
            return audio_start + "<unk>" * output_lens + audio_end

        # Same arithmetic for one chunk; assumes 100 feature frames per second.
        fbank_feat_in_chunk = int(chunk_length * 100)
        cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
        audio_embeds_in_chunk = (cnn_feat_in_chunk - pool_step) // pool_step + 1
        num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk

        pieces = []
        emitted = 0
        for _ in range(num_audio_chunks):
            # The last chunk only carries the remaining embeddings.
            unk_len = min(audio_embeds_in_chunk, output_lens - emitted)
            pieces.append(audio_start + "<unk>" * unk_len + audio_end)
            emitted += unk_len
        return "".join(pieces)

    def _init_streaming_processor(
        self,
        chunk_ms: int = 100,
        verbose: bool = False,
        cnn_redundancy_ms: int = 0,
        *,
        mode: str = "exact",
        first_chunk_ms: Optional[int] = None,
        # Sliding-window parameters (trigger mode)
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Build the streaming mel processor (internal).

        Args:
            chunk_ms: Chunk size in milliseconds; also the sliding step.
            verbose: Whether to print detailed information.
            cnn_redundancy_ms: Redundancy (ms) kept before/after CNN boundaries; 0 means standard mode.
            mode: Streaming mode; only "exact" is supported.
            first_chunk_ms: Size of the first chunk in ms; defaults to ``chunk_ms`` when None.
            enable_sliding_window: Whether to enable the sliding window (trigger mode).
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: How many seconds each slide moves.

        Raises:
            ValueError: If ``mode`` is not "exact".
        """
        if mode != "exact":
            raise ValueError(f"不支持的模式:{mode},可选 'exact'")
        self._streaming_mel_processor = StreamingMelProcessorExact(
            feature_extractor=self.audio_processor,
            chunk_ms=chunk_ms,
            first_chunk_ms=first_chunk_ms,
            sample_rate=16000,
            verbose=verbose,
            cnn_redundancy_ms=cnn_redundancy_ms,
            # Sliding-window parameters (trigger mode)
            enable_sliding_window=enable_sliding_window,
            slide_trigger_seconds=slide_trigger_seconds,
            slide_stride_seconds=slide_stride_seconds,
        )
        self._streaming_mode = mode

    def set_streaming_mode(
        self,
        mode: str = "exact",
        chunk_ms: int = 100,
        cnn_redundancy_ms: int = 0,
        verbose: bool = False,
        *,
        first_chunk_ms: Optional[int] = None,
        # Sliding-window parameters (trigger mode)
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Set the streaming mode by rebuilding the streaming processor.

        Args:
            mode: Streaming mode; only "exact" is supported.
            chunk_ms: Chunk size in milliseconds; also the sliding step.
            cnn_redundancy_ms: Redundancy (ms) kept before/after CNN boundaries; 0 means standard mode.
            verbose: Whether to print detailed information.
            first_chunk_ms: Size of the first chunk in ms; defaults to ``chunk_ms`` when None.
            enable_sliding_window: Whether to enable the sliding window (trigger mode).
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: How many seconds each slide moves.

        Raises:
            ValueError: If no audio processor is set, or ``mode`` is unsupported.
        """
        if self.audio_processor is None:
            raise ValueError("audio_processor未设置,无法初始化流式处理器")
        self._init_streaming_processor(
            chunk_ms=chunk_ms,
            verbose=verbose,
            cnn_redundancy_ms=cnn_redundancy_ms,
            mode=mode,
            first_chunk_ms=first_chunk_ms,
            # Sliding-window parameters (trigger mode)
            enable_sliding_window=enable_sliding_window,
            slide_trigger_seconds=slide_trigger_seconds,
            slide_stride_seconds=slide_stride_seconds,
        )

    # ============== Image processing ==============
    def process_image(
        self,
        images: Optional[ImageInput] = None,
        do_pad: bool = True,
        max_slice_nums: int = 1,
        return_tensors: str = "pt",
    ) -> MiniCPMOBatchFeature:
        """
        Process image inputs with the image processor.

        Args:
            images: Input images.
            do_pad: Whether to pad.
            max_slice_nums: Maximum number of slices per image.
            return_tensors: Tensor type to return.

        Returns:
            MiniCPMOBatchFeature with ``pixel_values``, ``image_sizes`` and ``tgt_sizes``
            (each ``[[]]`` when absent).
        """
        if images is None:
            return MiniCPMOBatchFeature(data={"pixel_values": [[]], "image_sizes": [[]], "tgt_sizes": [[]]})

        result = self.image_processor(
            images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
        )
        model_inputs = {
            "pixel_values": result.get("pixel_values", [[]]),
            "image_sizes": result.get("image_sizes", [[]]),
            "tgt_sizes": result.get("tgt_sizes", [[]]),
        }
        return MiniCPMOBatchFeature(data=model_inputs)

    # ============== Audio processing (batched) ==============
    def process_audio(
        self,
        audios: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        sampling_rate: int = 16000,
        regroup_to_seconds: Optional[int] = None,
        fps: int = 100,
    ) -> MiniCPMOBatchFeature:
        """
        Extract audio features for a batch of audios.

        Args:
            audios: Audio waveform(s).
            sampling_rate: Sampling rate of the input.
            regroup_to_seconds: If set, regroup features into segments of this many seconds.
            fps: Feature frames per second, used when regrouping.

        Returns:
            MiniCPMOBatchFeature with ``audio_features`` and ``audio_feature_lens``.
        """
        if audios is None:
            return MiniCPMOBatchFeature(data={"audio_features": [], "audio_feature_lens": []})

        audio_features, audio_feature_lens = process_audio_batch(
            audios=audios,
            feature_extractor=self.audio_processor,
            sampling_rate=sampling_rate,
            max_duration_seconds=30,
            return_attention_mask=True,
        )

        # Optional: regroup the features into fixed-length segments.
        if regroup_to_seconds is not None and len(audio_features) > 0:
            audio_features, audio_feature_lens = regroup_audio_features(
                audio_features=audio_features,
                audio_feature_lens=audio_feature_lens,
                regroup_seconds=regroup_to_seconds,
                fps=fps,
            )

        model_inputs = {"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}
        return MiniCPMOBatchFeature(data=model_inputs)

    # ============== Streaming audio processing ==============
    def process_audio_streaming(
        self,
        audio_chunk: np.ndarray,
        reset: bool = False,
        return_batch_feature: bool = False,
        is_last_chunk: bool = False,
    ) -> Union[Tuple[torch.Tensor, dict], MiniCPMOBatchFeature]:
        """
        Process one audio chunk in streaming mode.

        Args:
            audio_chunk: An arbitrary stretch of audio (e.g. 125 ms on the first call,
                100 ms on subsequent calls).
            reset: Reset the streaming processor state before processing.
            return_batch_feature: Return a MiniCPMOBatchFeature (same layout as
                ``process_audio``) instead of a raw tuple.
            is_last_chunk: Forwarded to the streaming processor's ``process``.

        Returns:
            If ``return_batch_feature`` is False: ``(audio_features, info)`` where
            ``audio_features`` is the mel feature tensor (documented as [1, 80, n_frames])
            and ``info`` is the processing-info dict.
            If True: MiniCPMOBatchFeature with ``audio_features``,
            ``audio_feature_lens`` (``[tensor([n_frames])]``) and ``streaming_info``.

        Raises:
            ValueError: If the streaming processor has not been initialized.
        """
        if self._streaming_mel_processor is None:
            raise ValueError("流式处理器未初始化,请确保audio_processor已设置")
        if reset:
            self._streaming_mel_processor.reset()

        mel_features, info = self._streaming_mel_processor.process(audio_chunk, is_last_chunk=is_last_chunk)

        if not return_batch_feature:
            # Raw tuple format (kept for backward compatibility).
            return mel_features, info

        # ``emitted_frames`` in ``info`` is the number of frames actually emitted;
        # fall back to the tensor's last dimension when it is missing.
        n_frames = info.get("emitted_frames", mel_features.shape[-1])
        model_inputs = {
            "audio_features": mel_features,
            "audio_feature_lens": [torch.tensor([n_frames])],
            "streaming_info": info,
        }
        return MiniCPMOBatchFeature(data=model_inputs)

    def reset_streaming(self):
        """Reset the streaming processor state (no-op when it is not initialized)."""
        if self._streaming_mel_processor is not None:
            self._streaming_mel_processor.reset()

    def get_streaming_chunk_size(self) -> int:
        """
        Return the audio chunk size currently expected by the streaming processor.

        Returns:
            Number of samples.

        Raises:
            ValueError: If the streaming processor has not been initialized.
        """
        if self._streaming_mel_processor is None:
            raise ValueError("流式处理器未初始化")
        return self._streaming_mel_processor.get_chunk_size()

    def configure_streaming(
        self,
        chunk_ms: int = 100,
        verbose: bool = False,
        # Sliding-window parameters (trigger mode)
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Reconfigure the streaming processor.

        Note: the processor is rebuilt, so ``cnn_redundancy_ms`` and ``first_chunk_ms``
        return to their defaults; use ``set_streaming_mode`` to control them.

        Args:
            chunk_ms: Chunk size in milliseconds.
            verbose: Whether to print detailed information.
            enable_sliding_window: Whether to enable the sliding window (trigger mode).
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: How many seconds each slide moves.

        Raises:
            ValueError: If no audio processor is set.
        """
        if self.audio_processor is None:
            raise ValueError("audio_processor未设置")

        self._init_streaming_processor(
            chunk_ms=chunk_ms,
            verbose=verbose,
            enable_sliding_window=enable_sliding_window,
            slide_trigger_seconds=slide_trigger_seconds,
            slide_stride_seconds=slide_stride_seconds,
        )
        if verbose:
            print(
                f"流式处理器已重新配置: chunk_ms={chunk_ms}ms, sliding_window={'enabled' if enable_sliding_window else 'disabled'}"
            )

    def get_streaming_config(self) -> dict:
        """Return the streaming processor's config dict ({} when not initialized)."""
        if self._streaming_mel_processor is None:
            return {}
        return self._streaming_mel_processor.get_config()

    def get_streaming_state(self) -> dict:
        """Return the streaming processor's state dict ({} when not initialized)."""
        if self._streaming_mel_processor is None:
            return {}
        return self._streaming_mel_processor.get_state()

    def get_streaming_snapshot(self) -> dict:
        """
        Return a full snapshot of the streaming processor (including its buffer),
        used to roll back after speculative ("run-ahead") processing.

        Returns:
            Snapshot dict ({} when not initialized).
        """
        if self._streaming_mel_processor is None:
            return {}
        return self._streaming_mel_processor.get_snapshot()

    def restore_streaming_snapshot(self, snapshot: dict) -> None:
        """
        Restore the streaming processor from a snapshot.

        Does nothing when the processor is not initialized or ``snapshot`` is empty.

        Args:
            snapshot: Snapshot dict as returned by ``get_streaming_snapshot``.
        """
        if self._streaming_mel_processor is None or not snapshot:
            return
        self._streaming_mel_processor.restore_snapshot(snapshot)

    # ============== Compatibility methods (carried over from modeling_minicpmo/processing_minicpmo.py) ==============
    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: ImageInput = None,
        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]] = None,
        audio_parts: Optional[list] = None,
        max_length: Optional[int] = None,
        do_pad: Optional[bool] = True,
        max_slice_nums: Optional[int] = None,
        use_image_id: bool = True,
        stream_input: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        sampling_rate: Optional[int] = 16000,
        online_streaming: bool = False,
        audio_chunk_idx: int = 0,
        is_last_chunk: bool = False,
        **kwargs,
    ) -> MiniCPMOBatchFeature:
        """
        Full multimodal processing: images, audio and text together.

        ``<image>./</image>`` and ``<audio>./</audio>`` tags in ``text`` are replaced
        with the corresponding image / audio placeholders before tokenization.

        Args:
            text: Prompt text(s).
            images: Optional images.
            audios: Optional audio(s): one array, a list of arrays, or a list of lists.
            audio_parts: Optional part ids; consecutive audios sharing an id are merged.
            max_length: Truncate token ids to this length.
            do_pad / max_slice_nums / use_image_id: Image-placeholder options.
            stream_input: Emit chunked audio placeholders.
            return_tensors: Tensor type for the image processor.
            sampling_rate: Audio sampling rate.
            online_streaming: Use the streaming mel processor for audio.
            audio_chunk_idx: Index of the current chunk in online streaming mode.
            is_last_chunk: Whether this is the final chunk in online streaming mode.

        Returns:
            MiniCPMOBatchFeature with token ids, masks, bounds and audio features.
        """
        # The returned MiniCPMOBatchFeature can be used directly as a dict.
        if images is not None:
            image_inputs = self.process_image(
                images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
            )
        else:
            image_inputs = None

        # Batched vs. streaming extraction is decided inside via ``online_streaming``.
        audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
            audios,
            audio_parts,
            stream_input,
            sampling_rate,
            online_streaming=online_streaming,
            is_last_chunk=is_last_chunk,
        )

        model_inputs = self._convert_omni_to_inputs(
            image_inputs,
            audio_phs,
            text,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
            max_length=max_length,
            **kwargs,
        )
        model_inputs["audio_features"] = audio_features
        model_inputs["audio_feature_lens"] = audio_feature_lens

        result = MiniCPMOBatchFeature(data={**model_inputs})

        if online_streaming:
            # Extra-context settings for streaming; presumably consumed by the
            # model's streaming audio encoder (no prefix context on the first chunk).
            result.use_extra_context = True
            result.prefix_extra_frames = 0 if audio_chunk_idx == 0 else 2
            result.suffix_extra_frames = 2
            result.chunk_idx = audio_chunk_idx

        return result

    @staticmethod
    def _merge_audio_parts(sample_audios, part_ids):
        """Concatenate consecutive audios that share the same part id."""
        merged = []
        current = []
        previous_part = None
        for position, (part, audio) in enumerate(zip(part_ids, sample_audios)):
            if position == 0 or part == previous_part:
                current.append(audio)
            else:
                merged.append(np.hstack(current))
                current = [audio]
            previous_part = part
        if current:
            merged.append(np.hstack(current))
        return merged

    @staticmethod
    def _split_long_audio(audios, max_len):
        """Split every audio longer than ``max_len`` samples into ``max_len``-sample pieces."""
        pieces = []
        for audio in audios:
            if len(audio) <= max_len:
                pieces.append(audio)
            else:
                pieces.extend(audio[start : start + max_len] for start in range(0, len(audio), max_len))
        return pieces

    def audio_feature_extract(
        self,
        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]], None] = None,
        audio_parts: Optional[list] = None,
        stream_input: Optional[bool] = False,
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = 1,
        online_streaming: bool = False,
        is_last_chunk: bool = False,
        **kwargs,
    ):
        """
        Extract audio features and build the audio placeholders.

        Args:
            audios: One array, a list of arrays (one sample), or a list of lists (batch).
            audio_parts: Optional per-sample part ids; consecutive audios with the same
                id are concatenated before feature extraction.
            stream_input: Emit chunked placeholders (see ``get_audio_placeholder``).
            sampling_rate: Sampling rate of the input; defaults to 16000 when None.
            chunk_length: Placeholder chunk length in seconds.
            online_streaming: Use ``process_audio_streaming`` instead of the batched
                audio processor (only one audio segment is supported in this mode).
            is_last_chunk: Forwarded to ``process_audio_streaming``.
            **kwargs: Forwarded to the batched audio processor.

        Returns:
            ``(audio_features, audio_feature_lens_list, audio_ph_list)``. All three are
            ``[]`` when there is no audio. Otherwise ``audio_features`` is a padded tensor.
        """
        if audios is None:
            return [], [], []
        if sampling_rate is None:
            # Without this, ``30 * sampling_rate`` below raises a TypeError.
            sampling_rate = 16000

        # Normalize to a batch: list (per sample) of lists of waveforms.
        if isinstance(audios, np.ndarray):
            audios_list = [[audios]]
        elif len(audios) > 0 and isinstance(audios[0], np.ndarray):
            audios_list = [audios]
        else:
            audios_list = audios

        if audio_parts is not None:
            assert len(audio_parts) == len(audios_list)
            for parts, sample_audios in zip(audio_parts, audios_list):
                assert len(parts) == len(sample_audios)

        # Placeholders are computed per original audio; they do not depend on audio_parts.
        audio_ph_list = [
            [self.get_audio_placeholder(len(a), chunk_input=stream_input, chunk_length=chunk_length) for a in sample_audios]
            for sample_audios in audios_list
        ]

        audio_feature_lens_list = []
        audio_features_all = []
        max_audio_inp_len = 30 * sampling_rate

        for idx, sample_audios in enumerate(audios_list):
            if not sample_audios:
                audio_feature_lens_list.append([])
                continue

            if audio_parts is not None:
                merged_audio = self._merge_audio_parts(sample_audios, audio_parts[idx])
            else:
                merged_audio = sample_audios
            # Audio longer than 30 seconds is split into 30-second pieces.
            segments = self._split_long_audio(merged_audio, max_audio_inp_len)

            if online_streaming:
                # Online streaming handles exactly one audio and uses the streaming return format.
                assert len(segments) == 1, f"online_streaming 模式只支持单个音频,当前有 {len(segments)} 个"
                result = self.process_audio_streaming(
                    segments[0], reset=False, return_batch_feature=True, is_last_chunk=is_last_chunk
                )
                # Drop the leading batch dim so the layout matches the batched path.
                audio_features_all.append(result["audio_features"].squeeze(0))
                audio_feature_lens_list.append(result["audio_feature_lens"][0])
            else:
                audio_inputs = self.audio_processor(
                    segments,
                    sampling_rate=sampling_rate,
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                    **kwargs,
                )
                # Trim each feature map to its unpadded length from the attention mask.
                actual_lens = audio_inputs["attention_mask"].sum(dim=1)
                sample_lens = []
                for feat, length in zip(audio_inputs["input_features"], actual_lens):
                    audio_features_all.append(feat[:, :length])
                    sample_lens.append(length)
                audio_feature_lens_list.append(torch.hstack(sample_lens))

        if not audio_features_all:
            return [], audio_feature_lens_list, audio_ph_list

        # pad_sequence pads along the first axis, so move time first and back again.
        time_major = [feat.permute(1, 0) for feat in audio_features_all]
        audio_features = torch.nn.utils.rnn.pad_sequence(time_major, batch_first=True, padding_value=0.0).permute(
            0, 2, 1
        )
        return audio_features, audio_feature_lens_list, audio_ph_list

    def _convert(self, input_str, max_inp_length: Optional[int] = None):
        """
        Tokenize ``input_str`` and locate image, audio and speaker spans.

        ``<|listen|>`` tokens are removed before truncation to ``max_inp_length``.

        Returns:
            ``(input_ids, image_bounds, audio_bounds, spk_bounds)``. ``input_ids`` is an
            int32 tensor. Each bounds tensor has shape (k, 2): the first index after an
            opening marker, and the index of the closing marker.
        """
        tokenizer = self.tokenizer
        listen_token_id = tokenizer.convert_tokens_to_ids("<|listen|>")
        # If "<|listen|>" is not in the vocab, some tokenizers return the unk id.
        # Filtering that id would strip the "<unk>" audio placeholders, so skip it.
        if listen_token_id is not None and listen_token_id == getattr(tokenizer, "unk_token_id", None):
            listen_token_id = None
        input_ids = [token for token in tokenizer.encode(input_str) if token != listen_token_id]
        if max_inp_length is not None:
            input_ids = input_ids[:max_inp_length]
        input_ids = torch.tensor(input_ids, dtype=torch.int32)

        ## image bound (image and slice markers)
        start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id)
        end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id)
        image_start_idx = torch.where(start_cond)[0] + 1
        image_end_idx = torch.where(end_cond)[0]
        # Pair only as many starts/ends as both exist (e.g. after truncation);
        # using max here would leave mismatched lengths and make hstack fail.
        valid_image_nums = min(len(image_start_idx), len(image_end_idx))
        image_bounds = torch.hstack(
            [
                image_start_idx[:valid_image_nums].unsqueeze(-1),
                image_end_idx[:valid_image_nums].unsqueeze(-1),
            ]
        )

        ## audio bound
        audio_start_idx = torch.where(input_ids == tokenizer.audio_start_id)[0]
        audio_end_idx = torch.where(input_ids == tokenizer.audio_end_id)[0]
        assert len(audio_start_idx) == len(audio_end_idx)
        audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])

        ## speaker bound
        spk_start_idx = torch.where(input_ids == tokenizer.spk_start_id)[0]
        spk_end_idx = torch.where(input_ids == tokenizer.spk_end_id)[0]
        assert len(spk_start_idx) == len(spk_end_idx)
        spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])

        return input_ids, image_bounds, audio_bounds, spk_bounds

    def _convert_omni_to_inputs(
        self,
        images,
        audio_phs,
        texts: Union[str, List[str]],
        truncation=None,
        max_length=None,
        max_slice_nums=None,
        use_image_id=None,
        return_tensors=None,
        **kwargs,
    ):
        """
        Expand image/audio tags in ``texts`` into placeholders, tokenize, and left-pad.

        When both ``images`` and ``audio_phs`` are None, the tokenizer is called
        directly and its output is returned as a MiniCPMOBatchFeature. Otherwise a
        dict with ``input_ids``, ``attention_mask``, ``pixel_values``, ``image_sizes``,
        ``image_bound``, ``tgt_sizes``, ``audio_bounds`` and ``spk_bounds`` is
        returned; all bounds are shifted by each row's left-padding length.
        """
        if images is None and audio_phs is None:
            model_inputs = self.tokenizer(
                texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
            )
            return MiniCPMOBatchFeature(data={**model_inputs})

        image_pattern = "<image>./</image>"
        audio_pattern = "<audio>./</audio>"
        # Escape the tags: "." is a regex wildcard. Unescaped, "<image>x/</image>"
        # would be counted as a tag but never replaced below.
        image_regex = re.escape(image_pattern)
        audio_regex = re.escape(audio_pattern)
        split_pattern = f"({image_regex}|{audio_regex})"

        if isinstance(texts, str):
            texts = [texts]

        bs = len(texts)
        if images is not None:
            images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
        else:
            images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs

        input_ids_list = []
        image_bounds_list = []
        audio_bounds_list = []
        spk_bounds_list = []

        for index, text in enumerate(texts):
            text_chunks = re.split(split_pattern, text)

            image_tags = re.findall(image_regex, text)
            audio_tags = re.findall(audio_regex, text)

            if image_tags:
                assert images is not None
                assert len(image_tags) == len(image_sizes[index])
            if audio_tags:
                assert audio_phs is not None
                assert len(audio_tags) == len(audio_phs[index])

            image_id = 0
            audio_id = 0
            for i, chunk in enumerate(text_chunks):
                if chunk == image_pattern:
                    text_chunks[i] = self.image_processor.get_slice_image_placeholder(
                        image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
                    )
                    image_id += 1
                elif chunk == audio_pattern:
                    text_chunks[i] = audio_phs[index][audio_id]
                    audio_id += 1

            final_text = "".join(text_chunks)
            input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length)

            input_ids_list.append(input_ids)
            image_bounds_list.append(image_bounds)
            audio_bounds_list.append(audio_bounds)
            spk_bounds_list.append(spk_bounds)

        padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
        attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
        for i, length in enumerate(padding_lengths):
            # Left padding shifts every bound of row i right by its padding length.
            image_bounds_list[i] = image_bounds_list[i] + length
            audio_bounds_list[i] = audio_bounds_list[i] + length
            spk_bounds_list[i] = spk_bounds_list[i] + length
            attention_mask[i, :length] = False

        data = {
            "input_ids": padded_input_ids,
            "attention_mask": attention_mask,
            "pixel_values": images,
            "image_sizes": image_sizes,
            "image_bound": image_bounds_list,
            "tgt_sizes": tgt_sizes,
            "audio_bounds": audio_bounds_list,
            "spk_bounds": spk_bounds_list,
        }
        return data

    def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
        """
        Stack tensors into one batch, padding along the sequence (first) axis.

        Args:
            inputs: A list of tensors, or a list of lists of tensors (flattened).
                Tensors may be 0-D, 1-D (seq,) or 2-D (seq, dim).
            max_length: Minimum padded length; the longest item wins if longer.
            padding_value: Value used for padding.
            padding_side: "left" pads before the data; anything else pads after.

        Returns:
            ``(tensor, padding_lengths)``. ``padding_lengths[i]`` is the number of
            padding positions added to item ``i``.
        """
        if isinstance(inputs[0], list):
            assert isinstance(inputs[0][0], torch.Tensor)
            items = [tensor for group in inputs for tensor in group]
        else:
            assert isinstance(inputs[0], torch.Tensor)
            items = inputs

        batch_size = len(items)
        shape = items[0].shape
        dim = len(shape)
        assert dim <= 2

        if dim == 0:
            # Scalars need no padding (this must be checked before reading any length).
            return torch.stack(list(items), dim=0), [0] * batch_size

        # The sequence axis is axis 0 (for 1-D tensors this equals shape[-1]).
        lengths = [item.shape[0] for item in items]
        target_length = max(max_length if max_length is not None else 0, max(lengths))
        dtype = items[0].dtype

        if dim == 1:
            if target_length == min(lengths):
                return torch.stack(list(items), dim=0), [0] * batch_size
            tensor = torch.zeros((batch_size, target_length), dtype=dtype) + padding_value
        else:
            tensor = torch.zeros((batch_size, target_length, shape[-1]), dtype=dtype) + padding_value

        padding_length = []
        for i, (item, length) in enumerate(zip(items, lengths)):
            if padding_side == "left":
                tensor[i, target_length - length :] = item
            else:
                tensor[i, :length] = item
            padding_length.append(target_length - length)
        return tensor, padding_length