| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| from dataclasses import dataclass |
| from typing import Any |
| from typing import Dict |
| from typing import List |
| from typing import Literal |
| from typing import Optional |
| from typing import Tuple |
| from typing import Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.nn.utils.parametrize as P |
| from transformers.cache_utils import DynamicCache |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
@dataclass
class GenerateChunkOutput:
    """Result of a single ``ChunkPrefillChunkGenerate.chunk_generate`` call."""

    # Token ids produced in this chunk, concatenated along dim 1. A terminator
    # token that ends the chunk is not included.
    chunk_token_ids: torch.Tensor
    # The input embeddings with the embedding of every generated token appended.
    current_inputs_embeds: torch.Tensor
    # Last-layer hidden states of the forward pass over the full input of the
    # first generate chunk; None when that pass did not happen in this call.
    input_last_hidden_states: Optional[torch.Tensor]
    # Last-layer hidden states of the single-token decoding steps.
    last_hidden_states: Optional[torch.Tensor]
    # Whatever cache object the model returned as ``outputs.past_key_values``.
    # This is a cache container rather than a plain tensor, hence ``Any``.
    past_key_values: Optional[Any]
    # True when a terminator token was produced inside this chunk.
    finished: bool
|
|
|
|
class ChunkPrefillChunkGenerate:
    """Decode text from input embeddings in fixed-size chunks.

    The wrapped model is called as ``model(inputs_embeds=..., past_key_values=...,
    use_cache=True, output_hidden_states=...)`` and must return an object with
    ``logits``, ``past_key_values`` and (when requested) ``hidden_states``.
    Decoding is done manually, one token per forward pass, so the KV cache can
    be carried from one :meth:`chunk_generate` call to the next.
    """

    def __init__(self, model, tokenizer, terminators):
        """Store the model/tokenizer and resolve terminator and forbidden token ids.

        Args:
            model: Model exposing ``get_input_embeddings()`` and the call
                signature described in the class docstring.
            tokenizer: Tokenizer exposing ``convert_tokens_to_ids``; an optional
                ``bad_token_ids`` attribute is appended to the forbidden ids.
            terminators: Iterable of token strings that end generation.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.terminators = terminators
        self.terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        self.embedding_layer = self.model.get_input_embeddings()

        # Tokens whose logits are forced to -inf at every decoding step.
        self.forbidden_tokens = [
            ":",
            ":",
            ";",
            "#",
            "“",
            "”",
            "‘",
            "’",
            "@",
            "*",
            "【",
            "】",
            "「",
            "」",
            "(",
            ")",
            "(",
            ")",
            "[",
            "]",
            "&",
            "/",
            "$",
        ]

        self.forbidden_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.forbidden_tokens]
        bad_token_ids = getattr(tokenizer, "bad_token_ids", [])
        if bad_token_ids:
            self.forbidden_token_ids.extend(bad_token_ids)

    @staticmethod
    def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
        """Build a generation-config dict for sampling, beam search or greedy decoding.

        Args:
            do_sample: Use the sampling defaults when true. Otherwise sampling is
                switched off, and ``repetition_penalty`` is 1.2 when
                ``num_beams > 1`` and 1.05 when it is not.
            max_new_tokens: Stored under ``"max_new_tokens"``.
            min_new_tokens: Stored under ``"min_new_tokens"``.
            **kwargs: ``num_beams`` (default 3) plus overrides; only keys already
                present in the config (``num_beams``, ``top_p``, ``top_k``,
                ``temperature``, ``do_sample``, ``repetition_penalty``) are applied.

        Returns:
            dict: The assembled generation configuration.
        """
        num_beams = kwargs.get("num_beams", 3)
        generation_config = {
            "num_beams": num_beams,
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }

        # The base dict already holds the sampling defaults; only the
        # non-sampling modes need adjusting.
        if not do_sample:
            generation_config["do_sample"] = False
            generation_config["repetition_penalty"] = 1.2 if num_beams > 1 else 1.05

        # Explicit kwargs win over the mode defaults, but only for known keys.
        for key in generation_config.keys() & kwargs.keys():
            generation_config[key] = kwargs[key]
        generation_config["min_new_tokens"] = min_new_tokens
        generation_config["max_new_tokens"] = max_new_tokens

        return generation_config

    @staticmethod
    def _collect_penalized_ids(all_input_ids, generated_tokens, window, device):
        """Return the unique token ids subject to the repetition penalty."""
        parts = []
        if all_input_ids is not None:
            # Only the trailing ``window`` prompt tokens are considered.
            parts.append(all_input_ids[:, -window:].to(device))
        parts.extend(generated_tokens)
        if not parts:
            return torch.empty(0, dtype=torch.long, device=device)
        return torch.unique(torch.cat(parts, dim=1))

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Scale row 0 of ``logits`` in place for ``token_ids``: positive / penalty, others * penalty."""
        if token_ids.numel() == 0:
            return
        picked = logits[0, token_ids]
        logits[0, token_ids] = torch.where(picked > 0, picked / penalty, picked * penalty)

    @staticmethod
    def _select_next_token(logits, do_sample, top_k, top_p):
        """Pick the next token: argmax, or top-k / top-p filtered multinomial sampling."""
        if not do_sample:
            return torch.argmax(logits, dim=-1, keepdim=True)

        if top_k > 0:
            top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
            logits_filtered = torch.full_like(logits, float("-inf"))
            logits_filtered.scatter_(1, top_k_indices, top_k_logits)
            logits = logits_filtered

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Shift the mask right by one so the first token crossing the
            # threshold is kept, and never drop the most likely token.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float("-inf")

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def chunk_generate(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values,
        is_first_generate_chunk: bool,
        chunk_size: int,
        return_hidden_states: bool,
        do_sample: bool,
        temperature: float,
        top_p: float,
        top_k: int,
        repetition_penalty: float = 1.05,
        length_penalty: float = 1.0,
        all_input_ids: Optional[torch.Tensor] = None,
    ) -> GenerateChunkOutput:
        """Decode up to ``chunk_size`` tokens, stopping early at a terminator.

        Only row 0 of the logits is penalised and ``next_token.item()`` is used,
        so the method effectively assumes batch size 1.

        Args:
            inputs_embeds: Input embeddings; indexed as ``[:, -1:, :]`` and
                concatenated along dim 1, i.e. ``[batch, seq_len, hidden_dim]``.
            past_key_values: Cache passed to the model and replaced by
                ``outputs.past_key_values`` after every step.
            is_first_generate_chunk: When true, the first step feeds the whole
                ``inputs_embeds``; every other step feeds only the last position.
            chunk_size: Maximum number of decoding steps in this call.
            return_hidden_states: Request and collect last-layer hidden states.
            do_sample: Sample with top-k / top-p when true, otherwise argmax.
            temperature: Logits are divided by this value when it is not 1.0.
            top_p: Nucleus threshold; applied when ``< 1.0`` and sampling.
            top_k: Top-k cutoff; applied when ``> 0`` and sampling.
            repetition_penalty: Applied to already-seen token ids when not 1.0.
            length_penalty: Applied to the terminator logits when not 1.0
                (positive logits are divided, others multiplied).
            all_input_ids: Optional prompt token ids; the trailing 128 take part
                in the repetition penalty together with the generated tokens.

        Returns:
            GenerateChunkOutput: ``last_hidden_states`` and
            ``input_last_hidden_states`` are ``None`` when no corresponding
            hidden states were collected.

        Raises:
            Exception: If no token was generated and no terminator was hit
                (e.g. ``chunk_size <= 0``).
        """
        penalty_window_size = 128

        finished = False
        current_inputs_embeds = inputs_embeds.clone()
        input_hidden_chunks = []
        step_hidden_chunks = []
        generated_tokens = []

        for token_idx in range(chunk_size):
            # The very first step of the first chunk consumes the full input;
            # afterwards only the newest position is fed (the rest is cached).
            is_full_input_step = is_first_generate_chunk and token_idx == 0
            step_embeds = current_inputs_embeds if is_full_input_step else current_inputs_embeds[:, -1:, :]

            with torch.no_grad():
                outputs = self.model(
                    inputs_embeds=step_embeds,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=return_hidden_states,
                )

            # Work on a float32 copy so the in-place edits below never touch
            # the model output.
            logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=inputs_embeds.device)

            if self.forbidden_token_ids:
                logits[:, self.forbidden_token_ids] = float("-inf")

            past_key_values = outputs.past_key_values

            if repetition_penalty != 1.0:
                seen_ids = self._collect_penalized_ids(
                    all_input_ids, generated_tokens, penalty_window_size, logits.device
                )
                self._apply_repetition_penalty(logits, seen_ids, repetition_penalty)

            if length_penalty != 1.0:
                for eos_token_id in self.terminators_ids:
                    if logits[0, eos_token_id] > 0:
                        logits[0, eos_token_id] = logits[0, eos_token_id] / length_penalty
                    else:
                        logits[0, eos_token_id] = logits[0, eos_token_id] * length_penalty

            if temperature != 1.0:
                logits = logits / temperature

            next_token = self._select_next_token(logits, do_sample, top_k, top_p)

            if return_hidden_states:
                target = input_hidden_chunks if is_full_input_step else step_hidden_chunks
                target.append(outputs.hidden_states[-1])

            if next_token.item() in self.terminators_ids:
                finished = True
                break

            generated_tokens.append(next_token)

            next_token_embed = self.embedding_layer(next_token)
            current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embed], dim=1)

        if generated_tokens:
            chunk_token_ids = torch.cat(generated_tokens, dim=1)
        elif finished:
            # Terminator on the very first step: an empty chunk.
            chunk_token_ids = torch.zeros((1, 0), dtype=torch.long, device=current_inputs_embeds.device)
        else:
            raise Exception("this should not happen")

        # torch.cat rejects an empty list, so report "nothing collected" as
        # None. That covers return_hidden_states=False as well as a chunk that
        # ended before any single-token decoding step ran.
        last_hidden_states = torch.cat(step_hidden_chunks, dim=1) if step_hidden_chunks else None
        input_last_hidden_states = torch.cat(input_hidden_chunks, dim=1) if input_hidden_chunks else None

        return GenerateChunkOutput(
            chunk_token_ids=chunk_token_ids,
            current_inputs_embeds=current_inputs_embeds,
            input_last_hidden_states=input_last_hidden_states,
            last_hidden_states=last_hidden_states,
            past_key_values=past_key_values,
            finished=finished,
        )
