diff --git "a/modeling_minicpmo.py" "b/modeling_minicpmo.py"
new file mode 100644
--- /dev/null
+++ "b/modeling_minicpmo.py"
@@ -0,0 +1,5072 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2026 The OpenBMB Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import math
+import os
+import tempfile
+import threading
+import time
+import types
+from copy import deepcopy
+from dataclasses import dataclass
+from functools import partial
+from threading import Thread
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn.utils.parametrize as P
+from torch import nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils.parametrizations import weight_norm
+from tqdm import tqdm
+
+# Optional kernel acceleration: when USE_FLAGOS=1, swap the transformers RMSNorm
+# classes for a FlagGems fused kernel BEFORE any model classes are instantiated.
+if os.getenv("USE_FLAGOS") == "1":
+    import importlib
+
+    flag_gems = importlib.import_module("flag_gems")  # noqa: F401
+    flag_gems_experimental = importlib.import_module("flag_gems.experimental_ops")
+    gems_rmsnorm = flag_gems_experimental.rmsnorm
+
+    class GemsRMSNorm(nn.Module):
+        """Drop-in replacement for Qwen3RMSNorm/LlamaRMSNorm backed by the FlagGems fused rmsnorm kernel."""
+
+        def __init__(self, hidden_size, eps=1e-6):
+            super().__init__()
+            # Same learnable scale + epsilon layout as the HF RMSNorm it replaces.
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            self.variance_epsilon = eps
+
+        def forward(self, hidden_states):
+            return gems_rmsnorm(hidden_states, self.weight, self.variance_epsilon)
+
+        def extra_repr(self):
+            return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+    from transformers.models.llama import modeling_llama
+    from transformers.models.qwen3 import modeling_qwen3
+
+    # Monkeypatch module-wide so every layer constructed afterwards uses the fused kernel.
+    modeling_qwen3.Qwen3RMSNorm = GemsRMSNorm
+    modeling_llama.LlamaRMSNorm = GemsRMSNorm
+
+from transformers import LlamaConfig
+from transformers import LlamaModel
+from transformers import PreTrainedModel
+from transformers import Qwen3ForCausalLM
+from transformers import Qwen3PreTrainedModel
+from transformers import TextIteratorStreamer
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
+from transformers.cache_utils import DynamicCache
+from transformers.cache_utils import EncoderDecoderCache
+from transformers.cache_utils import StaticCache
+from transformers.generation.logits_process import TopKLogitsWarper
+from transformers.generation.logits_process import TopPLogitsWarper
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import ModelOutput
+from transformers.models.whisper.configuration_whisper import WhisperConfig
+from transformers.models.whisper.modeling_whisper import WhisperEncoder
+
+from .configuration_minicpmo import MiniCPMOConfig
+from .configuration_minicpmo import MiniCPMTTSConfig
+from .modeling_navit_siglip import SiglipVisionTransformer
+from .processing_minicpmo import MiniCPMOProcessor
+from .utils import as_dynamic_cache
+from .utils import ChunkPrefillChunkGenerate
+from .utils import drop_tokens_from_cache
+from .utils import DuplexWindowConfig
+from .utils import get_kv_cache_length
+from .utils import normalize_content
+from .utils import realign_rotary_suffix
+from .utils import SpeculativeSnapshot
+from .utils import streaming_token_decoder
+from .utils import StreamingWindowConfig
+from .utils import torch_clone_recursive
+from .utils import TTSSamplingParams
+from .utils import TTSStreamingGenerator
+
+logger = logging.getLogger(__name__)
+
+
+class MiniCPMOPreTrainedModel(Qwen3PreTrainedModel):
+    """Base class binding the MiniCPM-O config to HF's PreTrainedModel machinery."""
+
+    config_class = MiniCPMOConfig
+
+class MiniCPMO(MiniCPMOPreTrainedModel):
+    """Omni-modal model: vision/audio/text input, text (and optionally speech) output, around a Qwen3 LLM."""
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.llm = Qwen3ForCausalLM(config)
+        self.embed_dim = self.llm.config.hidden_size
+        # Override generation input prep; `prepare_inputs_for_generation` is presumably
+        # a module-level function defined elsewhere in this file — TODO confirm.
+        self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm)  # patch llm
+
+        # init vision module (SigLIP tower + resampler down to LLM width)
+        if self.config.init_vision:
+            self.vpm = self.init_vision_module()
+            self.vision_dim = self.vpm.embed_dim
+            self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+
+        # init audio module (Whisper encoder + pooling + projection)
+        if self.config.init_audio:
+            self.apm = self.init_audio_module()
+            audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
+            self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
+            self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
+            # Which entry of encoder hidden_states to use (-1 == last).
+            self.audio_encoder_layer = -1
+
+        # init tts module
+        if self.config.init_tts:
+            self.tts = self.init_tts_module()
+
+        self.terminators = ["<|im_end|>", "<|endoftext|>"]
+
+        # NOTE(review): this is the literal 8-char string "\n\n\n\n" (escaped
+        # backslashes), not real newlines — looks like a think-tag template got
+        # mangled; confirm against the tokenizer's expected format.
+        self.think_str = ""
+        if self.llm.__class__.__name__ == "Qwen3ForCausalLM":
+            self.think_str = "\\n\\n\\n\\n"
+
+        # for streaming: initialize all per-session state
+        self.reset_session(reset_token2wav_cache=True)
+
+        # streaming audio processing constants
+        self.SAMPLE_RATE = 16000
+        self.CHUNK_MS = 1000  # regular chunk length (ms)
+        self.FIRST_CHUNK_MS = 1035  # first chunk length (ms)
+        self.CNN_REDUNDANCY_MS = 0  # CNN redundancy (ms)
+
+        # for sliding window
+        self.streaming_window_config = StreamingWindowConfig()
+        self.streaming_require_system_prompt = True
+        self.streaming_window_enabled = True
+        self.force_rope_reindex = False  # RoPE reindex testing switch
+
+    def init_streaming_processor(self):
+        """Prepare the processor for streaming input and reset its per-session state."""
+        self.prepare_processor(processor=None, tokenizer=None)
+
+        # Older processors may not support streaming; only configure when available.
+        if hasattr(self.processor, "set_streaming_mode"):
+            self.processor.set_streaming_mode(
+                mode="exact",
+                chunk_ms=self.CHUNK_MS,
+                first_chunk_ms=self.FIRST_CHUNK_MS,
+                cnn_redundancy_ms=self.CNN_REDUNDANCY_MS,
+                enable_sliding_window=True,
+                slide_trigger_seconds=30.0,
+                slide_stride_seconds=10.0,
+            )
+        self.processor.reset_streaming()
+        # Counter of streamed audio chunks consumed so far.
+        self.audio_chunk_idx = 0
+
+    def reset_session(self, reset_token2wav_cache=True):
+        """Clear all per-conversation streaming state (KV caches, flags, histories).
+
+        Args:
+            reset_token2wav_cache: also drop the token2wav vocoder cache; pass
+                False to keep it warm across sessions.
+        """
+        self.llm_past_key_values = None
+        self.audio_past_key_values = None
+        self.tts_last_turn_tokens = None
+        self.llm_generated = False  # last turn generated by llm or not
+        self.llm_generate_completed = False
+        self.new_user_msg = True
+
+        self.session_id = None
+
+        if reset_token2wav_cache:
+            self.token2wav_cache = None
+
+        # for sliding window
+        self.streaming_text_preserve = 0
+        self.streaming_position_offset = 0
+
+        # Memoized RoPE inverse-frequency tensors keyed by (dim, device).
+        self._rope_inv_freq_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
+
+        self._next_round_id = 0
+        self._pending_round_id = None
+
+        self._omni_chunk_history: List[Dict[str, Union[str, int]]] = []
+        self._round_history: List[Dict[str, Union[int, str, torch.Tensor, Optional[int]]]] = []
+
+ def init_vision_module(self):
+ if self.config._attn_implementation == "flash_attention_2":
+ self.config.vision_config._attn_implementation = "flash_attention_2"
+ else:
+ self.config.vision_config._attn_implementation = "eager"
+ model = SiglipVisionTransformer(self.config.vision_config)
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
+
+ return model
+
+ def init_resampler(self, embed_dim, vision_dim):
+ return Resampler(
+ num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ adaptive=True,
+ )
+
+ def init_audio_module(self):
+ if self.config._attn_implementation == "eager":
+ self.config.audio_config._attn_implementation = "eager"
+ else:
+ # using flash_attention_2 will cause: RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
+ self.config.audio_config._attn_implementation = "sdpa"
+
+ return MiniCPMWhisperEncoder(self.config.audio_config)
+
+ def init_tts_module(self):
+ if self.config._attn_implementation == "flash_attention_2":
+ self.config.tts_config.attn_implementation = "flash_attention_2"
+ else:
+ self.config.tts_config.attn_implementation = "eager"
+
+ return MiniCPMTTS(config=self.config.tts_config, audio_tokenizer=None)
+
+    def _ensure_asset_dir(self, asset_subpath: str, model_dir: Optional[str] = None) -> str:
+        """Ensure asset directory exists, downloading from HF if needed.
+
+        Args:
+            asset_subpath: Asset folder path relative to the model repo root.
+            model_dir: Explicit local directory; defaults to `<name_or_path>/<asset_subpath>`.
+
+        Returns:
+            Local path to the asset directory (asserted to exist).
+        """
+        model_dir = model_dir or os.path.join(self.config._name_or_path, asset_subpath)
+        if not os.path.exists(model_dir):
+            from huggingface_hub import snapshot_download
+
+            # Fetch only the needed subtree from the hub repo.
+            repo_dir = snapshot_download(
+                repo_id="openbmb/MiniCPM-o-4_5",
+                allow_patterns=[f"{asset_subpath}/**"],
+            )
+            model_dir = os.path.join(repo_dir, asset_subpath)
+        assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}"
+        return model_dir
+
+    def init_tts(self, model_dir=None, enable_float16=False, n_timesteps=10, **kwargs):
+        """Attach the Token2wav vocoder to the TTS head, downloading assets if needed.
+
+        Args:
+            model_dir: Local token2wav asset dir; auto-resolved/downloaded when None.
+            enable_float16: Run the vocoder in fp16.
+            n_timesteps: Diffusion/flow steps for the vocoder.
+
+        Returns:
+            The constructed Token2wav audio tokenizer (also stored on `self.tts`).
+        """
+        # Force the expected tokenizer type; warn if the config disagreed.
+        if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio":
+            logger.warning("audio tokenizer type is set to s3tokenizer_step_audio")
+            self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"
+
+        try:
+            from stepaudio2 import Token2wav
+        except ImportError:
+            raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]")
+
+        model_dir = self._ensure_asset_dir("assets/token2wav", model_dir)
+        self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
+        return self.tts.audio_tokenizer
+
+ def get_input_embeddings(self):
+ return self.llm.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.llm.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.llm.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.llm.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.llm = decoder
+
+ def get_decoder(self):
+ return self.llm
+
+ @staticmethod
+ def get_sys_prompt(ref_audio=None, mode="default", language="en", ref_audio_max_ms=None):
+ if ref_audio is not None:
+ if isinstance(ref_audio, str):
+ import os
+
+ import librosa
+
+ if os.path.isfile(ref_audio):
+ duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None
+ ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration)
+ else:
+ logger.error(f"Could not find {ref_audio}")
+ ref_audio = None
+
+ assert isinstance(ref_audio, np.ndarray), "ref_audio error"
+
+ if mode == "omni":
+ if language == "zh":
+ sys_prompt = ""
+ vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。"
+ vc_prompt_suffix = (
+ "请用这种声音风格来为用户提供帮助。 请认真、高质量地回复用户的问题。 请用高自然度的方式和用户聊天。"
+ )
+ else:
+ sys_prompt = ""
+ vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt."
+ vc_prompt_suffix = "As an assistant, you will speak using this voice style."
+
+ if ref_audio is not None:
+ sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
+ else:
+ sys_msgs = {"role": "system", "content": [sys_prompt]}
+
+ return sys_msgs
+ elif mode == "audio_assistant":
+ if language == "zh":
+ vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。"
+ vc_prompt_suffix = "你的任务是用这种声音模式来当一个助手。请认真、高质量地回复用户的问题。请用高自然度的方式和用户聊天。你是由面壁智能开发的人工智能助手:面壁小钢炮。"
+ else:
+ vc_prompt_prefix = "Clone the voice in the provided audio prompt."
+ vc_prompt_suffix = "Please assist users while maintaining this voice style. Please answer the user's questions seriously and in a high quality. Please chat with the user in a highly human-like and oral style. You are a helpful assistant developed by ModelBest: MiniCPM-Omni."
+
+ if ref_audio is not None:
+ sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
+ else:
+ logger.warning(
+ "Warning: ref_audio is None, speech generation will be performed based on the default voice."
+ )
+ sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]}
+
+ return sys_msgs
+ elif mode == "audio_roleplay":
+ if language == "zh":
+ vc_prompt_prefix = "模仿输入音频中的声音特征。"
+ vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。"
+ else:
+ vc_prompt_prefix = "Clone the voice in the provided audio prompt."
+ vc_prompt_suffix = "Try to role-play the character based on the audio prompt above."
+
+ if ref_audio is not None:
+ sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
+ else:
+ sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]}
+
+ return sys_msgs
+ elif mode == "voice_cloning":
+ if language == "zh":
+ vc_prompt_prefix = "模仿输入音频中的声音特征。"
+ else:
+ vc_prompt_prefix = "Clone the voice in the provided audio prompt."
+
+ if ref_audio is not None:
+ sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio]}
+ else:
+ raise ValueError("ref_audio con't be None in voice_cloning mode.")
+
+ return sys_msgs
+ else:
+ sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text."
+ sys_msgs = {"role": "system", "content": [sys_prompt]}
+
+ return sys_msgs
+
+ @staticmethod
+ def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+ num_lookhead: int = 0,
+ ) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+ num_lookhead:
+
+ Returns:
+ torch.Tensor: mask
+ """
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ for i in range(size):
+ if num_left_chunks < 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+ ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)
+ ret[i, start:ending] = True
+ return ret
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """Computes the output length of the convolutional layers and the output length of the audio encoder"""
+ input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
+ input_lengths_after_pooling = (
+ input_lengths_after_cnn - self.config.audio_pool_step
+ ) // self.config.audio_pool_step + 1
+ input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
+
+ return input_lengths_after_cnn, input_lengths_after_pooling
+
+    def get_vision_embedding(self, data):
+        """Encode all images in `data` through the vision tower + resampler.
+
+        Args:
+            data: dict with "pixel_values" (per-sample lists of image tensors),
+                "tgt_sizes" (per-image patch grids), and optionally a precomputed
+                "vision_hidden_states" which short-circuits the whole pass.
+
+        Returns:
+            Per-sample list of vision embeddings (empty list entries for samples
+            without images).
+        """
+        if "vision_hidden_states" not in data:
+            dtype = self.llm.model.embed_tokens.weight.dtype
+            device = self.llm.model.embed_tokens.weight.device
+            tgt_sizes = data["tgt_sizes"]
+            pixel_values_list = data["pixel_values"]
+            vision_hidden_states = []
+            all_pixel_values = []
+            # NOTE(review): this per-sample image-count list is never read; the
+            # name is later reused as an int in the slicing loop below.
+            img_cnt = []
+            for pixel_values in pixel_values_list:
+                img_cnt.append(len(pixel_values))
+                # Flatten each image to (patches, channels*patch_area) rows.
+                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+            # exist image
+            if all_pixel_values:
+                tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
+                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+                max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+                # Pad all images to the same patch count so they batch together.
+                all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+                    all_pixel_values, batch_first=True, padding_value=0.0
+                )
+                B, L, _ = all_pixel_values.shape
+                all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+                # Mask out padded patches per image.
+                patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+                for i in range(B):
+                    patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+                vision_batch_size = self.config.vision_batch_size
+                all_pixel_values = all_pixel_values.type(dtype)
+                # Chunk through the vision tower to bound peak memory.
+                if B > vision_batch_size:
+                    hs = []
+                    for i in range(0, B, vision_batch_size):
+                        start_idx = i
+                        end_idx = i + vision_batch_size
+                        tmp_hs = self.vpm(
+                            all_pixel_values[start_idx:end_idx],
+                            patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+                            tgt_sizes=tgt_sizes[start_idx:end_idx],
+                        ).last_hidden_state
+                        hs.append(tmp_hs)
+                    vision_embedding = torch.cat(hs, dim=0)
+                else:
+                    vision_embedding = self.vpm(
+                        all_pixel_values,
+                        patch_attention_mask=patch_attn_mask,
+                        tgt_sizes=tgt_sizes,
+                    ).last_hidden_state
+                vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+
+                # Split the flat embedding batch back into per-sample groups.
+                start = 0
+                for pixel_values in pixel_values_list:
+                    img_cnt = len(pixel_values)
+                    if img_cnt > 0:
+                        vision_hidden_states.append(vision_embedding[start : start + img_cnt])
+                        start += img_cnt
+                    else:
+                        vision_hidden_states.append([])
+            else:  # no image
+                if self.training:
+                    # Dummy forward keeps vision params in the autograd graph
+                    # (needed e.g. for DDP/ZeRO when a batch has no images).
+                    dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
+                    tgt_sizes = torch.Tensor(
+                        [
+                            [
+                                (224 // self.config.patch_size),
+                                math.ceil(224 / self.config.patch_size),
+                            ]
+                        ]
+                    ).type(torch.int32)
+                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
+                else:
+                    dummy_feature = []
+                for _ in range(len(pixel_values_list)):
+                    vision_hidden_states.append(dummy_feature)
+        else:
+            vision_hidden_states = data["vision_hidden_states"]
+
+        return vision_hidden_states
+
+    def get_vllm_embedding(self, data):
+        """Compute token embeddings and splice vision features into image slots.
+
+        Args:
+            data: dict with "input_ids", "image_bound" (per-sample [start, end)
+                index ranges for image placeholder tokens), and the vision inputs
+                consumed by `get_vision_embedding`.
+
+        Returns:
+            (vllm_embedding, vision_hidden_states): fused embeddings plus the raw
+            per-sample vision features.
+        """
+        vision_hidden_states = self.get_vision_embedding(data)
+
+        # Some LLM configs scale token embeddings by `scale_emb`.
+        if hasattr(self.llm.config, "scale_emb"):
+            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
+        else:
+            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
+
+        vision_hidden_states = [
+            i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
+        ]
+
+        bs = len(data["input_ids"])
+        for i in range(bs):
+            cur_vs_hs = vision_hidden_states[i]
+            if len(cur_vs_hs) > 0:
+                cur_vllm_emb = vllm_embedding[i]
+                cur_image_bound = data["image_bound"][i]
+                if len(cur_image_bound) > 0:
+                    # Expand each [start, end) bound into explicit indices, then
+                    # overwrite those embedding rows with the vision features.
+                    image_indices = torch.stack(
+                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
+                    ).to(vllm_embedding.device)
+
+                    cur_vllm_emb.scatter_(
+                        0,
+                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
+                    )
+                elif self.training:
+                    # Zero-contribution term keeps vision grads alive in training.
+                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+
+        return vllm_embedding, vision_hidden_states
+
+ def get_audio_embedding_streaming(
+ self,
+ data,
+ use_extra_context=False,
+ prefix_extra_frames=1,
+ suffix_extra_frames=1,
+ cnn_min_length=None,
+ ):
+ """Extract audio embeddings in a streaming manner using cached key-value pairs.
+
+ This method processes incoming audio features incrementally and stores/updates `past_key_values`
+ for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
+ for streaming scenarios.
+
+ Args:
+ data (dict):
+ - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
+ - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
+ use_extra_context (bool): If True, assumes input contains extra frames for CNN context.
+ prefix_extra_frames (int): Number of prefix extra frames.
+ suffix_extra_frames (int): Number of suffix extra frames.
+ cnn_min_length (int): Minimum length for CNN input padding.
+
+ Returns:
+ List[List[torch.Tensor]]: audio embeddings
+ """
+ wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
+ audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
+
+ # exist audio
+ if len(wavforms) > 0:
+ audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+ batch_size, _, max_mel_seq_len = wavforms.shape
+ assert batch_size == 1
+ max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+ # whisper's past_key_values management (core)
+ if self.audio_past_key_values is not None:
+ cache_length = self.audio_past_key_values[0][0].shape[2]
+ apm_max_len = self.apm.embed_positions.weight.shape[0]
+ if cache_length + max_seq_len >= apm_max_len:
+ logger.warning(
+ f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
+ )
+ self.audio_past_key_values = None
+
+ # build attention mask (bidirectional attention, same as offline mode)
+ batch_size, _, max_mel_seq_len = wavforms.shape
+ current_seq_len = (max_mel_seq_len - 1) // 2 + 1
+ # if use extra context, need to adjust sequence length
+ if use_extra_context:
+ # calculate actual sequence length after removing redundancy
+ # conv2's stride=2, so the mapping from mel frames to output frames is ceil(x/2)
+ prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0
+ suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0
+ current_seq_len = current_seq_len - prefix_to_remove - suffix_to_remove
+ # calculate history length (if there is KV cache)
+ if self.audio_past_key_values is not None:
+ past_len = self.audio_past_key_values[0][0].shape[2] # get history sequence length
+ total_seq_len = past_len + current_seq_len
+ else:
+ past_len = 0
+ total_seq_len = current_seq_len
+ # create bidirectional attention mask (full attention)
+ audio_attention_mask = torch.zeros(
+ (batch_size, 1, current_seq_len, total_seq_len),
+ dtype=self.apm.conv1.weight.dtype,
+ device=wavforms.device,
+ )
+
+ # Step 1: APM processing
+ audio_outputs = self.apm(
+ wavforms,
+ past_key_values=self.audio_past_key_values,
+ use_cache=True,
+ output_hidden_states=True,
+ attention_mask=audio_attention_mask,
+ use_extra_context=use_extra_context,
+ prefix_extra_frames=prefix_extra_frames,
+ suffix_extra_frames=suffix_extra_frames,
+ cnn_min_length=cnn_min_length,
+ )
+
+ if hasattr(self, "audio_encoder_layer"):
+ audio_states = audio_outputs.hidden_states[self.audio_encoder_layer]
+ else:
+ audio_states = audio_outputs.last_hidden_state
+
+ self.audio_past_key_values = audio_outputs.past_key_values
+
+ # Step 2: Projection
+ audio_embeds = self.audio_projection_layer(audio_states)
+
+ # Step 3: Pooling
+ audio_embeds = audio_embeds.transpose(1, 2)
+ audio_embeds = self.audio_avg_pooler(audio_embeds)
+ audio_embeds = audio_embeds.transpose(1, 2)
+
+ _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
+
+ num_audio_tokens = feature_lens_after_pooling
+
+ final_audio_embeds = []
+ idx = 0
+ for i in range(len(audio_feature_lens_raw)):
+ target_audio_embeds = []
+ for _ in range(len(audio_feature_lens_raw[i])):
+ target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
+ idx += 1
+ final_audio_embeds.append(target_audio_embeds)
+
+ return final_audio_embeds
+ else:
+ return final_audio_embeds
+ else:
+ return []
+
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
+        """Offline (non-streaming) audio feature extraction.
+
+        Args:
+            data: dict with "audio_features" (bs, 80, frames) and
+                "audio_feature_lens" (nested per-sample length lists).
+            chunk_length: >0 enables chunked (streaming-style) attention inside
+                the Whisper encoder; -1 uses full attention.
+            dummy: in training, run a dummy forward when there is no audio so
+                audio params stay in the autograd graph.
+
+        Returns:
+            List[List[torch.Tensor]]: per-sample lists of audio embeddings.
+        """
+        wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
+        audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
+
+        if len(wavforms) > 0:
+            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+            batch_size, _, max_mel_seq_len = wavforms.shape
+            max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+            # Create a sequence tensor of shape (batch_size, max_seq_len)
+            seq_range = (
+                torch.arange(
+                    0,
+                    max_seq_len,
+                    dtype=audio_feature_lens.dtype,
+                    device=audio_feature_lens.device,
+                )
+                .unsqueeze(0)
+                .expand(batch_size, max_seq_len)
+            )
+            lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
+            # Create mask
+            padding_mask = seq_range >= lengths_expand  # 1 for padded values
+
+            audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+                batch_size, 1, max_seq_len, max_seq_len
+            )
+            audio_attention_mask = audio_attention_mask_.to(
+                dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
+            )
+
+            if chunk_length > 0:
+                # 50 encoder frames per second of audio (Whisper conv stride).
+                chunk_num_frame = int(chunk_length * 50)
+                chunk_mask = self.subsequent_chunk_mask(
+                    size=max_seq_len,
+                    chunk_size=chunk_num_frame,
+                    num_left_chunks=-1,
+                    device=audio_attention_mask_.device,
+                )
+                # Block attention wherever padding OR the chunk mask forbids it.
+                audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask))
+
+            # Convert boolean "blocked" positions to additive -inf bias.
+            audio_attention_mask[audio_attention_mask_] = float("-inf")
+            audio_states = self.apm(
+                wavforms, output_hidden_states=True, attention_mask=audio_attention_mask
+            ).hidden_states[self.audio_encoder_layer]
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            # Pool along time: (B, T, C) -> (B, C, T) -> pool -> (B, T', C).
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+
+            _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
+
+            num_audio_tokens = feature_lens_after_pooling
+
+            # Regroup the flat batch according to the nested lens layout.
+            final_audio_embeds = []
+            idx = 0
+            for i in range(len(audio_feature_lens_raw)):
+                target_audio_embeds = []
+                for _ in range(len(audio_feature_lens_raw[i])):
+                    target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
+                    idx += 1
+                final_audio_embeds.append(target_audio_embeds)
+            return final_audio_embeds
+        elif self.training and dummy:
+            dtype = self.apm.embed_positions.weight.dtype
+            device = self.apm.embed_positions.weight.device
+
+            # Dummy forward keeps audio params in the autograd graph.
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+        else:
+            return []
+
+    def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
+        """Splice audio embeddings into the token embeddings at the audio bounds.
+
+        Args:
+            data: dict with audio inputs and "audio_bounds" (per-sample [start, end)
+                ranges of audio placeholder tokens).
+            input_embeddings: token embeddings to modify in place.
+            chunk_length: whisper use full attention or chunk attention
+            stream_input: use streaming audio embedding or not
+
+        Returns:
+            final embeddings with audio feature
+        """
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(data)
+        else:
+            audio_embeddings = self.get_audio_embedding(data, chunk_length)
+
+        bs = len(input_embeddings)
+        if len(data.get("audio_features", [])) > 0:
+            assert len(audio_embeddings) == len(input_embeddings)
+
+        if len(audio_embeddings) > 0:
+            audio_bounds = data["audio_bounds"]
+
+            if self.config.stream_input:
+                assert bs == 1, "audio stream_input mode only support batch size 1"
+                for i in range(bs):
+                    # Concatenate all segments, then distribute them across the
+                    # bounds sequentially.
+                    audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
+                        device=input_embeddings.device, dtype=input_embeddings.dtype
+                    )
+                    audio_start_pos = 0
+                    for bound in audio_bounds[i]:
+                        audio_len = bound[1] - bound[0]
+                        input_embeddings[i, bound[0] : bound[1]] = audio_embs[
+                            audio_start_pos : audio_start_pos + audio_len, :
+                        ]
+                        audio_start_pos += audio_len
+            else:
+                for i in range(bs):
+                    audio_embs = audio_embeddings[i]
+                    bounds = audio_bounds[i]
+                    # One embedding tensor per bound; sizes must match exactly.
+                    for embs, bound in zip(audio_embs, bounds):
+                        audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(
+                            input_embeddings.device
+                        )
+
+                        if embs.shape[0] != len(audio_indices):
+                            raise ValueError(
+                                f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
+                                f"to input indices of length {len(audio_indices)}"
+                            )
+                        input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype)
+        elif self.training:
+            for i in range(bs):
+                # dummy audio_embedings — zero-contribution term keeps audio grads
+                # alive. NOTE(review): indexes audio_embeddings[0] and adds to the
+                # whole tensor on every loop iteration; looks redundant — confirm.
+                input_embeddings += audio_embeddings[0].mean() * 0
+
+        return input_embeddings
+
+ def forward(self, data, **kwargs):
+ vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
+ vllm_embedding = self.get_omni_embedding(
+ data,
+ input_embeddings=vllm_embedding,
+ chunk_length=self.config.audio_chunk_length,
+ )
+
+ position_ids = data["position_ids"]
+ if position_ids.dtype != torch.int64:
+ position_ids = position_ids.long()
+
+ return self.llm(
+ input_ids=None,
+ position_ids=position_ids,
+ inputs_embeds=vllm_embedding,
+ **kwargs,
+ )
+
+ def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+ outputs = self.llm.generate(
+ inputs_embeds=inputs_embeds,
+ pad_token_id=0,
+ eos_token_id=terminators,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ return_dict_in_generate=True,
+ **kwargs,
+ )
+ return outputs
+
+ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+ streamer = TextIteratorStreamer(tokenizer=tokenizer)
+ generation_config = {
+ "inputs_embeds": inputs_embeds,
+ "pad_token_id": 0,
+ "eos_token_id": terminators,
+ "streamer": streamer,
+ }
+ generation_config.update(kwargs)
+ thread = Thread(target=self.llm.generate, kwargs=generation_config)
+ thread.start()
+ return streamer
+
+ def _decode_text(self, result_ids, tokenizer):
+ terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+ result_text = []
+ for result in result_ids:
+ result = result[result != 0]
+ if result[0] == tokenizer.bos_id:
+ result = result[1:]
+ if result[-1] in terminators:
+ result = result[:-1]
+ result_text.append(tokenizer.decode(result))
+ return result_text
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        input_ids=None,
+        pixel_values=None,
+        tgt_sizes=None,
+        audio_features=None,
+        audio_feature_lens=None,
+        image_bound=None,
+        audio_bounds=None,
+        spk_bounds=None,
+        attention_mask=None,
+        tokenizer=None,
+        vision_hidden_states=None,
+        stream=False,
+        **kwargs,
+    ):
+        """Multimodal generation entry point.
+
+        Fuses vision/audio features into the input embeddings, then either
+        streams tokens (stream=True) or decodes the full output.
+
+        Returns:
+            (result, outputs): decoded text list and the raw generate() output,
+            or (TextIteratorStreamer, {}) when streaming.
+        """
+        assert input_ids is not None
+        # NOTE(review): requires pixel_values to be a list (possibly of empty
+        # lists) aligned with input_ids; a None pixel_values raises here.
+        assert len(input_ids) == len(pixel_values)
+
+        model_inputs = {
+            "input_ids": input_ids,
+            "audio_features": audio_features,
+            "audio_feature_lens": audio_feature_lens,
+            "image_bound": image_bound,
+            "audio_bounds": audio_bounds,
+            "spk_bounds": spk_bounds,
+        }
+
+        # Reuse precomputed vision features when provided; otherwise pass raw pixels.
+        if vision_hidden_states is None:
+            model_inputs["pixel_values"] = pixel_values
+            model_inputs["tgt_sizes"] = tgt_sizes
+        else:
+            model_inputs["vision_hidden_states"] = vision_hidden_states
+
+        with torch.inference_mode():
+            model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
+            model_inputs["inputs_embeds"] = self.get_omni_embedding(
+                model_inputs,
+                input_embeddings=model_inputs["inputs_embeds"],
+                chunk_length=self.config.audio_chunk_length,
+            )
+
+            if stream:
+                result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
+                outputs = {}  # if stream return TextIteratorStreamer and output is empty
+            else:
+                outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
+                result = self._decode_text(outputs.sequences, tokenizer)
+
+        return result, outputs
+
+ def _build_streaming_mask(self, tts_tokens_len):
+ tts_sequence_full_length = 1 + self.tts.streaming_text_reserved_len + 1
+ streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8)
+ streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1
+ streaming_attention_mask[-1] = 1
+ return streaming_attention_mask
+
def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
    """Drive the chunked streaming TTS decoder over the generated answer text.

    The text between the last ``<|tts_bos|>`` and the first following
    ``<|tts_eos|>`` is re-tokenized with the TTS text tokenizer and prefilled
    into the TTS model in ``streaming_text_chunk_size`` pieces; after each
    piece the decoder emits up to ``output_chunk_size`` audio tokens. A tail
    loop keeps decoding until the TTS reports EOS or ``tts_max_new_tokens``
    is exceeded.

    Args:
        inputs: LLM inputs, used to extract the last speaker embedding.
        outputs: LLM generation outputs (hidden states), same purpose.
        text: decoded LLM answer containing the TTS span markers.
        output_chunk_size: audio tokens generated per text chunk (25 token/s).
        tts_max_new_tokens: cap on audio tokens produced by the tail loop.

    NOTE(review): nothing is returned and the final ``audio_input_ids`` are
    only kept in locals — presumably ``self.tts.generate`` mutates shared
    state that a caller consumes; confirm against the caller.
    """
    spk_embeds = self._get_last_spk_embeds(inputs, outputs)

    # Keep only the span to be spoken: after <|tts_bos|>, before <|tts_eos|>.
    text = text.split("<|tts_bos|>")[-1]
    gen_text = text.split("<|tts_eos|>")[0]
    tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
    tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
    tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
    streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)

    logits_warpers, logits_processors = gen_logits(
        num_code=626,  # audio codebook size used for sampling constraints
        top_p=self.tts.top_p,
        top_k=self.tts.top_k,
        repetition_penalty=self.tts.repetition_penalty,
    )

    # Fixed-length TTS condition: [bos] + reserved text slots + [eos].
    condition_length = 1 + self.tts.streaming_text_reserved_len + 1

    dtype = self.tts.emb_text.weight.dtype
    emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device)
    # Pre-allocate per-layer (key, value) KV buffers sized to the full condition
    # minus one, so prefill can fill them incrementally.
    past_key_values = [
        (
            torch.zeros(
                1,
                self.tts.config.num_attention_heads,
                condition_length - 1,
                self.tts.config.hidden_size // self.tts.config.num_attention_heads,
                dtype=emb.dtype,
                device=self.tts.device,
            ),
            torch.zeros(
                1,
                self.tts.config.num_attention_heads,
                condition_length - 1,
                self.tts.config.hidden_size // self.tts.config.num_attention_heads,
                dtype=emb.dtype,
                device=self.tts.device,
            ),
        )
        for _ in range(self.tts.config.num_hidden_layers)
    ]

    # Accumulator for generated audio codes across all chunks.
    audio_input_ids = torch.zeros(
        1,
        condition_length,
        self.tts.num_vq,
        dtype=torch.long,
        device=self.tts.device,
    )

    eos_lab = False
    # emb.shape[1] == condition_length, so this walks the whole text window.
    for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)):
        if chunk_idx == 0:
            # First chunk also covers position 0 (bos), hence one extra token.
            begin = chunk_idx * self.tts.streaming_text_chunk_size + 0
            end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1
        else:
            begin = chunk_idx * self.tts.streaming_text_chunk_size + 1
            # Clip to stay inside the condition window (last slot excluded).
            end = min(
                (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1,
                condition_length - 1,
            )

        if end - begin > 0:
            text_input_ids = tts_input_ids[:, begin:end]
            position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)

            if begin == 0:
                # Only the very first prefill carries the speaker embedding.
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    lm_spk_emb_last_hidden_states=spk_embeds,
                )
            else:
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                )

            # Decode a bounded burst of audio tokens for this text chunk.
            outputs = self.tts.generate(
                input_ids=audio_input_ids,
                past_key_values=past_key_values,
                streaming_tts_text_mask=streaming_tts_text_mask,
                max_new_token=output_chunk_size,
                force_no_stop=self.force_no_stop,
                temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
                eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
                logits_warpers=logits_warpers,
                logits_processors=logits_processors,
            )
            audio_input_ids = outputs.audio_input_ids
            past_key_values = outputs.past_key_values

            if outputs.finished:
                eos_lab = True
                break

    if not eos_lab:
        # Tail loop: all text is prefilled but the TTS has not emitted EOS yet.
        while True:
            outputs = self.tts.generate(
                input_ids=audio_input_ids,
                past_key_values=past_key_values,
                streaming_tts_text_mask=streaming_tts_text_mask,
                max_new_token=output_chunk_size,
                force_no_stop=self.force_no_stop,
                temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
                eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
                logits_warpers=logits_warpers,
                logits_processors=logits_processors,
            )

            audio_input_ids = outputs.audio_input_ids
            past_key_values = outputs.past_key_values

            if outputs.finished:
                break
            # Safety valve against runaway generation.
            if outputs.new_ids.shape[1] > tts_max_new_tokens:
                break
+
+ @staticmethod
+ def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
+ num_beams = kwargs.get("num_beams", 3)
+ generation_config = {
+ "num_beams": num_beams,
+ "top_p": 0.8,
+ "top_k": 100,
+ "temperature": 0.7,
+ "do_sample": True,
+ "repetition_penalty": 1.02,
+ }
+
+ if do_sample:
+ generation_config.update(
+ {
+ "top_p": 0.8,
+ "top_k": 100,
+ "temperature": 0.7,
+ "do_sample": True,
+ "repetition_penalty": 1.02,
+ }
+ )
+ elif num_beams > 1:
+ generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False})
+ else:
+ generation_config.update({"do_sample": False, "repetition_penalty": 1.02})
+
+ generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
+ generation_config["min_new_tokens"] = min_new_tokens
+ generation_config["max_new_tokens"] = max_new_tokens
+
+ return generation_config
+
def prepare_processor(self, processor=None, tokenizer=None):
    """Attach or refresh the multimodal processor and optionally its tokenizer.

    An explicitly supplied processor always wins; when none is held yet, the
    default one is loaded from the checkpoint path. A supplied tokenizer is
    installed onto whichever processor ends up active.
    """
    if processor is not None:
        self.processor = processor
    if getattr(self, "processor", None) is None:
        # Lazily load the processor shipped with the checkpoint.
        self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
    if tokenizer is not None:
        self.processor.tokenizer = tokenizer
+
@torch.inference_mode()
def chat(
    self,
    image=None,
    msgs=None,
    vision_hidden_states=None,
    max_new_tokens=4096,
    min_new_tokens=0,
    do_sample=True,
    max_inp_length=8192,
    max_slice_nums=None,
    use_image_id=None,
    enable_thinking=False,
    use_tts_template=False,
    generate_audio=False,
    output_audio_path=None,
    output_tts_inputs_embeds_path=None,
    omni_mode=False,
    teacher_forcing=False,
    return_prompt=False,
    tts_proj_layer=-1,
    tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
    merge_audio_from_same_content=True,
    stream=False,
    stream_input=False,
    tokenizer=None,
    processor=None,
    **kwargs,
):
    """High-level multimodal chat entry point (text / image / audio in, text and optional speech out).

    Builds prompts from ``msgs`` (single conversation or batch), runs LLM
    generation, and — when ``use_tts_template`` and ``generate_audio`` are set
    with an ``output_audio_path`` — synthesizes speech from the hidden states
    of the TTS span and writes a 24 kHz wav file.

    Args:
        image: single PIL image merged into the first message (non-batched only).
        msgs: list of role/content dicts, or a list of such lists for batching.
        stream: return a streamer over generated text instead of the final answer.
        teacher_forcing: generate audio for the given text (LLM generates 1 token).
        return_prompt: additionally return the rendered prompt string.

    Returns:
        The answer string (truncated at ``<|tts_eos|>``); with ``return_prompt``
        a ``(answer, prompt)`` tuple; with ``stream`` the streamer object.
    """
    from PIL import Image

    batched = isinstance(msgs[0], list)
    msgs_list = msgs
    images_list = image

    if not batched:
        images_list, msgs_list = [images_list], [msgs_list]
    else:
        assert images_list is None, "Please integrate image to msgs when using batch inference."
        images_list = [None] * len(msgs_list)
    assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

    self.prepare_processor(processor=processor, tokenizer=tokenizer)

    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []

    for image, msgs in zip(images_list, msgs_list):
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"
        assert do_sample or not stream, "if use stream mode, make sure do_sample=True"

        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant"]
            if i == 0:
                assert role in ["user", "system"], "The role of first msg should be user"
            # Normalize structured content (OpenAI format) to native format
            content = normalize_content(content)
            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("./")
                elif isinstance(c, np.ndarray):  # audio
                    audios.append(c)
                    audio_parts.append(i)
                    cur_msgs.append("")
                    use_tts_template = True
                elif isinstance(c, str):
                    cur_msgs.append(c)

            if omni_mode or stream_input:
                msg["content"] = "".join(cur_msgs)
            else:
                msg["content"] = "\n".join(cur_msgs)

        prompts_lists.append(
            self.processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                # teacher forcing supplies the assistant text, so no generation prompt
                add_generation_prompt=False if teacher_forcing else True,
                use_tts_template=use_tts_template,
                enable_thinking=enable_thinking,
            )
        )
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    if not merge_audio_from_same_content:
        audio_parts_list = None

    inputs = self.processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        max_slice_nums=max_slice_nums,
        use_image_id=use_image_id,
        stream_input=stream_input,
        return_tensors="pt",
        max_length=max_inp_length,
    ).to(self.device)

    if stream:
        kwargs["num_beams"] = 1
    generation_config = self.prepare_generation_config(
        do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
    )
    # max_new_tokens is passed explicitly to generate() below; avoid duplication.
    generation_config.pop("max_new_tokens", None)

    inputs.pop("image_sizes")

    # teacher_forcing = True => generate audio with given text
    with torch.inference_mode():
        res, outputs = self.generate(
            **inputs,
            tokenizer=self.processor.tokenizer,
            max_new_tokens=1 if teacher_forcing else max_new_tokens,
            vision_hidden_states=vision_hidden_states,
            stream=stream,
            **generation_config,
        )

    if stream:
        return res

    # spk bound and tts bound
    tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
    tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")

    # Combine input_ids and generated sequences to get complete sequence
    input_ids = inputs["input_ids"][0]
    generated_ids = outputs.sequences[0]
    # Combine by concatenating input_ids with the new tokens from generated sequence
    full_sequence = torch.cat([input_ids, generated_ids])
    # Update the sequences in outputs
    full_sequences = full_sequence.unsqueeze(0)

    outputs["full_sequences"] = full_sequences

    tts_bos_indices = []
    tts_eos_indices = []
    for i, x in enumerate(full_sequences[0]):
        if x == tts_bos_token:
            # tts_bos + 1 is the position of the first tts, so that it is convenient to slice hidden states for tts
            tts_bos_indices.append(i + 1)
        elif x == tts_eos_token:
            if teacher_forcing and i == len(full_sequences[0]) - 1:
                continue
            tts_eos_indices.append(i)

    tts_bos_idx = tts_bos_indices[-1] if tts_bos_indices else -1
    # Use None instead of -1 when no EOS token found, so that slice [start:None]
    # means "to the end" rather than [start:-1] which excludes the last element
    tts_eos_idx = tts_eos_indices[-1] if tts_eos_indices else None

    tts_bound = (tts_bos_idx, tts_eos_idx)

    answer = res[0]
    if answer is not None:
        answer = answer.split("<|tts_eos|>")[0]

    if use_tts_template and generate_audio and output_audio_path:
        import soundfile as sf

        try:
            generated_waveform = self._generate_speech_non_streaming(
                outputs=outputs,
                tts_bound=tts_bound,
                tts_proj_layer=tts_proj_layer,
                audio_prompt=(
                    input_audios_list[0][0]
                    if len(input_audios_list) > 0 and len(input_audios_list[0]) > 0
                    else None
                ),
                output_tts_inputs_embeds_path=output_tts_inputs_embeds_path,
                tts_sampling_params=tts_sampling_params,
            )
            if isinstance(generated_waveform, torch.Tensor):
                sf.write(output_audio_path, generated_waveform.cpu().numpy(), samplerate=24000)
            elif isinstance(generated_waveform, np.ndarray):
                sf.write(output_audio_path, generated_waveform, samplerate=24000)
            logger.debug(f"audio saved to {output_audio_path}")
        # Fix: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
        # Speech synthesis stays best-effort, but only for ordinary exceptions.
        except Exception:
            import traceback

            traceback.print_exc()

    if return_prompt:
        return answer, prompts_lists[0]
    else:
        return answer
+
@torch.inference_mode()
def _generate_speech_non_streaming(
    self,
    outputs,
    tts_bound,
    tts_proj_layer,
    audio_prompt,
    output_tts_inputs_embeds_path=None,
    tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
):
    """Synthesize a waveform from the LLM hidden states of the TTS span.

    Args:
        outputs: LLM generation outputs carrying ``hidden_states`` and
            ``full_sequences`` (set by ``chat``).
        tts_bound: (start, end) token indices delimiting the TTS span;
            ``end`` may be None meaning "to the end of the sequence".
        tts_proj_layer: which hidden-states layer to project (-1 = last).
        audio_prompt: optional 16 kHz reference audio (np.ndarray) used as a
            voice prompt for the audio tokenizer.
        output_tts_inputs_embeds_path: if set, the TTS condition embeddings
            are saved to this path for debugging.
        tts_sampling_params: sampling configuration forwarded to the TTS.

    Returns:
        A float32 waveform tensor decoded from the generated audio tokens.
        NOTE(review): when ``audio_prompt`` is None, ``prompt_wav_path`` stays
        None — presumably the audio tokenizer accepts that; confirm.
    """
    # Pick the requested layer from every generation step, then stack to
    # one (seq_len, hidden) matrix aligned with full_sequences.
    last_hidden_states = [hs[tts_proj_layer] for hs in outputs.hidden_states]
    last_hidden_states = torch.vstack([i[0] for i in last_hidden_states])

    # Empty speaker-embedding placeholder (0 rows) — concatenates as a no-op.
    spk_embeds = (
        torch.ones([0, self.tts.config.hidden_size]).to(last_hidden_states.device).to(last_hidden_states.dtype)
    )

    if self.tts.condition_type == "hidden_text_merge":
        # Condition = token embeddings of the TTS span + projected hidden states.
        llm_tokens = outputs["full_sequences"][0][tts_bound[0] : tts_bound[1]]
        llm_tokens = torch.tensor(llm_tokens, device=self.tts.emb_text.weight.device, dtype=torch.long)
        llm_embeds = self.tts.emb_text(llm_tokens)  # make sure emb_text is compatible with llm vocab size

        hidden_embeds = last_hidden_states[tts_bound[0] : tts_bound[1]]
        hidden_embeds = self.tts.projector_semantic(hidden_embeds)

        if self.tts.config.normalize_projected_hidden:
            hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)

        tts_embeds = llm_embeds + hidden_embeds
    else:
        raise NotImplementedError

    audio_bos = [self.tts.audio_bos_token_id]
    audio_bos = torch.tensor(audio_bos, device=self.tts.emb_text.weight.device, dtype=torch.long)

    audio_bos_embeds = self.tts.emb_text(audio_bos)

    text_eos_embed = self.tts.emb_text(
        torch.tensor(
            [self.tts.config.text_eos_token_id],
            device=self.tts.emb_text.weight.device,
            dtype=torch.long,
        )
    )

    # Final TTS condition: [spk (empty)] + tts span + [text eos] + [audio bos].
    inputs_embeds = torch.cat([spk_embeds, tts_embeds, text_eos_embed, audio_bos_embeds], dim=0).unsqueeze(0)

    # save inputs_embeds to file
    if output_tts_inputs_embeds_path:
        torch.save(inputs_embeds, output_tts_inputs_embeds_path)

    outputs = self.tts.generate(
        inputs_embeds=inputs_embeds,
        sampling_params=tts_sampling_params,
        eos_token=torch.tensor(
            [self.tts.config.num_audio_tokens - 1],
            dtype=torch.long,
            device=self.tts.device,
        ),
    )

    import io

    import soundfile as sf

    generated_tokens = outputs.new_ids.squeeze(-1)
    reference_audio = audio_prompt
    prompt_wav_path = None
    if reference_audio is not None:
        logger.debug("use reference audio in data to generate waveform")
        # NOTE(review): delete=False means this temp wav is never removed —
        # confirm whether the audio tokenizer needs the file after this call.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
            prompt_wav_path = tmp_wav.name
            sf.write(prompt_wav_path, reference_audio, 16000)
    wav_bytes = self.tts.audio_tokenizer(
        generated_tokens.squeeze(0).tolist(),
        prompt_wav_path,
    )
    # convert wav bytes back to tensor for caller compatibility
    waveform, sr = sf.read(io.BytesIO(wav_bytes))
    return torch.tensor(waveform, dtype=torch.float32)
+
@torch.inference_mode()
def init_token2wav_cache(self, prompt_speech_16k):
    """Initialize the token-to-waveform streaming cache from a voice prompt.

    Two backends are supported, detected by capability:
    - audio tokenizers exposing ``set_stream_cache``: the 16 kHz prompt is
      written to a temporary wav and the returned flow/hift caches are
      deep-cloned into ``self.token2wav_cache``;
    - otherwise the CosyVoice-style frontend path is used to build the flow
      prompt tokens/features/embedding and prepare the streaming cache.

    Args:
        prompt_speech_16k: reference speech at 16 kHz (array-like) used to
            condition the vocoder.
    """
    import soundfile as sf

    if hasattr(self.tts.audio_tokenizer, "set_stream_cache"):
        # Reset any previous stream state before building a fresh cache.
        self.tts.audio_tokenizer.cache = None
        # NOTE(review): delete=False leaves the temp wav on disk — confirm
        # whether set_stream_cache needs the file after returning.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
            prompt_wav_path = tmp_wav.name
            sf.write(prompt_wav_path, prompt_speech_16k, 16000)
            flow_cache_base, hift_cache_base = self.tts.audio_tokenizer.set_stream_cache(prompt_wav_path)

        # Clone so later in-place mutation of the live cache cannot corrupt
        # the baseline we restore from.
        self.token2wav_cache = {
            "flow_cache_base": torch_clone_recursive(flow_cache_base),
            "hift_cache_base": torch_clone_recursive(hift_cache_base),
        }
    else:
        # Frontend path: derive prompt tokens/features/embedding directly.
        model_input = self.tts.audio_tokenizer.frontend.frontend_token2wav(
            speech_tokens=torch.zeros(1, 1, dtype=torch.long, device=self.tts.device),
            speech_16k=None,
            prompt_speech_16k=prompt_speech_16k,
            resample_rate=self.tts.audio_tokenizer.sample_rate,
            prompt_speech=None,
        )

        prompt_token = model_input["flow_prompt_speech_token"]
        prompt_feat = model_input["prompt_speech_feat"]
        embedding = model_input["flow_embedding"]

        # Match the vocoder's compute dtype when it runs in fp16.
        if self.tts.audio_tokenizer.fp16:
            prompt_feat = prompt_feat.to(torch.half)
            embedding = embedding.to(torch.half)

        prepared_cache = self.tts.audio_tokenizer.model.prepare_cache_from_prompt(
            prompt_token=prompt_token,
            prompt_feat=prompt_feat,
            embedding=embedding,
            n_timesteps=self.tts.config.s3_stream_n_timesteps,
            code_chunk_size=self.tts.config.s3_stream_chunk_size,
            chunk_prelook_size=self.tts.config.s3_stream_prelook_size,
            use_attn_idx=False,
        )

        self.token2wav_cache = prepared_cache
+
+ # for sliding window
+ def _ensure_dynamic_cache(self):
+ cache = self.llm_past_key_values
+ if cache is None:
+ return None
+
+ cache = as_dynamic_cache(cache)
+ if isinstance(cache, DynamicCache):
+ self.llm_past_key_values = cache
+ return cache
+
+ return None
+
def _get_kv_cache_length(self, cache=None):
    """Token length of `cache`, defaulting to the model's own LLM KV cache."""
    target = self.llm_past_key_values if cache is None else cache
    return get_kv_cache_length(target)
