| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Processor class for MiniCPMO. |
| """ |
|
|
| import math |
| import re |
| from typing import List |
| from typing import Optional |
| from typing import Tuple |
| from typing import Union |
|
|
| import numpy as np |
| import torch |
| from transformers.image_utils import ImageInput |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.tokenization_utils_base import PreTokenizedInput |
| from transformers.tokenization_utils_base import TextInput |
| from transformers.utils import TensorType |
|
|
| from .audio_utils import process_audio_batch |
| from .audio_utils import regroup_audio_features |
| from .processing_audio_minicpma import MiniCPMAAudioProcessor |
| from .processing_image_minicpmv import MiniCPMOBatchFeature |
| from .processing_streaming_mel import StreamingMelProcessorExact |
|
|
|
|
class MiniCPMOProcessor(ProcessorMixin):
    """
    Multimodal processor for MiniCPMO.

    Combines image, audio and text processing. Supports batched (offline)
    processing as well as chunk-by-chunk streaming mel extraction for audio.
    """

    attributes = ["image_processor", "audio_processor", "tokenizer"]
    audio_processor_class = "AutoFeatureExtractor"
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, audio_processor=None, tokenizer=None, **kwargs):
        """
        Initialize the processor.

        Args:
            image_processor: Image processor; its ``version`` attribute is mirrored
                onto ``self.version``.
            audio_processor: Audio feature extractor. When provided, a default
                streaming mel processor is created immediately.
            tokenizer: Text tokenizer.
            **kwargs: Extra options. Only ``audio_pool_step`` (default 5) is read here.
        """
        super().__init__(image_processor, audio_processor, tokenizer)

        self.version = image_processor.version if image_processor is not None else None
        # Pooling factor applied to audio encoder outputs when sizing placeholders.
        self.pool_step = kwargs.get("audio_pool_step", 5)

        self._streaming_mel_processor = None
        if audio_processor is not None:
            self._init_streaming_processor()

    def get_audio_placeholder(
        self,
        audio_lens: int,
        chunk_input: bool = True,
        chunk_length: int = 1,
    ) -> str:
        """
        Build the audio placeholder string (public; used e.g. by the vLLM integration).

        The number of ``<unk>`` tokens equals the number of audio embeddings
        computed for ``audio_lens`` samples. With ``chunk_input`` the embeddings
        are split into per-chunk groups, each wrapped in its own
        ``audio_start``/``audio_end`` pair.

        Args:
            audio_lens: Length of the audio in samples.
            chunk_input: Whether to emit one placeholder group per chunk.
            chunk_length: Chunk length in seconds.

        Returns:
            The placeholder string.
        """
        pool_step = self.pool_step
        # Samples -> feature frames.
        feature_lens = math.ceil(audio_lens / self.audio_processor.hop_length)
        # Halved by a stride-2 stage (presumably the audio encoder's CNN), then pooled.
        feature_lens = (feature_lens - 1) // 2 + 1
        output_lens = (feature_lens - pool_step) // pool_step + 1

        audio_start = self.tokenizer.audio_start
        audio_end = self.tokenizer.audio_end

        if not chunk_input:
            return audio_start + "<unk>" * output_lens + audio_end

        # Assumes 100 feature frames per second (10 ms hop) -- TODO confirm vs hop_length.
        fbank_feat_in_chunk = int(chunk_length * 100)
        cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
        audio_embeds_in_chunk = (cnn_feat_in_chunk - pool_step) // pool_step + 1
        num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk

        groups = []
        total_unk_len = 0
        for _ in range(num_audio_chunks):
            # The last chunk may hold fewer embeddings than a full chunk.
            unk_len = min(audio_embeds_in_chunk, output_lens - total_unk_len)
            groups.append(audio_start + "<unk>" * unk_len + audio_end)
            total_unk_len += unk_len
        return "".join(groups)

    def _init_streaming_processor(
        self,
        chunk_ms: int = 100,
        verbose: bool = False,
        cnn_redundancy_ms: int = 0,
        *,
        mode: str = "exact",
        first_chunk_ms: Optional[int] = None,
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        (Re)create the streaming mel processor (internal).

        Args:
            chunk_ms: Chunk size in milliseconds; also the sliding step.
            verbose: Whether to print detailed information.
            cnn_redundancy_ms: Extra context (ms) before/after CNN boundaries;
                0 selects the standard mode.
            mode: Streaming mode; only "exact" is supported.
            first_chunk_ms: Size of the first chunk in ms; same as ``chunk_ms`` if None.
            enable_sliding_window: Enable the trigger-based sliding window.
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: How many seconds each slide moves the window.

        Raises:
            ValueError: If ``mode`` is not "exact".
        """
        if mode != "exact":
            raise ValueError(f"不支持的模式:{mode},可选 'exact'")
        self._streaming_mel_processor = StreamingMelProcessorExact(
            feature_extractor=self.audio_processor,
            chunk_ms=chunk_ms,
            first_chunk_ms=first_chunk_ms,
            sample_rate=16000,
            verbose=verbose,
            cnn_redundancy_ms=cnn_redundancy_ms,
            enable_sliding_window=enable_sliding_window,
            slide_trigger_seconds=slide_trigger_seconds,
            slide_stride_seconds=slide_stride_seconds,
        )
        self._streaming_mode = mode

    def set_streaming_mode(
        self,
        mode: str = "exact",
        chunk_ms: int = 100,
        cnn_redundancy_ms: int = 0,
        verbose: bool = False,
        *,
        first_chunk_ms: Optional[int] = None,
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Select the streaming mode and rebuild the streaming mel processor.

        Args:
            mode: Streaming mode; only "exact" is supported.
            chunk_ms: Chunk size in milliseconds; also the sliding step.
            cnn_redundancy_ms: Extra context (ms) before/after CNN boundaries;
                0 selects the standard mode.
            verbose: Whether to print detailed information.
            first_chunk_ms: Size of the first chunk in ms; same as ``chunk_ms`` if None.
            enable_sliding_window: Enable the trigger-based sliding window.
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: How many seconds each slide moves the window.

        Raises:
            ValueError: If no audio processor is set or ``mode`` is unsupported.
        """
        if self.audio_processor is None:
            raise ValueError("audio_processor未设置,无法初始化流式处理器")
        self._init_streaming_processor(
            chunk_ms=chunk_ms,
            verbose=verbose,
            cnn_redundancy_ms=cnn_redundancy_ms,
            mode=mode,
            first_chunk_ms=first_chunk_ms,
            enable_sliding_window=enable_sliding_window,
            slide_trigger_seconds=slide_trigger_seconds,
            slide_stride_seconds=slide_stride_seconds,
        )

    def process_image(
        self,
        images: Optional[ImageInput] = None,
        do_pad: bool = True,
        max_slice_nums: int = 1,
        return_tensors: str = "pt",
    ) -> MiniCPMOBatchFeature:
        """
        Run the image processor and keep the fields the model consumes.

        Args:
            images: Input images; None yields empty placeholders.
            do_pad: Forwarded to the image processor.
            max_slice_nums: Maximum number of slices per image.
            return_tensors: Tensor type forwarded to the image processor.

        Returns:
            MiniCPMOBatchFeature with ``pixel_values``, ``image_sizes`` and
            ``tgt_sizes`` (each ``[[]]`` when missing).
        """
        if images is None:
            return MiniCPMOBatchFeature(data={"pixel_values": [[]], "image_sizes": [[]], "tgt_sizes": [[]]})

        result = self.image_processor(
            images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
        )
        model_inputs = {
            key: result.get(key, [[]]) for key in ("pixel_values", "image_sizes", "tgt_sizes")
        }
        return MiniCPMOBatchFeature(data=model_inputs)

    def process_audio(
        self,
        audios: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
        sampling_rate: int = 16000,
        regroup_to_seconds: Optional[int] = None,
        fps: int = 100,
    ) -> MiniCPMOBatchFeature:
        """
        Extract audio features for a batch (offline, non-streaming).

        Args:
            audios: One waveform or a list of waveforms; None yields empty outputs.
            sampling_rate: Sampling rate of the waveforms.
            regroup_to_seconds: If given, regroup the features into segments of
                this many seconds via ``regroup_audio_features``.
            fps: Feature frames per second, used when regrouping.

        Returns:
            MiniCPMOBatchFeature with ``audio_features`` and ``audio_feature_lens``.
        """
        if audios is None:
            return MiniCPMOBatchFeature(data={"audio_features": [], "audio_feature_lens": []})

        audio_features, audio_feature_lens = process_audio_batch(
            audios=audios,
            feature_extractor=self.audio_processor,
            sampling_rate=sampling_rate,
            max_duration_seconds=30,
            return_attention_mask=True,
        )

        if regroup_to_seconds is not None and len(audio_features) > 0:
            audio_features, audio_feature_lens = regroup_audio_features(
                audio_features=audio_features,
                audio_feature_lens=audio_feature_lens,
                regroup_seconds=regroup_to_seconds,
                fps=fps,
            )

        return MiniCPMOBatchFeature(
            data={"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}
        )

    def process_audio_streaming(
        self,
        audio_chunk: np.ndarray,
        reset: bool = False,
        return_batch_feature: bool = False,
        is_last_chunk: bool = False,
    ) -> Union[Tuple[torch.Tensor, dict], MiniCPMOBatchFeature]:
        """
        Feed one audio chunk to the stateful streaming mel processor.

        Args:
            audio_chunk: An arbitrary slice of audio (e.g. 125 ms on the first
                call and 100 ms on subsequent calls).
            reset: Reset the streaming state before processing this chunk.
            return_batch_feature: Return a MiniCPMOBatchFeature (same layout as
                ``process_audio``) instead of a tuple.
            is_last_chunk: Forwarded to the streaming processor's ``process`` call.

        Returns:
            If ``return_batch_feature`` is False: ``(audio_features, info)`` where
            ``audio_features`` is the ``[1, 80, n_frames]`` mel feature and
            ``info`` is the processing-info dict.
            Otherwise a MiniCPMOBatchFeature with ``audio_features``,
            ``audio_feature_lens`` (``[tensor([n_frames])]``) and ``streaming_info``.

        Raises:
            ValueError: If the streaming processor was never initialized.
        """
        if self._streaming_mel_processor is None:
            raise ValueError("流式处理器未初始化,请确保audio_processor已设置")

        if reset:
            self._streaming_mel_processor.reset()

        mel_features, info = self._streaming_mel_processor.process(audio_chunk, is_last_chunk=is_last_chunk)

        if not return_batch_feature:
            return mel_features, info

        # Prefer the processor's own count of emitted frames when it reports one.
        n_frames = info.get("emitted_frames", mel_features.shape[-1])
        return MiniCPMOBatchFeature(
            data={
                "audio_features": mel_features,
                "audio_feature_lens": [torch.tensor([n_frames])],
                "streaming_info": info,
            }
        )

    def reset_streaming(self):
        """Reset the streaming processor state (no-op if it is not initialized)."""
        if self._streaming_mel_processor is not None:
            self._streaming_mel_processor.reset()

    def get_streaming_chunk_size(self) -> int:
        """
        Return the audio chunk size the streaming processor currently expects.

        Returns:
            Number of samples.

        Raises:
            ValueError: If the streaming processor was never initialized.
        """
        if self._streaming_mel_processor is None:
            raise ValueError("流式处理器未初始化")
        return self._streaming_mel_processor.get_chunk_size()

    def configure_streaming(
        self,
        chunk_ms: int = 100,
        verbose: bool = False,
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Rebuild the streaming processor with new chunk / sliding-window settings.

        Other options (``cnn_redundancy_ms``, ``first_chunk_ms``) revert to the
        defaults of ``_init_streaming_processor``.

        Args:
            chunk_ms: Chunk size in milliseconds.
            verbose: Whether to print detailed information.
            enable_sliding_window: Enable the trigger-based sliding window.
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: How many seconds each slide moves the window.

        Raises:
            ValueError: If no audio processor is set.
        """
        if self.audio_processor is None:
            raise ValueError("audio_processor未设置")

        self._init_streaming_processor(
            chunk_ms=chunk_ms,
            verbose=verbose,
            enable_sliding_window=enable_sliding_window,
            slide_trigger_seconds=slide_trigger_seconds,
            slide_stride_seconds=slide_stride_seconds,
        )

        if verbose:
            print(
                f"流式处理器已重新配置: chunk_ms={chunk_ms}ms, sliding_window={'enabled' if enable_sliding_window else 'disabled'}"
            )

    def get_streaming_config(self) -> dict:
        """Return the streaming processor's configuration dict ({} if not initialized)."""
        if self._streaming_mel_processor is None:
            return {}
        return self._streaming_mel_processor.get_config()

    def get_streaming_state(self) -> dict:
        """Return the streaming processor's state dict ({} if not initialized)."""
        if self._streaming_mel_processor is None:
            return {}
        return self._streaming_mel_processor.get_state()

    def get_streaming_snapshot(self) -> dict:
        """
        Return a full snapshot of the streaming processor (including its buffer).

        Intended for rolling the state back after speculative processing.
        Returns {} if the processor is not initialized.
        """
        if self._streaming_mel_processor is None:
            return {}
        return self._streaming_mel_processor.get_snapshot()

    def restore_streaming_snapshot(self, snapshot: dict) -> None:
        """
        Restore the streaming processor from a snapshot.

        Args:
            snapshot: Dict returned by ``get_streaming_snapshot``. Empty snapshots
                and an uninitialized processor are ignored.
        """
        if self._streaming_mel_processor is None or not snapshot:
            return
        self._streaming_mel_processor.restore_snapshot(snapshot)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: ImageInput = None,
        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]] = None,
        audio_parts: Optional[list] = None,
        max_length: Optional[int] = None,
        do_pad: Optional[bool] = True,
        max_slice_nums: Optional[int] = None,
        use_image_id: bool = True,
        stream_input: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        sampling_rate: Optional[int] = 16000,
        online_streaming: bool = False,
        audio_chunk_idx: int = 0,
        is_last_chunk: bool = False,
        **kwargs,
    ) -> MiniCPMOBatchFeature:
        """
        Process text together with optional images and audio.

        Images go through ``process_image``. Audio is converted to features and
        placeholder strings by ``audio_feature_extract``. The
        ``<image>./</image>`` / ``<audio>./</audio>`` tags in the text are then
        expanded and tokenized.

        Args:
            text: Prompt text or batch of texts containing media tags.
            images: Images for the image processor.
            audios: A waveform, a list of waveforms (one sample) or a list of
                lists of waveforms (a batch).
            audio_parts: Optional part ids per waveform; consecutive waveforms
                with equal ids are concatenated before feature extraction.
            max_length: Optional per-text truncation length (tokens).
            do_pad: Forwarded to the image processor.
            max_slice_nums: Forwarded to the image processor and placeholder builder.
            use_image_id: Forwarded to the image placeholder builder.
            stream_input: Emit chunked audio placeholders.
            return_tensors: Forwarded to the image processor.
            sampling_rate: Sampling rate of ``audios``.
            online_streaming: Use the stateful streaming mel processor for audio.
            audio_chunk_idx: Index of the current streaming chunk.
            is_last_chunk: Whether this is the final streaming chunk.
            **kwargs: Forwarded to ``_convert_omni_to_inputs``.

        Returns:
            MiniCPMOBatchFeature with token ids, masks, bounds, pixel values and
            audio features. In online-streaming mode it also carries
            ``use_extra_context``, ``prefix_extra_frames``, ``suffix_extra_frames``
            and ``chunk_idx`` attributes.
        """
        image_inputs = None
        if images is not None:
            image_inputs = self.process_image(
                images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
            )

        audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
            audios,
            audio_parts,
            stream_input,
            sampling_rate,
            online_streaming=online_streaming,
            is_last_chunk=is_last_chunk,
        )

        model_inputs = self._convert_omni_to_inputs(
            image_inputs,
            audio_phs,
            text,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
            max_length=max_length,
            **kwargs,
        )
        model_inputs["audio_features"] = audio_features
        model_inputs["audio_feature_lens"] = audio_feature_lens

        result = MiniCPMOBatchFeature(data={**model_inputs})

        if online_streaming:
            # Context hints for the consumer: no leading context on the first chunk.
            result.use_extra_context = True
            result.prefix_extra_frames = 0 if audio_chunk_idx == 0 else 2
            result.suffix_extra_frames = 2
            result.chunk_idx = audio_chunk_idx

        return result

    @staticmethod
    def _normalize_audio_batch(audios):
        """Return the audio input as a list (per batch sample) of lists of waveforms."""
        if isinstance(audios, np.ndarray):
            return [[audios]]
        if len(audios) == 0:
            return []
        if isinstance(audios[0], np.ndarray):
            return [audios]
        return audios

    @staticmethod
    def _merge_audio_parts(waveforms, parts):
        """Concatenate consecutive waveforms that share the same part id."""
        merged = []
        current = []
        prev_part = None
        for aid, (part, audio) in enumerate(zip(parts, waveforms)):
            if aid > 0 and part != prev_part:
                merged.append(np.hstack(current))
                current = []
            current.append(audio)
            prev_part = part
        if current:
            merged.append(np.hstack(current))
        return merged

    @staticmethod
    def _split_long_audio(waveforms, max_len):
        """Cut every waveform longer than ``max_len`` samples into consecutive ``max_len`` slices."""
        pieces = []
        for audio in waveforms:
            if len(audio) <= max_len:
                pieces.append(audio)
            else:
                pieces.extend(audio[start : start + max_len] for start in range(0, len(audio), max_len))
        return pieces

    def audio_feature_extract(
        self,
        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]], None] = None,
        audio_parts: Optional[list] = None,
        stream_input: Optional[bool] = False,
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = 1,
        online_streaming: bool = False,
        is_last_chunk: bool = False,
        **kwargs,
    ):
        """
        Extract audio features and build the matching placeholder strings.

        Args:
            audios: None, a waveform, a list of waveforms (one sample) or a list
                of lists of waveforms (one list per batch sample). None or an
                empty list yields empty outputs.
            audio_parts: Optional part ids parallel to ``audios``; consecutive
                waveforms with equal ids are concatenated before extraction.
            stream_input: Build chunked placeholders (see ``get_audio_placeholder``).
            sampling_rate: Sampling rate of the waveforms. When None, the audio
                processor's ``sampling_rate`` attribute (or 16000) sizes the
                30-second split; the caller's value is forwarded unchanged.
            chunk_length: Placeholder chunk length in seconds.
            online_streaming: Use the stateful streaming processor via
                ``process_audio_streaming``; each sample must reduce to exactly
                one waveform.
            is_last_chunk: Forwarded to the streaming processor.
            **kwargs: Forwarded to the audio processor (offline mode only).

        Returns:
            ``(audio_features, audio_feature_lens_list, audio_ph_list)``:
            features of all segments padded along time into one tensor (or []
            if none), per-sample feature lengths, and per-sample placeholder
            strings (computed from the un-merged waveforms).
        """
        if audios is None:
            return [], [], []

        audios_list = self._normalize_audio_batch(audios)

        if audio_parts is not None:
            assert len(audio_parts) == len(audios_list)
            for parts, sample_audios in zip(audio_parts, audios_list):
                assert len(parts) == len(sample_audios)

        audio_ph_list = [
            [
                self.get_audio_placeholder(len(a), chunk_input=stream_input, chunk_length=chunk_length)
                for a in sample_audios
            ]
            for sample_audios in audios_list
        ]

        split_rate = sampling_rate
        if split_rate is None:
            split_rate = getattr(self.audio_processor, "sampling_rate", None) or 16000
        # Audio longer than 30 s is split into 30 s segments before extraction.
        max_audio_inp_len = 30 * split_rate

        audio_features_all = []
        audio_feature_lens_list = []
        for idx, sample_audios in enumerate(audios_list):
            if not sample_audios:
                audio_feature_lens_list.append([])
                continue

            if audio_parts is not None:
                merged = self._merge_audio_parts(sample_audios, audio_parts[idx])
            else:
                merged = sample_audios
            final_merge_audio = self._split_long_audio(merged, max_audio_inp_len)

            if online_streaming:
                assert (
                    len(final_merge_audio) == 1
                ), f"online_streaming 模式只支持单个音频,当前有 {len(final_merge_audio)} 个"
                result = self.process_audio_streaming(
                    final_merge_audio[0], reset=False, return_batch_feature=True, is_last_chunk=is_last_chunk
                )
                audio_features_all.append(result["audio_features"].squeeze(0))
                audio_feature_lens_list.append(result["audio_feature_lens"][0])
                continue

            audio_inputs = self.audio_processor(
                final_merge_audio,
                sampling_rate=sampling_rate,
                return_attention_mask=True,
                padding="max_length",
                return_tensors="pt",
                **kwargs,
            )
            # The attention mask marks real (non-padded) frames per segment.
            actual_lens = audio_inputs["attention_mask"].sum(dim=1)
            sample_lens = []
            for feat, n_frames in zip(audio_inputs["input_features"], actual_lens):
                audio_features_all.append(feat[:, :n_frames])
                sample_lens.append(n_frames)
            audio_feature_lens_list.append(torch.hstack(sample_lens))

        if not audio_features_all:
            return [], audio_feature_lens_list, audio_ph_list

        # pad_sequence pads the leading axis, so move time first and back again.
        audio_features = torch.nn.utils.rnn.pad_sequence(
            [feat.permute(1, 0) for feat in audio_features_all], batch_first=True, padding_value=0.0
        ).permute(0, 2, 1)
        return audio_features, audio_feature_lens_list, audio_ph_list

    def _convert(self, input_str, max_inp_length: Optional[int] = None):
        """
        Tokenize ``input_str`` and locate image, audio and speaker spans.

        ``<|listen|>`` tokens are removed from the encoded ids before truncation.

        Args:
            input_str: Fully expanded prompt text.
            max_inp_length: Optional truncation length in tokens.

        Returns:
            ``(input_ids, image_bounds, audio_bounds, spk_bounds)``: an int32 id
            tensor and three (n, 2) tensors of ``[start, end)`` token indices
            covering the content between each start/end marker pair.
        """
        old_input_ids = self.tokenizer.encode(input_str)
        listen_token_id = self.tokenizer.convert_tokens_to_ids("<|listen|>")
        input_ids = [token for token in old_input_ids if token != listen_token_id]

        if max_inp_length is not None:
            input_ids = input_ids[:max_inp_length]
        input_ids = torch.tensor(input_ids, dtype=torch.int32)

        # Image and slice markers both delimit image-embedding spans.
        start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
        end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)

        image_start_idx = torch.where(start_cond)[0] + 1
        image_end_idx = torch.where(end_cond)[0]
        # Truncation can drop trailing end markers; pair only complete spans
        # (using max here made hstack fail on mismatched lengths).
        valid_image_nums = min(len(image_start_idx), len(image_end_idx))
        image_bounds = torch.hstack(
            [
                image_start_idx[:valid_image_nums].unsqueeze(-1),
                image_end_idx[:valid_image_nums].unsqueeze(-1),
            ]
        )

        audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
        audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
        assert len(audio_start_idx) == len(audio_end_idx)
        audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])

        spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
        spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
        assert len(spk_start_idx) == len(spk_end_idx)
        spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])

        return input_ids, image_bounds, audio_bounds, spk_bounds

    def _convert_omni_to_inputs(
        self,
        images,
        audio_phs,
        texts: Union[str, List[str]],
        truncation=None,
        max_length=None,
        max_slice_nums=None,
        use_image_id=None,
        return_tensors=None,
        **kwargs,
    ):
        """
        Expand media tags in ``texts`` into placeholders, tokenize and left-pad.

        Args:
            images: Output of ``process_image`` (with pixel_values, image_sizes,
                tgt_sizes) or None.
            audio_phs: Per-sample lists of audio placeholder strings, or None.
            texts: A single text or a batch of texts.
            truncation, return_tensors, **kwargs: Used only on the text-only path
                (``images`` and ``audio_phs`` both None), forwarded to the tokenizer.
            max_length: Forwarded to the tokenizer on the text-only path and
                to ``_convert`` otherwise.
            max_slice_nums, use_image_id: Forwarded to the image placeholder builder.

        Returns:
            A MiniCPMOBatchFeature on the text-only path; otherwise a dict with
            input_ids, attention_mask, pixel_values, image_sizes, image_bound,
            tgt_sizes, audio_bounds and spk_bounds. Bounds are shifted by each
            row's left padding.
        """
        if images is None and audio_phs is None:
            model_inputs = self.tokenizer(
                texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
            )
            return MiniCPMOBatchFeature(data={**model_inputs})

        image_pattern = "<image>./</image>"
        audio_pattern = "<audio>./</audio>"
        # Escape the tags: unescaped, "." matched any character, so findall could
        # count tags that the literal comparison below never expands.
        image_regex = re.escape(image_pattern)
        audio_regex = re.escape(audio_pattern)
        split_pattern = f"({image_regex}|{audio_regex})"

        if isinstance(texts, str):
            texts = [texts]

        bs = len(texts)
        if images is not None:
            images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
        else:
            images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs

        input_ids_list = []
        image_bounds_list = []
        audio_bounds_list = []
        spk_bounds_list = []

        for index, text in enumerate(texts):
            text_chunks = re.split(split_pattern, text)

            image_tags = re.findall(image_regex, text)
            audio_tags = re.findall(audio_regex, text)
            if image_tags:
                assert images is not None
                assert len(image_tags) == len(image_sizes[index])
            if audio_tags:
                assert audio_phs is not None
                assert len(audio_tags) == len(audio_phs[index])

            image_id = 0
            audio_id = 0
            for i, chunk in enumerate(text_chunks):
                if chunk == image_pattern:
                    text_chunks[i] = self.image_processor.get_slice_image_placeholder(
                        image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
                    )
                    image_id += 1
                elif chunk == audio_pattern:
                    text_chunks[i] = audio_phs[index][audio_id]
                    audio_id += 1

            input_ids, image_bounds, audio_bounds, spk_bounds = self._convert("".join(text_chunks), max_length)
            input_ids_list.append(input_ids)
            image_bounds_list.append(image_bounds)
            audio_bounds_list.append(audio_bounds)
            spk_bounds_list.append(spk_bounds)

        padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
        attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
        for i, length in enumerate(padding_lengths):
            # Left padding shifts every token index in the row by `length`.
            image_bounds_list[i] = image_bounds_list[i] + length
            audio_bounds_list[i] = audio_bounds_list[i] + length
            spk_bounds_list[i] = spk_bounds_list[i] + length
            attention_mask[i, :length] = False

        return {
            "input_ids": padded_input_ids,
            "attention_mask": attention_mask,
            "pixel_values": images,
            "image_sizes": image_sizes,
            "image_bound": image_bounds_list,
            "tgt_sizes": tgt_sizes,
            "audio_bounds": audio_bounds_list,
            "spk_bounds": spk_bounds_list,
        }

    def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
        """
        Stack tensors into a batch, padding along the first (sequence) axis.

        Args:
            inputs: A list of tensors, or a list of lists of tensors (flattened
                in order). Tensors are 0-, 1- or 2-D; 2-D tensors are
                (length, features) and must share the feature size.
            max_length: Minimum padded length; the longest item wins if larger.
            padding_value: Value written into padded positions.
            padding_side: "left" pads before each item; any other value pads after.

        Returns:
            ``(batch_tensor, padding_lengths)`` where ``padding_lengths[i]`` is
            the number of padded positions added to item ``i``.
        """
        if isinstance(inputs[0], list):
            assert isinstance(inputs[0][0], torch.Tensor)
            items = [tensor for group in inputs for tensor in group]
        else:
            assert isinstance(inputs[0], torch.Tensor)
            items = inputs

        batch_size = len(items)
        shape = items[0].shape
        dim = len(shape)
        assert dim <= 2
        dtype = items[0].dtype

        if dim == 0:
            # Scalars need no padding (previously this crashed on shape[-1]).
            return torch.stack(list(items), dim=0), [0] * batch_size

        # The sequence axis is axis 0 for both 1-D and 2-D items (the old code
        # used shape[-1], i.e. the feature axis, for 2-D items).
        lengths = [item.shape[0] for item in items]
        if max_length is None:
            max_length = 0
        target_length = max(max_length, max(lengths))

        if dim == 1:
            if target_length == min(lengths):
                return torch.stack(list(items), dim=0), [0] * batch_size
            tensor = torch.zeros((batch_size, target_length), dtype=dtype) + padding_value
        else:
            tensor = torch.zeros((batch_size, target_length, shape[-1]), dtype=dtype) + padding_value

        padding_length = []
        for i, (item, length) in enumerate(zip(items, lengths)):
            pad_len = target_length - length
            if padding_side == "left":
                # Explicit start index: `-len(item):` selected the whole row for empty items.
                tensor[i, pad_len:] = item
            else:
                tensor[i, :length] = item
            padding_length.append(pad_len)

        return tensor, padding_length
|
|