|
|
|
|
def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
    """Incrementally decode streamed tokens without emitting partial characters.

    A multi-byte character may be split across several tokens; decoding such a
    partial sequence produces replacement characters (U+FFFD). All tokens seen
    so far are re-decoded on every step, and trailing U+FFFD characters are
    held back until later tokens complete them.

    Args:
        token_iterator: Iterator yielding ``(token_ids, is_finished)`` tuples.
            ``token_ids`` may be a ``torch.Tensor``, any iterable of ints, or a
            single int.
        tokenizer: Object with ``decode(ids, skip_special_tokens=...)``.
        skip_special_tokens: Forwarded to ``tokenizer.decode``.

    Yields:
        ``(decoded_text, is_finished)`` tuples, where ``decoded_text`` is the
        text not yet yielded. When ``is_finished`` is true everything pending is
        flushed, including any trailing U+FFFD.
    """
    accumulated_token_ids = []
    yielded_text_len = 0

    for token_ids, is_finished in token_iterator:
        if torch.is_tensor(token_ids):
            accumulated_token_ids.extend(token_ids.reshape(-1).tolist())
        elif hasattr(token_ids, "__iter__"):
            accumulated_token_ids.extend(token_ids)
        else:
            accumulated_token_ids.append(token_ids)

        full_decoded = tokenizer.decode(accumulated_token_ids, skip_special_tokens=skip_special_tokens)
        new_text = full_decoded[yielded_text_len:]

        if not is_finished:
            # Trailing replacement characters mark a character that is still
            # incomplete; keep them back for the next step.
            new_text = new_text.rstrip("\ufffd")

        # Advance the offset in both branches so that items arriving after a
        # finished one never re-emit text that was already yielded.
        yielded_text_len += len(new_text)
        yield new_text, is_finished
|
|
|
|
def torch_clone_recursive(obj):
    """Clone every tensor inside a nested dict / list / tuple structure.

    Tensors are copied with ``Tensor.clone()``; dicts, lists and tuples are
    rebuilt (as plain ``dict`` / ``list`` / ``tuple``) with their contents
    cloned recursively.

    Raises:
        ValueError: If ``obj`` or any nested element is neither a tensor nor
            one of the supported containers (this includes scalars and None).
    """
    if torch.is_tensor(obj):
        return obj.clone()

    if isinstance(obj, dict):
        return {key: torch_clone_recursive(value) for key, value in obj.items()}

    if isinstance(obj, (list, tuple)):
        cloned_items = [torch_clone_recursive(item) for item in obj]
        return cloned_items if isinstance(obj, list) else tuple(cloned_items)

    raise ValueError(f"Unsupported type: {type(obj)}")
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split at ``dim // 2`` into ``(a, b)``, returns
    ``(-b, a)``. This is the helper used by the RoPE rotations in this file.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
@dataclass
class SpeculativeSnapshot:
    """State captured for rolling back a failed VAD speculation.

    A snapshot is taken after streaming_prefill and before streaming_generate.
    If the speculation turns out wrong (the user keeps speaking), the recorded
    state lets the caller resume streaming_prefill.

    What is stored:
    - LLM KV cache: only its length; restoring means truncating (no extra VRAM).
    - Audio KV cache: a clone, because generate sets it to None.
    - Mel processor: a full state snapshot, buffer included.
    """

    # Cache lengths at snapshot time.
    llm_cache_length: int
    audio_cache_length: int

    # Session flags at snapshot time.
    new_user_msg: bool
    llm_generated: bool
    llm_generate_completed: bool

    # Round / history bookkeeping (presumably owned by the streaming session;
    # not used inside this class beyond ``summary``).
    next_round_id: int
    pending_round_id: Optional[int]
    omni_chunk_history_length: int

    # TTS tokens of the previous turn, if any.
    tts_last_turn_tokens: Optional[torch.Tensor]

    # Index of the current audio chunk.
    audio_chunk_idx: int

    # Mel processor state; ``summary`` reads its optional "buffer" entry.
    mel_processor_snapshot: Optional[dict] = None

    # Cloned audio KV cache.
    audio_past_key_values: Optional[tuple] = None

    # Snapshot time (units not established here).
    timestamp: float = 0.0

    # Optional checksums, presumably for verifying a restore.
    llm_cache_checksum: Optional[float] = None
    audio_cache_checksum: Optional[float] = None
    mel_buffer_checksum: Optional[float] = None

    # Optional RNG states.
    rng_state_cpu: Optional[torch.Tensor] = None
    rng_state_cuda: Optional[torch.Tensor] = None

    def summary(self) -> str:
        """Return a one-line, comma-separated description of the snapshot."""
        mel_state = self.mel_processor_snapshot
        mel_buffer = mel_state.get("buffer") if mel_state else None
        mel_buffer_len = 0 if mel_buffer is None else len(mel_buffer)

        parts = (
            f"llm_cache={self.llm_cache_length}",
            f"audio_cache={self.audio_cache_length}",
            f"audio_chunk_idx={self.audio_chunk_idx}",
            f"mel_buffer={mel_buffer_len}",
            f"history_len={self.omni_chunk_history_length}",
            f"new_user_msg={self.new_user_msg}",
            f"llm_generated={self.llm_generated}",
        )
        return ", ".join(parts)
|
|
|
|
| |
@dataclass
class TTSSamplingParams:
    """Default sampling hyper-parameters for TTS token generation."""

    # Nucleus / min-probability / top-k filtering thresholds.
    top_p: float = 0.85
    min_p: float = 0.01
    top_k: int = 25
    # Repetition penalty and softmax temperature.
    repetition_penalty: float = 1.05
    temperature: float = 0.8
    # NOTE(review): win_size and tau_r are not used in the visible code;
    # presumably a repetition window size and a related threshold — confirm.
    win_size: int = 16
    tau_r: float = 0.1