+
# todo: not-used del?
def _rebuild_cache_from_history(self):
    """Recompute the LLM KV cache by re-prefilling all preserved chunk ids.

    Collects the stored ``input_ids`` of every history entry, concatenates
    them, and runs a single forward pass to rebuild ``llm_past_key_values``.
    Resets the streaming position offset and the RoPE inverse-frequency cache
    either way.
    """
    preserved_ids: List[torch.Tensor] = []
    for entry in self._omni_chunk_history:
        ids = entry.get("input_ids")
        # Entries without usable ids (None / non-tensor / empty) are skipped.
        if ids is None or not isinstance(ids, torch.Tensor) or ids.numel() == 0:
            continue
        preserved_ids.append(ids.to(self.device))
    if not preserved_ids:
        # Nothing to rebuild from: drop the cache entirely.
        self.llm_past_key_values = None
        self.streaming_position_offset = 0
        self._rope_inv_freq_cache.clear()
        return

    # Re-prefill the whole preserved history in one shot.
    concat_ids = torch.cat(preserved_ids, dim=1)
    attention_mask = torch.ones((1, concat_ids.shape[1]), dtype=torch.bool, device=self.device)
    outputs = self.llm(
        input_ids=concat_ids,
        attention_mask=attention_mask,
        use_cache=True,
        return_dict=True,
    )
    self.llm_past_key_values = outputs.past_key_values
    self.streaming_position_offset = 0
    self._rope_inv_freq_cache.clear()
+
+ def _get_rope_theta(self) -> float:
+ return float(getattr(self.llm.config, "rope_theta", 10000.0))
+
def _realign_rotary_suffix(
    self,
    suffix_keys: torch.Tensor,
    old_positions: torch.Tensor,
    new_positions: torch.Tensor,
) -> torch.Tensor:
    """Re-apply RoPE so cached suffix keys move from `old_positions` to `new_positions`."""
    theta = self._get_rope_theta()
    return realign_rotary_suffix(
        suffix_keys,
        old_positions,
        new_positions,
        rope_theta=theta,
        inv_freq_cache=self._rope_inv_freq_cache,
    )
+
+ def _encode_text(self, tokenizer, text) -> Optional[torch.Tensor]:
+ if tokenizer is None or not text:
+ return None
+ ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ return ids.to(self.device)
+
+ @staticmethod
+ def _safe_decode(tokenizer, input_ids):
+ if tokenizer is None or input_ids is None:
+ return None
+ if isinstance(input_ids, torch.Tensor):
+ ids = input_ids.cpu().tolist()
+ if ids and isinstance(ids[0], list):
+ ids = ids[0]
+ else:
+ ids = input_ids
+ try:
+ return tokenizer.decode(ids, skip_special_tokens=False)
+ except Exception:
+ return None
+
+ def _finalize_round(
+ self, round_id: Optional[int], cache_before: int, assistant_input_ids: Optional[torch.Tensor] = None
+ ):
+ if round_id is None:
+ self._pending_round_id = None
+ return
+ cache_after = self._get_kv_cache_length()
+ if assistant_input_ids is not None:
+ assistant_len = assistant_input_ids.shape[1]
+ else:
+ assistant_len = max(cache_after - cache_before, 0)
+ if assistant_len > 0:
+ self._register_chunk(
+ assistant_len,
+ "assistant",
+ round_id=round_id,
+ input_ids=assistant_input_ids,
+ tokenizer=self.processor.tokenizer if hasattr(self, "processor") else None,
+ )
+
+ self._pending_round_id = None
+ self._next_round_id += 1
+
+ def _register_chunk(
+ self,
+ seq_len: int,
+ chunk_type: str,
+ *,
+ round_id: int,
+ input_ids=None,
+ tokenizer=None,
+ ) -> None:
+ if seq_len <= 0:
+ return
+ entry = {"length": int(seq_len), "type": chunk_type, "round": round_id}
+ if input_ids is not None:
+ entry["input_ids"] = input_ids.clone().detach()
+ entry["decoded"] = self._safe_decode(tokenizer, entry["input_ids"])
+ else:
+ entry["input_ids"] = None
+ entry["decoded"] = None
+ self._omni_chunk_history.append(entry)
+
+ if chunk_type == "system":
+ self.streaming_text_preserve = max(self.streaming_text_preserve, entry["length"])
+
def _drop_tokens_from_cache(self, length: int, cache: DynamicCache) -> bool:
    """Evict `length` tokens from `cache`, keeping the preserved system prefix.

    Delegates to the shared utility; on success the streaming position offset
    is advanced to the value the utility reports.
    """
    _, updated_offset, ok = drop_tokens_from_cache(
        cache=cache,
        length=length,
        preserve=self.streaming_text_preserve,
        position_offset=self.streaming_position_offset,
        rope_theta=self._get_rope_theta(),
        inv_freq_cache=self._rope_inv_freq_cache,
    )
    if ok:
        self.streaming_position_offset = updated_offset
    return ok
+
def _drop_next_round(self, cache: DynamicCache) -> bool:
    """Evict the oldest evictable round from the cache; True if one was dropped.

    Rounds are scanned in history order; any round containing a system chunk
    is skipped so the system prompt is never evicted.
    """
    visited = set()
    for entry in self._omni_chunk_history:
        rid = entry.get("round")
        if rid is None or rid in visited:
            continue
        visited.add(rid)
        members = [e for e in self._omni_chunk_history if e.get("round") == rid]
        if any(m.get("type") == "system" for m in members):
            continue
        if self._drop_round(rid, cache):
            return True
    return False
+
def _drop_round(self, round_id: int, cache: DynamicCache) -> bool:
    """Remove every chunk of `round_id` from the history and the KV cache.

    Returns True only when cache tokens were actually evicted; zero-length
    bookkeeping entries are purged from history but count as no eviction.
    """
    members = [e for e in self._omni_chunk_history if e.get("round") == round_id]
    if not members:
        return False
    span = sum(e["length"] for e in members)
    if span <= 0:
        for e in members:
            self._omni_chunk_history.remove(e)
        return False
    if not self._drop_tokens_from_cache(span, cache):
        return False
    for e in members:
        self._omni_chunk_history.remove(e)
    return True
+
+ def _enforce_text_window(self) -> None:
+ if not self.streaming_window_enabled:
+ return
+ cache = self._ensure_dynamic_cache()
+ if cache is None:
+ return
+ high_limit = max(0, int(self.streaming_window_config.text_window_high_tokens))
+ low_limit = max(0, int(self.streaming_window_config.text_window_low_tokens))
+ if high_limit <= 0:
+ return
+ target = max(0, low_limit)
+ total_len = self._get_kv_cache_length(cache)
+ if total_len <= high_limit:
+ return
+ dropped_any = False
+ while total_len > target:
+ if not self._drop_next_round(cache):
+ break
+ dropped_any = True
+ total_len = self._get_kv_cache_length(cache)
+
# snapshot, vad
def save_speculative_snapshot(self) -> SpeculativeSnapshot:
    """Internal method: save speculative snapshot.

    Called at the start of streaming_generate; NOTE(review): this method
    *returns* the snapshot — the caller is expected to store it on
    self._speculative_snapshot; confirm against streaming_generate.

    Save strategy:
    - LLM KV Cache: only record length (restore by truncation, zero extra VRAM)
    - Audio KV Cache: deep clone (as generate sets it to None)
    - Mel processor: full state snapshot (including buffer)
    """
    # get LLM cache information
    llm_cache_length = self._get_kv_cache_length()
    llm_cache_checksum = None
    if self.llm_past_key_values is not None and hasattr(self.llm_past_key_values, "key_cache"):
        if len(self.llm_past_key_values.key_cache) > 0:
            # Checksum of layer-0 keys, for debug comparison after restore.
            llm_cache_checksum = self.llm_past_key_values.key_cache[0].sum().item()

    # get audio cache length and clone audio_past_key_values
    audio_cache_length = 0
    audio_cache_checksum = None
    audio_past_key_values_clone = None
    if self.audio_past_key_values is not None:
        # handle DynamicCache format (Whisper encoder may return this format)
        if isinstance(self.audio_past_key_values, DynamicCache):
            if hasattr(self.audio_past_key_values, "key_cache") and len(self.audio_past_key_values.key_cache) > 0:
                audio_cache_length = self.audio_past_key_values.key_cache[0].shape[2]
                audio_cache_checksum = self.audio_past_key_values.key_cache[0].sum().item()
                # deep clone DynamicCache
                cloned_cache = DynamicCache()
                for k, v in zip(self.audio_past_key_values.key_cache, self.audio_past_key_values.value_cache):
                    cloned_cache.update(k.clone(), v.clone(), layer_idx=len(cloned_cache.key_cache))
                audio_past_key_values_clone = cloned_cache

        # handle EncoderDecoderCache format
        elif isinstance(self.audio_past_key_values, EncoderDecoderCache):
            self_attn_cache = self.audio_past_key_values.self_attention_cache
            if hasattr(self_attn_cache, "key_cache") and len(self_attn_cache.key_cache) > 0:
                audio_cache_length = self_attn_cache.key_cache[0].shape[2]
                audio_cache_checksum = self_attn_cache.key_cache[0].sum().item()
            # deep clone EncoderDecoderCache
            cloned_self_attn = DynamicCache()
            if hasattr(self_attn_cache, "key_cache"):
                for k, v in zip(self_attn_cache.key_cache, self_attn_cache.value_cache):
                    cloned_self_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_self_attn.key_cache))
            cross_attn_cache = self.audio_past_key_values.cross_attention_cache
            cloned_cross_attn = DynamicCache()
            if hasattr(cross_attn_cache, "key_cache"):
                for k, v in zip(cross_attn_cache.key_cache, cross_attn_cache.value_cache):
                    cloned_cross_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_cross_attn.key_cache))
            audio_past_key_values_clone = EncoderDecoderCache(cloned_self_attn, cloned_cross_attn)

        # handle tuple format (compatible with old format)
        elif isinstance(self.audio_past_key_values, tuple) and len(self.audio_past_key_values) > 0:
            audio_cache_length = self.audio_past_key_values[0][0].shape[2]
            audio_cache_checksum = self.audio_past_key_values[0][0].sum().item()
            # deep clone audio_past_key_values (tuple of tuples of tensors)
            audio_past_key_values_clone = tuple(
                tuple(t.clone() for t in layer_cache) for layer_cache in self.audio_past_key_values
            )

    # get mel processor snapshot
    mel_processor_snapshot = None
    mel_buffer_checksum = None
    if hasattr(self, "processor") and self.processor is not None:
        mel_processor_snapshot = self.processor.get_streaming_snapshot()
        if mel_processor_snapshot:
            buf = mel_processor_snapshot.get("buffer")
            if buf is not None and len(buf) > 0:
                mel_buffer_checksum = float(buf.sum())

    # save RNG state (important: for deterministic dithering and other random operations after restoration)
    rng_state_cpu = torch.get_rng_state()
    rng_state_cuda = None
    if torch.cuda.is_available() and self.device.type == "cuda":
        rng_state_cuda = torch.cuda.get_rng_state(self.device)

    # create snapshot
    snapshot = SpeculativeSnapshot(
        llm_cache_length=llm_cache_length,
        audio_cache_length=audio_cache_length,
        new_user_msg=self.new_user_msg,
        llm_generated=self.llm_generated,
        llm_generate_completed=self.llm_generate_completed,
        next_round_id=self._next_round_id,
        pending_round_id=self._pending_round_id,
        omni_chunk_history_length=len(self._omni_chunk_history),
        tts_last_turn_tokens=self.tts_last_turn_tokens.clone() if self.tts_last_turn_tokens is not None else None,
        audio_chunk_idx=self.audio_chunk_idx,
        mel_processor_snapshot=mel_processor_snapshot,
        audio_past_key_values=audio_past_key_values_clone,
        timestamp=time.time(),
        # debug fields
        llm_cache_checksum=llm_cache_checksum,
        audio_cache_checksum=audio_cache_checksum,
        mel_buffer_checksum=mel_buffer_checksum,
        # RNG state
        rng_state_cpu=rng_state_cpu,
        rng_state_cuda=rng_state_cuda,
    )

    return snapshot
+
def restore_speculative_snapshot(self, snapshot=None) -> bool:
    """Restore speculative snapshot - called when VAD speculation fails.

    Restores model state to before streaming_generate was called,
    allowing continued streaming_prefill for newly arrived audio.

    Notes:
        - Snapshot is saved when streaming_generate is called with enable_speculative_snapshot=True
        - This method uses the most recent snapshot for restoration
        - Snapshot is cleared after restore, cannot be called repeatedly

    Args:
        snapshot: explicit snapshot to restore; defaults to the stored
            ``self._speculative_snapshot``.

    Returns:
        bool: Whether restoration was successful
    """
    snapshot = snapshot or getattr(self, "_speculative_snapshot", None)

    if snapshot is None:
        return False

    try:
        current_cache_length = self._get_kv_cache_length()
        current_history_length = len(self._omni_chunk_history)

        # 1. truncate LLM KV Cache
        if current_cache_length > snapshot.llm_cache_length:
            self._truncate_llm_cache(snapshot.llm_cache_length)

        # 2. restore Audio KV Cache (important: restore from cloned copy)
        # because streaming_generate will set audio_past_key_values to None
        self.audio_past_key_values = snapshot.audio_past_key_values

        # 3. restore session state
        self.new_user_msg = snapshot.new_user_msg
        self.llm_generated = snapshot.llm_generated
        self.llm_generate_completed = snapshot.llm_generate_completed

        # 4. restore Round management
        self._next_round_id = snapshot.next_round_id
        self._pending_round_id = snapshot.pending_round_id

        # 5. truncate chunk history
        if current_history_length > snapshot.omni_chunk_history_length:
            self._omni_chunk_history = self._omni_chunk_history[: snapshot.omni_chunk_history_length]

        # 6. restore TTS state
        self.tts_last_turn_tokens = snapshot.tts_last_turn_tokens

        # 7. restore streaming processor state
        self.audio_chunk_idx = snapshot.audio_chunk_idx

        # 8. restore mel processor state (important: otherwise subsequent prefill will fail due to frame number mismatch)
        if (
            snapshot.mel_processor_snapshot is not None
            and hasattr(self, "processor")
            and self.processor is not None
        ):
            self.processor.restore_streaming_snapshot(snapshot.mel_processor_snapshot)

        # 9. restore RNG state (important: ensure determinism of dithering and other random operations after restoration)
        if snapshot.rng_state_cpu is not None:
            torch.set_rng_state(snapshot.rng_state_cpu)
        if snapshot.rng_state_cuda is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state(snapshot.rng_state_cuda, self.device)

        # 10. clean up temporary states generated during generation
        if hasattr(self, "_streaming_generated_token_ids"):
            del self._streaming_generated_token_ids
        if hasattr(self, "_last_streaming_text"):
            del self._last_streaming_text

        # 11. clear snapshot (can only be restored once)
        self._speculative_snapshot = None

        return True
    except Exception as e:
        # Best-effort restore: log the full traceback and report failure.
        import traceback

        logger.error(traceback.format_exc())
        return False
+
def has_speculative_snapshot(self) -> bool:
    """Whether a restorable speculative snapshot is currently held."""
    snapshot = getattr(self, "_speculative_snapshot", None)
    return snapshot is not None
+
def clear_speculative_snapshot(self) -> None:
    """Discard any held speculative snapshot (no-op when none was ever set)."""
    if getattr(self, "_speculative_snapshot", None) is not None:
        self._speculative_snapshot = None
+
def _truncate_llm_cache(self, target_length: int) -> None:
    """Shrink the LLM KV cache to at most `target_length` tokens.

    No-op when there is no cache, the cache cannot be made dynamic, or it is
    already short enough. Used by speculative-snapshot restore to roll back
    tokens appended after the snapshot was taken.
    """
    if self.llm_past_key_values is None:
        return

    cache = self._ensure_dynamic_cache()
    if cache is None:
        return

    current_length = self._get_kv_cache_length(cache)
    if current_length <= target_length:
        return

    # truncate each layer of cache
    for layer_idx in range(len(cache.key_cache)):
        if cache.key_cache[layer_idx].numel() > 0:
            cache.key_cache[layer_idx] = cache.key_cache[layer_idx][:, :, :target_length, :].contiguous()
            cache.value_cache[layer_idx] = cache.value_cache[layer_idx][:, :, :target_length, :].contiguous()

    # update cache metadata
    # NOTE(review): crop() also slices the per-layer tensors — the manual
    # slicing above looks redundant with it; confirm against the installed
    # transformers version before simplifying.
    cache.crop(target_length)
    # _seen_tokens is a private transformers field; kept in sync explicitly.
    cache._seen_tokens = target_length