|
|
|
|
class TTSStreamingGenerator:
    """
    Streaming generator for TTS that processes chunks and yields audio tokens in real-time.

    One call to ``generate_with_buffer`` handles one condition (text) chunk; the
    KV cache, the generated-token history and the output buffer persist on the
    instance between calls.

    Supported attention types:
    - full_attention: Full attention, all tokens can attend to each other
    - sliding_window: Sliding window attention, KV cache is truncated to fixed size (token_window_size)
    - sliding_recompute: Sliding recompute, only keep previous chunk and recompute with current chunk
    - reindex: Keep first chunk as sink, reindex sliding window positions via RoPE rotation
    """

    def __init__(
        self,
        model,
        temperature: float,
        eos_token: Union[int, torch.Tensor],
        chunk_size: int = 25,
        tts_last_turn_tokens: torch.Tensor = None,
        logits_processors=None,
        logits_warpers=None,
    ):
        """Bind the TTS model and initialise all streaming state.

        Args:
            model: TTS model; this constructor reads ``device``, ``num_vq``,
                ``num_audio_tokens``, ``recomputed_chunks``, ``emb_code``,
                ``head_code``, ``emb_text``, ``audio_bos_token_id``,
                ``attention_type``, ``chunk_window_size``, ``token_window_size``,
                ``config.text_eos_token_id`` and ``model.config`` from it.
            temperature: Sampling temperature; logits are divided by it.
            eos_token: Audio end-of-sequence token id (int or tensor).
            chunk_size: Number of audio tokens per yielded batch.
            tts_last_turn_tokens: Stored as-is; not used inside this class.
            logits_processors: Callables ``(generated_tokens, logits) -> logits``.
            logits_warpers: Callables ``(generated_tokens, logits) -> logits``,
                applied after the processors.
        """
        self.tts = model
        self.device = model.device
        self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
        self.eos_token = (
            torch.tensor(eos_token, device=self.device) if isinstance(eos_token, int) else eos_token.to(self.device)
        )

        self.num_vq = model.num_vq
        self.num_audio_tokens = model.num_audio_tokens
        self.recomputed_chunks = model.recomputed_chunks
        self.emb_code = model.emb_code
        self.head_code = model.head_code

        # Attention / windowing configuration.
        self.attention_type = model.attention_type
        self.chunk_window_size = model.chunk_window_size
        self.token_window_size = model.token_window_size

        # RoPE parameters, used when re-rotating cached keys in reindex mode.
        self.rope_theta = model.model.config.rope_theta
        self.head_dim = model.model.config.hidden_size // model.model.config.num_attention_heads

        # Logits post-processing hooks; default to "none".
        self.logits_processors = logits_processors if logits_processors is not None else []

        self.logits_warpers = logits_warpers if logits_warpers is not None else []

        # Streaming state carried across generate_with_buffer calls.
        self.past_key_values = None
        self.text_start_pos = 0  # position id of the first token of the next prefill
        self.idx = -1  # index of the current chunk; incremented at the start of each call
        self.all_conditions = []
        self.all_generated_tokens = []
        self.tts_last_turn_tokens = tts_last_turn_tokens
        self.spk_emb = None

        # NOTE(review): torch.Tensor(...) builds a float tensor that is then
        # cast to long; torch.tensor(..., dtype=torch.long) would avoid the
        # float round-trip.
        audio_bos = [self.tts.audio_bos_token_id]
        audio_bos = torch.Tensor(audio_bos).to(self.tts.emb_text.weight.device, dtype=torch.long)

        # Embeddings appended to every condition (audio BOS) and to the final
        # condition (text EOS); both carry a leading batch dimension.
        self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
        self.text_eos_embed = self.tts.emb_text(
            torch.tensor(
                [self.tts.config.text_eos_token_id],
                device=self.tts.emb_text.weight.device,
                dtype=torch.long,
            )
        ).unsqueeze(0)

        # Generated tokens not yet yielded; drained chunk_size at a time.
        self.chunk_size = chunk_size
        self._token_buffer: List[torch.Tensor] = []

        # Per-chunk records (condition, lengths, audio tokens) used by the
        # sliding_recompute and reindex modes.
        self._chunk_info: List[dict] = []
        self._total_seq_len = 0

        # NOTE(review): not read anywhere in this class.
        self._sink_kv_len = 0

    def _build_recompute_inputs(self, current_condition: torch.Tensor) -> torch.Tensor:
        """Build recompute inputs for sliding_recompute mode.

        Returns the previous chunk's condition, the embeddings of its audio
        tokens (if any) and ``current_condition`` concatenated along dim 1, or
        ``current_condition`` alone when no chunk has been recorded yet.
        """
        if len(self._chunk_info) == 0:
            return current_condition

        prev_chunk = self._chunk_info[-1]
        prev_condition = prev_chunk["condition"]
        prev_audio_tokens = prev_chunk["audio_tokens"]

        recompute_list = [prev_condition]
        if len(prev_audio_tokens) > 0:
            prev_audio_embeds = torch.cat([self.emb_code[0](tok) for tok in prev_audio_tokens], dim=1)
            recompute_list.append(prev_audio_embeds)

        recompute_list.append(current_condition)
        return torch.cat(recompute_list, dim=1)

    def _truncate_kv_cache_sliding_window(self):
        """Truncate KV cache for sliding_window mode.

        Keeps only the last ``token_window_size`` entries along dim 2 of every
        layer's key/value tensors and stores them in a fresh ``DynamicCache``.
        Both cache objects exposing ``key_cache``/``value_cache`` and legacy
        per-layer ``(key, value)`` sequences are handled.
        """
        if self.past_key_values is None:
            return

        if hasattr(self.past_key_values, "get_seq_length"):
            current_kv_len = self.past_key_values.get_seq_length()
        else:
            current_kv_len = self.past_key_values[0][0].shape[2]

        # Nothing to drop yet.
        if current_kv_len <= self.token_window_size:
            return

        new_cache = DynamicCache()
        num_layers = (
            len(self.past_key_values.key_cache)
            if hasattr(self.past_key_values, "key_cache")
            else len(self.past_key_values)
        )

        for layer_idx in range(num_layers):
            if hasattr(self.past_key_values, "key_cache"):
                key = self.past_key_values.key_cache[layer_idx][:, :, -self.token_window_size :, :]
                value = self.past_key_values.value_cache[layer_idx][:, :, -self.token_window_size :, :]
            else:
                key = self.past_key_values[layer_idx][0][:, :, -self.token_window_size :, :]
                value = self.past_key_values[layer_idx][1][:, :, -self.token_window_size :, :]
            new_cache.update(key, value, layer_idx)

        self.past_key_values = new_cache

    @staticmethod
    def _apply_rope_rotation(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply RoPE rotation to tensor: ``x * cos + rotate_half(x) * sin``."""
        return x * cos + rotate_half(x) * sin

    def _compute_rope_cos_sin(self, positions: torch.Tensor, device: torch.device, dtype: torch.dtype):
        """Compute RoPE cos and sin for given positions.

        Returns ``(cos, sin)`` of shape ``positions.shape + (2 * (head_dim // 2),)``
        cast to ``dtype``; the angles are computed in float32.
        """
        dim_half = self.head_dim // 2
        freq_seq = torch.arange(0, dim_half, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (self.rope_theta ** (freq_seq / dim_half))

        # Outer product positions x inv_freq, then duplicated so both halves of
        # the head dimension share the same angles (matches rotate_half).
        angles = positions.float().unsqueeze(-1) * inv_freq.unsqueeze(0)
        angles = torch.cat([angles, angles], dim=-1)

        cos = angles.cos().to(dtype)
        sin = angles.sin().to(dtype)
        return cos, sin

    def _reindex_kv_cache(self):
        """
        Reindex KV cache for reindex mode:
        1. Keep first chunk as attention sink
        2. Keep last chunk
        3. Discard middle chunks
        4. Reindex the last chunk's key positions to be right after sink via RoPE rotation

        Updates ``self.past_key_values`` and ``self.text_start_pos`` in place.
        Unlike the sliding-window truncation, this method reads
        ``key_cache``/``value_cache`` directly and does not handle legacy tuple caches.
        """
        if self.past_key_values is None or len(self._chunk_info) < 2:
            return

        if hasattr(self.past_key_values, "get_seq_length"):
            current_kv_len = self.past_key_values.get_seq_length()
        else:
            current_kv_len = self.past_key_values[0][0].shape[2]

        # Sink = condition + audio tokens of the very first chunk.
        sink_len = self._chunk_info[0]["condition_len"] + self._chunk_info[0]["audio_token_count"]

        # Most recent chunk, kept verbatim (keys re-rotated below).
        last_chunk = self._chunk_info[-1]
        last_chunk_len = last_chunk["condition_len"] + last_chunk["audio_token_count"]

        keep_len = sink_len + last_chunk_len

        device = self.past_key_values.key_cache[0].device
        dtype = self.past_key_values.key_cache[0].dtype

        # Cache already fits in sink + last chunk: nothing to discard.
        if current_kv_len <= keep_len:
            last_chunk_kv_len = current_kv_len - sink_len
            if last_chunk_kv_len <= 0:
                return
            self.text_start_pos = current_kv_len
            return

        new_cache = DynamicCache()
        num_layers = len(self.past_key_values.key_cache)

        # The last chunk moves from [current_kv_len - last_chunk_len, ...) to
        # [sink_len, ...); every key in it is shifted by the same (negative) delta.
        original_start_pos = current_kv_len - last_chunk_len
        new_start_pos = sink_len
        delta = new_start_pos - original_start_pos
        delta_positions = torch.full((last_chunk_len,), delta, dtype=torch.float32, device=device)

        # Rotation by delta; unsqueezed twice to broadcast over dims 0 and 1 of the keys.
        cos, sin = self._compute_rope_cos_sin(delta_positions, device, dtype)
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        for layer_idx in range(num_layers):
            key_full = self.past_key_values.key_cache[layer_idx]
            value_full = self.past_key_values.value_cache[layer_idx]

            # Slice along dim 2 (the sequence axis of the cache tensors).
            key_sink = key_full[:, :, :sink_len, :]
            value_sink = value_full[:, :, :sink_len, :]
            key_last = key_full[:, :, -last_chunk_len:, :]
            value_last = value_full[:, :, -last_chunk_len:, :]

            # Only keys are rotated; values are left untouched.
            key_last_reindexed = self._apply_rope_rotation(key_last, cos, sin)

            key = torch.cat([key_sink, key_last_reindexed], dim=2)
            value = torch.cat([value_sink, value_last], dim=2)

            new_cache.update(key, value, layer_idx)

        self.past_key_values = new_cache

        # The next prefill starts right after the compacted cache.
        self.text_start_pos = sink_len + last_chunk_len

    @torch.inference_mode()
    def generate_with_buffer(
        self,
        condition: torch.Tensor,
        text_finished: bool = False,
        max_new_token: int = 500,
    ):
        """input a condition embedding chunk, generate audio token each time,
        and accumulate to buffer, only yield when buffer satisfies chunk_size.

        Args:
            condition: Condition embeddings for this chunk; concatenated along
                dim 1 with the text-EOS (when ``text_finished``) and audio-BOS embeddings.
            text_finished: Marks the last text chunk; enables the final flush.
            max_new_token: Maximum number of decoding steps for this chunk.

        Yields:
            ``(tokens, is_last)`` tuples. ``tokens`` is the concatenation along
            dim 1 of ``chunk_size`` buffered tokens with ``is_last=False``, or,
            once EOS is sampled on the last text chunk, the remaining buffer
            (possibly an empty ``[1, 0]`` tensor) with ``is_last=True``.

        Note:
            This is a generator: the bookkeeping after the decoding loop
            (``_chunk_info``, ``text_start_pos``) only runs when the caller
            iterates it to exhaustion.
        """
        self.idx += 1
        self.device = self.tts.device

        # The last text chunk is terminated with the text-EOS embedding.
        if text_finished:
            condition = torch.cat([condition, self.text_eos_embed], dim=1)

        # Every chunk ends with audio-BOS to start audio-token generation.
        condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)

        self.all_conditions.append(condition)

        # Record for this chunk; appended to self._chunk_info after the loop.
        current_chunk_info = {
            "condition_len": condition.shape[1],
            "audio_token_count": 0,
            "condition": condition.clone(),
            "audio_tokens": [],
        }

        # Choose the prefill input and starting position per attention type.
        if self.attention_type == "sliding_recompute" and self.idx >= 1:
            # Drop the cache and re-prefill previous chunk + current condition from position 0.
            self.past_key_values = None
            current_condition = self._build_recompute_inputs(condition)
            self.text_start_pos = 0
        elif self.attention_type == "reindex" and self.idx >= 1:
            # Compact the cache to sink + last chunk, then continue after it.
            self._reindex_kv_cache()
            current_condition = condition

            if self.past_key_values is not None:
                if hasattr(self.past_key_values, "get_seq_length"):
                    kv_len = self.past_key_values.get_seq_length()
                else:
                    kv_len = self.past_key_values[0][0].shape[2]
                self.text_start_pos = kv_len
        else:
            current_condition = condition

        condition_length = current_condition.shape[1]
        prefill_len = condition_length
        finished = torch.zeros(1, dtype=torch.bool, device=self.device)
        chunk_generated_tokens = []

        for t in range(max_new_token):
            if t == 0:
                # Prefill step: feed the whole condition.
                inputs_embeds = current_condition
                pos_ids = torch.arange(
                    self.text_start_pos,
                    self.text_start_pos + condition_length,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                # Decode step: feed the embedding of the most recently generated token.
                last = self.all_generated_tokens[-1]

                inputs_embeds = self.emb_code[0](last)
                pos_ids = torch.tensor(
                    [self.text_start_pos + prefill_len + t - 1],
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs = self.tts.model(
                position_ids=pos_ids,
                past_key_values=self.past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
            )
            hidden_states = outputs.last_hidden_state

            # sliding_window trims the cache after every step.
            if self.attention_type == "sliding_window":
                self.past_key_values = outputs.past_key_values
                self._truncate_kv_cache_sliding_window()
            else:
                self.past_key_values = outputs.past_key_values

            # One output head per codebook; P.cached() caches parametrized
            # weights while the heads are evaluated.
            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            # Keep only the last position.
            logits = logits[:, -1].float()

            # Move the codebook axis before the vocabulary axis and flatten it into the rows.
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))

            logits /= self.temperature

            # True only for the very first step of the whole stream (no history yet).
            audio_bos = len(self.all_generated_tokens) == 0 and t == 0

            if not audio_bos:
                # Processors and warpers see the full generated history.
                all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device)
                for processor in self.logits_processors:
                    logits = processor(all_generated_tokens, logits)

                for warper in self.logits_warpers:
                    logits = warper(all_generated_tokens, logits)
                del all_generated_tokens

            # Sample one id per row and keep the first codebook column only.
            scores = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(scores, num_samples=1)
            next_id = idx_next.view(-1, self.num_vq)[:, 0:1]
            del scores

            if next_id.eq(
                self.eos_token
            ).any():
                # EOS is not stored or buffered.
                finished[:] = True
            else:
                # NOTE(review): next_id comes from a 2-D slice above, so the
                # 0-D / 1-D branches look unreachable.
                if next_id.dim() == 0:
                    next_tok = next_id.unsqueeze(0).unsqueeze(0)
                elif next_id.dim() == 1:
                    next_tok = next_id.unsqueeze(0)
                else:
                    next_tok = next_id

                self.all_generated_tokens.append(next_tok)
                chunk_generated_tokens.append(next_tok)

                # Tracked per chunk for the recompute / reindex modes.
                current_chunk_info["audio_tokens"].append(next_tok.clone())
                current_chunk_info["audio_token_count"] += 1

                self._token_buffer.append(next_tok)

            if len(self._token_buffer) == 0:
                # Nothing buffered: emit an empty final batch on the last text chunk, then stop.
                if text_finished:
                    yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
                    break

                else:
                    break
            else:
                # Enough tokens buffered: emit one full batch and keep the rest.
                if len(self._token_buffer) >= self.chunk_size:
                    batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1)
                    yield batch, False

                    self._token_buffer = self._token_buffer[self.chunk_size :]

                else:
                    # Partial buffer: flush only when EOS was hit on the last
                    # text chunk; otherwise the leftover stays buffered for the next call.
                    if finished.all():
                        if text_finished:
                            batch = torch.cat(self._token_buffer, dim=1)
                            yield batch, True
                            self._token_buffer = []
                            break
                        else:

                            break
                    else:
                        continue

        # Record this chunk for later recompute / reindex.
        self._chunk_info.append(current_chunk_info)
        self._total_seq_len += condition.shape[1] + len(chunk_generated_tokens)

        # Advance the starting position for the next chunk.
        if self.attention_type == "sliding_recompute":
            # prefill_len includes the recomputed previous chunk.
            self.text_start_pos += prefill_len + len(chunk_generated_tokens)
        elif self.attention_type == "reindex":
            # Follow the actual cache length when a cache exists.
            if self.past_key_values is not None:
                if hasattr(self.past_key_values, "get_seq_length"):
                    self.text_start_pos = self.past_key_values.get_seq_length()
                else:
                    self.text_start_pos = self.past_key_values[0][0].shape[2]
            else:
                self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
        else:
            self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
| |
|
|
|
|
| |
@dataclass
class StreamingWindowConfig:
    """Token thresholds for the streaming text window."""

    # NOTE(review): presumably trimming starts above the high mark and stops
    # at the low mark — confirm against the code that consumes this config.
    text_window_high_tokens: int = 8000
    text_window_low_tokens: int = 6000
|
|
|
|
@dataclass
class DuplexWindowConfig:
    """Sliding-window settings for duplex mode.

    ``sliding_window_mode`` selects the strategy:
    - "off": sliding window disabled
    - "basic": basic sliding window, triggered by cache length
    - "context": sliding window with context, triggered by unit count; the
      generated text is preserved as "previous" context
    """

    # One of "off", "basic", "context".
    sliding_window_mode: str = "off"

    # Thresholds for the "basic" mode.
    basic_window_high_tokens: int = 8000
    basic_window_low_tokens: int = 6000

    # Limits for the "context" mode.
    context_previous_max_tokens: int = 500
    context_max_units: int = 24

    # NOTE(review): purpose not visible here; presumably enables extra
    # verification of the window operations — confirm.
    verify_mode: bool = False
|
|
|
|
def as_dynamic_cache(past_key_values):
    """Return ``past_key_values`` as a ``DynamicCache`` when it is a legacy tuple.

    A ``DynamicCache`` is returned unchanged, a tuple is converted with
    ``DynamicCache.from_legacy_cache``, and any other value (including
    ``None``) is passed through untouched.
    """
    is_legacy_tuple = not isinstance(past_key_values, DynamicCache) and isinstance(past_key_values, tuple)
    return DynamicCache.from_legacy_cache(past_key_values) if is_legacy_tuple else past_key_values
|
|
|
|
def get_kv_cache_length(cache) -> int:
    """Return how many tokens a KV cache currently holds.

    Args:
        cache: ``None``, a ``DynamicCache``, or a legacy tuple cache whose
            first layer's key tensor has the sequence on dim 2.

    Returns:
        The sequence length taken from the first layer's key tensor, or 0 for
        ``None``, an empty ``DynamicCache`` and any unsupported type.
    """
    if cache is None:
        return 0

    if isinstance(cache, DynamicCache):
        layer_keys = cache.key_cache
        is_populated = bool(layer_keys) and bool(layer_keys[0].numel())
        return layer_keys[0].shape[-2] if is_populated else 0

    if isinstance(cache, tuple):
        first_layer_key = cache[0][0]
        return first_layer_key.shape[2]

    return 0
|
|
|
|
def get_rotary_cos_sin(
    head_dim: int,
    positions: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute RoPE cos and sin components for given positions.

    Args:
        head_dim: Dimension of each attention head
        positions: 1-D tensor of position indices
        device: Target device
        dtype: Target dtype of the returned tensors (angles are computed in float32)
        rope_theta: RoPE base frequency (default 10000.0)
        inv_freq_cache: Optional dict used to memoise the inverse frequencies.
            Entries are keyed by ``(head_dim, device, rope_theta)`` so one dict
            can safely be shared between models with different RoPE bases.

    Returns:
        Tuple of (cos, sin) tensors with shape [1, 1, seq_len, head_dim]
    """
    # rope_theta is part of the key: inv_freq depends on it, and keying on
    # (head_dim, device) alone would hand back frequencies for the wrong base.
    cache_key = (head_dim, device, rope_theta)

    inv_freq = inv_freq_cache.get(cache_key) if inv_freq_cache is not None else None
    if inv_freq is None or inv_freq.device != device or inv_freq.shape[0] != head_dim // 2:
        exponent = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (rope_theta**exponent)
        if inv_freq_cache is not None:
            inv_freq_cache[cache_key] = inv_freq

    positions = positions.to(device=device, dtype=torch.float32)
    angles = torch.einsum("i,j->ij", positions, inv_freq)
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    # Duplicate so both halves of the head dimension share the same angles
    # (the layout rotate_half expects), then add two leading broadcast dims.
    cos_full = torch.cat([cos, cos], dim=-1).to(dtype=dtype)
    sin_full = torch.cat([sin, sin], dim=-1).to(dtype=dtype)
    cos_full = cos_full.unsqueeze(0).unsqueeze(0)
    sin_full = sin_full.unsqueeze(0).unsqueeze(0)
    return cos_full, sin_full
|
|
|
|
def realign_rotary_suffix(
    suffix_keys: torch.Tensor,
    old_positions: torch.Tensor,
    new_positions: torch.Tensor,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> torch.Tensor:
    """Move cached keys from their old RoPE positions to new ones.

    After tokens are evicted from the middle of a cache, the keys that follow
    the gap still carry the rotation of their old positions. This undoes that
    rotation and applies the rotation of ``new_positions`` instead.

    Args:
        suffix_keys: Keys to realign, shape [batch, heads, seq_len, head_dim]
        old_positions: Position indices the keys were originally rotated with
        new_positions: Position indices the keys should carry afterwards
        rope_theta: RoPE base frequency
        inv_freq_cache: Optional cache dict forwarded to ``get_rotary_cos_sin``

    Returns:
        Realigned keys with the input's shape; an empty input is returned as-is.
    """
    if suffix_keys.numel() == 0:
        return suffix_keys

    rope_kwargs = {
        "head_dim": suffix_keys.shape[-1],
        "device": suffix_keys.device,
        "dtype": suffix_keys.dtype,
        "rope_theta": rope_theta,
        "inv_freq_cache": inv_freq_cache,
    }

    # Inverse rotation: strip the encoding of the old positions.
    cos_prev, sin_prev = get_rotary_cos_sin(positions=old_positions, **rope_kwargs)
    unrotated = cos_prev * suffix_keys - sin_prev * rotate_half(suffix_keys)

    # Forward rotation: encode the new positions.
    cos_next, sin_next = get_rotary_cos_sin(positions=new_positions, **rope_kwargs)
    return cos_next * unrotated + sin_next * rotate_half(unrotated)
|
|
|
|
def drop_tokens_from_cache(
    cache: Optional[DynamicCache | Tuple],
    length: int,
    preserve: int,
    position_offset: int,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> Tuple[Optional[DynamicCache], int, bool]:
    """Evict the run of tokens that sits right behind the preserved prefix.

    The slice ``[preserve, preserve + length)`` is cut out of every layer of
    the cache. Keys located behind the removed slice are re-rotated with
    ``realign_rotary_suffix`` so that their RoPE positions start at
    ``preserve`` again; values are moved without modification.

    Args:
        cache: DynamicCache, or a tuple cache (passed through
            ``as_dynamic_cache`` first).
        length: Number of tokens to evict.
        preserve: Number of leading tokens that must stay (clamped to the
            current cache length).
        position_offset: Running offset kept by the caller; it is only
            incremented here, it does not take part in the position math.
        rope_theta: RoPE base frequency forwarded to the realignment.
        inv_freq_cache: Optional dict forwarded to the realignment.

    Returns:
        ``(cache, offset, ok)``. On success ``offset`` equals
        ``position_offset + length`` and the cache was modified in place.
        On any refusal (nothing to drop, empty cache, not enough tokens
        behind the prefix) the offset is returned unchanged with ``ok=False``.
    """
    if cache is None or length <= 0:
        return cache, position_offset, False

    cache = as_dynamic_cache(cache)
    cached_len = get_kv_cache_length(cache)
    if cached_len <= 0:
        return cache, position_offset, False

    preserve = min(preserve, cached_len)
    droppable = cached_len - preserve
    if droppable < length:
        logger.warning(
            "Cannot drop %d tokens: only %d available (total=%d, preserve=%d)",
            length,
            droppable,
            cached_len,
            preserve,
        )
        return cache, position_offset, False

    # Layout before: [0, preserve) kept | [preserve, tail_start) dropped | tail.
    tail_start = preserve + length
    tail_len = cached_len - tail_start
    remaining_len = cached_len - length

    # Position vectors are shared by all layers; they live on layer 0's device.
    src_positions = None
    dst_positions = None
    if tail_len > 0:
        pos_device = cache.key_cache[0].device
        src_positions = torch.arange(tail_start, tail_start + tail_len, device=pos_device, dtype=torch.long)
        dst_positions = torch.arange(preserve, preserve + tail_len, device=pos_device, dtype=torch.long)

    for idx, layer_keys in enumerate(cache.key_cache):
        layer_values = cache.value_cache[idx]
        if not layer_keys.numel():
            continue

        head_keys = layer_keys[:, :, :preserve, :]
        head_values = layer_values[:, :, :preserve, :]

        if tail_len <= 0:
            # Everything behind the prefix is dropped.
            cache.key_cache[idx] = head_keys.contiguous()
            cache.value_cache[idx] = head_values.contiguous()
            continue

        tail_keys = layer_keys[:, :, tail_start:, :]
        tail_values = layer_values[:, :, tail_start:, :]
        if tail_keys.numel():
            tail_keys = realign_rotary_suffix(
                tail_keys,
                src_positions,
                dst_positions,
                rope_theta,
                inv_freq_cache,
            )

        cache.key_cache[idx] = torch.cat([head_keys, tail_keys], dim=-2).contiguous()
        cache.value_cache[idx] = torch.cat([head_values, tail_values], dim=-2).contiguous()

    cache.crop(remaining_len)
    cache._seen_tokens = max(remaining_len, 0)

    logger.debug("Dropped %d tokens from cache, new length=%d", length, remaining_len)
    return cache, position_offset + length, True
|
|
|
|
| |
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Mask logits that fall outside the top-k set and/or the top-p nucleus.

    Args:
        logits: Tensor whose last dimension is the vocabulary. Leading batch
            dimensions (if any) are filtered independently, row by row.
        top_k: Keep only the ``top_k`` largest logits per row; 0 disables.
        top_p: Nucleus threshold; 0.0 disables. Tokens are removed once the
            cumulative probability of the better-ranked tokens already
            exceeds ``top_p``. The best token of each row is always kept.
        filter_value: Value written into the removed positions.

    Returns:
        A filtered copy of ``logits``; the input tensor is left untouched.
    """
    logits = logits.clone()

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Smallest logit that is still inside the top-k of each row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the first token that crosses the threshold
        # is kept, and never remove the best token.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order per row.
        # (Indexing row 0 with the flattened indices, as before, corrupted
        # row 0 and ignored the other rows whenever batch size was > 1.)
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits[remove] = filter_value

    return logits
|
|
|
|
| class StreamDecoder: |
def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None):
    """Set up the streaming decoder state.

    Args:
        llm: Language model, stored as ``self.m``.
        tokenizer: Tokenizer used to resolve the control-token ids.
        special_token_ids: Optional list stored as ``self.special_token_ids``
            (defaults to an empty list).
        forbidden_token_ids: ``None``, a single int, or an iterable of ints.
            The ids are copied into a new list (the caller's list is not
            modified) and ``<|chunk_eos|>`` is always appended to it.
    """
    self.m = llm
    self.tokenizer = tokenizer
    self.listen_id = self.tokenizer.eos_token_id

    self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
    self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
    self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>")
    self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>")

    self.special_token_ids = special_token_ids if special_token_ids is not None else []

    # Ids / surface forms that are stripped when generated text is extracted.
    self._all_special_ids = set()
    self._all_special_tokens_text = set()
    if self.tokenizer:
        if hasattr(self.tokenizer, "all_special_ids"):
            self._all_special_ids = set(self.tokenizer.all_special_ids)
        if hasattr(self.tokenizer, "all_special_tokens"):
            self._all_special_tokens_text = set(self.tokenizer.all_special_tokens)

    custom_special_tokens = [
        "<unit>",
        "</unit>",
        "<image>",
        "</image>",
        "<slice>",
        "</slice>",
        "<|listen|>",
        "<|speak|>",
        "<|tts_bos|>",
        "<|tts_eos|>",
        "<|audio_start|>",
        "<|audio_end|>",
        "<|chunk_eos|>",
        "<|chunk_tts_eos|>",
        "<|turn_eos|>",
    ]
    self._all_special_tokens_text.update(custom_special_tokens)
    for token in custom_special_tokens:
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        # Tokens unknown to the tokenizer map to None / unk and are skipped.
        if token_id is not None and token_id != self.tokenizer.unk_token_id:
            self._all_special_ids.add(token_id)

    # Normalise to a private list. The int branch previously read
    # self.forbidden_token_ids before it existed.
    if forbidden_token_ids is None:
        self.forbidden_token_ids = []
    elif isinstance(forbidden_token_ids, int):
        self.forbidden_token_ids = [forbidden_token_ids]
    else:
        self.forbidden_token_ids = list(forbidden_token_ids)
    self.forbidden_token_ids.append(self.chunk_eos_id)

    self.cache = None
    self.context = ""
    self.generated_tokens = []
    self.generated_special_tokens = []
    self.reset()
    self.embeds = None
    self.system_embeds = None

    # Sliding-window bookkeeping.
    self._unit_history: List[Dict[str, Any]] = []
    self._next_unit_id: int = 0
    self._pending_unit_id: Optional[int] = None
    self._pending_unit_start_cache_len: int = 0
    self._system_preserve_length: int = 0
    self._position_offset: int = 0
    self._window_config = DuplexWindowConfig()
    self._window_enabled: bool = True
    self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {}

    # Context-preserving mode; cache layout: [prefix] [previous] [suffix] [units...]
    self._preserve_prefix_length: int = 0
    self._previous_content_length: int = 0
    self._suffix_token_ids: List[int] = []

    # Marker inserted in front of the accumulated "previous" content.
    self._previous_marker: str = "\n\nprevious: "
    self._previous_marker_token_ids: List[int] = []
    self._has_previous: bool = False

    # Accumulated "previous" content (text and token ids, marker included in ids).
    self._previous_text: str = ""
    self._previous_token_ids: List[int] = []

    # Statistics counters.
    self._sliding_event_count: int = 0
    self._total_dropped_tokens: int = 0
    self._total_dropped_units: int = 0
|
|
def sliding_embeds(self):
    """Placeholder hook; no behaviour is implemented here."""
    return None
|
|
def reset(self):
    """Return the decoder to its initial, empty-session state.

    Clears the KV cache, generated-token buffers, unit history,
    context-preserving state and statistics counters. The window
    configuration and the enabled flag are not touched.
    """
    self.context = ""
    self.cache = None
    self.generated_tokens = []
    self.generated_special_tokens = []
    self.embeds = None
    self.system_embeds = None

    # Sliding-window bookkeeping.
    self._unit_history = []
    self._next_unit_id = 0
    self._pending_unit_id = None
    self._pending_unit_start_cache_len = 0
    self._system_preserve_length = 0
    self._position_offset = 0
    self._rope_inv_freq_cache = {}

    # Context-preserving state.
    self._preserve_prefix_length = 0
    self._previous_content_length = 0
    self._suffix_token_ids = []
    self._previous_marker = "\n\nprevious: "
    self._previous_marker_token_ids = []
    self._has_previous = False
    self._previous_text = ""
    self._previous_token_ids = []

    # Statistics counters.
    self._sliding_event_count = 0
    self._total_dropped_tokens = 0
    self._total_dropped_units = 0
|
|
def get_cache_length(self) -> int:
    """Return the number of cached positions (size of dim 2 of layer 0's keys).

    ``None`` and an empty DynamicCache both count as length 0; any other
    cache object is treated as a tuple cache indexed as ``cache[0][0]``.
    """
    kv = self.cache
    if kv is None:
        return 0
    if not isinstance(kv, DynamicCache):
        # Tuple cache: first layer, key tensor.
        return kv[0][0].shape[2]
    populated = len(kv.key_cache) > 0 and kv.key_cache[0].numel() > 0
    return kv.key_cache[0].shape[2] if populated else 0
|
|
def get_total_generated_tokens(self) -> int:
    """Count the generated token ids recorded across all units in the history."""
    total = 0
    for unit in self._unit_history:
        total += len(unit.get("generated_tokens", []))
    return total
|
|
def register_unit_start(self) -> int:
    """Open a new unit: remember its id and the cache length at its start.

    Returns:
        The id assigned to the unit that is now pending.
    """
    unit_id = self._next_unit_id
    self._pending_unit_id = unit_id
    self._pending_unit_start_cache_len = self.get_cache_length()
    return unit_id
|
|
def register_unit_end(
    self,
    input_type: str,
    generated_tokens: Optional[List[int]] = None,
    is_listen: bool = False,
    generated_text: Optional[str] = None,
):
    """Close the pending unit and record it in the unit history.

    Intended to be called after the ``</unit>`` token has been fed. The unit
    length is the cache growth since ``register_unit_start``; a unit that did
    not grow the cache is not recorded, but its id is still consumed.

    Args:
        input_type: "audio" / "video" / "omni" / "system".
        generated_tokens: Token ids generated during the unit.
        is_listen: Whether the unit ended in the listen state.
        generated_text: Text generated during the unit (used by the
            context-preserving mode).
    """
    if self._pending_unit_id is None:
        logger.warning("register_unit_end called without register_unit_start")
        return

    consumed = self.get_cache_length() - self._pending_unit_start_cache_len
    if consumed > 0:
        self._unit_history.append(
            {
                "unit_id": self._pending_unit_id,
                "length": consumed,
                "type": input_type,
                "generated_tokens": generated_tokens or [],
                "generated_text": generated_text or "",
                "is_listen": is_listen,
            }
        )

    # Clear the pending slot and move on to the next id.
    self._pending_unit_id = None
    self._pending_unit_start_cache_len = 0
    self._next_unit_id += 1
|
|
def register_system_prompt(self):
    """Record the current cache length as the protected system-prompt length.

    Meant to be called right after the system prompt has been prefilled.
    """
    protected_len = self.get_cache_length()
    self._system_preserve_length = protected_len
|
|
| |
|
|
def _get_rope_theta(self) -> float:
    """Return ``rope_theta`` from the model config as a float (10000.0 if absent)."""
    theta = getattr(self.m.config, "rope_theta", 10000.0)
    return float(theta)
|
|
def _drop_tokens_from_cache(self, length: int) -> bool:
    """Evict ``length`` tokens located right after the protected system prompt.

    Delegates to the module-level ``drop_tokens_from_cache`` with
    ``preserve=self._system_preserve_length``; the removed range is therefore
    ``[preserve, preserve + length)``. Both DynamicCache and tuple caches are
    accepted (conversion happens in the helper).

    Returns:
        True when tokens were removed; in that case ``self.cache`` and
        ``self._position_offset`` are updated. False otherwise, with the
        decoder state left unchanged.
    """
    if self.cache is None or length <= 0:
        return False

    new_cache, new_offset, success = drop_tokens_from_cache(
        cache=self.cache,
        length=length,
        preserve=self._system_preserve_length,
        position_offset=self._position_offset,
        rope_theta=self._get_rope_theta(),
        inv_freq_cache=self._rope_inv_freq_cache,
    )
    if success:
        self.cache = new_cache
        self._position_offset = new_offset

    return success
|
|
def _drop_unit(self, unit_id: int) -> bool:
    """Evict the unit with the given id from the cache and the history.

    History entries are removed when the eviction succeeds, and also when the
    unit has no tokens at all (in which case False is returned). If the cache
    eviction fails the history is left untouched.

    NOTE(review): the eviction always targets the tokens directly behind the
    protected prefix, so this presumably relies on ``unit_id`` being the
    oldest unit still in the cache - confirm with callers.

    Returns:
        True only if tokens were actually removed from the cache.
    """
    matches = [entry for entry in self._unit_history if entry["unit_id"] == unit_id]
    if not matches:
        return False

    span = 0
    for entry in matches:
        span += entry["length"]

    if span > 0 and not self._drop_tokens_from_cache(span):
        # Cache eviction refused: keep the bookkeeping as it was.
        return False

    for entry in matches:
        self._unit_history.remove(entry)
    return span > 0
|
|
def _drop_next_unit(self) -> bool:
    """Evict the earliest non-system unit that can be dropped.

    Units are tried in history order; entries without a ``unit_id`` and
    entries of type "system" are skipped.

    Returns:
        True as soon as one unit was dropped, False if none could be.
    """
    # Iterate over a snapshot: _drop_unit may remove entries from
    # self._unit_history even when it returns False, and mutating the list
    # under a live iterator would make the loop skip the following unit.
    for entry in list(self._unit_history):
        unit_id = entry.get("unit_id")
        if unit_id is None:
            continue
        if entry.get("type") == "system":
            continue
        if self._drop_unit(unit_id):
            return True
    return False
|
|
def enforce_window(self) -> bool:
    """Apply the basic sliding window, driven purely by cache length.

    Nothing happens while the cache length is at or below the high-water
    mark. Once it is exceeded, the oldest units are dropped one at a time
    until the length is at or below the low-water mark or no further unit can
    be dropped. After a successful slide the statistics counters are updated
    and the history/cache lengths are cross-checked (a mismatch is logged).

    Returns:
        True if at least one unit was dropped.
    """
    if not self._window_enabled:
        return False

    cfg = self._window_config
    initial_len = self.get_cache_length()
    if initial_len <= cfg.basic_window_high_tokens:
        return False

    units_dropped = 0
    current_len = initial_len
    while current_len > cfg.basic_window_low_tokens and self._drop_next_unit():
        units_dropped += 1
        current_len = self.get_cache_length()

    if units_dropped == 0:
        return False

    # Statistics.
    self._sliding_event_count += 1
    self._total_dropped_tokens += initial_len - current_len
    self._total_dropped_units += units_dropped

    # Bookkeeping must add up to the real cache length.
    tracked = sum(u["length"] for u in self._unit_history)
    if self._system_preserve_length + tracked != current_len:
        logger.error(
            "CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d",
            self._system_preserve_length,
            tracked,
            current_len,
            self._position_offset,
        )

    return True
|
|
| |
|
|
def register_system_prompt_with_context(
    self,
    suffix_token_ids: Optional[List[int]] = None,
    context_previous_marker: str = "\n\nprevious: ",
):
    """Register the system prompt for the context-preserving mode.

    Cache layout at registration time: ``[prefix] [suffix] [units...]``;
    after the first slide: ``[prefix] [marker + content] [suffix] [units...]``.

    Call this while the cache holds only the prefix (no previous marker);
    the suffix is expected to be fed afterwards and is already counted in
    the protected length.

    Args:
        suffix_token_ids: Token ids of the suffix (e.g. the id of <|im_end|>).
        context_previous_marker: Text placed in front of the accumulated
            previous content, e.g. "\\n\\nprevious: ".
    """
    prefix_len = self.get_cache_length()
    suffix_ids = suffix_token_ids or []

    self._preserve_prefix_length = prefix_len
    self._previous_content_length = 0
    self._suffix_token_ids = suffix_ids
    # Protected region = prefix + (yet to be fed) suffix.
    self._system_preserve_length = prefix_len + len(suffix_ids)

    self._previous_marker = context_previous_marker
    if self.tokenizer:
        marker_ids = self.tokenizer.encode(context_previous_marker, add_special_tokens=False)
    else:
        marker_ids = []
    self._previous_marker_token_ids = marker_ids

    # No previous content has been accumulated yet.
    self._has_previous = False
    self._previous_text = ""
    self._previous_token_ids = []
|
|
def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]:
    """Collect the generated text and token ids of the given units.

    Units flagged ``is_listen`` are ignored. Special-token strings are
    stripped from the text (text that is blank afterwards is dropped) and
    special-token ids are filtered out of the id list.

    Args:
        units: Unit history entries to read from.

    Returns:
        ``(text, token_ids)``: the concatenated cleaned text and the
        concatenated filtered ids.
    """
    texts = []
    ids = []

    for unit in units:
        if unit.get("is_listen", False):
            continue

        raw_text = unit.get("generated_text", "")
        raw_ids = unit.get("generated_tokens", [])

        if raw_text:
            stripped = raw_text
            for special in self._all_special_tokens_text:
                stripped = stripped.replace(special, "")
            if stripped.strip():
                texts.append(stripped)

        if raw_ids:
            ids.extend(t for t in raw_ids if t not in self._all_special_ids)

    return "".join(texts), ids
|
|
def _rebuild_cache_with_previous(
    self,
    new_previous_tokens: List[int],
    units_to_keep_len: Optional[int] = None,
) -> bool:
    """Rebuild the cache with new "previous" content between prefix and suffix.

    Cache layout change:
        [prefix] [old_prev] [suffix] [old_units] -> [prefix] [new_prev] [suffix] [remaining_units]

    Args:
        new_previous_tokens: Token ids of the new previous section.
        units_to_keep_len: Number of trailing cache positions that belong to
            the units being kept; if None it is the sum of the lengths in
            ``self._unit_history``.

    Returns:
        False when there is no cache, True after a rebuild.

    Note:
        ``_previous_content_length`` and ``_system_preserve_length`` are only
        updated on the path that re-runs the model; the no-previous fast path
        leaves them as they are.
    """
    if self.cache is None:
        return False

    old_previous_len = self._previous_content_length
    new_previous_len = len(new_previous_tokens)
    suffix_len = len(self._suffix_token_ids)
    total_cache_len = self.get_cache_length()

    if units_to_keep_len is None:
        units_to_keep_len = sum(u["length"] for u in self._unit_history)

    # Fast path: there was no previous section and there is none now, so the
    # prefix+suffix block is reused as is and only the dropped units are cut.
    if new_previous_len == 0 and old_previous_len == 0:
        preserve_len = self._preserve_prefix_length + suffix_len

        if units_to_keep_len > 0:
            prefix_suffix_cache = self._slice_cache(0, preserve_len)
            units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None)

            # Tokens that sat between the protected block and the kept units.
            dropped_tokens = total_cache_len - preserve_len - units_to_keep_len

            # Kept units move forward by dropped_tokens positions.
            if dropped_tokens > 0:
                old_start = preserve_len + dropped_tokens
                new_start = preserve_len
                units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

            self.cache = self._concat_caches(prefix_suffix_cache, units_cache)
        else:
            self.cache = self._slice_cache(0, preserve_len)

        return True

    # 1) The prefix is reused unchanged.
    prefix_end = self._preserve_prefix_length
    prefix_cache = self._slice_cache(0, prefix_end)

    # 2) Kept units are the trailing units_to_keep_len positions of the old cache.
    units_start_in_old_cache = total_cache_len - units_to_keep_len
    units_cache = None
    if units_to_keep_len > 0:
        units_cache = self._slice_cache(units_start_in_old_cache, None)

    # 3) previous + suffix are recomputed by running the model on top of the prefix.
    prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids
    prev_suffix_len = len(prev_suffix_tokens)

    new_prefix_prev_suffix_cache = prefix_cache
    if prev_suffix_len > 0:
        prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens)
        start_pos = self._preserve_prefix_length + self._position_offset

        with torch.no_grad():
            device = prev_suffix_embeds.device
            position_ids = torch.arange(
                start_pos,
                start_pos + prev_suffix_len,
                device=device,
            ).unsqueeze(0)

            outputs = self.m(
                inputs_embeds=(
                    prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds
                ),
                position_ids=position_ids,
                past_key_values=prefix_cache,
                use_cache=True,
                return_dict=True,
            )
            new_prefix_prev_suffix_cache = outputs.past_key_values

    # 4) Kept units now start right after the rebuilt system block.
    new_system_total = prefix_end + new_previous_len + suffix_len
    if units_cache is not None and self._get_cache_len(units_cache) > 0:
        old_start = units_start_in_old_cache
        new_start = new_system_total

        if old_start != new_start:
            units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

    # 5) Stitch the pieces together.
    if units_cache is not None and self._get_cache_len(units_cache) > 0:
        self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache)
    else:
        self.cache = new_prefix_prev_suffix_cache

    # 6) Bookkeeping for the new layout.
    self._previous_content_length = new_previous_len
    self._system_preserve_length = prefix_end + new_previous_len + suffix_len

    return True
|
|
def _slice_cache(self, start: int, end: Optional[int], clone: bool = True):
    """Return ``self.cache`` restricted to positions ``[start, end)`` on dim 2.

    Args:
        start: First position to keep.
        end: One past the last position to keep (None means up to the end).
        clone: Copy the sliced tensors (default) instead of returning views
            that share memory with the live cache.

    Returns:
        None if there is no cache; otherwise a new DynamicCache or a tuple of
        ``(key, value)`` pairs, matching the type of ``self.cache``.
    """
    source = self.cache
    if source is None:
        return None

    def take(tensor):
        piece = tensor[:, :, start:end, :]
        return piece.clone() if clone else piece

    if isinstance(source, DynamicCache):
        sliced = DynamicCache()
        sliced.key_cache = [take(k) for k in source.key_cache]
        sliced.value_cache = [take(v) for v in source.value_cache]
        return sliced

    # Tuple cache: one (key, value) pair per layer.
    return tuple((take(layer[0]), take(layer[1])) for layer in source)
|
|
| @staticmethod |
| def _get_cache_len(cache) -> int: |
| if cache is None: |
| return 0 |
| if isinstance(cache, DynamicCache): |
| if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0: |
| return cache.key_cache[0].shape[2] |
| return 0 |
|
|
| if cache and cache[0] and cache[0][0] is not None: |
| return cache[0][0].shape[2] |
| return 0 |
|
|
| @staticmethod |
| def _concat_caches(cache1, cache2): |
| if cache1 is None: |
| return cache2 |
| if cache2 is None: |
| return cache1 |
|
|
| if isinstance(cache1, DynamicCache): |
| new_cache = DynamicCache() |
| new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)] |
| new_cache.value_cache = [ |
| torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache) |
| ] |
| return new_cache |
| else: |
| return tuple( |
| ( |
| torch.cat([layer1[0], layer2[0]], dim=2), |
| torch.cat([layer1[1], layer2[1]], dim=2), |
| ) |
| for layer1, layer2 in zip(cache1, cache2) |
| ) |
|
|
def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int):
    """Re-rotate the keys of ``cache`` from one contiguous position range to another.

    Keys are moved from positions ``[old_start, old_start + length)`` to
    ``[new_start, new_start + length)`` via ``realign_rotary_suffix``; values
    are not touched.

    Returns:
        The input unchanged when it is None, ``length <= 0`` or it has no
        layers. Otherwise the same DynamicCache object with its key list
        replaced, or a new tuple cache for tuple input.
    """
    if cache is None or length <= 0:
        return cache

    is_dynamic = isinstance(cache, DynamicCache)
    if is_dynamic:
        device = cache.key_cache[0].device if cache.key_cache else None
    else:
        device = cache[0][0].device if cache and cache[0] else None
    if device is None:
        return cache

    src = torch.arange(old_start, old_start + length, device=device, dtype=torch.long)
    dst = torch.arange(new_start, new_start + length, device=device, dtype=torch.long)
    theta = self._get_rope_theta()

    if is_dynamic:
        cache.key_cache = [
            realign_rotary_suffix(keys, src, dst, theta, self._rope_inv_freq_cache) for keys in cache.key_cache
        ]
        return cache

    return tuple(
        (realign_rotary_suffix(layer[0], src, dst, theta, self._rope_inv_freq_cache), layer[1]) for layer in cache
    )
|
|
def _update_previous(
    self,
    new_text: str,
    new_tokens: List[int],
    max_tokens: int,
) -> None:
    """Append content to the "previous" context and rebuild the cache.

    On the first call with content the marker token ids are placed in front
    of the new tokens; later calls append. If the content (marker excluded)
    grows beyond ``max_tokens``, the oldest content tokens are discarded and
    the text is re-decoded from the remaining tokens. The cache is rebuilt in
    every case, including when there is nothing new to add.

    Args:
        new_text: Text to append.
        new_tokens: Token ids to append.
        max_tokens: Maximum number of content tokens (marker not counted).
    """
    marker_len = len(self._previous_marker_token_ids)

    # Nothing to add: still rebuild so the cache matches the unit history.
    if not new_tokens and not new_text:
        self._rebuild_cache_with_previous(self._previous_token_ids)
        return

    if not self._has_previous:
        # First content: marker + tokens.
        self._previous_text = new_text
        self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens
        self._has_previous = True
    else:
        self._previous_text += new_text
        self._previous_token_ids.extend(new_tokens)

    content_token_count = len(self._previous_token_ids) - marker_len
    if content_token_count > max_tokens:
        # Drop the oldest content tokens, keep the marker.
        tokens_to_drop = content_token_count - max_tokens
        content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :]
        self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens
        try:
            self._previous_text = self.tokenizer.decode(
                content_tokens,
                skip_special_tokens=True,
            )
        except Exception as e:
            # Best effort: on failure the untruncated text is kept.
            logger.warning("_update_previous: decode failed: %s", e)

    self._rebuild_cache_with_previous(self._previous_token_ids)
|
|
def _drop_unit_with_context(
    self,
    unit_id: int,
    max_previous_tokens: int,
) -> Tuple[bool, str, List[int]]:
    """Drop a unit while folding its generated content into "previous".

    Steps:
        1. extract the generated text/tokens of the unit,
        2. remove the unit from the history,
        3. append the extracted content to the previous context, which also
           rebuilds the cache (see ``_update_previous``).

    A unit whose recorded length is not positive is removed from the history
    without touching the previous context, and reported as not dropped.

    Args:
        unit_id: Id of the unit to remove.
        max_previous_tokens: Token budget of the previous context.

    Returns:
        ``(success, extracted_text, extracted_tokens)``.
    """
    entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
    if not entries:
        return False, "", []

    extracted_text, extracted_tokens = self._extract_generated_text(entries)

    total_len = sum(e["length"] for e in entries)

    # The history is updated first: the cache rebuild derives the length of
    # the units to keep from it.
    for e in entries:
        self._unit_history.remove(e)

    if total_len <= 0:
        return False, extracted_text, extracted_tokens

    self._update_previous(extracted_text, extracted_tokens, max_previous_tokens)

    return True, extracted_text, extracted_tokens
|
|
def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool:
    """Drop the earliest non-system unit, preserving its generated content.

    Units are tried in history order; entries without a ``unit_id`` and
    entries of type "system" are skipped.

    Args:
        max_previous_tokens: Token budget forwarded to ``_drop_unit_with_context``.

    Returns:
        True as soon as one unit was dropped, False if none could be.
    """
    # Iterate over a snapshot: _drop_unit_with_context removes entries from
    # self._unit_history (also when it reports failure), and mutating the
    # list under a live iterator would make the loop skip the next unit.
    for entry in list(self._unit_history):
        unit_id = entry.get("unit_id")
        if unit_id is None:
            continue
        if entry.get("type") == "system":
            continue
        success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens)
        if success:
            return True
    return False
|
|
def enforce_window_with_context(self) -> bool:
    """Apply the context-preserving sliding window.

    When the window mode is not "context" this defers to ``enforce_window``.
    Otherwise, while the number of units in the history exceeds
    ``context_max_units``, the earliest unit is dropped and its generated
    content is accumulated into the previous context (the cache is rebuilt
    as part of that).

    Returns:
        True if at least one unit was dropped (or whatever ``enforce_window``
        returns when the mode is not "context").
    """
    if not self._window_enabled:
        return False

    cfg = self._window_config

    if cfg.sliding_window_mode != "context":
        return self.enforce_window()

    cache_len_before = self.get_cache_length()

    # The trigger is the unit count, not the cache length.
    if len(self._unit_history) <= cfg.context_max_units:
        return False

    dropped_count = 0
    while len(self._unit_history) > cfg.context_max_units:
        if not self._drop_next_unit_with_context(cfg.context_previous_max_tokens):
            break
        dropped_count += 1

    cache_len_after = self.get_cache_length()

    if dropped_count > 0:
        self._sliding_event_count += 1
        self._total_dropped_tokens += cache_len_before - cache_len_after
        self._total_dropped_units += dropped_count

    return dropped_count > 0
|
|
def get_previous_context(self) -> Tuple[str, List[int]]:
    """Return the accumulated previous context.

    Returns:
        ``(previous_text, previous_token_ids)``; the id list is a copy, so
        the caller may modify it freely.
    """
    ids_snapshot = list(self._previous_token_ids)
    return self._previous_text, ids_snapshot
|
|
def get_window_stats(self) -> Dict[str, Any]:
    """Return a snapshot of the sliding-window state as a plain dict.

    The dict contains cache/unit lengths, the bookkeeping offsets, the active
    window configuration (under "config") and the context-preserving fields.
    "has_system_template" is False when no ``_system_prompt_template``
    attribute has been set on the instance.
    """
    unit_lengths = [u["length"] for u in self._unit_history]
    cfg = self._window_config
    # The attribute is not assigned by __init__/reset in the code visible
    # here, so it is read defensively instead of raising AttributeError.
    system_template = getattr(self, "_system_prompt_template", None)
    return {
        "cache_length": self.get_cache_length(),
        "unit_count": len(self._unit_history),
        "unit_lengths": unit_lengths,
        "unit_total_length": sum(unit_lengths),
        "system_preserve_length": self._system_preserve_length,
        "position_offset": self._position_offset,
        "window_enabled": self._window_enabled,
        "total_generated_tokens": self.get_total_generated_tokens(),
        "pending_unit_id": self._pending_unit_id,
        "next_unit_id": self._next_unit_id,
        "config": {
            "sliding_window_mode": cfg.sliding_window_mode,
            "basic_window_high_tokens": cfg.basic_window_high_tokens,
            "basic_window_low_tokens": cfg.basic_window_low_tokens,
            "context_previous_max_tokens": cfg.context_previous_max_tokens,
            "context_max_units": cfg.context_max_units,
        },
        # Context-preserving mode.
        "preserve_prefix_length": self._preserve_prefix_length,
        "previous_content_length": self._previous_content_length,
        "suffix_token_count": len(self._suffix_token_ids),
        "previous_text_length": len(self._previous_text),
        "previous_token_count": len(self._previous_token_ids),
        "has_system_template": system_template is not None,
    }
|
|
def _verify_consistency(self) -> bool:
    """Check that protected length + recorded unit lengths equals the cache length."""
    tracked = 0
    for unit in self._unit_history:
        tracked += unit["length"]
    expected_len = self._system_preserve_length + tracked
    return expected_len == self.get_cache_length()
|
|
def print_verification_summary(self) -> Dict[str, Any]:
    """Build a summary used to compare the off/basic/context window modes.

    Despite the name, nothing is printed here; the summary is only returned.
    Generated text and tokens are gathered from the units still in the
    history, skipping units flagged ``is_listen``.

    Returns:
        Dictionary with the key verification data.
    """
    speaking_units = [u for u in self._unit_history if not u.get("is_listen", False)]

    text_parts = []
    token_ids = []
    for unit in speaking_units:
        unit_text = unit.get("generated_text", "")
        if unit_text:
            text_parts.append(unit_text)
        unit_tokens = unit.get("generated_tokens", [])
        if unit_tokens:
            token_ids.extend(unit_tokens)

    return {
        "mode": self._window_config.sliding_window_mode,
        "final_cache_length": self.get_cache_length(),
        "final_unit_count": len(self._unit_history),
        "sliding_event_count": self._sliding_event_count,
        "total_dropped_tokens": self._total_dropped_tokens,
        "total_dropped_units": self._total_dropped_units,
        "total_generated_tokens": len(token_ids),
        "generated_text": "".join(text_parts),
        "previous_text": self._previous_text,
        "previous_token_count": len(self._previous_token_ids),
        "position_offset": self._position_offset,
        "system_preserve_length": self._system_preserve_length,
    }
|
|
def set_window_config(self, config: DuplexWindowConfig) -> None:
    """Replace the active sliding-window configuration with ``config``."""
    self._window_config = config
|
|
def set_window_enabled(self, enabled: bool) -> None:
    """Turn the sliding window on or off.

    When disabled, ``enforce_window`` and ``enforce_window_with_context``
    return False without touching the cache.
    """
    self._window_enabled = enabled
|
|
def get_context(self):
    """Return the text accumulated in ``self.context``."""
    accumulated = self.context
    return accumulated
|
|
def embed_token(self, tid):
    """Embed a token id with the model's input embedding layer.

    Args:
        tid: A Python int (wrapped into a one-element tensor on the model's
            device) or anything the embedding layer accepts directly.

    Returns:
        Whatever ``self.m.model.embed_tokens`` returns for the id(s).
    """
    ids = torch.tensor([tid], device=self.m.device) if isinstance(tid, int) else tid
    return self.m.model.embed_tokens(ids)
|
|
def embed_tokens(self, token_ids: List[int]) -> torch.Tensor:
    """Embed a list of token ids in one call.

    Args:
        token_ids: Token ids to embed.

    Returns:
        The embedding layer's output for the ids; for an empty list, an
        empty tensor of shape ``[0, hidden_size]`` on the model's device.
    """
    model = self.m
    if not token_ids:
        return torch.empty(0, model.config.hidden_size, device=model.device)
    ids = torch.tensor(token_ids, device=model.device)
    return model.model.embed_tokens(ids)
|
|
@torch.no_grad()
def feed(self, embeds: torch.Tensor, return_logits: bool = False):
    """Run the model on a new embedding sequence and extend the KV cache.

    Args:
        embeds: ``[L, H]`` embeddings fed to the model in a single step.
        return_logits: Also compute next-token logits from the last position.

    Returns:
        None by default. With ``return_logits=True``, a pair
        ``(logits, last_hidden)`` where ``logits`` is ``lm_head`` applied to
        the final hidden states, taken at the last position.
    """
    seq_len = embeds.size(0)
    # New tokens continue right after the positions already in the cache.
    start = self.get_cache_length()
    position_ids = torch.arange(start, start + seq_len, device=embeds.device).unsqueeze(0)

    outputs = self.m(
        inputs_embeds=embeds.unsqueeze(0),
        position_ids=position_ids,
        past_key_values=self.cache,
        return_dict=True,
        output_hidden_states=True,
    )
    self.cache = outputs.past_key_values

    if not return_logits:
        return None

    final_hidden = outputs.hidden_states[-1]
    return self.m.lm_head(final_hidden)[:, -1], final_hidden
|
|
@torch.no_grad()
def decode(
    self,
    logits,
    mode: Literal["sampling", "greedy"] = "sampling",
    temperature=0.7,
    top_k=20,
    top_p=0.8,
    listen_top_k=None,
    listen_prob_scale=1.0,
    text_repetition_penalty=1.05,
    text_repetition_window_size=512,
):
    """Choose the next token id from next-token logits.

    Args:
        logits: next-token logits; row 0 is the one decoded. The input is
            cloned, so the caller's tensor is never modified.
        mode: ``"sampling"`` or ``"greedy"``.
        temperature: divisor applied to the logits before the final
            sampling step (sampling mode only).
        top_k: forwarded to ``top_k_top_p_filtering`` (sampling mode only).
        top_p: forwarded to ``top_k_top_p_filtering`` (sampling mode only).
        listen_top_k: if set, ``listen_id`` is returned whenever fewer than
            ``listen_top_k`` tokens have a strictly larger logit.
        listen_prob_scale: factor the ``listen_id`` *logit* is multiplied by.
            NOTE(review): this scales the logit, not the probability, so
            for a negative logit a factor < 1 raises the token's chance;
            confirm this is the intended behaviour.
        text_repetition_penalty: > 1.0 discourages, < 1.0 encourages tokens
            that were generated recently.
        text_repetition_window_size: how many of the most recent entries of
            ``self.generated_tokens`` are subject to the penalty.

    Returns:
        A one-element tensor holding the chosen token id.

    Raises:
        ValueError: if ``mode`` is neither ``"sampling"`` nor ``"greedy"``.

    Procedure:
        1. Pick once from the untouched logits (argmax in greedy mode,
           otherwise a multinomial draw over their softmax; no temperature
           is applied here). If that pick is ``chunk_eos_id`` it is returned
           immediately, so the model's own stop decision is preserved.
        2. Otherwise mask forbidden tokens, apply the repetition penalty and
           the listen-token adjustments, then pick the final token.
           NOTE(review): ``chunk_eos_id`` is not masked in this step unless
           it is listed in ``forbidden_token_ids``; confirm that is intended.

    Side effects:
        Appends the final token to ``self.generated_tokens`` or, when it is
        in ``self.special_token_ids``, to ``self.generated_special_tokens``.
        The early ``chunk_eos`` and listen returns append to neither; the
        listen return appends to ``self.context`` instead.
    """
    # Reject bad modes before any sampling happens; previously an unknown
    # mode was treated as "sampling" in step 1 and could return a token.
    if mode not in ("sampling", "greedy"):
        raise ValueError(f"Unsupported decode mode: {mode}")

    logits = logits.clone()

    # Step 1: independent stop decision on the unmodified distribution.
    eos_id = self.chunk_eos_id
    if mode == "greedy":
        first_pick = torch.argmax(logits[0]).item()
    else:
        first_pick = torch.multinomial(F.softmax(logits[0], dim=-1), num_samples=1).item()
    if first_pick == eos_id:
        return torch.tensor([eos_id], device=logits.device)

    # Step 2a: tokens that must never be emitted.
    if self.forbidden_token_ids:
        logits[:, self.forbidden_token_ids] = float("-inf")

    # Step 2b: repetition penalty over the recent, de-duplicated history.
    if text_repetition_penalty != 1.0 and self.generated_tokens:
        vocab_size = logits.size(-1)
        window = self.generated_tokens[-text_repetition_window_size:]
        recent = sorted({tok for tok in window if tok < vocab_size})
        if recent:
            idx = torch.tensor(recent, dtype=torch.long, device=logits.device)
            scores = logits[0, idx]
            # Sign-aware: dividing a negative logit by a penalty > 1 would
            # move it towards zero and make the repeated token MORE likely,
            # so negative logits are multiplied instead.
            logits[0, idx] = torch.where(
                scores < 0,
                scores * text_repetition_penalty,
                scores / text_repetition_penalty,
            )

    # Step 2c: listen-token adjustments.
    if listen_prob_scale != 1.0:
        logits[0, self.listen_id] *= listen_prob_scale

    if listen_top_k is not None:
        # Rank = number of tokens with a strictly larger logit (0 = best).
        listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item()
        if listen_rank < listen_top_k:
            next_token_id = torch.tensor([self.listen_id], device=logits.device)
            next_token_str = self.tokenizer.decode(next_token_id)
            # The listen marker itself is recorded as a single space.
            self.context += " " if next_token_str == "<|listen|>" else next_token_str
            return next_token_id

    # Step 2d: final pick.
    if mode == "greedy":
        next_token_id = torch.argmax(logits, dim=-1)
    else:
        logits = logits / temperature
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1)

    token = next_token_id.item()
    if token not in self.special_token_ids:
        self.generated_tokens.append(token)
    else:
        self.generated_special_tokens.append(token)

    return next_token_id
|
|
|
|
def _download_url_to_tempfile(url: str, suffix: str = "", timeout: int = 60) -> str:
    """Download a URL to a temporary file and return the file's path.

    The response body is streamed to disk in chunks rather than buffered in
    memory, so large media files do not have to fit in RAM. The caller owns
    the returned file and is responsible for deleting it.

    Args:
        url: HTTP/HTTPS URL to download.
        suffix: file suffix for the temporary file (e.g. ".jpg", ".wav", ".mp4").
        timeout: timeout in seconds passed to ``requests.get``.

    Returns:
        Path to the downloaded temporary file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection or read failures.
    """
    import os
    import tempfile

    import requests

    chunk_size = 1 << 20  # 1 MiB per write

    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        # Both the file and the HTTP response are closed when the block exits.
        with tmp, requests.get(url, timeout=timeout, stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=chunk_size):
                tmp.write(chunk)
    except BaseException:
        # delete=False means nothing else will remove a partial download.
        os.unlink(tmp.name)
        raise
    return tmp.name
|
|
|
|
| def _is_url(path: str) -> bool: |
| return path.startswith(("http://", "https://")) |
|
|
|
|
def _run_on_local_copy(url, suffix, timeout, loader):
    """Return ``loader(path)`` for *url*, downloading it to a temp file first when remote.

    The temporary file is removed after ``loader`` returns, and also when it raises.
    """
    import os

    if not _is_url(url):
        return loader(url)

    temp_path = _download_url_to_tempfile(url, suffix=suffix, timeout=timeout)
    try:
        return loader(temp_path)
    finally:
        os.unlink(temp_path)


def normalize_content_item(item) -> Union[str, Any, List[Any]]:
    """Normalize a structured content item to its native form.

    Supports:
        - Native format: str, PIL.Image, np.ndarray (passed through unchanged)
        - OpenAI structured format:
            - {"type": "text", "text": "..."} -> str
            - {"type": "image_url", "image_url": {"url": "..."}} -> PIL.Image
            - {"type": "audio_url", "audio_url": {"url": "..."}} -> np.ndarray
              (loaded with librosa at 16 kHz, mono)
            - {"type": "video_url", "video_url": {"url": "...", ...}} ->
              ("__video_contents__", [frame, audio, ...])

    For the ``*_url`` types the value may also be given directly instead of
    as a ``{"url": ...}`` dict. ``video_url`` dicts additionally accept
    ``stack_frames`` (default 1), ``use_ffmpeg`` (default False) and
    ``use_audio`` (default True).

    URL formats supported:
        - Local file path: "/path/to/file.jpg"
        - HTTP/HTTPS URL: "https://example.com/image.jpg" (downloaded to a
          temporary file that is deleted again before returning)

    Args:
        item: content item to normalize.

    Returns:
        The normalized item. For ``video_url`` a tuple
        ``("__video_contents__", list)`` that normalize_content() flattens.

    Raises:
        ValueError: if a dict has an unknown ``type`` or the item is of an
            unsupported Python type.
    """
    # Heavy dependencies (PIL, numpy, librosa) are imported only on the
    # paths that need them, so plain-text items stay cheap.
    if isinstance(item, str):
        return item

    if isinstance(item, dict):
        item_type = item.get("type")

        if item_type == "text":
            return item.get("text", "")

        if item_type == "image_url":
            from PIL import Image

            image_spec = item.get("image_url", {})
            url = image_spec.get("url", "") if isinstance(image_spec, dict) else image_spec

            if not _is_url(url):
                return Image.open(url)

            def _open_fully(path):
                # Image.open is lazy; read the pixel data before the temp
                # file backing it is deleted.
                img = Image.open(path)
                img.load()
                return img

            return _run_on_local_copy(url, ".jpg", 30, _open_fully)

        if item_type == "audio_url":
            import librosa

            audio_spec = item.get("audio_url", {})
            url = audio_spec.get("url", "") if isinstance(audio_spec, dict) else audio_spec

            return _run_on_local_copy(
                url,
                ".wav",
                60,
                lambda path: librosa.load(path, sr=16000, mono=True)[0],
            )

        if item_type == "video_url":
            from minicpmo.utils import get_video_frame_audio_segments

            video_spec = item.get("video_url", {})
            if isinstance(video_spec, dict):
                video_url = video_spec.get("url", "")
                stack_frames = video_spec.get("stack_frames", 1)
                use_ffmpeg = video_spec.get("use_ffmpeg", False)
                use_audio = video_spec.get("use_audio", True)
            else:
                video_url = video_spec
                stack_frames = 1
                use_ffmpeg = False
                use_audio = True

            video_frames, audio_segments, stacked_frames = _run_on_local_copy(
                video_url,
                ".mp4",
                120,
                lambda path: get_video_frame_audio_segments(
                    path,
                    stack_frames=stack_frames,
                    use_ffmpeg=use_ffmpeg,
                ),
            )

            # Interleave per position: frame, then (optionally) its audio
            # segment, then its stacked frame when one exists.
            omni_contents = []
            for i, frame in enumerate(video_frames):
                omni_contents.append(frame)
                if use_audio:
                    omni_contents.append(audio_segments[i])
                if stacked_frames is not None and i < len(stacked_frames) and stacked_frames[i] is not None:
                    omni_contents.append(stacked_frames[i])

            # Tagged tuple so normalize_content() knows to flatten the list.
            return "__video_contents__", omni_contents

        raise ValueError(f"Unknown content type: {item_type}")

    import numpy as np
    from PIL import Image

    if isinstance(item, (Image.Image, np.ndarray)):
        return item

    raise ValueError(f"Cannot normalize content item of type: {type(item)}")
|
|
|
|
def normalize_content(content) -> list:
    """Normalize message content to a flat list of native items.

    Input formats:
        - str: "hello" -> ["hello"]
        - list: every element is passed through normalize_content_item();
          native items (str, PIL.Image, np.ndarray) and structured dicts
          may be mixed freely
        - single non-list item (Image, ndarray or structured dict): treated
          like a one-element list
        - video items are expanded in place into their frame/audio contents

    Args:
        content: message content in any supported format.

    Returns:
        List of native items (str, PIL.Image, np.ndarray).

    Raises:
        ValueError: propagated from normalize_content_item() for unknown or
            unsupported items.

    Examples:
        >>> normalize_content("hello")
        ['hello']

        >>> normalize_content([{"type": "text", "text": "hi"}])
        ['hi']

        >>> normalize_content([{"type": "video_url", "video_url": {"url": "/path/to/video.mp4"}}])
        [<PIL.Image>, <np.ndarray>, <PIL.Image>, <np.ndarray>, ...]
    """
    if isinstance(content, str):
        return [content]

    # A lone item goes through the same path as a list of one; native
    # Image/ndarray values are passed through by normalize_content_item().
    items = content if isinstance(content, list) else [content]

    result = []
    for item in items:
        normalized = normalize_content_item(item)
        is_video = (
            isinstance(normalized, tuple)
            and len(normalized) == 2
            and normalized[0] == "__video_contents__"
        )
        if is_video:
            # Tagged video result: splice its contents into the output.
            result.extend(normalized[1])
        else:
            result.append(normalized)
    return result
|
|