+
@torch.inference_mode()
def streaming_prefill(
    self,
    session_id,
    msgs,
    omni_mode=True,
    max_slice_nums=None,
    use_tts_template=True,
    enable_thinking=False,
    is_last_chunk=False,  # for audio chunk, if is the last chunk, set to True
    tokenizer=None,
    processor=None,
    **kwargs,
):
    """Prefill one streaming chunk (text / image / audio) into the LLM KV cache.

    Maintains per-session streaming state: a new ``session_id`` resets the
    session, role transitions insert the proper chat-template markers, and
    each prefilled chunk is registered in the round history so the sliding
    text window can evict old rounds.

    Args:
        session_id: identifier of the streaming session; switching ids resets state.
        msgs: a single-element list with one role/content message; content items
            may be PIL images, np.ndarray audio chunks, or strings.
        omni_mode: join content pieces with "" instead of "\\n".
        is_last_chunk: marks the final audio chunk of the current utterance.

    Returns:
        The prompt string that was prefilled (for logging/debugging).
    """
    from PIL import Image

    assert session_id is not None, "session_id cannot be None"
    # A changed/unknown session id means this is the first prefill of a session.
    self.is_first = self.session_id is None or session_id != self.session_id

    self.prepare_processor(processor=processor, tokenizer=tokenizer)

    images = []
    audios = []

    assert len(msgs) == 1
    copy_msgs = deepcopy(msgs)
    msg = copy_msgs[0]

    assert msg["role"] in ["system", "user", "assistant"]
    is_not_system_prefill = msg["role"] != "system"

    # Split the message content into media payloads and text placeholders.
    content = msg["content"]
    cur_msgs = []
    for j, c in enumerate(content):
        if isinstance(c, Image.Image):
            images.append(c)
            cur_msgs.append("./")
        elif isinstance(c, np.ndarray):
            audios.append(c)
            cur_msgs.append("")
        elif isinstance(c, str):
            cur_msgs.append(c)
        else:
            logger.error(f"Invalid content type: {c}, ignore it.")

    cur_contents = "".join(cur_msgs) if omni_mode else "\n".join(cur_msgs)

    if msg["role"] in ["system", "assistant"]:
        # A system/assistant turn ends any in-flight user turn and resets the
        # audio encoder cache.
        self.new_user_msg = True
        self.audio_past_key_values = None

    if self.is_first:
        self.reset_session(reset_token2wav_cache=False)
        self.session_id = session_id

        self.init_streaming_processor()

        if msg["role"] == "user":
            # no system prefill, the first segment of the first user turn
            # do not use apply_chat_template, manually build prompt to avoid automatic addition of <|im_end|>
            prompt = "<|im_start|>user\n" + cur_contents
            self.new_user_msg = False  # mark subsequent segments do not need to add user prefix anymore
        else:
            # system or assistant prefill, use apply_chat_template
            msg["content"] = cur_contents
            prompt = self.processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                add_generation_prompt=False,
                use_tts_template=use_tts_template,
                enable_thinking=enable_thinking,
            )
        add_special_tokens = True  # add bos
    else:
        # non-first prefill
        if self.new_user_msg and msg["role"] == "user":
            # the first segment of the new user turn
            if self.llm_generated:
                if self.llm_generate_completed:
                    prompt = "<|im_end|>\n<|im_start|>user\n" + cur_contents
                else:
                    # Generation was interrupted: close the TTS span first.
                    prompt = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents
            else:
                prompt = "<|im_start|>user\n" + cur_contents
            self.new_user_msg = False
        else:
            # subsequent segments of the same turn, directly use content
            prompt = cur_contents
        add_special_tokens = False

    # when first user audio prefill, ensure audio length satisfies FIRST_CHUNK_MS requirements
    if is_not_system_prefill and len(audios) > 0 and self.audio_chunk_idx == 0:
        assert len(audios) == 1, f"streaming mode only supports single audio, currently {len(audios)}"
        first_chunk_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000)
        if len(audios[0]) < first_chunk_samples:
            # Left-pad with silence so the first chunk reaches the minimum size.
            pad_len = first_chunk_samples - len(audios[0])
            audios[0] = np.concatenate([np.zeros(pad_len, dtype=audios[0].dtype), audios[0]])

    model_inputs = self.processor(
        [prompt],
        [images],
        [audios],
        max_slice_nums=1 if max_slice_nums is None else max_slice_nums,
        use_image_id=False,
        chunk_input=True,
        return_tensors="pt",
        max_length=None,
        sampling_rate=16000,
        add_special_tokens=add_special_tokens,
        online_streaming=is_not_system_prefill,
        audio_chunk_idx=self.audio_chunk_idx,
        is_last_chunk=is_last_chunk,
    ).to(self.device)

    if len(audios) > 0 and is_not_system_prefill:
        self.audio_chunk_idx += 1

    # 1. prepare input embeddings
    model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
    # get audio embedding with audio_past_key_values
    inputs_embeds = self.get_omni_embedding(
        model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=is_not_system_prefill
    )

    if self.is_first:
        self.audio_past_key_values = None

    round_id = self._next_round_id
    self._pending_round_id = round_id
    chunk_type = "system" if msg["role"] == "system" else ("user" if msg["role"] == "user" else "assistant")
    seq_len = inputs_embeds.shape[1]
    # Evict old rounds before growing the cache, then size the mask to
    # cover existing cache plus this chunk.
    self._enforce_text_window()
    cache_length = self._get_kv_cache_length()

    attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device)

    # 2. do prefill
    outputs = self.llm(
        past_key_values=self.llm_past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=None,
        use_cache=True,
        return_dict=True,
    )

    self.llm_past_key_values = as_dynamic_cache(outputs["past_key_values"])
    self._register_chunk(
        seq_len,
        chunk_type,
        round_id=round_id,
        input_ids=model_inputs["input_ids"],
        tokenizer=self.processor.tokenizer,
    )
    self._enforce_text_window()
    if self.force_rope_reindex:
        self._force_reindex_all_cache()

    return prompt
+
+ @torch.inference_mode()
+ def streaming_generate(
+ self,
+ session_id,
+ bos_input=None,
+ generate_audio=True,
+ audio_token_chunk_size=25, # 25 token/s
+ tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
+ max_new_tokens=256,
+ enable_thinking=False,
+ use_tts_template=True,
+ do_sample=True,
+ enable_speculative_snapshot=False,
+ tokenizer=None,
+ processor=None,
+ # Teacher forcing (only for the "text → hidden → TTS condition" pipeline in streaming_generate)
+ # When enabled: instead of letting the LLM auto-regressively generate the text to be spoken,
+ # it forces the tokens from teacher_forcing_text to be fed in, using the hidden states
+ # corresponding to these tokens to construct the TTS condition, ensuring the output audio matches the input text.
+ teacher_forcing: bool = False,
+ teacher_forcing_text: str = "",
+ **kwargs,
+ ):
+ # save speculative snapshot (before modifying any state)
+ # for VAD speculative snapshot: if speculative snapshot fails, can call restore_speculative_snapshot() to restore
+ # enable_speculative_snapshot=True when enabled, skip (save some overhead) when disabled
+ if enable_speculative_snapshot:
+ self._speculative_snapshot = self.save_speculative_snapshot()
+
+ # reset buf
+ self.new_user_msg = True
+ self.llm_generated = True
+ self.llm_generate_completed = False
+ self.audio_past_key_values = None
+
+ self.prepare_processor(processor=processor, tokenizer=tokenizer)
+
+ # reset current turn generated token IDs
+ if hasattr(self, "_streaming_generated_token_ids"):
+ del self._streaming_generated_token_ids
+ # reset full generated text
+ if hasattr(self, "_last_streaming_text"):
+ del self._last_streaming_text
+
+ cache = self._ensure_dynamic_cache()
+ cache_length = self._get_kv_cache_length(cache)
+ host_round_id = self._pending_round_id
+
+ ## in single-turn streaming, each call to streaming_generate needs to reinitialize the streaming_processor, enter the next turn
+ self.init_streaming_processor()
+
+ # 1) llm generate token and hidden states per chunk=10, 2) tts generate audio token chunk per chunk=25, 3) yield 1 chunk audio token
+ def audio_chunk_generator(
+ bos_input,
+ tokenizer,
+ generate_audio,
+ tts_sampling_params,
+ max_new_tokens,
+ do_sample,
+ teacher_forcing=False,
+ teacher_forcing_text="",
+ **kwargs,
+ ):
+ generate_chunk_size = 10
+
+ if bos_input is None:
+ bos_input = "".join(
+ [
+ "<|im_end|>\n<|im_start|>assistant\n",
+ "" if enable_thinking else self.think_str.replace("\\n", "\n"),
+ "<|tts_bos|>" if use_tts_template else "",
+ ]
+ )
+
+ bos_input_ids = tokenizer.encode(bos_input)
+ bos_input_ids = torch.tensor(bos_input_ids, dtype=torch.long, device=self.device).unsqueeze(0)
+
+ bos_input_embeds = self.llm.get_input_embeddings()(bos_input_ids)
+
+ generation_inputs_embeds = bos_input_embeds
+ generated_ids = torch.empty((1, 0), dtype=torch.long, device=self.device)
+
+ num_chunks_decode = (max_new_tokens + generate_chunk_size - 1) // generate_chunk_size
+
+ conditions = []
+
+ # generate chunk by chunk, each chunk has 10 tokens, each chunk takes last hidden states, and pass tokens to tts
+ llm_streaming_generator = ChunkPrefillChunkGenerate(
+ model=self.llm,
+ tokenizer=tokenizer,
+ terminators=["<|tts_eos|>", "<|im_end|>", ""],
+ )
+
+ if generate_audio:
+ logits_warpers, logits_processors = gen_logits(
+ num_code=self.tts.config.num_audio_tokens,
+ repetition_penalty=tts_sampling_params.repetition_penalty,
+ top_p=tts_sampling_params.top_p,
+ top_k=tts_sampling_params.top_k,
+ )
+
+ tts_streaming_generator = TTSStreamingGenerator(
+ model=self.tts,
+ temperature=tts_sampling_params.temperature,
+ eos_token=torch.tensor(
+ [self.tts.config.num_audio_tokens - 1],
+ dtype=torch.long,
+ device=self.tts.device,
+ ),
+ chunk_size=audio_token_chunk_size, # s3tokenizer 1s = 25token
+ tts_last_turn_tokens=self.tts_last_turn_tokens,
+ logits_processors=logits_processors,
+ logits_warpers=logits_warpers,
+ )
+
+ # Teacher forcing branch
+ # This branch does not rely on ChunkPrefillChunkGenerate's sampling logic, instead:
+ # 1) First prefill bos_input (assistant + tts_bos) into llm_past_key_values
+ # 2) Tokenize teacher_forcing_text into token ids
+ # 3) Feed tokens one by one into the LLM (teacher forcing), obtaining the last_hidden_states for each token
+ # 4) Use (token_ids, hidden_states) to construct tts condition, then feed it to TTSStreamingGenerator
+ if teacher_forcing:
+ # --- 1) prefill bos_input,延续 streaming_prefill 的 KV cache ---
+ bos_outputs = self.llm(
+ inputs_embeds=generation_inputs_embeds,
+ past_key_values=self.llm_past_key_values,
+ use_cache=True,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ self.llm_past_key_values = bos_outputs.past_key_values
+
+ if generate_audio:
+ # Give a length-0 tensor as speaker embedding (no speaker embedding)
+ spk_emb = torch.empty(
+ (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]),
+ dtype=bos_input_embeds.dtype,
+ device=bos_input_embeds.device,
+ )
+ tts_streaming_generator.spk_emb = spk_emb
+
+ # --- 2) tokenize teacher_forcing_text ---
+ tf_text = teacher_forcing_text or ""
+ try:
+ forced_input_ids = tokenizer(tf_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
+ except Exception:
+ # Compatible with rare tokenizer return object attributes
+ forced_input_ids = tokenizer(tf_text, add_special_tokens=False, return_tensors="pt").input_ids
+ forced_input_ids = forced_input_ids.to(self.device)
+
+ total_len = int(forced_input_ids.shape[1])
+ ptr = 0
+
+ # Special case: empty text should also let TTS finish (text_finished=True will automatically concatenate text_eos_embed)
+ if total_len == 0:
+ if not generate_audio:
+ yield forced_input_ids, True
+ return
+ empty_tts_embeds = torch.empty(
+ (1, 0, self.tts.config.hidden_size),
+ dtype=bos_input_embeds.dtype,
+ device=self.device,
+ )
+ if not hasattr(self, "_streaming_generated_token_ids"):
+ self._streaming_generated_token_ids = []
+ tts_generator = tts_streaming_generator.generate_with_buffer(
+ condition=empty_tts_embeds,
+ text_finished=True,
+ )
+ for audio_token_chunk, is_last_audio_chunk in tts_generator:
+ yield audio_token_chunk, is_last_audio_chunk
+ self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
+ self._last_streaming_text = ""
+ yield None, None
+ return
+
+ # --- 3) chunk-by-chunk teacher forcing ---
+ while ptr < total_len:
+ end = min(ptr + generate_chunk_size, total_len)
+ chunk_ids = forced_input_ids[:, ptr:end] # [1, chunk_len]
+ chunk_hidden_list = []
+
+ for j in range(chunk_ids.shape[1]):
+ tok = chunk_ids[:, j : j + 1] # [1, 1]
+ tok_emb = self.llm.get_input_embeddings()(tok)
+ out = self.llm(
+ inputs_embeds=tok_emb,
+ past_key_values=self.llm_past_key_values,
+ use_cache=True,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ self.llm_past_key_values = out.past_key_values
+ chunk_hidden_list.append(out.hidden_states[-1]) # [1, 1, hidden]
+
+ chunk_hidden = torch.cat(chunk_hidden_list, dim=1) # [1, chunk_len, hidden]
+ text_finished = end >= total_len
+
+ # Save token IDs cache (external eval script will use _last_streaming_text to write generated_text)
+ if not hasattr(self, "_streaming_generated_token_ids"):
+ self._streaming_generated_token_ids = []
+ self._streaming_generated_token_ids.extend(chunk_ids[0].tolist())
+
+ if not generate_audio:
+ yield chunk_ids, text_finished
+ else:
+ llm_embeds = self.tts.emb_text(chunk_ids)
+ hidden_embeds = self.tts.projector_semantic(chunk_hidden)
+ if self.tts.config.normalize_projected_hidden:
+ hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
+ tts_embeds = llm_embeds + hidden_embeds
+
+ tts_generator = tts_streaming_generator.generate_with_buffer(
+ condition=tts_embeds,
+ text_finished=text_finished,
+ )
+ for audio_token_chunk, is_last_audio_chunk in tts_generator:
+ yield audio_token_chunk, is_last_audio_chunk
+
+ ptr = end
+ if text_finished:
+ if generate_audio:
+ self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
+ break
+
+ # Finish: decode this round of text
+ if hasattr(self, "_streaming_generated_token_ids"):
+ try:
+ self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids)
+ assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text)
+ self._finalize_round(
+ round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids
+ )
+ except Exception:
+ self._last_streaming_text = None
+ else:
+ self._last_streaming_text = None
+
+ # Finally send the end signal
+ if generate_audio:
+ yield None, None
+ else:
+ return
+ return
+
+ # LLM chunk generate outer loop
+ for chunk_idx in range(num_chunks_decode):
+ is_first_generate_chunk = chunk_idx == 0
+
+ output = llm_streaming_generator.chunk_generate(
+ inputs_embeds=generation_inputs_embeds,
+ past_key_values=self.llm_past_key_values,
+ is_first_generate_chunk=is_first_generate_chunk,
+ return_hidden_states=True,
+ chunk_size=generate_chunk_size + 1 * is_first_generate_chunk,
+ do_sample=do_sample,
+ temperature=kwargs.get("temperature", 0.7),
+ top_p=kwargs.get("top_p", 0.8),
+ top_k=kwargs.get("top_k", 100),
+ repetition_penalty=kwargs.get("repetition_penalty", 1.02),
+ length_penalty=kwargs.get("length_penalty", 1.0),
+ all_input_ids=generated_ids,
+ )
+
+ if output.chunk_token_ids is None:
+ break
+
+ if is_first_generate_chunk:
+ if generate_audio:
+ spk_emb = torch.empty(
+ (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]),
+ dtype=bos_input_embeds.dtype,
+ device=bos_input_embeds.device,
+ )
+ tts_streaming_generator.spk_emb = spk_emb
+
+ if output.finished:
+ yield_chunk_token_ids = output.chunk_token_ids
+ else:
+ # the first chunk generated chunk_size + 1 tokens, we only take the first chunk_size tokens,
+ # the last token is not prefilled, and last hidden states is not obtained
+ yield_chunk_token_ids = output.chunk_token_ids[:, :-1]
+
+ elif output.finished:
+ yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids], dim=1)
+ else:
+ # in the chunk that is not the first chunk, we need to add the token at the end of the previous chunk,
+ # it is not prefilled into the model to get last hidden states
+ # similarly, the last generated token of subsequent chunks is not prefilled, and last hidden states is not obtained,
+ # so it is not passed out
+ yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids[:, :-1]], dim=1)
+
+ if not generate_audio:
+ chunk_generated_text = tokenizer.decode(yield_chunk_token_ids[0])
+ yield yield_chunk_token_ids, output.finished
+ else:
+ # TTS inner loop
+ # dense connection here is hardcoded to use text-hidden merged as condition
+ llm_embeds = self.tts.emb_text(yield_chunk_token_ids)
+ hidden_embeds = output.last_hidden_states
+ hidden_embeds = self.tts.projector_semantic(hidden_embeds)
+ if self.tts.config.normalize_projected_hidden: # default should be opened
+ hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
+
+ tts_embeds = llm_embeds + hidden_embeds
+ conditions.append(tts_embeds)
+
+ # Store token IDs instead of decoded text to avoid UTF-8 multi-byte character truncation
+ if not hasattr(self, "_streaming_generated_token_ids"):
+ self._streaming_generated_token_ids = []
+ self._streaming_generated_token_ids.extend(yield_chunk_token_ids[0].tolist())
+
+ # there is buffer generated, each time exactly returns 25 audio tokens,
+ # the last audio chunk returns audio tokens of variable length, length [0, 25]
+ tts_generator = tts_streaming_generator.generate_with_buffer(
+ condition=tts_embeds, text_finished=output.finished
+ )
+
+ for audio_token_chunk, is_last_audio_chunk in tts_generator:
+ yield audio_token_chunk, is_last_audio_chunk
+
+ generated_ids = torch.cat([generated_ids, output.chunk_token_ids], dim=1)
+ generation_inputs_embeds = output.current_inputs_embeds
+ self.llm_past_key_values = output.past_key_values
+
+ if output.finished:
+ if generate_audio:
+ self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
+ break
+
+ # IMPORTANT: Flush remaining TTS buffer when LLM generation ends
+ # This handles BOTH cases:
+ # 1. LLM finished with terminator (output.finished=True) - buffer may still have tokens
+ # 2. LLM hit max chunks limit (output.finished=False) - buffer definitely has tokens
+ if generate_audio:
+ if len(tts_streaming_generator._token_buffer) > 0:
+ batch = torch.cat(tts_streaming_generator._token_buffer, dim=1)
+ yield batch, True
+ tts_streaming_generator._token_buffer = []
+
+ if generate_audio:
+ if hasattr(self, "_streaming_generated_token_ids"):
+ try:
+ self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids)
+ assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text)
+ self._finalize_round(
+ round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids
+ )
+ except Exception:
+ self._last_streaming_text = None
+ else:
+ self._last_streaming_text = None
+
+ yield None, None
+ else:
+ return
+
+ # iter for generating text chunk and audio chunk
+ audio_chunk_generator_iter = audio_chunk_generator(
+ bos_input=bos_input,
+ tokenizer=self.processor.tokenizer,
+ generate_audio=generate_audio,
+ tts_sampling_params=tts_sampling_params,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ teacher_forcing=teacher_forcing,
+ teacher_forcing_text=teacher_forcing_text,
+ **kwargs,
+ )
+
+ if generate_audio:
+ if self.tts.config.audio_tokenizer_type == "s3tokenizer_step_audio":
+ self.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.token2wav_cache["flow_cache_base"])
+ self.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(
+ self.token2wav_cache["hift_cache_base"]
+ )
+
+ # pre-insert 3-5 prefix 4218 silence tokens, each token corresponds to 0.04s,
+ # adding 5 tokens means introducing 0.2s of silence
+ buffer = [4218] * 3
+ pre_lookahead = 3
+ CHUNK_SIZE = 25
+ chunk_idx = 0
+ prev_text_len = 0 # track text position for streaming text output
+ for audio_token_chunk, is_last_audio_chunk in audio_chunk_generator_iter:
+ if audio_token_chunk is None:
+ break
+
+ buffer += audio_token_chunk.reshape(-1).tolist()
+
+ if len(buffer) >= CHUNK_SIZE + pre_lookahead:
+ waveform_chunk = self.tts.audio_tokenizer.stream(
+ buffer[: CHUNK_SIZE + pre_lookahead],
+ prompt_wav=None,
+ last_chunk=is_last_audio_chunk,
+ return_waveform=True,
+ )
+
+ waveform_chunk = torch.from_numpy(waveform_chunk)
+
+ # get new text chunk corresponding to this waveform
+ # Decode from accumulated token IDs to avoid UTF-8 multi-byte truncation
+ new_text = ""
+ if hasattr(self, "_streaming_generated_token_ids"):
+ current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids)
+ # Filter out trailing replacement characters (incomplete UTF-8 sequences)
+ safe_end = len(current_text)
+ while safe_end > 0 and current_text[safe_end - 1] == "\ufffd":
+ safe_end -= 1
+ safe_text = current_text[:safe_end]
+ new_text = safe_text[prev_text_len:]
+ prev_text_len = len(safe_text)
+
+ yield waveform_chunk, new_text
+
+ buffer = buffer[CHUNK_SIZE:]
+ chunk_idx += 1
+
+ # flush rest
+ if len(buffer) > 0:
+ waveform_chunk = self.tts.audio_tokenizer.stream(
+ buffer,
+ prompt_wav=None,
+ last_chunk=True,
+ return_waveform=True,
+ )
+
+ waveform_chunk = torch.from_numpy(waveform_chunk)
+
+ # get remaining new text for the final chunk
+ # Final chunk: decode all remaining text without filtering
+ new_text = ""
+ if hasattr(self, "_streaming_generated_token_ids"):
+ current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids)
+ new_text = current_text[prev_text_len:]
+ prev_text_len = len(current_text)
+
+ yield waveform_chunk, new_text
+
+ # maybe the buffer is empty, and text is not empty, should we flush text without wave?
+ else:
+ raise NotImplementedError(f"not supported audio tokenizer: {self.tts.config.audio_tokenizer_type}")
+ else:
+ # For text-only generation, decode tokens and handle partial multi-byte characters
+ yield from streaming_token_decoder(
+ audio_chunk_generator_iter,
+ self.processor.tokenizer,
+ skip_special_tokens=False,
+ )
+
+ def as_duplex(self, device: Optional[str] = None, **kwargs) -> "MiniCPMODuplex":
+ """Convert this MiniCPMO instance to MiniCPMODuplex for full-duplex streaming."""
+ return MiniCPMODuplex.from_existing_model(
+ model=self,
+ device=device,
+ **kwargs,
+ )
+
+
class MiniCPMODuplex:
    """MiniCPMODuplex model with full-duplex streaming capabilities.

    This is a wrapper class that provides duplex streaming functionality.
    Use MiniCPMO.as_duplex() to create from an existing model without reloading.
    """

    # Default duplex parameters. from_existing_model() falls back to these for
    # any keyword argument the caller does not supply.
    _default_duplex_params = {
        # overall behavior
        "generate_audio": True,
        "ls_mode": "explicit",
        # LLM text-generation sampling
        "max_new_speak_tokens_per_chunk": 20,
        "text_repetition_penalty": 1.05,
        "temperature": 0.7,
        "top_k": 100,
        "top_p": 0.8,
        "text_repetition_window_size": 512,
        # listen/speak control
        "listen_prob_scale": 1.0,
        "force_listen_count": 0,
        # TTS sampling and initialization (enable_float16 / n_timesteps are
        # passed to model.init_tts in from_existing_model)
        "tts_temperature": 0.8,
        "tts_repetition_penalty": 1.05,
        "enable_float16": False,
        "n_timesteps": 10,
        # streaming audio input timing (milliseconds) and sample rate (Hz)
        "chunk_ms": 1000,
        "first_chunk_ms": 1035,
        "cnn_redundancy_ms": 20,
        "sample_rate": 16000,
        # KV-cache sliding-window configuration (consumed via DuplexWindowConfig)
        "sliding_window_mode": "off",
        "basic_window_high_tokens": 8000,
        "basic_window_low_tokens": 6000,
        "context_previous_max_tokens": 500,
        "context_max_units": 24,
    }
+
    @classmethod
    def from_existing_model(
        cls,
        model: "MiniCPMO",
        device: Optional[str] = None,
        **kwargs,
    ) -> "MiniCPMODuplex":
        """Create MiniCPMODuplex from an existing MiniCPMO instance.

        Reuses ``model``'s weights, processor, and tokenizer without reloading.
        Any keyword argument overrides the matching entry in
        ``_default_duplex_params``. Also initializes TTS (via
        ``model.init_tts``), the break/stop events, the special token ids, the
        StreamDecoder, and its sliding-window configuration.

        Args:
            model: loaded MiniCPMO model to wrap.
            device: target device string; when None, inferred from the model's
                parameters (falls back to "cuda" if it has none).
            **kwargs: overrides for ``_default_duplex_params`` entries.

        Returns:
            A fully wired MiniCPMODuplex instance (``__init__`` is bypassed).
        """
        # Create instance without calling __init__
        instance = cls.__new__(cls)

        instance.name_or_path = getattr(model.config, "_name_or_path", "")

        # Get default params helper
        def get_param(name):
            if name in kwargs:
                return kwargs[name]
            return cls._default_duplex_params.get(name)

        instance.generate_audio = get_param("generate_audio")
        instance.ls_mode = get_param("ls_mode")

        # Determine device
        if device is not None:
            instance.device = device
        else:
            try:
                instance.device = str(next(model.parameters()).device)
            except StopIteration:
                instance.device = "cuda"

        # Reuse the existing model - THIS IS THE KEY: no reloading!
        instance.model = model
        instance.processor = getattr(model, "processor", None)
        instance.tokenizer = getattr(instance.processor, "tokenizer", None) if instance.processor else None

        # Lazily load tokenizer/processor from the checkpoint path if the
        # wrapped model does not already carry them.
        if instance.tokenizer is None:
            from transformers import AutoTokenizer

            instance.tokenizer = AutoTokenizer.from_pretrained(instance.name_or_path, trust_remote_code=True)

        if instance.processor is None:
            from .processing_minicpmo import MiniCPMOProcessor

            instance.processor = MiniCPMOProcessor.from_pretrained(instance.name_or_path, trust_remote_code=True)
            instance.processor.tokenizer = instance.tokenizer

        # Ensure model has processor reference (same as __init__)
        instance.model.processor = instance.processor

        # Initialize TTS (same as __init__)
        enable_float16 = get_param("enable_float16")
        n_timesteps = get_param("n_timesteps")
        instance.model.init_tts(enable_float16=enable_float16, n_timesteps=n_timesteps)

        instance.break_event = threading.Event()
        instance.session_stop_event = threading.Event()

        # LLM generation config
        instance.max_new_speak_tokens_per_chunk = get_param("max_new_speak_tokens_per_chunk")
        instance.text_repetition_penalty = get_param("text_repetition_penalty")
        instance.temperature = get_param("temperature")
        instance.top_k = get_param("top_k")
        instance.top_p = get_param("top_p")
        instance.text_repetition_window_size = get_param("text_repetition_window_size")
        instance.listen_prob_scale = get_param("listen_prob_scale")
        instance.force_listen_count = get_param("force_listen_count")

        # TTS generation config
        tts_temp_value = get_param("tts_temperature")
        instance.tts_temperature = torch.tensor([tts_temp_value], dtype=torch.float, device=instance.device)
        instance.tts_repetition_penalty = get_param("tts_repetition_penalty")

        # Stream config (mirrored onto the wrapped model below)
        instance.CHUNK_MS = get_param("chunk_ms")
        instance.FIRST_CHUNK_MS = get_param("first_chunk_ms")
        instance.CNN_REDUNDANCY_MS = get_param("cnn_redundancy_ms")
        instance.SAMPLE_RATE = get_param("sample_rate")

        instance.model.CHUNK_MS = instance.CHUNK_MS
        instance.model.FIRST_CHUNK_MS = instance.FIRST_CHUNK_MS
        instance.model.CNN_REDUNDANCY_MS = instance.CNN_REDUNDANCY_MS
        instance.model.SAMPLE_RATE = instance.SAMPLE_RATE

        # Special tokens
        # NOTE(review): the following five lookups use empty token strings; for
        # most tokenizers convert_tokens_to_ids("") yields the unk id — it looks
        # like the intended special-token literals were lost. Verify against the
        # tokenizer's special-token map.
        instance.unit_token_id = instance.tokenizer.convert_tokens_to_ids("")
        instance.image_start_token_id = instance.tokenizer.convert_tokens_to_ids("")
        instance.image_end_token_id = instance.tokenizer.convert_tokens_to_ids("")
        instance.slice_start_token_id = instance.tokenizer.convert_tokens_to_ids("")
        instance.slice_end_token_id = instance.tokenizer.convert_tokens_to_ids("")

        instance.listen_token_id = instance.tokenizer.convert_tokens_to_ids("<|listen|>")
        instance.speak_token_id = instance.tokenizer.convert_tokens_to_ids("<|speak|>")
        instance.tts_bos_token_id = instance.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
        instance.tts_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|tts_eos|>")

        instance.chunk_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
        instance.chunk_tts_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
        instance.turn_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|turn_eos|>")

        instance.chunk_terminator_token_ids = [
            instance.listen_token_id,
            instance.chunk_eos_token_id,
            instance.chunk_tts_eos_token_id,
        ]
        instance.turn_terminator_token_ids = [instance.turn_eos_token_id]
        instance.chunk_speak_token_ids = [instance.speak_token_id]

        instance.tts_pad_id = instance.tokenizer.convert_tokens_to_ids("<|tts_pad|>")
        bad_token_ids = getattr(instance.tokenizer, "bad_token_ids", [])
        instance.forbidden_token_ids = [instance.tts_pad_id] + list(bad_token_ids)

        from .utils import StreamDecoder

        instance.decoder = StreamDecoder(
            llm=instance.model.llm, tokenizer=instance.tokenizer, forbidden_token_ids=instance.forbidden_token_ids
        )

        # Sliding window config
        sliding_window_mode = get_param("sliding_window_mode")
        basic_window_high_tokens = get_param("basic_window_high_tokens")
        basic_window_low_tokens = get_param("basic_window_low_tokens")
        context_previous_max_tokens = get_param("context_previous_max_tokens")
        context_max_units = get_param("context_max_units")

        instance.decoder.set_window_config(
            DuplexWindowConfig(
                sliding_window_mode=sliding_window_mode,
                basic_window_high_tokens=basic_window_high_tokens,
                basic_window_low_tokens=basic_window_low_tokens,
                context_previous_max_tokens=context_previous_max_tokens,
                context_max_units=context_max_units,
            )
        )
        window_enabled = sliding_window_mode != "off"
        instance.decoder.set_window_enabled(window_enabled)

        # TTS logits machinery is only needed when audio output is enabled.
        instance.tts_logits_processors = None
        instance.tts_eos_token = None
        if instance.generate_audio:
            instance.tts_logits_processors = gen_logits(
                num_code=instance.model.tts.config.num_audio_tokens,
                repetition_penalty=instance.tts_repetition_penalty,
            )
            instance.tts_eos_token = torch.tensor(
                [instance.model.tts.config.num_audio_tokens - 1],
                dtype=torch.long,
                device=instance.device,
            )

        instance._reset_streaming_state()

        return instance
+
+ def set_break_event(self):
+ self.break_event.set()
+
+ def clear_break_event(self):
+ self.break_event.clear()
+
+ def set_session_stop(self):
+ self.session_stop_event.set()
+ self.break_event.set()
+
+ def clear_session_stop(self):
+ self.session_stop_event.clear()
+
+ def is_break_set(self) -> bool:
+ return self.break_event.is_set()
+
+ def is_session_stop_set(self) -> bool:
+ return self.session_stop_event.is_set()
+
+ def _init_token2wav_cache(self, prompt_wav_path: str):
+ self.model.tts.audio_tokenizer.cache = None
+ flow_cache, hift_cache = self.model.tts.audio_tokenizer.set_stream_cache(prompt_wav_path)
+ self.flow_cache_base = torch_clone_recursive(flow_cache)
+ self.hift_cache_base = torch_clone_recursive(hift_cache)
+ self.pre_lookahead = int(self.model.tts.audio_tokenizer.flow.pre_lookahead_len)
+ self.token2wav_initialized = True
+
+ def _reset_token2wav_for_new_turn(self):
+ if self.token2wav_initialized:
+ self.model.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.flow_cache_base)
+ self.model.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(self.hift_cache_base)
+ self.token2wav_buffer = [4218] * 3 # silence token prefix
+
+ def _reset_streaming_state(self):
+ self.audio_chunk_idx = 0
+ self.current_turn_ended = True
+ self.speak_count = 0
+ self.res_ids = []
+ self.total_ids = []
+ self.total_hidden = []
+
+ # TTS state
+ self.tts_text_start_pos = 0
+ self.tts_past_key_values = None
+ self.tts_current_turn_start_time = None
+
+ # token2wav state
+ self.token2wav_initialized = False
+ self.token2wav_buffer = []
+ self.flow_cache_base = None
+ self.hift_cache_base = None
+
+ # Audio prefill state
+ self.audio_buffer = np.array([], dtype=np.float32)
+ self.pending_logits: Optional[torch.Tensor] = None
+ self.current_mode: Optional[str] = None
+
+ # Force listen state
+ self._streaming_generate_count = 0
+
+ # Schema tracking: record the complete prefill + generate token sequence
+ # prefill_schema_tokens: each element is a list of prefill tokens for a unit
+ # format: [[unit0_prefill_tokens], [unit1_prefill_tokens], ...]
+ self.prefill_schema_tokens = []
+ self._current_unit_prefill_tokens = []
+
    def prepare(
        self,
        prefix_system_prompt: Optional[str] = None,
        ref_audio: Optional[np.ndarray] = None,
        prompt_wav_path: Optional[str] = None,
        context_previous_marker: str = "\n\nprevious: ",
        **kwargs,
    ):
        """Reset per-session state and prefill the system turn into the decoder.

        Args:
            prefix_system_prompt: system prompt text; defaults to
                "Streaming Omni Conversation." when falsy.
            ref_audio: optional reference audio; its embedding is wrapped in
                <|audio_start|>/<|audio_end|> inside the system turn.
            prompt_wav_path: optional prompt wav used to initialize the
                token2wav caches (only when audio generation is enabled).
            context_previous_marker: marker string used by context-preserve
                sliding-window mode for preserved context after the first slide.
            **kwargs: accepted for API compatibility (unused here).

        Returns:
            The full system prompt string that was prefilled, with any audio
            embedding shown as "[audio embedding]"; "" when nothing was fed.
        """
        prefix_system_prompt = prefix_system_prompt or "Streaming Omni Conversation."

        prefix_system_prompt = "<|im_start|>system\n" + prefix_system_prompt
        suffix_system_prompt = "<|im_end|>"
        if isinstance(ref_audio, np.ndarray):
            prefix_system_prompt += "\n<|audio_start|>"
            suffix_system_prompt = "<|audio_end|>" + suffix_system_prompt

        self.clear_break_event()
        self.clear_session_stop()

        self._reset_streaming_state()
        self.decoder.reset()

        self.model.init_streaming_processor()

        if prompt_wav_path is not None and prompt_wav_path and self.generate_audio:
            self._init_token2wav_cache(prompt_wav_path)
            self._reset_token2wav_for_new_turn()

        # Prefill system prompt prefix (token by token)
        if prefix_system_prompt:
            tokens = self.tokenizer.encode(prefix_system_prompt, add_special_tokens=False)
            for token_id in tokens:
                self.decoder.feed(self.decoder.embed_token(token_id))

        # Prefill reference audio (flatten the nested per-group embeddings)
        if ref_audio is not None:
            data = self.processor.process_audio([ref_audio])
            embeds_nested = self.model.get_audio_embedding(data, chunk_length=self.model.config.audio_chunk_length)
            embeds = torch.cat([t for g in embeds_nested for t in g], dim=0) if embeds_nested else None
            if embeds is not None:
                self.decoder.feed(embeds)

        # register system prompt protection length (protect this part from being removed when sliding window is enabled)
        if prefix_system_prompt or suffix_system_prompt or ref_audio is not None:
            if self.decoder._window_config.sliding_window_mode == "context":
                # Context preserve mode:
                # initial layout: [prefix] [suffix] [units...]
                # after the first sliding window: [prefix] [context_previous_marker + content] [suffix] [units...]
                # register prefix length first, then feed suffix
                self._prefix_system_prompt = prefix_system_prompt
                self._suffix_system_prompt = suffix_system_prompt
                self._ref_audio = ref_audio

                suffix_token_ids = []
                if suffix_system_prompt:
                    suffix_token_ids = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False)

                # register (when cache only has prefix, no suffix, no previous)
                self.decoder.register_system_prompt_with_context(
                    suffix_token_ids=suffix_token_ids,
                    context_previous_marker=context_previous_marker,  # dynamically added after the first sliding window
                )

                # now feed suffix
                for token_id in suffix_token_ids:
                    self.decoder.feed(self.decoder.embed_token(token_id))
            else:
                # non-context preserve mode: first feed suffix, then register total length
                if suffix_system_prompt:
                    tokens = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False)
                    for token_id in tokens:
                        self.decoder.feed(self.decoder.embed_token(token_id))
                self.decoder.register_system_prompt()

        # Build the human-readable prompt string that mirrors what was prefilled.
        if prefix_system_prompt or suffix_system_prompt:
            if ref_audio is not None:
                full_prompt = (prefix_system_prompt or "") + "[audio embedding]" + (suffix_system_prompt or "")
            else:
                full_prompt = (prefix_system_prompt or "") + (suffix_system_prompt or "")

            return full_prompt

        return ""
+
    @torch.no_grad()
    def streaming_prefill(
        self,
        audio_waveform: Optional[np.ndarray] = None,
        frame_list: Optional[list] = None,
        text_list: Optional[list] = None,
        max_slice_nums: Union[int, List[int]] = 1,
        batch_vision_feed: bool = False,
    ):
        """Streaming prefill - called once per second, processing audio/video/text data.

        Args:
            audio_waveform: audio waveform data (1-D float array; accumulated into
                ``self.audio_buffer`` until a full streaming chunk is available)
            frame_list: image frame list
            text_list: list of text strings (concatenated and tokenized)
            max_slice_nums: maximum number of slices for HD image encoding (default 1, no slicing)
                Can be an int (same for all images) or a list matching frame_list length
            batch_vision_feed: if True, batch all vision embeddings into a single feed call for better performance.
                if False (default), feed each embedding individually (original behavior).

        Process:
            0. determine mode based on input: AUDIO / VISION / OMNI / TEXT
            1. feed unit-start token
            2. get and feed image embed (if frame_list) - produces pending logits in VISION mode
            3. get and feed audio embed (if audio_waveform) - produces pending logits in AUDIO/OMNI mode
            4. feed text tokens (if text_list) - produces pending logits in TEXT mode

        Returns:
            dict with keys:
                - success: bool
                - reason: str (failure description; "; "-joined if multiple)
                - cost_vision_process: float (image processing time)
                - cost_vision_embed: float (vision embedding time)
                - cost_vision_feed: float (vision feed time)
                - cost_audio_process: float (audio processing time)
                - cost_audio_embed: float (audio embedding time)
                - cost_audio_feed: float (audio feed time)
                - cost_all: float (total time)
        """
        start_time = time.time()
        # per-stage timing accumulators, all reported even on early failure
        cost_vision_process = 0.0
        cost_vision_embed = 0.0
        cost_vision_feed = 0.0
        cost_audio_process = 0.0
        cost_audio_embed = 0.0
        cost_audio_feed = 0.0

        def _make_result(success, reasons=""):
            # Build the uniform result dict; `reasons` may be a single string
            # or a list of strings (joined with "; ").
            reason = reasons
            if isinstance(reasons, list):
                reason = "; ".join(reasons)

            return {
                "success": success,
                "reason": reason,
                "cost_vision_process": cost_vision_process,
                "cost_vision_embed": cost_vision_embed,
                "cost_vision_feed": cost_vision_feed,
                "cost_audio_process": cost_audio_process,
                "cost_audio_embed": cost_audio_embed,
                "cost_audio_feed": cost_audio_feed,
                "cost_all": time.time() - start_time,
            }

        # bail out early if the session was stopped or interrupted
        if self.is_session_stop_set() or self.is_break_set():
            return _make_result(False)

        has_frames = frame_list is not None and len(frame_list) > 0
        has_audio = audio_waveform is not None and len(audio_waveform) > 0
        has_text = text_list is not None and len(text_list) > 0

        # Step 0: determine which modality combination this call carries
        if has_frames and has_audio:
            mode = "OMNI"
        elif has_frames:
            mode = "VISION"
        elif has_audio:
            mode = "AUDIO"
        elif has_text:
            mode = "TEXT"
        else:
            return _make_result(False)

        # clear any logits left over from a previous prefill
        self.pending_logits = None

        # sliding window: record unit start position
        self.decoder.register_unit_start()

        # Schema tracking: start new unit, record prefill tokens
        self._current_unit_prefill_tokens = []

        # Step 1: Feed the unit-start token
        self.decoder.feed(self.decoder.embed_token(self.unit_token_id))
        self._current_unit_prefill_tokens.append(self.unit_token_id)

        # Step 2: process image
        if has_frames:
            t0 = time.time()

            # normalize max_slice_nums to a list matching frame_list length
            if isinstance(max_slice_nums, int):
                max_slice_nums_list = [max_slice_nums] * len(frame_list)
            else:
                max_slice_nums_list = list(max_slice_nums)
                if len(max_slice_nums_list) != len(frame_list):
                    raise ValueError(
                        f"max_slice_nums list length ({len(max_slice_nums_list)}) "
                        f"must match frame_list length ({len(frame_list)})"
                    )

            # check if all max_slice_nums are the same (can use batch processing)
            all_same = len(set(max_slice_nums_list)) == 1

            if all_same:
                # all images use the same max_slice_nums, use batch processing
                processed_frames = self.processor.process_image(frame_list, max_slice_nums=max_slice_nums_list[0])
                # NOTE(review): relies on truthiness of self.device — a
                # torch.device is always truthy, None skips the move.
                if self.device:
                    processed_frames = processed_frames.to(self.device)
            else:
                # different max_slice_nums per image, process individually and merge
                all_pixel_values = []
                all_tgt_sizes = []
                for frame, max_slices in zip(frame_list, max_slice_nums_list):
                    pf = self.processor.process_image([frame], max_slice_nums=max_slices)
                    if self.device:
                        pf = pf.to(self.device)
                    # pf["pixel_values"][0] is the list of slices for this image
                    all_pixel_values.extend(pf["pixel_values"][0])
                    # pf["tgt_sizes"][0] is the array of target sizes for this image's slices
                    if hasattr(pf["tgt_sizes"][0], "tolist"):
                        all_tgt_sizes.extend(pf["tgt_sizes"][0].tolist())
                    else:
                        all_tgt_sizes.extend(list(pf["tgt_sizes"][0]))

                # reconstruct processed_frames with merged data
                processed_frames = {
                    "pixel_values": [all_pixel_values],
                    "tgt_sizes": [torch.tensor(all_tgt_sizes) if all_tgt_sizes else []],
                }

            cost_vision_process = time.time() - t0

            t0 = time.time()
            # get vision embeddings for all images (each may have multiple slices)
            # vision_hidden_states is a list, one entry per input image
            # each entry contains embeddings for [source_image, slice_1, slice_2, ...]
            vision_hidden_states = self.model.get_vision_embedding(processed_frames)
            cost_vision_embed = time.time() - t0

            if vision_hidden_states is not None and len(vision_hidden_states) > 0:
                t0 = time.time()

                # vision_hidden_states[0] contains ALL slices from ALL images (flattened)
                # shape: [total_slices, 64, D] where total_slices = sum of slices across all images
                # we need to know how many slices each image has to correctly group them

                # calculate slice counts for each image using get_sliced_grid (lightweight, no actual slicing)
                slice_counts = []  # e.g., [5, 9] means img1 has 5 slices (1 source + 4 HD), img2 has 9 slices
                for frame_idx, frame in enumerate(frame_list):
                    max_slices = max_slice_nums_list[frame_idx]
                    if hasattr(frame, "size"):
                        # get_sliced_grid returns [M, N] grid or None if no slicing needed
                        # total images = 1 (source) + M * N (HD slices)
                        grid = self.processor.image_processor.get_sliced_grid(
                            frame.size, max_slices, nerver_split=False
                        )
                        if grid is not None:
                            slice_counts.append(1 + grid[0] * grid[1])  # 1 source + M*N slices
                        else:
                            slice_counts.append(1)  # no slicing, only source image
                    else:
                        slice_counts.append(1)  # default: single image, no slicing

                # get the flattened embeddings tensor
                # vision_hidden_states is a list with one element (the batch)
                # vision_hidden_states[0] shape: [total_slices, 64, D]
                all_embeds = vision_hidden_states[0]

                # collect all feed operations first, then execute
                # this allows us to identify the last token for VISION mode logits
                feed_operations = []  # List of (embed, is_last_for_vision_mode, token_id_or_none)

                embed_idx = 0  # current index in all_embeds
                for img_idx, num_slices in enumerate(slice_counts):
                    if num_slices == 0:
                        continue

                    # the first embedding is always the source image (downsampled overview)
                    # Feed the image-start token
                    feed_operations.append(
                        (self.decoder.embed_token(self.image_start_token_id), False, self.image_start_token_id)
                    )
                    # Feed source image embedding (shape: [64, D]) - use None to indicate embedding
                    feed_operations.append((all_embeds[embed_idx], False, None))
                    # Feed the image-end token
                    feed_operations.append(
                        (self.decoder.embed_token(self.image_end_token_id), False, self.image_end_token_id)
                    )
                    embed_idx += 1

                    # remaining embeddings are HD slices (if num_slices > 1)
                    if num_slices > 1:
                        for slice_i in range(1, num_slices):
                            # Feed the slice-start token
                            feed_operations.append(
                                (self.decoder.embed_token(self.slice_start_token_id), False, self.slice_start_token_id)
                            )
                            # Feed slice embedding (shape: [64, D])
                            feed_operations.append((all_embeds[embed_idx], False, None))
                            # Feed the slice-end token
                            feed_operations.append(
                                (self.decoder.embed_token(self.slice_end_token_id), False, self.slice_end_token_id)
                            )
                            embed_idx += 1

                # mark the last operation for VISION mode logits
                if feed_operations:
                    feed_operations[-1] = (feed_operations[-1][0], True, feed_operations[-1][2])

                # execute feed operations
                if batch_vision_feed and feed_operations:
                    # batch mode: concatenate all embeddings and feed at once
                    # this reduces LLM forward passes from N to 1
                    #
                    # NOTE: batch mode may have slight numerical differences compared to for-loop mode
                    # due to floating-point precision in attention computation. This is expected behavior
                    # for causal attention with incremental vs batch computation.

                    all_embeds_list = []
                    for embed, is_last, token_id in feed_operations:
                        # ensure all embeddings have shape [L, H]
                        if embed.dim() == 1:
                            embed = embed.unsqueeze(0)
                        all_embeds_list.append(embed)

                    # concatenate all embeddings
                    # torch.cat requires consistent dtype; embeddings should already be same dtype
                    all_embeds_to_feed = torch.cat(all_embeds_list, dim=0)  # [total_L, H]

                    if mode == "VISION":
                        # vision mode needs logits from the last token
                        self.pending_logits, _ = self.decoder.feed(all_embeds_to_feed, return_logits=True)
                    else:
                        # omni mode: just feed, wait for audio to get logits
                        self.decoder.feed(all_embeds_to_feed)

                    # schema tracking: record all token IDs and embedding markers
                    for embed, is_last, token_id in feed_operations:
                        if token_id is not None:
                            self._current_unit_prefill_tokens.append(token_id)
                        else:
                            embed_dim = embed.shape[0] if len(embed.shape) > 1 else 1
                            self._current_unit_prefill_tokens.append(("img", embed_dim))
                else:
                    for embed, is_last, token_id in feed_operations:
                        if mode == "VISION" and is_last:
                            # get logits from the last token
                            self.pending_logits, _ = self.decoder.feed(embed, return_logits=True)
                        else:
                            self.decoder.feed(embed)
                        # schema tracking: record token ID or embedding marker
                        if token_id is not None:
                            self._current_unit_prefill_tokens.append(token_id)
                        else:
                            # use tuple to mark image embedding: ("img", dim)
                            embed_dim = embed.shape[0] if len(embed.shape) > 1 else 1
                            self._current_unit_prefill_tokens.append(("img", embed_dim))
                # for omni mode, no pending logits needed here (wait for audio)

                cost_vision_feed = time.time() - t0

        # Step 3: process audio (if any)
        if has_audio:
            # accumulate audio to buffer
            self.audio_buffer = np.concatenate([self.audio_buffer, audio_waveform])

            # calculate required audio length; the first chunk is left-padded
            # with zeros to reach FIRST_CHUNK_MS
            if self.audio_chunk_idx == 0:
                required_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000)
                if len(self.audio_buffer) < required_samples:
                    padding_samples = required_samples - len(self.audio_buffer)
                    padding = np.zeros(padding_samples, dtype=np.float32)
                    self.audio_buffer = np.concatenate([padding, self.audio_buffer])
            else:
                # NOTE(review): required_samples is unused in this branch —
                # need_samples below governs the availability check.
                required_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000)

            need_samples = self.processor.get_streaming_chunk_size()
            if len(self.audio_buffer) < need_samples:
                return _make_result(
                    False, f"audio not enough: need {need_samples} samples, only {len(self.audio_buffer)}"
                )

            audio_chunk = self.audio_buffer[:need_samples]

            t0 = time.time()
            batch_feature = self.processor.process_audio_streaming(
                audio_chunk,
                reset=False,
                return_batch_feature=True,
            )

            if batch_feature is None or batch_feature.audio_features.shape[-1] == 0:
                return _make_result(False, "streaming audio processing returned empty")

            # metadata: the first chunk has no left context, later chunks keep
            # 2 extra frames of context on each side
            batch_feature.chunk_idx = self.audio_chunk_idx
            batch_feature.use_extra_context = True
            batch_feature.prefix_extra_frames = 0 if self.audio_chunk_idx == 0 else 2
            batch_feature.suffix_extra_frames = 2

            batch_feature = batch_feature.to(self.device)
            cost_audio_process = time.time() - t0

            t0 = time.time()
            embeds_nested = self.model.get_audio_embedding_streaming(
                batch_feature,
                use_extra_context=batch_feature.use_extra_context,
                prefix_extra_frames=batch_feature.prefix_extra_frames,
                suffix_extra_frames=batch_feature.suffix_extra_frames,
            )
            # flatten the nested [groups][tensors] structure into one [L, D] tensor
            audio_embeds = torch.cat([t for g in embeds_nested for t in g], dim=0)
            cost_audio_embed = time.time() - t0

            t0 = time.time()
            # audio is always the last modality fed, so its logits become pending
            self.pending_logits, _ = self.decoder.feed(audio_embeds, return_logits=True)
            cost_audio_feed = time.time() - t0

            # schema tracking: use tuple to mark audio embedding: ("audio", dim)
            embed_dim = audio_embeds.shape[0] if len(audio_embeds.shape) > 1 else 1
            self._current_unit_prefill_tokens.append(("audio", embed_dim))

            # advance the buffer by exactly what the mel processor consumed
            if self.audio_chunk_idx == 0:
                cfg = self.processor._streaming_mel_processor.get_config()
                consumed_ms = int(cfg.get("effective_first_chunk_ms", self.FIRST_CHUNK_MS))
                consumed_samples = int(consumed_ms * self.SAMPLE_RATE / 1000)
            else:
                consumed_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000)

            self.audio_buffer = self.audio_buffer[consumed_samples:]

            self.audio_chunk_idx += 1

        # Step 4: process text
        if has_text:
            # concatenate all text items
            text_content = "".join(text_list) if isinstance(text_list, list) else str(text_list)

            # tokenize text
            text_token_ids = self.tokenizer.encode(text_content, add_special_tokens=False)

            if len(text_token_ids) > 0:
                # get token embeddings
                text_token_ids_tensor = torch.tensor(text_token_ids, dtype=torch.long, device=self.device)
                text_embeds = self.decoder.embed_token(text_token_ids_tensor)

                # feed to decoder
                if mode == "TEXT":
                    # text-only mode: get logits from the last token
                    self.pending_logits, _ = self.decoder.feed(text_embeds, return_logits=True)
                else:
                    # mixed mode: just feed, let other modality get logits
                    self.decoder.feed(text_embeds)

                # schema tracking: record text token IDs
                for token_id in text_token_ids:
                    self._current_unit_prefill_tokens.append(token_id)

        self.current_mode = mode

        # vision-only units still advance the logical clock by one chunk
        if mode == "VISION":
            self.audio_chunk_idx += 1

        # schema tracking: save current unit's prefill tokens
        self.prefill_schema_tokens.append(self._current_unit_prefill_tokens)

        return _make_result(True)
+
    @torch.no_grad()
    def streaming_generate(
        self,
        prompt_wav_path=None,
        max_new_speak_tokens_per_chunk=20,
        decode_mode: str = "sampling",
        temperature=0.7,
        top_k=100,
        top_p=0.8,
        listen_prob_scale=1.0,
        listen_top_k=None,
        text_repetition_penalty=1.05,
        text_repetition_window_size=512,
    ):
        """Generate one unit of response from the logits pending after streaming_prefill.

        Decodes up to ``max_new_speak_tokens_per_chunk`` text tokens, decides
        listen vs. speak, optionally runs the TTS + token2wav pipeline, and
        maintains the sliding-window KV cache bookkeeping.

        Args:
            prompt_wav_path: optional reference wav path for token2wav voice prompting.
            max_new_speak_tokens_per_chunk: hard cap on tokens decoded this unit.
            decode_mode: decoding strategy passed to ``self.decoder.decode``.
            temperature / top_k / top_p: sampling parameters.
            listen_prob_scale / listen_top_k: listen-token biasing controls.
            text_repetition_penalty / text_repetition_window_size: repetition controls.

        Returns:
            dict with is_listen, text, audio_waveform, end_of_turn, current_time,
            per-stage cost_* timings, n_tokens, and n_tts_tokens.
        """
        start_time = time.time()

        # session stopped/interrupted: emit a silent "listen" unit and end the turn
        if self.is_session_stop_set() or self.is_break_set():
            return {
                "is_listen": True,
                "text": "",
                "audio_waveform": self._generate_silence_waveform(),
                "end_of_turn": True,
                "current_time": self.audio_chunk_idx,
                "cost_llm": 0.0,
                "cost_tts_prep": 0.0,
                "cost_tts": 0.0,
                "cost_token2wav": 0.0,
                "cost_all": time.time() - start_time,
                "n_tokens": 0,
                "n_tts_tokens": 0,
            }

        # check if there are pending logits to process (set by streaming_prefill)
        if not hasattr(self, "pending_logits") or self.pending_logits is None:
            return {
                "is_listen": True,
                "text": "",
                "audio_waveform": self._generate_silence_waveform(),
                "end_of_turn": False,
                "current_time": self.audio_chunk_idx,
                "cost_llm": 0.0,
                "cost_tts_prep": 0.0,
                "cost_tts": 0.0,
                "cost_token2wav": 0.0,
                "cost_all": time.time() - start_time,
                "n_tokens": 0,
                "n_tts_tokens": 0,
            }

        # use pending logits generated in streaming_prefill (consume them once)
        logits = self.pending_logits
        self.pending_logits = None

        # Force listen: check if we should force listen for first N calls
        force_listen = self._streaming_generate_count < self.force_listen_count
        self._streaming_generate_count += 1

        total_hidden_in_unit = []
        total_ids_in_unit = []
        current_time = self.audio_chunk_idx
        is_listen = False
        end_of_turn = False

        llm_start_time = time.time()

        for j in range(max_new_speak_tokens_per_chunk):
            # budget exhausted: close the chunk explicitly (if in explicit
            # listen/speak mode) and stop decoding
            if j == max_new_speak_tokens_per_chunk - 1:
                if self.ls_mode == "explicit":
                    self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id))
                    self.total_ids.append(self.chunk_eos_token_id)
                break

            if force_listen:
                last_id = torch.tensor([self.listen_token_id], dtype=torch.long, device=self.device)
            else:
                last_id = self.decoder.decode(
                    logits=logits,
                    mode=decode_mode,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    listen_top_k=listen_top_k,
                    listen_prob_scale=listen_prob_scale,
                    text_repetition_penalty=text_repetition_penalty,
                    text_repetition_window_size=text_repetition_window_size,
                )

                # if current turn not ended, not allowed to listen (only check when not force_listen)
                if last_id.item() == self.listen_token_id and (not self.current_turn_ended):
                    last_id = torch.tensor([self.tts_bos_token_id], dtype=torch.long, device=self.device)

            self.total_ids.append(last_id.item())

            is_listen = last_id.item() == self.listen_token_id

            # termination condition detection
            if last_id.item() in self.chunk_terminator_token_ids:
                if self.ls_mode == "explicit":
                    logits, _ = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)
                break
            else:
                # normal speak: a speak token reopens the turn
                self.current_turn_ended = False

                # chunk-level speak markers are not part of the response text
                if last_id.item() in self.chunk_speak_token_ids:
                    pass
                else:
                    self.res_ids.append(last_id.item())
                    self.speak_count += 1

                logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)

                # hidden must be [1, 1, D]: one layer output for one new token
                assert len(hidden.shape) == 3
                assert hidden.shape[0] == 1
                assert hidden.shape[1] == 1

                end_of_turn = last_id.item() in self.turn_terminator_token_ids

                if end_of_turn:
                    self.current_turn_ended = True

                # skip j == 0 (the listen/speak decision token carries no content)
                if j != 0:
                    total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn])
                    total_ids_in_unit.append(last_id.item())

        # Feed the unit-end token to close this unit in the KV cache.
        # NOTE(review): the empty token string below looks like the special
        # unit-end token literal may have been stripped from this file —
        # verify the intended token; convert_tokens_to_ids("") yields the
        # unk id on most tokenizers.
        unit_end_id = self.tokenizer.convert_tokens_to_ids("")
        self.decoder.feed(self.decoder.embed_token(unit_end_id))
        self.total_ids.append(unit_end_id)

        # calculate generated text (for sliding window context preserve, filter out special tokens)
        generated_text = self.tokenizer.decode(total_ids_in_unit, skip_special_tokens=True) if total_ids_in_unit else ""

        # sliding window: register unit end, and check if sliding window is needed
        input_type = self.current_mode.lower() if self.current_mode else "audio"

        self.decoder.register_unit_end(
            input_type=input_type,
            generated_tokens=total_ids_in_unit,
            is_listen=is_listen,
            generated_text=generated_text,
        )
        # select sliding window method based on sliding window mode
        if self.decoder._window_config.sliding_window_mode == "context":
            self.decoder.enforce_window_with_context()
        elif self.decoder._window_config.sliding_window_mode == "basic":
            self.decoder.enforce_window()

        llm_end_time = time.time()

        # listen unit: no text, no TTS — return silence
        if is_listen:
            self.total_hidden.append([])
            return {
                "is_listen": True,
                "text": "",
                "audio_waveform": self._generate_silence_waveform(),
                "end_of_turn": False,
                "current_time": current_time,
                "cost_llm": llm_end_time - llm_start_time,
                "cost_tts_prep": 0.0,
                "cost_tts": 0.0,
                "cost_token2wav": 0.0,
                "cost_all": time.time() - start_time,
                "n_tokens": len(total_ids_in_unit),
                "n_tts_tokens": 0,
            }

        self.total_hidden.append(total_hidden_in_unit)
        text = generated_text  # reuse already calculated text

        # text-only configuration: skip the whole TTS pipeline
        if not self.generate_audio:
            return {
                "is_listen": False,
                "text": text,
                "audio_waveform": None,
                "end_of_turn": end_of_turn,
                "current_time": current_time,
                "cost_llm": llm_end_time - llm_start_time,
                "cost_tts_prep": 0.0,
                "cost_tts": 0.0,
                "cost_token2wav": 0.0,
                "cost_all": time.time() - start_time,
                "n_tokens": len(total_ids_in_unit),
                "n_tts_tokens": 0,
            }

        # TTS generate
        tts_start_time = time.time()
        tts_prep_start_time = time.time()
        tts_condition = self._convert_results_to_tts_input(total_hidden_in_unit)
        tts_prep_end_time = time.time()

        # 25 audio tokens per second plus the bos/eos slot
        max_token_per_chunk = 25 + 1
        min_token_per_chunk = 25 + 1

        if end_of_turn:
            min_token_per_chunk = 0
        force_flush = False
        if self.tts_text_start_pos == 0:  # this is the start of the turn
            min_token_per_chunk = 0  # allow decoding <1s audio
            force_flush = True

        if self.tts_current_turn_start_time is None:
            self.tts_current_turn_start_time = current_time

        new_tokens, old_kv = self.model.tts.generate_chunk(
            inputs_embeds=tts_condition,
            temperature=self.tts_temperature,
            repetition_penalty=self.tts_repetition_penalty,
            eos_token=self.tts_eos_token,
            force_no_stop=False,
            max_new_token=max_token_per_chunk,
            min_new_tokens=min_token_per_chunk,
            past_key_values=self.tts_past_key_values,
            logits_processors=self.tts_logits_processors,
            text_start_pos=self.tts_text_start_pos,
        )

        tts_end_time = time.time()

        # update TTS state (note: token2wav reset must be after audio generation, otherwise tokens in buffer will be lost)
        if end_of_turn:
            self.tts_text_start_pos = 0
            self.tts_past_key_values = None
            self.tts_current_turn_start_time = None
        else:
            self.tts_past_key_values = old_kv
            self.tts_text_start_pos += tts_condition.shape[1] + new_tokens.shape[1]

        # token2wav generation (must be before reset, otherwise tokens in the last but second chunk will be lost)
        token2wav_start_time = time.time()
        audio_waveform = self._generate_waveform_from_tokens(
            new_tokens, prompt_wav_path, end_of_turn, force_flush=force_flush
        )
        token2wav_end_time = time.time()

        # reset token2wav state after audio generation, ensure all tokens in buffer are processed
        if end_of_turn:
            self._reset_token2wav_for_new_turn()

        end_time = time.time()

        return {
            "is_listen": False,
            "text": text,
            "audio_waveform": audio_waveform,
            "end_of_turn": end_of_turn,
            "current_time": current_time,
            "cost_llm": llm_end_time - llm_start_time,
            "cost_tts_prep": tts_prep_end_time - tts_prep_start_time,
            "cost_tts": tts_end_time - tts_start_time,
            "cost_token2wav": token2wav_end_time - token2wav_start_time,
            "cost_all": end_time - start_time,
            "n_tokens": len(total_ids_in_unit),
            "n_tts_tokens": new_tokens.numel(),
        }
+
+ def get_session_schema(self, include_embeddings: bool = True) -> str:
+ """get complete schema for current session (includes prefill and generate stages)
+
+ Args:
+ include_embeddings: whether to include embedding placeholders (e.g. [img_embed_64], [audio_embed_50])
+
+ Returns:
+ complete schema string, each unit format:
+ [img_embed_64][audio_embed_50]<|listen|or|speak|>generated_content
+ """
+ if not hasattr(self, "prefill_schema_tokens") or not hasattr(self, "total_ids"):
+ return ""
+
+ # get token id for splitting generate tokens
+ unit_end_token_id = self.tokenizer.convert_tokens_to_ids("")
+
+ # split generate tokens into each unit
+ generate_units = []
+ current_unit = []
+ for tid in self.total_ids:
+ current_unit.append(tid)
+ if tid == unit_end_token_id:
+ generate_units.append(current_unit)
+ current_unit = []
+
+ # build complete schema
+ full_schema_parts = []
+ num_units = max(len(self.prefill_schema_tokens), len(generate_units))
+
+ for unit_idx in range(num_units):
+ unit_schema = ""
+
+ # prefill part
+ if unit_idx < len(self.prefill_schema_tokens):
+ prefill_tokens = self.prefill_schema_tokens[unit_idx]
+ for item in prefill_tokens:
+ if isinstance(item, tuple):
+ # tuple represents embedding: ("img", dim) or ("audio", dim)
+ embed_type, embed_dim = item
+ if include_embeddings:
+ unit_schema += f"[{embed_type}_embed_{embed_dim}]"
+ else:
+ # normal token ID
+ unit_schema += self.tokenizer.decode([item], skip_special_tokens=False)
+
+ # generate part
+ if unit_idx < len(generate_units):
+ unit_schema += self.tokenizer.decode(generate_units[unit_idx], skip_special_tokens=False)
+
+ full_schema_parts.append(unit_schema)
+
+ return "".join(full_schema_parts)
+
+ def get_unit_schemas(self, include_embeddings: bool = True) -> list:
+ """get list of schema for each unit
+
+ Returns:
+ list of schema strings for each unit
+ """
+ if not hasattr(self, "prefill_schema_tokens") or not hasattr(self, "total_ids"):
+ return []
+
+ unit_end_token_id = self.tokenizer.convert_tokens_to_ids("")
+
+ # split generate tokens into each unit
+ generate_units = []
+ current_unit = []
+ for tid in self.total_ids:
+ current_unit.append(tid)
+ if tid == unit_end_token_id:
+ generate_units.append(current_unit)
+ current_unit = []
+
+ # build schema for each unit
+ unit_schemas = []
+ num_units = max(len(self.prefill_schema_tokens), len(generate_units))
+
+ for unit_idx in range(num_units):
+ unit_schema = ""
+
+ # prefill part
+ if unit_idx < len(self.prefill_schema_tokens):
+ prefill_tokens = self.prefill_schema_tokens[unit_idx]
+ for item in prefill_tokens:
+ if isinstance(item, tuple):
+ # tuple represents embedding: ("img", dim) or ("audio", dim)
+ embed_type, embed_dim = item
+ if include_embeddings:
+ unit_schema += f"[{embed_type}_embed_{embed_dim}]"
+ else:
+ # normal token ID
+ unit_schema += self.tokenizer.decode([item], skip_special_tokens=False)
+
+ # generate part
+ if unit_idx < len(generate_units):
+ unit_schema += self.tokenizer.decode(generate_units[unit_idx], skip_special_tokens=False)
+
+ unit_schemas.append(unit_schema)
+
+ return unit_schemas
+
    def _convert_results_to_tts_input(self, results):
        """Convert LLM hidden states to TTS input embeddings.

        Args:
            results: list of ``[token_id, hidden, end_of_turn]`` entries as
                collected by streaming_generate, where ``hidden`` is a
                [1, 1, D] tensor for that token.

        Returns:
            [1, L+1, D_tts] tensor: per-token (text embedding + normalized
            projected hidden state), followed by the audio-BOS embedding.
            With empty ``results`` only the audio-BOS embedding is returned.
        """
        if len(results) == 0:
            # no speak tokens this unit: condition TTS on audio-BOS alone
            audio_bos = self.model.tts.emb_text(
                torch.tensor(
                    [self.model.tts.audio_bos_token_id],
                    device=self.model.tts.emb_text.weight.device,
                    dtype=torch.long,
                )
            )
            return audio_bos.unsqueeze(0)

        llm_tokens = []
        llm_hidden = []
        for hidden in results:
            llm_tokens.append(hidden[0])
            llm_hidden.append(hidden[1].squeeze(0))

        # NOTE(review): torch.Tensor(...) builds a float tensor that is then
        # cast to long by .to(); torch.tensor(..., dtype=torch.long) would be
        # the direct spelling — behavior is the same for integer token ids.
        llm_tokens_tensor = torch.Tensor(llm_tokens).to(self.device, dtype=torch.long)
        llm_embeds = self.model.tts.emb_text(llm_tokens_tensor)

        # project LLM hidden states into the TTS space and L2-normalize
        llm_hidden_tensor = torch.cat(llm_hidden, dim=0)
        llm_hidden_tensor = self.model.tts.projector_semantic(llm_hidden_tensor)
        llm_hidden_tensor = torch.nn.functional.normalize(llm_hidden_tensor, p=2, dim=-1)

        # fuse token identity and semantic hidden state additively
        tts_embeds = llm_embeds + llm_hidden_tensor

        audio_bos = self.model.tts.emb_text(
            torch.tensor(
                [self.model.tts.audio_bos_token_id],
                device=self.model.tts.emb_text.weight.device,
                dtype=torch.long,
            )
        )

        tts_embeds = torch.cat([tts_embeds, audio_bos], dim=0)
        return tts_embeds.unsqueeze(0)
+
+ def _generate_waveform_from_tokens(
+ self,
+ new_tokens: torch.Tensor,
+ prompt_wav_path: Optional[str],
+ is_last_chunk: bool = False,
+ force_flush: bool = False,
+ ) -> Optional[np.ndarray]:
+ if not self.token2wav_initialized:
+ logger.warning("token2wav_initialized is uninitialized")
+ return None
+
+ CHUNK_SIZE = 25
+
+ token_ids = torch.reshape(new_tokens, (-1,)).tolist()
+ self.token2wav_buffer += token_ids
+
+ has_chunk_eos = any(tid in self.chunk_terminator_token_ids for tid in token_ids)
+
+ pcm_bytes_list = []
+
+ # process enough tokens
+ # if there is chunk_eos, try to flush more content
+ if has_chunk_eos or force_flush:
+ # when there is chunk_eos, try to flush more content
+ while len(self.token2wav_buffer) >= self.pre_lookahead + 5: # at least keep some lookahead
+ chunk_to_process = min(CHUNK_SIZE + self.pre_lookahead, len(self.token2wav_buffer))
+ pcm_bytes = self.model.tts.audio_tokenizer.stream(
+ self.token2wav_buffer[:chunk_to_process],
+ prompt_wav=prompt_wav_path,
+ )
+ pcm_bytes_list.append(pcm_bytes)
+ self.token2wav_buffer = self.token2wav_buffer[min(CHUNK_SIZE, chunk_to_process - self.pre_lookahead) :]
+ else:
+ while len(self.token2wav_buffer) >= CHUNK_SIZE + self.pre_lookahead:
+ pcm_bytes = self.model.tts.audio_tokenizer.stream(
+ self.token2wav_buffer[: CHUNK_SIZE + self.pre_lookahead],
+ prompt_wav=prompt_wav_path,
+ )
+ pcm_bytes_list.append(pcm_bytes)
+ self.token2wav_buffer = self.token2wav_buffer[CHUNK_SIZE:]
+
+ # if is the last chunk, flush remaining tokens
+ if is_last_chunk and len(self.token2wav_buffer) > 0:
+ pcm_bytes = self.model.tts.audio_tokenizer.stream(
+ self.token2wav_buffer,
+ prompt_wav=prompt_wav_path,
+ last_chunk=True,
+ )
+ pcm_bytes_list.append(pcm_bytes)
+ self.token2wav_buffer = []
+
+ if not pcm_bytes_list:
+ return None
+
+ # merge PCM and convert to numpy array (24kHz, int16 -> float32)
+ all_pcm = b"".join(pcm_bytes_list)
+ if len(all_pcm) == 0:
+ return None
+
+ pcm_np = np.frombuffer(all_pcm, dtype=" np.ndarray:
+ """generate silence waveform (24kHz)"""
+ sample_rate = 24000
+ num_samples = int(duration_sec * sample_rate)
+ return np.zeros(num_samples, dtype=np.float32)
+
+ def get_generated_text(self) -> str:
+ return self.tokenizer.decode(self.res_ids)
+
+ def get_current_time(self) -> int:
+ return self.audio_chunk_idx
+
+ def as_simplex(self, reset_session: bool = True, reset_token2wav_cache: bool = False) -> "MiniCPMO":
+ """Convert this MiniCPMODuplex instance back to MiniCPMO for simplex mode.
+
+ Args:
+ reset_session: If True, reset streaming session state (KV cache, etc.).
+ Recommended when switching from duplex to simplex mode.
+
+ Returns the underlying MiniCPMO model instance without reloading.
+ """
+ if reset_session:
+ self.model.reset_session(reset_token2wav_cache=reset_token2wav_cache)
+ return self.model
+
+
def get_2d_sincos_pos_embed(embed_dim, image_size):
    """Build a 2D sine-cosine positional embedding table.

    Args:
        embed_dim: embedding dimension (must be even).
        image_size: int (square grid) or (image_height, image_width).

    Returns:
        pos_embed: ndarray of shape [image_height, image_width, embed_dim].
    """
    if isinstance(image_size, int):
        height = width = image_size
    else:
        height, width = image_size[0], image_size[1]

    rows = np.arange(height, dtype=np.float32)
    cols = np.arange(width, dtype=np.float32)
    # meshgrid with the width axis first, matching the downstream indexing
    grid = np.stack(np.meshgrid(cols, rows), axis=0)

    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid, half the channels per axis.

    Args:
        embed_dim: total embedding dimension (must be even).
        grid: ndarray [2, H, W]; grid[0]/grid[1] hold per-axis coordinates.

    Returns:
        ndarray of shape (H, W, embed_dim).
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    first_axis = get_1d_sincos_pos_embed_from_grid_new(half, grid[0])  # (H, W, D/2)
    second_axis = get_1d_sincos_pos_embed_from_grid_new(half, grid[1])  # (H, W, D/2)

    return np.concatenate([first_axis, second_axis], axis=-1)  # (H, W, D)


def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
    """Sine-cosine encode a 2D array of scalar positions.

    Args:
        embed_dim: output dimension for each position (must be even).
        pos: positions to encode, shape (H, W).

    Returns:
        ndarray of shape (H, W, embed_dim): sines then cosines.
    """
    assert embed_dim % 2 == 0

    # inverse frequencies: 1 / 10000^(2k / embed_dim)
    exponents = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents  # (D/2,)

    angles = np.einsum("hw,d->hwd", pos, inv_freq)  # (H, W, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (H, W, D)
+
+
class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross attention layers by
    given learnable queries and 2d sincos pos_emb

    Compresses a variable number of vision patches into a fixed number of
    query tokens via a single cross-attention layer; the 2D sin-cos position
    embedding is cached and grown lazily as larger inputs arrive.

    Outputs:
        A tensor with the shape of (batch_size, num_queries, embed_dim)
    """

    def __init__(
        self,
        num_queries,
        embed_dim,
        num_heads,
        kv_dim=None,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        adaptive=False,
        max_size=(70, 70),
    ):
        # num_queries: fixed number of learnable query tokens (output length)
        # kv_dim: input patch dim; projected to embed_dim when different
        # max_size: initial (H, W) capacity of the position-embedding cache
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.adaptive = adaptive
        self.max_size = max_size

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))

        # project keys/values only when input dim differs from embed_dim
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        # output projection, scaled init (1/sqrt(d))
        self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

        self._set_2d_pos_cache(self.max_size)

    def _set_2d_pos_cache(self, max_size, device="cpu"):
        """(Re)build the cached 2D sin-cos position embedding of shape (H, W, D)."""
        # NOTE(review): is_deepspeed_zero3_enabled comes from outside this
        # block; under ZeRO-3 the cache is created directly on CUDA.
        if is_deepspeed_zero3_enabled():
            device = "cuda"
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
        # non-persistent: recomputed on load rather than stored in checkpoints
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes, device):
        """Grow the position cache when an input exceeds the current (H, W) capacity."""
        max_h = torch.max(tgt_sizes[:, 0])
        max_w = torch.max(tgt_sizes[:, 1])
        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
            self._set_2d_pos_cache(self.max_size, device)

    def _init_weights(self, m):
        """Weight init hook: trunc-normal linears, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, tgt_sizes=None):
        """Cross-attend the learnable queries over padded patch sequences.

        Args:
            x: patch features, one row per image (padded along the patch axis).
            tgt_sizes: per-image (h, w) patch-grid sizes; row i must describe x[i].

        Returns:
            (batch_size, num_queries, embed_dim) tensor.
        """
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        # true (unpadded) number of patches per image
        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = torch.max(patch_len)
        key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i]
            # slice the cached table to this image's grid and flatten to (patches, D)
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            # mask out the padded tail so attention ignores it
            key_padding_mask[i, patch_len[i] :] = True

        pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
            1, 0, 2
        )  # BLD => L * B * D

        x = self.kv_proj(x)  # B * L * D
        # nn.MultiheadAttention default layout is (L, B, D)
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        # position information is added to keys only; values stay positionless
        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        # broadcast the shared queries across the batch: (Q, D) -> (Q, N, D)
        return query.unsqueeze(1).repeat(1, N, 1)
+
+
class MiniCPMWhisperEncoderLayer(nn.Module):
    """One Whisper encoder layer (pre-LayerNorm self-attention + FFN),
    extended to accept an optional KV cache (``past_key_values``/``use_cache``)
    for streaming inference."""

    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        super().__init__()
        self.embed_dim = config.d_model
        try:
            # compatible old transformers: newer releases expose a registry
            # keyed by the configured attention implementation.
            from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES

            self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
                layer_idx=layer_idx,
            )
        # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit. `Exception` still covers the intended fallback cases
        # (ImportError for missing registry, KeyError for unknown impl,
        # TypeError for signature differences).
        except Exception:
            from transformers.models.whisper.modeling_whisper import WhisperAttention

            self.self_attn = WhisperAttention(
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
                layer_idx=layer_idx,
            )

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Run one encoder layer.

        Returns:
            Tuple ``(hidden_states[, attn_weights][, past_key_values])`` —
            the optional entries are present when ``output_attentions`` /
            ``use_cache`` are set.  (Fixed annotation: the original claimed
            a bare ``torch.Tensor`` but always returned a tuple.)
        """
        # Pre-norm self-attention with residual connection.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, past_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_values,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Pre-norm feed-forward with residual connection.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # fp16 overflow guard, as in upstream Whisper.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs
+
+
# Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference
class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder extended with KV-cache support for streaming inference,
    optional CNN minimum-length padding, and optional trimming of extra
    context frames around each chunk."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace the stock layers with cache-aware MiniCPM layers.
        self.layers = nn.ModuleList(
            [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
        )

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = None,
        use_extra_context: Optional[bool] = False,
        prefix_extra_frames: Optional[int] = 1,
        suffix_extra_frames: Optional[int] = 1,
        cnn_min_length: Optional[int] = None,
    ):
        """Encode log-mel ``input_features``.

        Streaming extras over upstream WhisperEncoder:
            past_key_values / use_cache: self-attention KV cache so successive
                chunks continue positional embeddings where the last chunk ended.
            use_extra_context: the input carries ``prefix_extra_frames`` /
                ``suffix_extra_frames`` context frames that are trimmed again
                after the convolutions (conv2 stride 2 halves them, rounding up).
            cnn_min_length: zero-pad short inputs to this length before the
                convolutions, then crop the conv output back to the true length.

        Returns:
            ``BaseModelOutputWithPast`` (or a tuple when ``return_dict`` is
            falsy) with last_hidden_state, optional hidden_states/attentions,
            and the updated cache.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Ignore copy
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        # Optional: pad short input to minimum length for CNN computation consistency
        original_length = input_features.shape[2]
        padded_for_cnn = False
        if cnn_min_length is not None and original_length < cnn_min_length:
            padded_features = torch.zeros(
                input_features.shape[0],
                input_features.shape[1],
                cnn_min_length,
                dtype=input_features.dtype,
                device=input_features.device,
            )
            padded_features[:, :, :original_length] = input_features
            input_features = padded_features
            padded_for_cnn = True

        conv1_output = self.conv1(input_features)
        inputs_embeds = nn.functional.gelu(conv1_output)
        conv2_output = self.conv2(inputs_embeds)
        inputs_embeds = nn.functional.gelu(conv2_output)
        # If padding was done before, now need to remove the effect of padding
        if padded_for_cnn:
            # Conv1: stride=1, output length=input length
            # Conv2: stride=2, output length=(input length+1)//2
            actual_cnn_output_length = (original_length + 1) // 2
            inputs_embeds = inputs_embeds[:, :, :actual_cnn_output_length]

        # If extra context is used, CNN operations need to remove redundant frames
        # conv2 stride=2, so the redundant frames in the input will be halved (upward rounding)
        if use_extra_context:
            # Input has prefix_extra_frames prefix frames and suffix_extra_frames suffix frames
            # conv2 stride=2, output length = ceil(input length / 2)
            prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0
            suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0

            # Remove redundant frames before and after (batch, channels, time)
            if prefix_to_remove > 0:
                inputs_embeds = inputs_embeds[:, :, prefix_to_remove:]
            if 0 < suffix_to_remove < inputs_embeds.shape[2]:
                inputs_embeds = inputs_embeds[:, :, :-suffix_to_remove]

        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        embed_pos = self.embed_positions.weight
        past_key_values_length = 0
        if use_cache:
            # Normalize whatever cache form the caller passed into an
            # EncoderDecoderCache, then offset positional embeddings by the
            # number of frames already encoded.
            if past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
            elif isinstance(past_key_values, list):
                past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
            elif isinstance(past_key_values, DynamicCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            else:
                pass
            past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1])
            if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
                logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
                embed_pos_front = embed_pos[past_key_values_length:, :]
                embed_pos = torch.cat(
                    (
                        embed_pos_front,
                        torch.repeat_interleave(
                            embed_pos[-1, :].unsqueeze(0),
                            inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length,
                            dim=0,
                        ),
                    )
                )
            else:
                embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :]
        else:
            embed_pos = embed_pos[: inputs_embeds.shape[1], :]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        # Fix: initialize outside the loop so it is defined even when
        # `self.layers` is empty or the last layer is dropped.
        next_encoder_cache = None

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop: # skip the layer
                    to_drop = True

            # Ignore copy
            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                        past_key_values,
                        use_cache,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                        past_key_values=past_key_values,
                        use_cache=use_cache,
                    )

            # Fix vs. original: only consume the layer's outputs when the layer
            # actually ran. The original unconditionally did
            # `hidden_states = layer_outputs[0]`, which set hidden_states to
            # None whenever LayerDrop skipped a layer during training.
            if not to_drop:
                hidden_states = layer_outputs[0]
                if use_cache:
                    next_encoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            result = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
            return result
        result = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            past_key_values=next_encoder_cache,
        )

        return result
+
+
class MultiModalProjector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping audio features from
    ``in_dim`` to ``out_dim``."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Submodule names (linear1/relu/linear2) and creation order are kept
        # identical to the original so checkpoints and RNG state still match.
        self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)

    def forward(self, audio_features):
        """Project ``audio_features`` (..., in_dim) -> (..., out_dim)."""
        projected = self.linear1(audio_features)
        return self.linear2(self.relu(projected))
+
+
class MiniCPMMLP(nn.Module):
    """Gated (SwiGLU-style) MLP projecting LLM hidden states
    (``config.llm_hidden_size``) to the TTS hidden size (``config.hidden_size``)
    through ``config.llm_intermediate_size``, with the activation chosen by
    ``config.hidden_act``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.in_dim = config.llm_hidden_size
        self.out_dim = config.hidden_size
        self.intermediate_size = config.llm_intermediate_size
        # Attribute names and creation order match the original exactly so
        # state_dict keys and parameter initialization stay identical.
        self.gate_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True)
        self.up_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True)
        self.down_proj = nn.Linear(self.intermediate_size, self.out_dim, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """down_proj(act(gate_proj(x)) * up_proj(x))."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
+
+
@dataclass
class MiniCPMTTSGenerationOutput(ModelOutput):
    """
    Output class for MiniCPMTTS generation.

    Args:
        new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        past_input_ids (torch.LongTensor): Input IDs consumed by previous generation steps; left as None by producers that do not track them.
        finished (bool): Boolean indicating whether generation is complete.
    """

    # All fields default to None; producers fill only the ones they track
    # (e.g. MiniCPMTTS.generate returns new_ids/finished and leaves the rest None).
    new_ids: Optional[torch.LongTensor] = None
    audio_input_ids: Optional[torch.LongTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_input_ids: Optional[torch.LongTensor] = None
    finished: Optional[bool] = None
+
+
def make_streaming_chunk_mask_inference(
    tts_text_scope: List[int],
    tts_text_mask: torch.Tensor,
    streaming_audio_chunk_size: int = 50,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cuda"),
    max_sequence_length: int = 4096,
    num_text_tokens_per_audio_chunk: int = 10,
):
    """Build the 4D additive attention mask for chunked streaming TTS inference.

    Audio positions may only attend to text tokens that have been "revealed":
    each audio chunk of ``streaming_audio_chunk_size`` positions unlocks
    ``num_text_tokens_per_audio_chunk`` additional text tokens (capped at the
    number of valid text tokens, excluding the trailing [Ptts] token, which is
    always visible). Text padding positions are masked for every query row.

    Args:
        tts_text_scope: ``[start, end)`` span of the text region in the sequence.
        tts_text_mask: int8 mask over that span; 1 = valid token, 0 = padding.
            Must be int8 (asserted).
        streaming_audio_chunk_size: audio tokens per streaming chunk.
        dtype: mask dtype; masked entries hold ``torch.finfo(dtype).min``.
        device: device on which the mask is allocated.
        max_sequence_length: side length of the (square) mask; must be > 1.
        num_text_tokens_per_audio_chunk: text tokens revealed per audio chunk
            (generalized from the previously hard-coded value of 10; the
            default preserves the old behavior).

    Returns:
        Tensor of shape ``(1, 1, max_sequence_length, max_sequence_length)``
        with 0 at visible positions and ``finfo(dtype).min`` at masked ones.

    Example:
        Input sequence:
        [t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...]
        Output 4D causal mask:
        ------- text positions -------
        [0] <- here is [Stts]
        [0, 0] <- here is [spk_emb] * N
        [0, 0, 0]
        [0, 0, 0, 0]
        [0, 0, 0, 0, 0]
        ------- audio positions --------
        [0, 0, -inf, -inf, -inf, 0] <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token
        v- here is [Ptts]
        [0, 0, -inf, -inf, -inf, 0, 0]
        [0, 0, -inf, -inf, -inf, 0, 0, 0]
        [0, 0, -inf, -inf, -inf, 0, 0, 0, 0]
        [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0]
        [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0] # end of first 1s audio chunk
        [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0]
        [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0]
        [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    """

    # Create a complete attention mask for input embeds [batch_size, seq_len], without considering audio mask as audio is always at the end

    assert tts_text_mask.dtype == torch.int8

    padding_mask = torch.ones(max_sequence_length, dtype=torch.int8, device=device)
    padding_mask[tts_text_scope[0] : tts_text_scope[1]] = tts_text_mask

    # Initialize a standard upper triangular causal mask
    min_dtype = torch.finfo(dtype).min

    causal_mask = torch.full(
        (max_sequence_length, max_sequence_length),
        fill_value=min_dtype,
        dtype=dtype,
        device=device,
    )
    if max_sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    else:
        raise ValueError("max_sequence_length of tts could not be 1.")

    audio_token_start = tts_text_scope[1]
    audio_duration = max_sequence_length - tts_text_scope[1]

    # Record which text chunk the current audio chunk can see up to
    text_pivot = 0
    num_valid_text_tokens = torch.sum(tts_text_mask).item() - 1 # [Ptts] excluded

    # For each chunk of audio
    for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)):
        audio_chunk_start = audio_token_start + chunk_idx * streaming_audio_chunk_size
        audio_chunk_end = audio_token_start + (chunk_idx + 1) * streaming_audio_chunk_size
        # New text seen by this new audio chunk
        new_text_this_chunk = num_text_tokens_per_audio_chunk
        # The right bound of visible text tokens
        text_pivot = min(new_text_this_chunk + text_pivot, num_valid_text_tokens)
        # Mask all text chunks after the visible ones
        # -> [text_pivot, len(tts_text_scope)-1] excluding [Ptts]
        causal_mask[
            audio_chunk_start - 1 : audio_chunk_end - 1,
            tts_text_scope[0] + text_pivot : tts_text_scope[1] - 1,
        ] = min_dtype

    # Mask the padding parts in tts_text_masks (no position will attend to it)
    causal_mask[:, padding_mask == 0] = min_dtype

    # Add extra dimensions, [seq_len, seq_len] -> [1, 1, seq_len, seq_len]
    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

    return causal_mask
+
+
+class MiniCPMTTS(PreTrainedModel):
+ config_class = MiniCPMTTSConfig
+
    def __init__(self, config: MiniCPMTTSConfig, audio_tokenizer: Optional[object]):
        """Build the TTS head: a Llama backbone over text embeddings plus
        per-codebook audio-code embeddings and weight-normed output heads.

        Args:
            config: MiniCPMTTSConfig with backbone, streaming and VQ settings.
            audio_tokenizer: external audio codec/tokenizer; stored as-is and
                not used inside this constructor.
        """
        super().__init__(config)

        self.use_llm_hidden_state = config.use_llm_hidden_state

        self.use_text = config.use_text
        self.streaming = config.streaming
        self.streaming_text_chunk_min = config.streaming_text_chunk_min
        self.streaming_text_chunk_max = config.streaming_text_chunk_max
        self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
        self.streaming_text_reserved_len = config.streaming_text_reserved_len
        # streaming tts
        self.streaming_text_chunk_size = config.streaming_text_chunk_max
        self.audio_bos_token_id = config.audio_bos_token_id
        self.num_mel_bins = config.num_mel_bins
        self.num_vq = config.num_vq
        self.num_audio_tokens = config.num_audio_tokens

        # Default sampling parameters (callers may override per-call).
        self.top_p = config.top_p
        self.top_k = config.top_k
        self.repetition_penalty = config.repetition_penalty

        self.interleaved = config.interleaved
        self.attention_type = config.attention_type
        self.recomputed_chunks = config.recomputed_chunks

        # Two different window size concepts:
        # 1. chunk_window_size: number of chunks for sliding_recompute mode (default 2)
        # 2. token_window_size: number of tokens for sliding_window mode (default 300)
        self.chunk_window_size = config.window_size # chunk-level window for sliding_recompute
        self.token_window_size = (
            config.streaming_sliding_window_audio_window_size
        ) # token-level window for sliding_window

        # Legacy aliases (for backward compatibility with existing code)
        self.window_size = self.chunk_window_size # used in generate_streaming for sliding_recompute
        self.sliding_window_size = self.token_window_size # used in TTSStreamingGenerator for sliding_window

        if self.attention_type == "sliding_recompute" and self.chunk_window_size <= self.recomputed_chunks:
            raise ValueError(
                f"sliding_recompute requires chunk_window_size > recomputed_chunks, "
                f"but got chunk_window_size={self.chunk_window_size} and recomputed_chunks={self.recomputed_chunks}"
            )

        if config.backbone_model == "llama":
            # Backbone transformer that consumes text/audio-code embeddings.
            model_config = LlamaConfig(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                num_attention_heads=config.num_attention_heads,
                num_hidden_layers=config.num_hidden_layers,
                num_key_value_heads=config.num_key_value_heads,
                max_position_embeddings=config.max_position_embeddings,
                attn_implementation=config.attn_implementation,
            )

            self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)

            model = LlamaModel(model_config)
            self.model = model
        else:
            raise ValueError(f"Unsupported backbone model: {config.backbone_model}")

        # Projectors mapping LLM-side features (speaker / semantic) into the
        # TTS hidden space; the flavor is chosen by config.projector_type.
        self.projector_spk = self.create_projector(config)
        self.projector_semantic = self.create_projector(config)

        self.audio_tokenizer = audio_tokenizer

        # One embedding table and one weight-normed output head per VQ codebook.
        self.emb_code = nn.ModuleList(
            [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)]
        )

        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                    name="weight",
                )
                for _ in range(config.num_vq)
            ]
        )

        self.condition_type = config.condition_type

        return
+
+ @staticmethod
+ def create_projector(config):
+ if config.projector_type == "mlp":
+ return MultiModalProjector(config.llm_dim, config.hidden_size)
+ elif config.projector_type == "minicpm":
+ return MiniCPMMLP(config)
+ elif config.projector_type == "default":
+ return nn.Linear(config.llm_dim, config.hidden_size, bias=False)
+ else:
+ raise ValueError(f"Unsupported projector type: {config.projector_type}")
+
    # non-streaming
    @torch.inference_mode()
    def generate(
        self,
        inputs_embeds: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        force_no_stop=False,
        min_new_token=50,
        max_new_token=2048,
        show_tqdm=True,
        streaming=False,
        text_lengths=None,
        sampling_params: TTSSamplingParams = TTSSamplingParams(),
    ):
        """Autoregressively sample audio codes from a conditioning prefix.

        The condition embeddings (``inputs_embeds``, batch size must be 1) are
        prefilled on the first step; each following step feeds back the sum of
        the per-codebook embeddings of the previously sampled token. Sampling
        uses temperature scaling plus the processors/warpers from gen_logits,
        and stops when any codebook emits ``eos_token`` (EOS is suppressed for
        the first ``min_new_token`` steps, or always when ``force_no_stop``).

        Note: ``streaming=True`` raises NotImplementedError; ``text_lengths``
        is accepted but unused here (kept for signature compatibility with the
        streaming variants).

        Returns:
            MiniCPMTTSGenerationOutput with ``new_ids`` of shape
            (1, num_generated, num_vq) — the final EOS-bearing step is
            excluded — and ``finished`` indicating whether EOS was reached.
        """
        # One temperature per codebook, flattened to align with the reshaped logits.
        temperature = torch.tensor(
            [sampling_params.temperature] * self.config.num_vq,
            dtype=torch.float,
            device=self.device,
        )
        temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to(
            inputs_embeds.device
        )

        logits_warpers, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens,
            repetition_penalty=sampling_params.repetition_penalty,
            top_p=sampling_params.top_p,
            top_k=sampling_params.top_k,
        )

        # We only support batch size `1` for now
        assert inputs_embeds.shape[0] == 1
        eos_token = eos_token.to(inputs_embeds.device)
        finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()

        condition_length = inputs_embeds.shape[1]
        pbar: Optional[tqdm] = None
        if show_tqdm:
            pbar = tqdm(
                total=max_new_token,
                desc="code",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
            )

        if streaming:
            raise NotImplementedError("this kind of streaming is not supported yet")

        new_tokens = torch.zeros(
            inputs_embeds.shape[0],
            max_new_token,
            self.num_vq,
            device=inputs_embeds.device,
            dtype=torch.long,
        )

        past_key_values = None

        for t in range(max_new_token):
            audio_bos = False
            # If this is the first audio token, the case is special
            if t == 0:
                audio_bos = True
                inputs_embeds = inputs_embeds
                position_ids = torch.tensor(
                    list(range(0, condition_length)),
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

                if streaming:
                    raise NotImplementedError("this kind of streaming is not supported yet")
                else:
                    causal_mask_4d = None

            else:
                # Embed the previously sampled token: sum the per-codebook embeddings.
                code_emb = []
                for q in range(self.num_vq):
                    x = self.emb_code[q](new_tokens[:, t - 1 : t, q])
                    code_emb.append(x)

                inputs_embeds = torch.stack(code_emb, 3).sum(3)

                position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze(
                    0
                )

                if streaming:
                    raise NotImplementedError("this kind of streaming is not supported yet")
                else:
                    causal_mask_4d = None

            if self.config.backbone_model == "llama":
                outputs: BaseModelOutputWithPast = self.model(
                    position_ids=position_ids,
                    cache_position=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    attention_mask=causal_mask_4d,
                    use_cache=True,
                    output_attentions=False,
                    # return_dict=True, # Add this to ensure returns dict with past_key_values
                )
            else:
                raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}")

            # Aggressive `del`s keep peak GPU memory low during long generations.
            del position_ids
            del inputs_embeds

            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

            # Decode per-codebook logits from the weight-normed heads.
            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            # Only the last position predicts the next token.
            logits = logits[:, -1].float()

            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))

            logits /= temperature

            # Repetition penalty / top-p / top-k need the history of sampled
            # tokens; skipped on the very first (BOS) step.
            if not audio_bos:
                input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens
                logits_token = input_ids_sliced.reshape(
                    input_ids_sliced.size(0) * input_ids_sliced.size(1),
                    -1,
                ).to(self.device)

                del input_ids_sliced

                for logitsProcessors in logits_processors:
                    logits = logitsProcessors(logits_token, logits)

                for logitsWarpers in logits_warpers:
                    logits = logitsWarpers(logits_token, logits)

                del logits_token

            # Suppress EOS until the minimum length is reached (or always,
            # when the caller forces continuation).
            if t < min_new_token:
                logits[:, eos_token] = -torch.inf

            if force_no_stop:
                logits[:, eos_token] = -torch.inf

            scores = F.softmax(logits, dim=-1)

            del logits

            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

            del scores

            idx_next = idx_next.view(-1, self.num_vq)

            # A sample is finished once ANY codebook produced EOS.
            finish_or = idx_next.eq(eos_token).any(1)
            finish.logical_or_(finish_or)

            del finish_or
            new_tokens[:, t] = idx_next

            if t == 0 and finish.any():
                break

            del idx_next

            if finish.all():
                break

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        if not finish.all():
            logger.warning(f"incomplete result. hit max_new_token: {max_new_token}")

        # Return tokens up to (excluding) the step that produced EOS.
        genrated_input_ids = new_tokens[:, 0:t, :]

        return MiniCPMTTSGenerationOutput(
            new_ids=genrated_input_ids,
            audio_input_ids=None, # for update purpose
            past_key_values=None, # for update purpose
            past_input_ids=None, # for update purpose
            finished=finish.all(),
        )
+
    # fake streaming
    @torch.inference_mode()
    def generate_mock_legacy_streaming(
        self,
        inputs_embeds: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        force_no_stop=False,
        min_new_token=50,
        max_new_token=2048,
        show_tqdm=True,
        streaming=False,
        text_lengths=None,
        sampling_params: TTSSamplingParams = TTSSamplingParams(),
        valid_text_length=None,
    ):
        """Non-streaming generation that *simulates* streaming by precomputing
        a chunked attention mask: audio chunks progressively reveal text
        tokens, as in true streaming, but all tokens are sampled in one call.

        Same sampling loop as ``generate``, except a 4D chunk mask (from
        make_streaming_chunk_mask_inference) is sliced per step, and the
        per-codebook temperatures are hard-coded to [0.1, 0.3, 0.1, 0.3]
        (NOTE(review): this assumes num_vq == 4 — confirm against config).
        ``valid_text_length`` is the number of real (non-padding) text tokens
        at the front of ``inputs_embeds`` and is required.

        Returns:
            MiniCPMTTSGenerationOutput with ``new_ids`` (1, num_generated,
            num_vq) and ``finished``.
        """
        assert valid_text_length is not None, "valid_text_length should be not None"

        # Text occupies the whole condition prefix; mark valid tokens plus the
        # trailing [Ptts] position, the rest of the span is padding.
        tts_text_scope = [0, inputs_embeds.shape[1]]
        tts_text_mask = torch.zeros(inputs_embeds.shape[1], dtype=torch.int8, device=inputs_embeds.device)
        tts_text_mask[0:valid_text_length] = 1
        tts_text_mask[-1] = 1 # [Ptts]

        # Full-size chunked visibility mask; sliced per decoding step below.
        streaming_mask_4d_full = make_streaming_chunk_mask_inference(
            tts_text_scope=tts_text_scope,
            tts_text_mask=tts_text_mask,
            dtype=torch.bfloat16,
            device=self.device,
            streaming_audio_chunk_size=50,
            max_sequence_length=4096,
        )

        temperature = torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.device)
        temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to(
            inputs_embeds.device
        )

        logits_warpers, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens,
            repetition_penalty=sampling_params.repetition_penalty,
            top_p=sampling_params.top_p,
            top_k=sampling_params.top_k,
        )

        # We only support batch size `1` for now
        assert inputs_embeds.shape[0] == 1
        eos_token = eos_token.to(inputs_embeds.device)
        finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()

        condition_length = inputs_embeds.shape[1]
        pbar: Optional[tqdm] = None
        if show_tqdm:
            pbar = tqdm(
                total=max_new_token,
                desc="code",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
            )

        new_tokens = torch.zeros(
            inputs_embeds.shape[0],
            max_new_token,
            self.num_vq,
            device=inputs_embeds.device,
            dtype=torch.long,
        )

        past_key_values = None

        for t in range(max_new_token):
            audio_bos = False
            if t == 0:
                # Prefill the whole condition; its mask is the top-left block.
                audio_bos = True
                inputs_embeds = inputs_embeds
                position_ids = torch.tensor(
                    list(range(0, condition_length)),
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

                causal_mask_4d = streaming_mask_4d_full[:, :, :condition_length, :condition_length]
            else:
                # Feed back the previous token as the sum of codebook embeddings.
                code_emb = []
                for q in range(self.num_vq):
                    x = self.emb_code[q](new_tokens[:, t - 1 : t, q])
                    code_emb.append(x)

                inputs_embeds = torch.stack(code_emb, 3).sum(3)

                position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze(
                    0
                )

                # One query row of the precomputed mask, over all cached keys.
                causal_mask_4d = streaming_mask_4d_full[
                    :,
                    :,
                    condition_length + t : condition_length + t + 1,
                    : condition_length + t,
                ]

                # get length of past_key_values
                past_key_values_length = past_key_values[0][0].shape[2]

                assert causal_mask_4d.shape[-1] == (past_key_values_length + 1)

            if self.config.backbone_model == "llama":
                outputs: BaseModelOutputWithPast = self.model(
                    position_ids=position_ids,
                    cache_position=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    attention_mask=causal_mask_4d,
                    use_cache=True,
                    output_attentions=False,
                    # return_dict=True, # Add this to ensure returns dict with past_key_values
                )
            else:
                raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}")

            del position_ids
            del inputs_embeds

            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

            # Decode per-codebook logits from the weight-normed heads.
            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            logits = logits[:, -1].float()
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))
            logits /= temperature

            # Apply repetition penalty and warpers against the sampled history.
            if not audio_bos:
                input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens

                logits_token = input_ids_sliced.reshape(
                    input_ids_sliced.size(0) * input_ids_sliced.size(1),
                    -1,
                ).to(self.device)

                del input_ids_sliced

                for logitsProcessors in logits_processors:
                    logits = logitsProcessors(logits_token, logits)

                for logitsWarpers in logits_warpers:
                    logits = logitsWarpers(logits_token, logits)

                del logits_token

            if t < min_new_token:
                logits[:, eos_token] = -torch.inf

            if force_no_stop:
                logits[:, eos_token] = -torch.inf

            scores = F.softmax(logits, dim=-1)

            del logits
            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

            del scores

            idx_next = idx_next.view(-1, self.num_vq)
            # Finished once any codebook produced EOS.
            finish_or = idx_next.eq(eos_token).any(1)
            finish.logical_or_(finish_or)

            del finish_or
            new_tokens[:, t] = idx_next

            if t == 0 and finish.any():
                break

            del idx_next

            if finish.all():
                break

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        if not finish.all():
            logger.warning(f"incomplete result. hit max_new_token: {max_new_token}")

        # Exclude the step that carried EOS.
        genrated_input_ids = new_tokens[:, 0:t, :]

        return MiniCPMTTSGenerationOutput(
            new_ids=genrated_input_ids,
            audio_input_ids=None, # for update purpose
            past_key_values=None, # for update purpose
            past_input_ids=None, # for update purpose
            finished=finish.all(),
        )
+
    # non-streaming, interleave
    @torch.inference_mode()
    def generate_chunk(
        self,
        inputs_embeds: torch.Tensor,
        temperature: torch.Tensor,
        repetition_penalty: float,
        eos_token: Union[int, torch.Tensor],
        force_no_stop=False,
        max_new_token=500,
        min_new_tokens=0,
        past_key_values=None,
        logits_processors=None,
        text_start_pos=None,
    ):
        """For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like:
        |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS|
        where the last position is the audio BOS token.
        So, the first iteration in generation directly forward the model with inputs_embeds, and
        the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token.

        Generates one chunk of audio codes for interleaved text/audio decoding,
        continuing from ``past_key_values``; positions are offset by
        ``text_start_pos`` so the cache stays consistent across chunks.
        Only codebook 0's embedding is fed back each step (single-VQ path).

        Returns:
            (generated_ids, past_key_values): ids of shape (1, num_generated,
            num_vq) excluding the final EOS-bearing step, plus the updated cache.
        """
        # NOTE(review): the `logits_processors` parameter is immediately
        # shadowed by gen_logits' result, and `logits_warpers` is never
        # applied below — presumably intentional for this chunked path, but
        # worth confirming.
        logits_warpers, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty
        )

        # We only support batch size `1` for now
        assert inputs_embeds.shape[0] == 1
        eos_token = eos_token.to(inputs_embeds.device)
        finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()

        # One temperature per codebook, flattened to align with reshaped logits.
        temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to(
            inputs_embeds.device
        )

        condition_length = inputs_embeds.shape[1]

        new_tokens = torch.zeros(
            inputs_embeds.shape[0],
            max_new_token,
            self.num_vq,
            device=inputs_embeds.device,
            dtype=torch.long,
        )

        for t in range(max_new_token):
            audio_bos = False

            # If this is the first audio token, the case is special
            if t == 0:
                audio_bos = True
                inputs_embeds_ = inputs_embeds
                position_ids = torch.tensor(
                    list(range(text_start_pos, text_start_pos + condition_length)),
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                # Generate the following audio tokens, it is applicable to all other cases, including second and the following calling of `generate`
                inputs_embeds_ = self.emb_code[0](new_tokens[:, t - 1 : t, 0])

                position_ids = torch.tensor(
                    [text_start_pos + condition_length + t - 1], # prefill the previous token
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs: BaseModelOutputWithPast = self.model(
                position_ids=position_ids,
                # cache_position=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds_,
                use_cache=True,
                output_attentions=False,
                # return_dict=True, # Add this to ensure returns dict with past_key_values
            )

            del position_ids
            del inputs_embeds_

            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

            # Decode per-codebook logits from the weight-normed heads.
            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            logits = logits[:, -1].float()
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))

            logits /= temperature

            # Repetition penalty against the tokens sampled so far in this chunk.
            if not audio_bos:
                input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens

                logits_token = input_ids_sliced.reshape(
                    input_ids_sliced.size(0) * input_ids_sliced.size(1),
                    -1,
                ).to(self.device)

                del input_ids_sliced

                for logitsProcessors in logits_processors:
                    logits = logitsProcessors(logits_token, logits)

                del logits_token

            # Suppress EOS while forced or below the minimum length.
            if force_no_stop or t < min_new_tokens:
                logits[:, eos_token] = -torch.inf

            scores = F.softmax(logits, dim=-1)
            del logits

            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
            del scores

            idx_next = idx_next.view(-1, self.num_vq)

            # Finished once any codebook produced EOS.
            finish_or = idx_next.eq(eos_token).any(1)
            finish.logical_or_(finish_or)

            del finish_or
            new_tokens[:, t] = idx_next

            if t == 0 and finish.any():
                break

            del idx_next

            if finish.all():
                break

        # The latest generated token is not in the range returned this time. If it is an eos token, it is not returned. If it is a normal token, it is not returned.
        genrated_input_ids = new_tokens[:, 0:t, :]

        return genrated_input_ids, past_key_values
+
    @torch.inference_mode()
    def interleaved_generate(
        self,
        spk_embeds: torch.Tensor,
        conditions: List[torch.Tensor],
        temperature: torch.Tensor,
        repetition_penalty: float,
        eos_token: Union[int, torch.Tensor],
        **kwargs,
    ):
        """
        Generate audio tokens chunk-by-chunk, conditioning each chunk on the previous ones.

        For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like:
        |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS|
        where the last position is the audio BOS token.
        So, the first iteration in generation directly forward the model with inputs_embeds, and the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token.

        Args:
            spk_embeds: speaker embedding tensor. NOTE(review): not referenced in this
                body — presumably already folded into `conditions` by the caller; confirm.
            conditions: per-chunk condition embeddings, each shaped like the
                inputs_embeds layout described above.
            temperature: sampling temperature (scalar; converted to a tensor below).
            repetition_penalty: penalty forwarded to `gen_logits`.
            eos_token: audio end-of-sequence token id (tensor; moved to the condition device).

        Returns:
            MiniCPMTTSGenerationOutput with all generated audio tokens concatenated
            along the sequence dimension and `finished=True`.
        """
        temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)

        # Only `logits_processors` (repetition penalty) is consumed below; the
        # warpers returned by gen_logits are unused here.
        logits_warpers, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens,
            repetition_penalty=repetition_penalty,
        )

        eos_token = eos_token.to(conditions[0].device)

        num_chunks = len(conditions)
        # Running position offset of the current chunk's text inside the KV cache.
        text_start_pos = 0
        # Length (condition + generated tokens) of the previous chunk's window,
        # used to drop the oldest window from the cache in sliding_window mode.
        last_window_size = 0
        past_key_values = None

        for idx in range(num_chunks):
            condition = conditions[idx].to(conditions[0].device)
            if self.attention_type == "sliding_recompute":
                recomputed_conditions = []

                # Periodically rebuild the context from scratch: every
                # (window_size - recomputed_chunks) chunks once past the first
                # window, re-embed the last `recomputed_chunks` chunks
                # (their conditions interleaved with their generated tokens)
                # and prepend them to the current condition.
                if (
                    idx >= self.window_size
                    and (idx - self.recomputed_chunks) % (self.window_size - self.recomputed_chunks) == 0
                ):
                    for i in range(self.recomputed_chunks):
                        recomputed_conditions.append(conditions[idx - self.recomputed_chunks + i])
                        # Re-embed previously generated tokens of that chunk
                        # (first VQ codebook only) via the code embedding table.
                        recomputed_conditions.append(
                            self.emb_code[0](generated_tokens[-self.recomputed_chunks + i][:, :, 0])
                        )
                    recomputed_conditions.append(condition)
                    condition = torch.cat(recomputed_conditions, dim=1)

                    # Cache is discarded (past_key_values=None), so positions restart at 0.
                    text_start_pos = 0
                    new_tokens, old_kv = self.generate_chunk(
                        inputs_embeds=condition,
                        temperature=temperature,
                        repetition_penalty=repetition_penalty,
                        eos_token=eos_token,
                        force_no_stop=False,
                        max_new_token=500,
                        past_key_values=None,
                        logits_processors=logits_processors,
                        text_start_pos=text_start_pos,
                    )

                else:
                    # Between recompute points, keep extending the existing cache.
                    new_tokens, old_kv = self.generate_chunk(
                        inputs_embeds=condition,
                        temperature=temperature,
                        repetition_penalty=repetition_penalty,
                        eos_token=eos_token,
                        force_no_stop=False,
                        max_new_token=500,
                        past_key_values=past_key_values,
                        logits_processors=logits_processors,
                        text_start_pos=text_start_pos,
                    )
            else:
                # "sliding_window" and full-attention modes: always reuse the cache;
                # sliding_window trims it after generation (below).
                new_tokens, old_kv = self.generate_chunk(
                    inputs_embeds=condition,
                    temperature=temperature,
                    repetition_penalty=repetition_penalty,
                    eos_token=eos_token,
                    force_no_stop=False,
                    max_new_token=500,
                    past_key_values=past_key_values,
                    logits_processors=logits_processors,
                    text_start_pos=text_start_pos,
                )

            past_key_values = []
            if self.attention_type == "sliding_window" and idx >= 1:
                # Drop the oldest window from the legacy-tuple KV cache: keep only
                # positions after the previous chunk's window for every layer.
                for layer_idx in range(len(old_kv)):
                    past_key_values.append(
                        (
                            old_kv[layer_idx][0][:, :, last_window_size:, :],
                            old_kv[layer_idx][1][:, :, last_window_size:, :],
                        )
                    )
            else:
                past_key_values = old_kv

            # Window just consumed = this chunk's condition plus what it generated.
            last_window_size = condition.shape[1] + new_tokens.shape[1]
            text_start_pos += last_window_size

            if idx == 0:
                generated_tokens = [new_tokens]
            else:
                generated_tokens.append(new_tokens)

        return MiniCPMTTSGenerationOutput(new_ids=torch.cat(generated_tokens, dim=1), finished=True)
+
+
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Frequency-compounded repetition penalty over a sliding token window.

    Unlike the stock HF repetition penalty (flat factor per seen token), a
    token that occurred ``k`` times within the last ``past_window`` generated
    positions is scaled by ``penalty ** k``, pushing its score toward zero.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        # Reject non-float or non-positive penalties up front.
        if not (isinstance(penalty, float) and penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Only the most recent `past_window` tokens contribute to the penalty.
        window = input_ids
        if window.size(1) > self.past_window:
            window = window.narrow(1, -self.past_window, self.past_window)

        # Per-row occurrence count of each vocabulary id inside the window.
        counts = F.one_hot(window, scores.size(1)).sum(1)

        # NOTE(review): this zeroes *rows* (dim 0, the batch/vq dimension) beyond
        # `max_input_ids`, not vocabulary columns — kept verbatim from the
        # original; confirm the intent upstream.
        if counts.size(0) > self.max_input_ids:
            counts.narrow(0, self.max_input_ids, counts.size(0) - self.max_input_ids).zero_()

        factor = torch.pow(self.penalty, counts)
        base = scores.contiguous()
        # Shrink magnitudes toward zero: negatives are multiplied by the factor,
        # non-negatives are divided by it.
        return torch.where(base < 0, base.multiply(factor), base.divide(factor))
+
+
def gen_logits(num_code: int, top_p=0.7, top_k=20, repetition_penalty=1.0):
    """Build the sampling (warpers, processors) pair for audio-token decoding.

    Args:
        num_code: size of the audio-token vocabulary (upper bound passed to the
            repetition-penalty processor).
        top_p: nucleus-sampling threshold; pass None to disable.
        top_k: top-k threshold; pass None to disable.
        repetition_penalty: compounding penalty; None or 1.0 disables it.

    Returns:
        Tuple ``(logits_warpers, logits_processors)``.
    """
    warpers = []
    if top_p is not None:
        warpers.append(TopPLogitsWarper(top_p, min_tokens_to_keep=3))
    if top_k is not None:
        warpers.append(TopKLogitsWarper(top_k, min_tokens_to_keep=3))

    processors = []
    # A penalty of exactly 1 is a no-op, so skip constructing the processor.
    if repetition_penalty is not None and repetition_penalty != 1:
        # 16 = past-window length the penalty looks back over.
        processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16))

    return warpers, processors
+
+
# Copy and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Trim inputs to the not-yet-processed suffix and assemble model-call kwargs.

    Patched onto the LM as its `prepare_inputs_for_generation`. Given the
    current cache state, keeps only the unprocessed tokens of `input_ids`,
    derives `position_ids` from `attention_mask` when absent, chooses between
    `inputs_embeds` (first step only) and `input_ids`, and — for a StaticCache
    with a 2D mask — expands the mask to a 4D causal mask.

    Returns:
        dict with keys `input_ids`/`inputs_embeds`, `position_ids`,
        `past_key_values`, `use_cache`, and `attention_mask`.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
        else:
            # Legacy tuple cache: length is the key tensor's sequence dimension.
            cache_length = past_length = past_key_values[0][0].shape[2]

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
        # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various
        # strides during the decoding. Here, simply using `.contiguous()` is not sufficient as
        # in the batch size = 1 case, `position_ids` is already contiguous but with varying
        # stride which retriggers a capture.
        position_ids = position_ids.clone(memory_format=torch.contiguous_format)

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        # The clone here is for the same reason as for `position_ids`.
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

        dtype = self.lm_head.weight.dtype
        min_dtype = torch.finfo(dtype).min

        # Borrow the 4D-causal-mask helper; a StaticCache needs the mask expanded
        # up to its fixed maximum length.
        from transformers.models.paligemma.modeling_paligemma import (
            _prepare_4d_causal_attention_mask_with_cache_position,
        )

        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_length(),
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            # "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )

    return model_inputs