|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from dataclasses import dataclass |
|
|
from typing import Any |
|
|
from typing import Dict |
|
|
from typing import List |
|
|
from typing import Literal |
|
|
from typing import Optional |
|
|
from typing import Tuple |
|
|
from typing import Union |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.nn.utils.parametrize as P |
|
|
from transformers.cache_utils import DynamicCache |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GenerateChunkOutput:
    """Result of one `ChunkPrefillChunkGenerate.chunk_generate` call."""

    # [1, n_generated] token ids produced in this chunk (possibly empty).
    chunk_token_ids: torch.Tensor
    # [1, seq_len + n_generated, hidden] input embeddings extended with the
    # embeddings of every generated token.
    current_inputs_embeds: torch.Tensor
    # Last-layer hidden states of the prefill step (first chunk only), or None.
    input_last_hidden_states: Optional[torch.Tensor]
    # Last-layer hidden states of each generation step, or None when hidden
    # states were not requested / no generation step produced any.
    last_hidden_states: Optional[torch.Tensor]
    # Updated KV cache to feed back into the next call.
    past_key_values: Optional[torch.Tensor]
    # True when a terminator token was sampled.
    finished: bool


class ChunkPrefillChunkGenerate:
    """Chunked, KV-cache-reusing token generation on top of a causal LM.

    Each `chunk_generate` call produces at most `chunk_size` tokens, reusing
    (and returning) the model KV cache so generation can be interleaved with
    incremental prefill of new input embeddings.
    """

    # Number of trailing prompt tokens considered by the repetition penalty
    # (bounds the penalty computation for long prompts).
    PENALTY_WINDOW_SIZE = 128

    def __init__(self, model, tokenizer, terminators):
        """
        Args:
            model: HF-style causal LM (callable with ``inputs_embeds`` and
                ``past_key_values``; exposes ``get_input_embeddings``).
            tokenizer: Tokenizer providing ``convert_tokens_to_ids``.
            terminators: Iterable of token strings whose ids end generation.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.terminators = terminators
        self.terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        self.embedding_layer = self.model.get_input_embeddings()

        # Punctuation/markup tokens that must never be sampled (ASCII plus
        # full-width CJK variants).
        self.forbidden_tokens = [
            ":",
            ":",
            ";",
            "#",
            "“",
            "”",
            "‘",
            "’",
            "@",
            "*",
            "【",
            "】",
            "「",
            "」",
            "(",
            ")",
            "(",
            ")",
            "[",
            "]",
            "&",
            "/",
            "$",
        ]

        self.forbidden_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.forbidden_tokens]
        # The tokenizer may contribute extra banned ids of its own.
        bad_token_ids = getattr(tokenizer, "bad_token_ids", [])
        if bad_token_ids:
            self.forbidden_token_ids.extend(bad_token_ids)

    @staticmethod
    def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
        """Build a HF-style generation-config dict.

        Args:
            do_sample: Use stochastic sampling; otherwise beam search (when
                ``num_beams > 1``) or greedy decoding.
            max_new_tokens: Upper bound on generated tokens.
            min_new_tokens: Lower bound on generated tokens.
            **kwargs: Overrides for keys already present in the base config
                (unknown keys are ignored); ``num_beams`` also selects the
                beam-search branch.

        Returns:
            dict of generation parameters.
        """
        num_beams = kwargs.get("num_beams", 3)
        # Base config doubles as the sampling configuration.
        generation_config = {
            "num_beams": num_beams,
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }

        if do_sample:
            # Sampling mode: the base config already holds the right values.
            pass
        elif num_beams > 1:
            generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False})
        else:
            generation_config.update({"do_sample": False, "repetition_penalty": 1.05})

        # Caller overrides win, but only for keys the config recognises.
        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
        generation_config["min_new_tokens"] = min_new_tokens
        generation_config["max_new_tokens"] = max_new_tokens

        return generation_config

    def _collect_penalty_token_ids(self, generated_tokens, all_input_ids, device):
        """Return the unique token ids subject to the repetition penalty.

        Uses a bounded window of the prompt (``PENALTY_WINDOW_SIZE`` trailing
        tokens) plus everything generated so far in this chunk.
        """
        if all_input_ids is not None:
            if generated_tokens:
                current_sequence = torch.cat(
                    [
                        all_input_ids[:, -self.PENALTY_WINDOW_SIZE :],
                        torch.cat(generated_tokens, dim=1),
                    ],
                    dim=1,
                )
            else:
                current_sequence = all_input_ids[:, -self.PENALTY_WINDOW_SIZE :]
            return torch.unique(current_sequence.squeeze(0))
        if generated_tokens:
            return torch.unique(torch.cat(generated_tokens, dim=1).squeeze(0))
        return torch.tensor([], dtype=torch.long, device=device)

    @staticmethod
    def _penalize(logits, token_ids, penalty):
        """CTRL-style penalty in place: divide positive logits by ``penalty``,
        multiply negative ones. ``token_ids`` should not contain duplicates
        (duplicates would be penalized once, not repeatedly)."""
        if token_ids.numel() == 0:
            return
        vals = logits[0, token_ids]
        logits[0, token_ids] = torch.where(vals > 0, vals / penalty, vals * penalty)

    @staticmethod
    def _sample_next_token(logits, do_sample, top_p, top_k):
        """Pick the next token id ([1, 1]) via greedy argmax or top-k/top-p sampling."""
        if not do_sample:
            return torch.argmax(logits, dim=-1, keepdim=True)

        if top_k > 0:
            top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
            filtered = torch.full_like(logits, float("-inf"))
            filtered.scatter_(1, top_k_indices, top_k_logits)
            logits = filtered

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float("-inf")

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def chunk_generate(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values,
        is_first_generate_chunk: bool,
        chunk_size: int,
        return_hidden_states: bool,
        do_sample: bool,
        temperature: float,
        top_p: float,
        top_k: int,
        repetition_penalty: float = 1.05,
        length_penalty: float = 1.0,
        all_input_ids: Optional[torch.Tensor] = None,
    ) -> GenerateChunkOutput:
        """Generate up to ``chunk_size`` tokens, reusing the KV cache.

        Args:
            inputs_embeds: [1, seq_len, hidden_dim], input embeddings of the current chunk.
            past_key_values: KV cache for the LLM (model-specific structure).
            is_first_generate_chunk: Whether this is the first generate chunk
                (the full ``inputs_embeds`` is prefilled; afterwards only the
                newest token is fed).
            chunk_size: Max number of tokens to generate in this call.
            return_hidden_states: Whether to collect last-layer hidden states.
            do_sample: Sample stochastically instead of greedy argmax.
            temperature: Softmax temperature (1.0 disables scaling).
            top_p: Nucleus-sampling threshold (1.0 disables).
            top_k: Top-k filter size (0 disables).
            repetition_penalty: Penalty applied to recently seen tokens (1.0 disables).
            length_penalty: Penalty applied to terminator logits; values > 1.0
                suppress EOS and thus lengthen the output (1.0 disables).
            all_input_ids: Optional [1, n] prompt ids used to widen the
                repetition-penalty window.

        Returns:
            GenerateChunkOutput; ``finished`` is True when a terminator was sampled.

        Raises:
            Exception: if the loop produced no tokens without hitting a terminator
                (only possible with ``chunk_size <= 0``).
        """
        finished = False
        current_inputs_embeds = inputs_embeds.clone()
        input_last_hidden_states = []
        last_hidden_states = []
        generated_tokens = []

        for token_idx in range(chunk_size):
            if is_first_generate_chunk and token_idx == 0:
                # Prefill: feed the whole chunk's embeddings once.
                step_embeds = current_inputs_embeds
            else:
                # Decode: the cache holds everything but the newest token.
                step_embeds = current_inputs_embeds[:, -1:, :]

            model_inputs = {
                "inputs_embeds": step_embeds,
                "past_key_values": past_key_values,
                "use_cache": True,
                "output_hidden_states": return_hidden_states,
            }

            with torch.no_grad():
                outputs = self.model(**model_inputs)

            # Work on a float32 copy so in-place penalty edits never touch the
            # model's own logits tensor.
            logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=inputs_embeds.device)

            if self.forbidden_token_ids:
                logits[:, self.forbidden_token_ids] = float("-inf")

            past_key_values = outputs.past_key_values

            if repetition_penalty != 1.0:
                penalty_ids = self._collect_penalty_token_ids(generated_tokens, all_input_ids, logits.device)
                self._penalize(logits, penalty_ids, repetition_penalty)

            if length_penalty != 1.0:
                # Suppress (or boost) every terminator's logit.
                eos_ids = torch.unique(torch.tensor(self.terminators_ids, dtype=torch.long, device=logits.device))
                self._penalize(logits, eos_ids, length_penalty)

            if temperature != 1.0:
                logits = logits / temperature

            next_token = self._sample_next_token(logits, do_sample, top_p, top_k)

            if return_hidden_states:
                if is_first_generate_chunk and token_idx == 0:
                    # Hidden states of the prefill step belong to the input.
                    input_last_hidden_states.append(outputs.hidden_states[-1])
                else:
                    last_hidden_states.append(outputs.hidden_states[-1])

            if next_token.item() in self.terminators_ids:
                # Terminator is not appended to the output chunk.
                finished = True
                break

            generated_tokens.append(next_token)

            # Extend the running embeddings so the next step can feed the new token.
            next_token_embed = self.embedding_layer(next_token)
            current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embed], dim=1)

        if generated_tokens:
            chunk_token_ids = torch.cat(generated_tokens, dim=1)
        elif finished:
            # Terminator on the very first step: empty but valid chunk.
            chunk_token_ids = torch.zeros((1, 0), dtype=torch.long, device=current_inputs_embeds.device)
        else:
            raise Exception("this should not happen")

        # BUG FIX vs. original: an empty list previously hit torch.cat([])
        # (RuntimeError) when finished, or an unconditional raise when
        # return_hidden_states=False; return None instead, matching the
        # Optional field types on GenerateChunkOutput.
        last_hidden_states = torch.cat(last_hidden_states, dim=1) if last_hidden_states else None
        input_last_hidden_states = (
            torch.cat(input_last_hidden_states, dim=1) if input_last_hidden_states else None
        )

        return GenerateChunkOutput(
            chunk_token_ids=chunk_token_ids,
            current_inputs_embeds=current_inputs_embeds,
            input_last_hidden_states=input_last_hidden_states,
            last_hidden_states=last_hidden_states,
            past_key_values=past_key_values,
            finished=finished,
        )
|
|
|
|
|
|
|
|
def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
    """
    Incrementally decode tokens from an iterator, handling partial multi-byte characters.

    When streaming tokens, multi-byte characters (like Chinese) may be split across multiple
    tokens. Decoding partial tokens results in replacement characters (U+FFFD). This function
    buffers tokens and only yields complete characters.

    Args:
        token_iterator: An iterator yielding (token_ids, is_finished) tuples.
            token_ids can be torch.Tensor or any iterable of integers.
        tokenizer: The tokenizer to use for decoding.
        skip_special_tokens: Whether to skip special tokens during decoding.

    Yields:
        (decoded_text, is_finished) tuples where decoded_text is the new text since last yield.
    """
    accumulated_token_ids = []
    yielded_text_len = 0

    for token_ids, is_finished in token_iterator:
        # Normalize incoming ids to a flat Python list of ints.
        if torch.is_tensor(token_ids):
            accumulated_token_ids.extend(token_ids.reshape(-1).tolist())
        else:
            accumulated_token_ids.extend(list(token_ids) if hasattr(token_ids, "__iter__") else [token_ids])

        # Re-decode from scratch each time: token boundaries don't align with
        # character boundaries, so incremental per-token decoding is unsafe.
        full_decoded = tokenizer.decode(accumulated_token_ids, skip_special_tokens=skip_special_tokens)
        new_text = full_decoded[yielded_text_len:]

        if is_finished:
            # Final flush: emit everything held back, including a trailing
            # replacement char if the stream genuinely ended mid-character.
            # BUG FIX: advance yielded_text_len so a (misbehaving) iterator
            # that keeps yielding after is_finished can't duplicate text.
            yielded_text_len = len(full_decoded)
            yield new_text, is_finished
        else:
            # Hold back trailing U+FFFD chars: they mark a partially decoded
            # multi-byte character that later tokens will complete.
            safe_end = len(new_text)
            while safe_end > 0 and new_text[safe_end - 1] == "\ufffd":
                safe_end -= 1

            safe_text = new_text[:safe_end] if safe_end > 0 else ""
            yielded_text_len += len(safe_text)
            yield safe_text, is_finished
|
|
|
|
|
|
|
|
def torch_clone_recursive(obj):
    """Recursively clone nested containers of torch.Tensors.

    Supported container types: dict, list, tuple. Non-container non-Tensor
    objects are returned as-is.
    """
    if torch.is_tensor(obj):
        return obj.clone()
    if isinstance(obj, dict):
        return {k: torch_clone_recursive(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [torch_clone_recursive(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(torch_clone_recursive(v) for v in obj)
    # BUG FIX: the docstring always promised pass-through for leaf values
    # (needed for structures containing None / ints / floats), but the code
    # raised ValueError instead. Return the object unchanged.
    return obj
|
|
|
|
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input for RoPE."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
|
|
|
@dataclass
class SpeculativeSnapshot:
    """Snapshot of streaming state for VAD speculative rollback.

    Captured after streaming_prefill and before streaming_generate: if the
    speculation fails (the user keeps speaking), the saved state is restored
    so streaming_prefill can continue.

    Storage strategy:
    - LLM KV cache: only its length is stored; rollback truncates in place
      (zero extra VRAM).
    - Audio KV cache: cloned, because generate sets it to None.
    - Mel processor: full state snapshot, including its buffer.
    """

    # Cache lengths at snapshot time (restored by truncation).
    llm_cache_length: int
    audio_cache_length: int

    # Conversation/turn flags.
    new_user_msg: bool
    llm_generated: bool
    llm_generate_completed: bool

    # Round bookkeeping.
    next_round_id: int
    pending_round_id: Optional[int]
    omni_chunk_history_length: int

    # TTS tokens carried over from the previous turn.
    tts_last_turn_tokens: Optional[torch.Tensor]

    # Index of the current audio chunk.
    audio_chunk_idx: int

    # Full mel-processor state (including its buffer).
    mel_processor_snapshot: Optional[dict] = None

    # Cloned audio KV cache (the live one is cleared by generate).
    audio_past_key_values: Optional[tuple] = None

    # Time the snapshot was taken.
    timestamp: float = 0.0

    # Optional checksums — presumably for integrity verification/debugging.
    llm_cache_checksum: Optional[float] = None
    audio_cache_checksum: Optional[float] = None
    mel_buffer_checksum: Optional[float] = None

    # Saved RNG states — presumably so sampling is reproducible after rollback.
    rng_state_cpu: Optional[torch.Tensor] = None
    rng_state_cuda: Optional[torch.Tensor] = None

    def summary(self) -> str:
        """Return a compact one-line description of the snapshot for logging."""
        buffer_len = 0
        snapshot = self.mel_processor_snapshot
        if snapshot:
            mel_buffer = snapshot.get("buffer")
            if mel_buffer is not None:
                buffer_len = len(mel_buffer)
        parts = (
            f"llm_cache={self.llm_cache_length}",
            f"audio_cache={self.audio_cache_length}",
            f"audio_chunk_idx={self.audio_chunk_idx}",
            f"mel_buffer={buffer_len}",
            f"history_len={self.omni_chunk_history_length}",
            f"new_user_msg={self.new_user_msg}",
            f"llm_generated={self.llm_generated}",
        )
        return ", ".join(parts)
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TTSSamplingParams:
    """Sampling hyper-parameters for TTS audio-token decoding."""

    # Nucleus (top-p) sampling threshold.
    top_p: float = 0.85
    # Minimum-probability cutoff — presumably relative to the top token; confirm
    # against the sampler that consumes this config.
    min_p: float = 0.01
    # Keep only the k most probable tokens.
    top_k: int = 25
    # Penalty applied to recently generated tokens (>1 discourages repeats).
    repetition_penalty: float = 1.05
    # Softmax temperature.
    temperature: float = 0.8
    # NOTE(review): win_size/tau_r semantics are not visible in this file —
    # presumably a repetition window size and a threshold; verify at the call site.
    win_size: int = 16
    tau_r: float = 0.1
|
|
|
|
|
|
|
|
class TTSStreamingGenerator:
    """
    Streaming generator for TTS that processes chunks and yields audio tokens in real-time.

    Supported attention types:
    - full_attention: Full attention, all tokens can attend to each other
    - sliding_window: Sliding window attention, KV cache is truncated to fixed size (token_window_size)
    - sliding_recompute: Sliding recompute, only keep previous chunk and recompute with current chunk
    - reindex: Keep first chunk as sink, reindex sliding window positions via RoPE rotation
    """

    def __init__(
        self,
        model,
        temperature: float,
        eos_token: Union[int, torch.Tensor],
        chunk_size: int = 25,
        tts_last_turn_tokens: torch.Tensor = None,
        logits_processors=None,
        logits_warpers=None,
    ):
        """
        Args:
            model: TTS model exposing ``model`` (the transformer), ``emb_text``,
                ``emb_code``, ``head_code`` and the attention/window attributes
                read below.
            temperature: Sampling temperature applied to code-head logits.
            eos_token: End-of-audio token id (int or tensor).
            chunk_size: Number of audio tokens buffered before each yield.
            tts_last_turn_tokens: Optional tokens carried over from the last turn.
            logits_processors: Optional sequence of (ids, logits) -> logits callables.
            logits_warpers: Optional sequence of (ids, logits) -> logits callables.
        """
        self.tts = model
        self.device = model.device
        # Kept as a tensor so it can divide logits in-place on the right device.
        self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
        self.eos_token = (
            torch.tensor(eos_token, device=self.device) if isinstance(eos_token, int) else eos_token.to(self.device)
        )

        # VQ codebook layout and per-codebook embedding/head modules.
        self.num_vq = model.num_vq
        self.num_audio_tokens = model.num_audio_tokens
        self.recomputed_chunks = model.recomputed_chunks
        self.emb_code = model.emb_code
        self.head_code = model.head_code

        # Windowing policy (see class docstring).
        self.attention_type = model.attention_type
        self.chunk_window_size = model.chunk_window_size
        self.token_window_size = model.token_window_size

        # RoPE parameters used when reindexing the KV cache.
        self.rope_theta = model.model.config.rope_theta
        self.head_dim = model.model.config.hidden_size // model.model.config.num_attention_heads

        self.logits_processors = logits_processors if logits_processors is not None else []

        self.logits_warpers = logits_warpers if logits_warpers is not None else []

        # Mutable generation state carried across generate_with_buffer calls.
        self.past_key_values = None
        self.text_start_pos = 0  # absolute RoPE position for the next prefill
        self.idx = -1  # chunk counter; incremented at each generate_with_buffer call
        self.all_conditions = []
        self.all_generated_tokens = []
        self.tts_last_turn_tokens = tts_last_turn_tokens
        self.spk_emb = None

        # Pre-compute the audio-BOS and text-EOS embeddings appended to conditions.
        audio_bos = [self.tts.audio_bos_token_id]
        audio_bos = torch.Tensor(audio_bos).to(self.tts.emb_text.weight.device, dtype=torch.long)

        self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
        self.text_eos_embed = self.tts.emb_text(
            torch.tensor(
                [self.tts.config.text_eos_token_id],
                device=self.tts.emb_text.weight.device,
                dtype=torch.long,
            )
        ).unsqueeze(0)

        # Output buffering: generated tokens accumulate here and are yielded
        # in batches of `chunk_size`.
        self.chunk_size = chunk_size
        self._token_buffer: List[torch.Tensor] = []

        # Per-chunk bookkeeping consumed by sliding_recompute / reindex modes.
        self._chunk_info: List[dict] = []
        self._total_seq_len = 0

        # Length of the attention-sink prefix (reindex mode bookkeeping).
        self._sink_kv_len = 0

    def _build_recompute_inputs(self, current_condition: torch.Tensor) -> torch.Tensor:
        """Build recompute inputs for sliding_recompute mode."""
        # First chunk: nothing to recompute.
        if len(self._chunk_info) == 0:
            return current_condition

        # Replay the previous chunk (condition + its audio tokens) in front of
        # the current condition so the model recomputes its context.
        prev_chunk = self._chunk_info[-1]
        prev_condition = prev_chunk["condition"]
        prev_audio_tokens = prev_chunk["audio_tokens"]

        recompute_list = [prev_condition]
        if len(prev_audio_tokens) > 0:
            # Embed the stored audio tokens through the first VQ codebook.
            prev_audio_embeds = torch.cat([self.emb_code[0](tok) for tok in prev_audio_tokens], dim=1)
            recompute_list.append(prev_audio_embeds)

        recompute_list.append(current_condition)
        return torch.cat(recompute_list, dim=1)

    def _truncate_kv_cache_sliding_window(self):
        """Truncate KV cache for sliding_window mode."""
        if self.past_key_values is None:
            return

        # Support both DynamicCache and legacy tuple caches.
        if hasattr(self.past_key_values, "get_seq_length"):
            current_kv_len = self.past_key_values.get_seq_length()
        else:
            current_kv_len = self.past_key_values[0][0].shape[2]

        # Nothing to drop while within the window.
        if current_kv_len <= self.token_window_size:
            return

        # Rebuild a DynamicCache holding only the trailing token_window_size entries.
        new_cache = DynamicCache()
        num_layers = (
            len(self.past_key_values.key_cache)
            if hasattr(self.past_key_values, "key_cache")
            else len(self.past_key_values)
        )

        for layer_idx in range(num_layers):
            if hasattr(self.past_key_values, "key_cache"):
                key = self.past_key_values.key_cache[layer_idx][:, :, -self.token_window_size :, :]
                value = self.past_key_values.value_cache[layer_idx][:, :, -self.token_window_size :, :]
            else:
                key = self.past_key_values[layer_idx][0][:, :, -self.token_window_size :, :]
                value = self.past_key_values[layer_idx][1][:, :, -self.token_window_size :, :]
            new_cache.update(key, value, layer_idx)

        self.past_key_values = new_cache

    @staticmethod
    def _apply_rope_rotation(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply RoPE rotation to tensor."""
        return x * cos + rotate_half(x) * sin

    def _compute_rope_cos_sin(self, positions: torch.Tensor, device: torch.device, dtype: torch.dtype):
        """Compute RoPE cos and sin for given positions.

        Returns (cos, sin) of shape [len(positions), head_dim]; callers add
        broadcast dims as needed.
        """
        dim_half = self.head_dim // 2
        freq_seq = torch.arange(0, dim_half, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (self.rope_theta ** (freq_seq / dim_half))

        # Outer product: one angle per (position, frequency) pair; duplicate
        # halves to match rotate_half's [x1, x2] layout.
        angles = positions.float().unsqueeze(-1) * inv_freq.unsqueeze(0)
        angles = torch.cat([angles, angles], dim=-1)

        cos = angles.cos().to(dtype)
        sin = angles.sin().to(dtype)
        return cos, sin

    def _reindex_kv_cache(self):
        """
        Reindex KV cache for reindex mode:
        1. Keep first chunk as attention sink
        2. Keep last chunk
        3. Discard middle chunks
        4. Reindex the last chunk's key positions to be right after sink via RoPE rotation
        """
        # Need a cache and at least two completed chunks to have a "middle".
        if self.past_key_values is None or len(self._chunk_info) < 2:
            return

        if hasattr(self.past_key_values, "get_seq_length"):
            current_kv_len = self.past_key_values.get_seq_length()
        else:
            current_kv_len = self.past_key_values[0][0].shape[2]

        # Sink = first chunk's condition plus the audio tokens it produced.
        sink_len = self._chunk_info[0]["condition_len"] + self._chunk_info[0]["audio_token_count"]

        last_chunk = self._chunk_info[-1]
        last_chunk_len = last_chunk["condition_len"] + last_chunk["audio_token_count"]

        keep_len = sink_len + last_chunk_len

        device = self.past_key_values.key_cache[0].device
        dtype = self.past_key_values.key_cache[0].dtype

        if current_kv_len <= keep_len:
            # Cache already fits in sink + last chunk: no eviction needed.
            # NOTE(review): last_chunk_kv_len is only used as a guard here.
            last_chunk_kv_len = current_kv_len - sink_len
            if last_chunk_kv_len <= 0:
                return
            self.text_start_pos = current_kv_len
            return

        new_cache = DynamicCache()
        num_layers = len(self.past_key_values.key_cache)

        # The last chunk's keys move from original_start_pos to right after the
        # sink; rotate them by the (negative) position delta.
        original_start_pos = current_kv_len - last_chunk_len
        new_start_pos = sink_len
        delta = new_start_pos - original_start_pos
        delta_positions = torch.full((last_chunk_len,), delta, dtype=torch.float32, device=device)

        cos, sin = self._compute_rope_cos_sin(delta_positions, device, dtype)
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        for layer_idx in range(num_layers):
            key_full = self.past_key_values.key_cache[layer_idx]
            value_full = self.past_key_values.value_cache[layer_idx]

            key_sink = key_full[:, :, :sink_len, :]
            value_sink = value_full[:, :, :sink_len, :]
            key_last = key_full[:, :, -last_chunk_len:, :]
            value_last = value_full[:, :, -last_chunk_len:, :]

            # Only keys carry RoPE; values are position-independent.
            key_last_reindexed = self._apply_rope_rotation(key_last, cos, sin)

            key = torch.cat([key_sink, key_last_reindexed], dim=2)
            value = torch.cat([value_sink, value_last], dim=2)

            new_cache.update(key, value, layer_idx)

        self.past_key_values = new_cache

        # Next prefill starts right after the compacted cache.
        self.text_start_pos = sink_len + last_chunk_len

    @torch.inference_mode()
    def generate_with_buffer(
        self,
        condition: torch.Tensor,
        text_finished: bool = False,
        max_new_token: int = 500,
    ):
        """input a condition embedding chunk, generate audio token each time,
        and accumulate to buffer, only yield when buffer satisfies chunk_size.

        Args:
            condition: [1, seq_len, hidden] text-condition embeddings for this chunk.
            text_finished: True when this is the final text chunk of the turn.
            max_new_token: Safety cap on tokens sampled per call.

        Yields:
            (tokens, finished) pairs; tokens is [1, chunk_size] (the final
            flush may be shorter or empty) and finished marks the last yield.
        """
        self.idx += 1
        self.device = self.tts.device

        # Append the text-EOS embedding on the last text chunk, then the
        # audio-BOS embedding that triggers audio generation.
        if text_finished:
            condition = torch.cat([condition, self.text_eos_embed], dim=1)

        condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)

        self.all_conditions.append(condition)

        # Record what this chunk fed and produced (used by recompute/reindex).
        current_chunk_info = {
            "condition_len": condition.shape[1],
            "audio_token_count": 0,
            "condition": condition.clone(),
            "audio_tokens": [],
        }

        if self.attention_type == "sliding_recompute" and self.idx >= 1:
            # Drop the cache and re-run the previous chunk with the new one.
            self.past_key_values = None
            current_condition = self._build_recompute_inputs(condition)
            self.text_start_pos = 0
        elif self.attention_type == "reindex" and self.idx >= 1:
            # Compact cache to sink + last chunk, rotating positions.
            self._reindex_kv_cache()
            current_condition = condition

            if self.past_key_values is not None:
                if hasattr(self.past_key_values, "get_seq_length"):
                    kv_len = self.past_key_values.get_seq_length()
                else:
                    kv_len = self.past_key_values[0][0].shape[2]
                self.text_start_pos = kv_len
        else:
            current_condition = condition

        condition_length = current_condition.shape[1]
        prefill_len = condition_length
        finished = torch.zeros(1, dtype=torch.bool, device=self.device)
        chunk_generated_tokens = []

        for t in range(max_new_token):
            if t == 0:
                # Prefill step: feed the whole condition with explicit positions.
                inputs_embeds = current_condition
                pos_ids = torch.arange(
                    self.text_start_pos,
                    self.text_start_pos + condition_length,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                # Decode step: feed only the last generated token's embedding.
                last = self.all_generated_tokens[-1]

                inputs_embeds = self.emb_code[0](last)
                pos_ids = torch.tensor(
                    [self.text_start_pos + prefill_len + t - 1],
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs = self.tts.model(
                position_ids=pos_ids,
                past_key_values=self.past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
            )
            hidden_states = outputs.last_hidden_state

            # Keep the cache; sliding_window additionally truncates it.
            if self.attention_type == "sliding_window":
                self.past_key_values = outputs.past_key_values
                self._truncate_kv_cache_sliding_window()
            else:
                self.past_key_values = outputs.past_key_values

            # Project hidden states through every VQ code head; P.cached()
            # caches parametrizations across the head_code calls.
            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            # Keep only the last position and flatten codebooks into the batch dim.
            logits = logits[:, -1].float()
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))

            logits /= self.temperature

            # Very first step of the whole stream has no history to process.
            audio_bos = len(self.all_generated_tokens) == 0 and t == 0

            if not audio_bos:
                all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device)
                for processor in self.logits_processors:
                    logits = processor(all_generated_tokens, logits)

                for warper in self.logits_warpers:
                    logits = warper(all_generated_tokens, logits)
                del all_generated_tokens

            scores = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(scores, num_samples=1)
            # Keep the first codebook's sample as the token id.
            next_id = idx_next.view(-1, self.num_vq)[:, 0:1]
            del scores

            if next_id.eq(
                self.eos_token
            ).any():
                # EOS is recorded but not buffered; draining happens below.
                finished[:] = True
            else:
                # Normalize to shape [1, 1] regardless of sampling output rank.
                if next_id.dim() == 0:
                    next_tok = next_id.unsqueeze(0).unsqueeze(0)
                elif next_id.dim() == 1:
                    next_tok = next_id.unsqueeze(0)
                else:
                    next_tok = next_id

                self.all_generated_tokens.append(next_tok)
                chunk_generated_tokens.append(next_tok)

                current_chunk_info["audio_tokens"].append(next_tok.clone())
                current_chunk_info["audio_token_count"] += 1

                self._token_buffer.append(next_tok)

            if len(self._token_buffer) == 0:
                # Buffer empty (EOS before anything was buffered): signal end
                # of turn with an empty final yield if the text is done.
                if text_finished:
                    yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
                    break

                else:
                    break
            else:
                # Yield full chunks as soon as the buffer holds one.
                if len(self._token_buffer) >= self.chunk_size:
                    batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1)
                    yield batch, False

                    self._token_buffer = self._token_buffer[self.chunk_size :]

                else:
                    # Partial buffer: flush only when EOS was seen.
                    # NOTE(review): if EOS fires while the buffer still holds a
                    # full chunk, the loop keeps sampling until it drains — confirm intended.
                    if finished.all():
                        if text_finished:
                            batch = torch.cat(self._token_buffer, dim=1)
                            yield batch, True
                            self._token_buffer = []
                            break
                        else:
                            # Keep the remainder buffered for the next call.
                            break
                    else:
                        continue

        # Record this chunk and advance the RoPE start position for the next call.
        self._chunk_info.append(current_chunk_info)
        self._total_seq_len += condition.shape[1] + len(chunk_generated_tokens)

        if self.attention_type == "sliding_recompute":
            # prefill_len includes the replayed previous chunk here.
            self.text_start_pos += prefill_len + len(chunk_generated_tokens)
        elif self.attention_type == "reindex":
            # After reindexing, the cache length is the authoritative position.
            if self.past_key_values is not None:
                if hasattr(self.past_key_values, "get_seq_length"):
                    self.text_start_pos = self.past_key_values.get_seq_length()
                else:
                    self.text_start_pos = self.past_key_values[0][0].shape[2]
            else:
                self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
        else:
            self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class StreamingWindowConfig:
    """Token thresholds for the streaming text sliding window.

    NOTE(review): inferred from the names and the high/low gap — presumably the
    window trims from the high watermark down to the low one (hysteresis);
    confirm against the code that consumes this config.
    """

    # Trim trigger (high watermark), in tokens.
    text_window_high_tokens: int = 8000
    # Post-trim target (low watermark), in tokens.
    text_window_low_tokens: int = 6000
|
|
|
|
|
|
|
|
@dataclass
class DuplexWindowConfig:
    """Duplex sliding-window configuration.

    Sliding-window modes:
    - "off": disable the sliding window
    - "basic": basic sliding window (triggered by cache length)
    - "context": sliding window with context (triggered by unit number,
      preserving generated text into the previous context)
    """

    # One of "off" | "basic" | "context" (see class docstring).
    sliding_window_mode: str = "off"

    # "basic" mode watermarks (tokens): trim at the high value, down to the low.
    basic_window_high_tokens: int = 8000
    basic_window_low_tokens: int = 6000

    # "context" mode limits: max tokens kept as previous context, max units retained.
    context_previous_max_tokens: int = 500
    context_max_units: int = 24

    # When True — presumably enables extra consistency checks on window
    # operations; confirm at the call site.
    verify_mode: bool = False
|
|
|
|
|
|
|
|
def as_dynamic_cache(past_key_values):
    """Convert legacy tuple cache to DynamicCache if needed."""
    # Only legacy tuple caches need conversion; a DynamicCache (or any other
    # cache object, or None) passes straight through.
    if isinstance(past_key_values, tuple):
        return DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values
|
|
|
|
|
|
|
|
def get_kv_cache_length(cache) -> int:
    """Get the sequence length of a KV cache.

    Args:
        cache: DynamicCache or tuple-based cache

    Returns:
        The number of tokens in the cache
    """
    if cache is None:
        return 0

    if isinstance(cache, DynamicCache):
        key_layers = cache.key_cache
        # An empty DynamicCache has no layers, or an empty first key tensor.
        if not key_layers or not key_layers[0].numel():
            return 0
        return key_layers[0].shape[-2]

    if isinstance(cache, tuple):
        # Legacy layout: cache[layer][k_or_v] with seq length on dim 2.
        return cache[0][0].shape[2]

    # Unknown cache type: treat as empty.
    return 0
|
|
|
|
|
|
|
|
def get_rotary_cos_sin(
    head_dim: int,
    positions: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute RoPE cos and sin components for given positions.

    Args:
        head_dim: Dimension of each attention head
        positions: Position indices tensor
        device: Target device
        dtype: Target dtype
        rope_theta: RoPE base frequency (default 10000.0)
        inv_freq_cache: Optional cache dict for inverse frequencies

    Returns:
        Tuple of (cos, sin) tensors with shape [1, 1, seq_len, head_dim]
    """
    # BUG FIX: rope_theta is part of the cache key. Previously the key was
    # (head_dim, device), so two callers sharing one cache but using different
    # thetas silently reused the wrong inverse frequencies.
    cache_key = (head_dim, device, rope_theta)

    inv_freq = inv_freq_cache.get(cache_key) if inv_freq_cache is not None else None
    if inv_freq is None or inv_freq.device != device or inv_freq.shape[0] != head_dim // 2:
        exponent = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (rope_theta**exponent)
        if inv_freq_cache is not None:
            inv_freq_cache[cache_key] = inv_freq

    positions = positions.to(device=device, dtype=torch.float32)
    # One angle per (position, frequency) pair.
    angles = torch.einsum("i,j->ij", positions, inv_freq)
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    # Duplicate halves to match rotate_half's [x1, x2] layout, then add
    # broadcast dims for [batch, heads, seq, head_dim] tensors.
    cos_full = torch.cat([cos, cos], dim=-1).to(dtype=dtype)
    sin_full = torch.cat([sin, sin], dim=-1).to(dtype=dtype)
    cos_full = cos_full.unsqueeze(0).unsqueeze(0)
    sin_full = sin_full.unsqueeze(0).unsqueeze(0)
    return cos_full, sin_full
|
|
|
|
|
|
|
|
def realign_rotary_suffix(
    suffix_keys: torch.Tensor,
    old_positions: torch.Tensor,
    new_positions: torch.Tensor,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> torch.Tensor:
    """Realign RoPE position encoding after cache eviction.

    When tokens are dropped from the middle of a cache, the suffix tokens need
    their RoPE rotation recomputed for their new (shifted) position indices.
    This is done in two steps: undo the rotation at the old positions, then
    apply the rotation for the new positions.

    Args:
        suffix_keys: Key tensor to realign, shape [batch, heads, seq_len, head_dim]
        old_positions: Original position indices
        new_positions: New position indices after eviction
        rope_theta: RoPE base frequency
        inv_freq_cache: Optional cache dict for inverse frequencies

    Returns:
        Realigned key tensor with the same shape as the input
    """
    if suffix_keys.numel() == 0:
        return suffix_keys

    dim = suffix_keys.shape[-1]
    dev = suffix_keys.device
    dt = suffix_keys.dtype

    # Inverse rotation: rotating by -theta uses (cos, -sin).
    cos_old, sin_old = get_rotary_cos_sin(dim, old_positions, dev, dt, rope_theta, inv_freq_cache)
    unrotated = cos_old * suffix_keys - sin_old * rotate_half(suffix_keys)

    # Forward rotation at the new positions.
    cos_new, sin_new = get_rotary_cos_sin(dim, new_positions, dev, dt, rope_theta, inv_freq_cache)
    return cos_new * unrotated + sin_new * rotate_half(unrotated)
|
|
|
|
|
|
|
|
def drop_tokens_from_cache(
    cache: Optional[DynamicCache | Tuple],
    length: int,
    preserve: int,
    position_offset: int,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> Tuple[Optional[DynamicCache], int, bool]:
    """Drop tokens from a KV cache while preserving system prompt.

    Removes tokens in the range [preserve, preserve + length) from the cache,
    realigning RoPE embeddings for the suffix.

    Args:
        cache: DynamicCache or tuple-based cache (will be converted to DynamicCache)
        length: Number of tokens to drop
        preserve: Number of tokens to preserve at the start (system prompt)
        position_offset: Current position offset for RoPE calculation
        rope_theta: RoPE base frequency
        inv_freq_cache: Optional cache dict for inverse frequencies

    Returns:
        Tuple of (cache, new_position_offset, success)
        Note: Tuple cache will be converted to DynamicCache. Modification is in-place.
    """
    if cache is None or length <= 0:
        return cache, position_offset, False

    # Normalize tuple caches so the per-layer surgery below has a single code path.
    cache = as_dynamic_cache(cache)

    total_len = get_kv_cache_length(cache)
    if total_len <= 0:
        return cache, position_offset, False

    # Clamp preserve to what actually exists, then check there is enough to drop.
    preserve = min(preserve, total_len)
    available = total_len - preserve

    if available < length:
        logger.warning(
            "Cannot drop %d tokens: only %d available (total=%d, preserve=%d)",
            length,
            available,
            total_len,
            preserve,
        )
        return cache, position_offset, False

    suffix_len = total_len - preserve - length

    # Tokens after the dropped range slide left by `length`: their positions
    # change from [preserve+length, ...) to [preserve, ...), so their RoPE
    # rotation must be recomputed.
    suffix_offset = preserve + length
    prefix_offset = preserve

    old_positions = None
    new_positions = None
    if suffix_len > 0:
        device = cache.key_cache[0].device
        old_positions = torch.arange(
            suffix_offset,
            suffix_offset + suffix_len,
            device=device,
            dtype=torch.long,
        )
        new_positions = torch.arange(
            prefix_offset,
            prefix_offset + suffix_len,
            device=device,
            dtype=torch.long,
        )

    keep_len = total_len - length

    for layer_idx in range(len(cache.key_cache)):
        key_tensor = cache.key_cache[layer_idx]
        value_tensor = cache.value_cache[layer_idx]

        # Layers that have not been populated yet are left untouched.
        if not key_tensor.numel():
            continue

        # Protected system-prompt prefix (kept verbatim).
        prefix_keys = key_tensor[:, :, :preserve, :]
        prefix_values = value_tensor[:, :, :preserve, :]

        if suffix_len > 0:
            suffix_keys = key_tensor[:, :, preserve + length :, :]
            suffix_values = value_tensor[:, :, preserve + length :, :]

            # Only keys carry RoPE; values are position-independent and are
            # concatenated unchanged.
            if old_positions is not None and new_positions is not None and suffix_keys.numel():
                suffix_keys = realign_rotary_suffix(
                    suffix_keys,
                    old_positions,
                    new_positions,
                    rope_theta,
                    inv_freq_cache,
                )

            cache.key_cache[layer_idx] = torch.cat([prefix_keys, suffix_keys], dim=-2).contiguous()
            cache.value_cache[layer_idx] = torch.cat([prefix_values, suffix_values], dim=-2).contiguous()
        else:
            cache.key_cache[layer_idx] = prefix_keys.contiguous()
            cache.value_cache[layer_idx] = prefix_values.contiguous()

    # Keep DynamicCache's internal length bookkeeping in sync with the tensors
    # we just rewrote. NOTE(review): this touches the private _seen_tokens
    # attribute of DynamicCache — verify against the installed transformers
    # version if the cache API changes.
    cache.crop(keep_len)
    cache._seen_tokens = max(keep_len, 0)

    new_offset = position_offset + length
    logger.debug("Dropped %d tokens from cache, new length=%d", length, keep_len)

    return cache, new_offset, True
|
|
|
|
|
|
|
|
|
|
|
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Filter a logits distribution using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits tensor whose last dimension is the vocabulary
        top_k: keep only the k highest-scoring tokens (0 disables)
        top_p: keep the smallest token set whose cumulative probability
            exceeds top_p (0.0 disables)
        filter_value: value written into removed positions

    Returns:
        A new tensor (the input is not modified) with filtered positions
        set to ``filter_value``.
    """
    logits = logits.clone()

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove every logit strictly below the k-th largest value.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # BUG FIX: scatter the removal mask back to the original token order.
        # The previous `logits[0, sorted_indices[mask]]` indexing only handled
        # batch size 1; scatter works for any leading batch shape.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
|
|
|
|
|
|
|
|
class StreamDecoder: |
|
|
def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None): |
|
|
self.m = llm |
|
|
self.tokenizer = tokenizer |
|
|
self.listen_id = self.tokenizer.eos_token_id |
|
|
|
|
|
self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>") |
|
|
self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>") |
|
|
self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>") |
|
|
self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>") |
|
|
|
|
|
self.special_token_ids = special_token_ids if special_token_ids is not None else [] |
|
|
|
|
|
|
|
|
self._all_special_ids = set() |
|
|
self._all_special_tokens_text = set() |
|
|
if self.tokenizer: |
|
|
if hasattr(self.tokenizer, "all_special_ids"): |
|
|
self._all_special_ids = set(self.tokenizer.all_special_ids) |
|
|
if hasattr(self.tokenizer, "all_special_tokens"): |
|
|
self._all_special_tokens_text = set(self.tokenizer.all_special_tokens) |
|
|
|
|
|
custom_special_tokens = [ |
|
|
"<unit>", |
|
|
"</unit>", |
|
|
"<image>", |
|
|
"</image>", |
|
|
"<slice>", |
|
|
"</slice>", |
|
|
"<|listen|>", |
|
|
"<|speak|>", |
|
|
"<|tts_bos|>", |
|
|
"<|tts_eos|>", |
|
|
"<|audio_start|>", |
|
|
"<|audio_end|>", |
|
|
"<|chunk_eos|>", |
|
|
"<|chunk_tts_eos|>", |
|
|
"<|turn_eos|>", |
|
|
"<|audio_start|>", |
|
|
"<|audio_end|>", |
|
|
] |
|
|
self._all_special_tokens_text.update(custom_special_tokens) |
|
|
for token in custom_special_tokens: |
|
|
token_id = self.tokenizer.convert_tokens_to_ids(token) |
|
|
if token_id is not None and token_id != self.tokenizer.unk_token_id: |
|
|
self._all_special_ids.add(token_id) |
|
|
|
|
|
if forbidden_token_ids is None: |
|
|
self.forbidden_token_ids = [] |
|
|
elif isinstance(forbidden_token_ids, int): |
|
|
self.forbidden_token_ids = [self.forbidden_token_ids] |
|
|
else: |
|
|
self.forbidden_token_ids = forbidden_token_ids |
|
|
self.forbidden_token_ids.append(self.chunk_eos_id) |
|
|
|
|
|
assert isinstance(self.forbidden_token_ids, list) |
|
|
|
|
|
self.cache = None |
|
|
self.context = "" |
|
|
self.generated_tokens = [] |
|
|
self.generated_special_tokens = [] |
|
|
self.reset() |
|
|
self.embeds = None |
|
|
self.system_embeds = None |
|
|
|
|
|
|
|
|
self._unit_history: List[Dict[str, Any]] = [] |
|
|
self._next_unit_id: int = 0 |
|
|
self._pending_unit_id: Optional[int] = None |
|
|
self._pending_unit_start_cache_len: int = 0 |
|
|
self._system_preserve_length: int = 0 |
|
|
self._position_offset: int = 0 |
|
|
self._window_config = DuplexWindowConfig() |
|
|
self._window_enabled: bool = True |
|
|
self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._preserve_prefix_length: int = 0 |
|
|
self._previous_content_length: int = 0 |
|
|
self._suffix_token_ids: List[int] = [] |
|
|
|
|
|
|
|
|
self._previous_marker: str = "\n\nprevious: " |
|
|
self._previous_marker_token_ids: List[int] = [] |
|
|
self._has_previous: bool = False |
|
|
|
|
|
|
|
|
self._previous_text: str = "" |
|
|
self._previous_token_ids: List[int] = [] |
|
|
|
|
|
|
|
|
self._sliding_event_count: int = 0 |
|
|
self._total_dropped_tokens: int = 0 |
|
|
self._total_dropped_units: int = 0 |
|
|
|
|
|
    def sliding_embeds(self):
        """Currently a no-op stub.

        Sliding-window eviction is implemented at the KV-cache level (see
        enforce_window / enforce_window_with_context); no embedding-level
        handling is performed here.
        """
        pass
|
|
|
|
|
def reset(self): |
|
|
self.context = "" |
|
|
self.cache = None |
|
|
self.generated_tokens = [] |
|
|
self.generated_special_tokens = [] |
|
|
self.embeds = None |
|
|
self.system_embeds = None |
|
|
|
|
|
|
|
|
old_unit_count = len(self._unit_history) if hasattr(self, "_unit_history") else 0 |
|
|
self._unit_history = [] |
|
|
self._next_unit_id = 0 |
|
|
self._pending_unit_id = None |
|
|
self._pending_unit_start_cache_len = 0 |
|
|
self._system_preserve_length = 0 |
|
|
self._position_offset = 0 |
|
|
self._rope_inv_freq_cache = {} |
|
|
|
|
|
|
|
|
self._preserve_prefix_length = 0 |
|
|
self._previous_content_length = 0 |
|
|
self._suffix_token_ids = [] |
|
|
self._previous_marker = "\n\nprevious: " |
|
|
self._previous_marker_token_ids = [] |
|
|
self._has_previous = False |
|
|
self._previous_text = "" |
|
|
self._previous_token_ids = [] |
|
|
|
|
|
|
|
|
self._sliding_event_count = 0 |
|
|
self._total_dropped_tokens = 0 |
|
|
self._total_dropped_units = 0 |
|
|
|
|
|
def get_cache_length(self) -> int: |
|
|
if self.cache is None: |
|
|
return 0 |
|
|
if isinstance(self.cache, DynamicCache): |
|
|
if len(self.cache.key_cache) > 0 and self.cache.key_cache[0].numel() > 0: |
|
|
return self.cache.key_cache[0].shape[2] |
|
|
return 0 |
|
|
|
|
|
return self.cache[0][0].shape[2] |
|
|
|
|
|
def get_total_generated_tokens(self) -> int: |
|
|
return sum(len(u.get("generated_tokens", [])) for u in self._unit_history) |
|
|
|
|
|
def register_unit_start(self) -> int: |
|
|
self._pending_unit_id = self._next_unit_id |
|
|
self._pending_unit_start_cache_len = self.get_cache_length() |
|
|
return self._pending_unit_id |
|
|
|
|
|
def register_unit_end( |
|
|
self, |
|
|
input_type: str, |
|
|
generated_tokens: Optional[List[int]] = None, |
|
|
is_listen: bool = False, |
|
|
generated_text: Optional[str] = None, |
|
|
): |
|
|
"""Call when unit ends, record unit information |
|
|
|
|
|
Should be called after feeding </unit> token |
|
|
|
|
|
Args: |
|
|
input_type: "audio" / "video" / "omni" / "system" |
|
|
generated_tokens: tokens generated by the unit (token ids) |
|
|
is_listen: whether the unit is in listen state |
|
|
generated_text: text generated by the unit (used for context preserving mode) |
|
|
""" |
|
|
if self._pending_unit_id is None: |
|
|
logger.warning("register_unit_end called without register_unit_start") |
|
|
return |
|
|
|
|
|
|
|
|
current_cache_len = self.get_cache_length() |
|
|
unit_len = current_cache_len - self._pending_unit_start_cache_len |
|
|
|
|
|
if unit_len > 0: |
|
|
entry = { |
|
|
"unit_id": self._pending_unit_id, |
|
|
"length": unit_len, |
|
|
"type": input_type, |
|
|
"generated_tokens": generated_tokens or [], |
|
|
"generated_text": generated_text or "", |
|
|
"is_listen": is_listen, |
|
|
} |
|
|
self._unit_history.append(entry) |
|
|
|
|
|
self._pending_unit_id = None |
|
|
self._pending_unit_start_cache_len = 0 |
|
|
self._next_unit_id += 1 |
|
|
|
|
|
    def register_system_prompt(self):
        """Call after system prompt prefill, record preserve length"""
        # Everything currently in the cache is treated as system prompt and is
        # protected from sliding-window eviction.
        self._system_preserve_length = self.get_cache_length()
|
|
|
|
|
|
|
|
|
|
|
def _get_rope_theta(self) -> float: |
|
|
"""get model rope_theta configuration""" |
|
|
return float(getattr(self.m.config, "rope_theta", 10000.0)) |
|
|
|
|
|
def _drop_tokens_from_cache(self, length: int) -> bool: |
|
|
"""remove specified number of tokens from cache (protect system prompt) |
|
|
|
|
|
remove tokens in the range [preserve, preserve + length) |
|
|
supports DynamicCache and tuple cache formats |
|
|
""" |
|
|
if self.cache is None or length <= 0: |
|
|
return False |
|
|
|
|
|
cache_type = "DynamicCache" if isinstance(self.cache, DynamicCache) else "TupleCache" |
|
|
cache_len_before = self.get_cache_length() |
|
|
offset_before = self._position_offset |
|
|
|
|
|
new_cache, new_offset, success = drop_tokens_from_cache( |
|
|
cache=self.cache, |
|
|
length=length, |
|
|
preserve=self._system_preserve_length, |
|
|
position_offset=self._position_offset, |
|
|
rope_theta=self._get_rope_theta(), |
|
|
inv_freq_cache=self._rope_inv_freq_cache, |
|
|
) |
|
|
if success: |
|
|
self.cache = new_cache |
|
|
self._position_offset = new_offset |
|
|
|
|
|
return success |
|
|
|
|
|
def _drop_unit(self, unit_id: int) -> bool: |
|
|
"""remove specified unit""" |
|
|
entries = [u for u in self._unit_history if u["unit_id"] == unit_id] |
|
|
if not entries: |
|
|
return False |
|
|
|
|
|
total_len = sum(e["length"] for e in entries) |
|
|
if total_len <= 0: |
|
|
for e in entries: |
|
|
self._unit_history.remove(e) |
|
|
return False |
|
|
|
|
|
if not self._drop_tokens_from_cache(total_len): |
|
|
return False |
|
|
|
|
|
for e in entries: |
|
|
self._unit_history.remove(e) |
|
|
|
|
|
return True |
|
|
|
|
|
def _drop_next_unit(self) -> bool: |
|
|
"""remove the earliest non-system unit""" |
|
|
for entry in self._unit_history: |
|
|
unit_id = entry.get("unit_id") |
|
|
if unit_id is None: |
|
|
continue |
|
|
|
|
|
if entry.get("type") == "system": |
|
|
continue |
|
|
if self._drop_unit(unit_id): |
|
|
return True |
|
|
return False |
|
|
|
|
|
def enforce_window(self) -> bool: |
|
|
"""enforce sliding window strategy (same as single-mode, only look at cache length) |
|
|
|
|
|
when cache length exceeds high water line, loop to remove the earliest unit, |
|
|
until cache length drops below the low water line. |
|
|
""" |
|
|
if not self._window_enabled: |
|
|
return False |
|
|
|
|
|
cfg = self._window_config |
|
|
cache_len_before = self.get_cache_length() |
|
|
|
|
|
if cache_len_before <= cfg.basic_window_high_tokens: |
|
|
return False |
|
|
|
|
|
dropped_count = 0 |
|
|
cache_len = cache_len_before |
|
|
while cache_len > cfg.basic_window_low_tokens: |
|
|
if not self._drop_next_unit(): |
|
|
break |
|
|
dropped_count += 1 |
|
|
cache_len = self.get_cache_length() |
|
|
|
|
|
if dropped_count > 0: |
|
|
|
|
|
self._sliding_event_count += 1 |
|
|
self._total_dropped_tokens += cache_len_before - cache_len |
|
|
self._total_dropped_units += dropped_count |
|
|
|
|
|
|
|
|
expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
|
|
is_consistent = expected == cache_len |
|
|
if not is_consistent: |
|
|
logger.error( |
|
|
"CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d", |
|
|
self._system_preserve_length, |
|
|
sum(u["length"] for u in self._unit_history), |
|
|
cache_len, |
|
|
self._position_offset, |
|
|
) |
|
|
|
|
|
return dropped_count > 0 |
|
|
|
|
|
|
|
|
|
|
|
def register_system_prompt_with_context( |
|
|
self, |
|
|
suffix_token_ids: Optional[List[int]] = None, |
|
|
context_previous_marker: str = "\n\nprevious: ", |
|
|
): |
|
|
"""register system prompt (with context preserving mode) |
|
|
|
|
|
initial cache layout: [prefix] [suffix] [units...] |
|
|
after first sliding window: [prefix] [context_previous_marker + content] [suffix] [units...] |
|
|
|
|
|
when calling this method, cache should only have prefix (without previous marker) |
|
|
suffix will be fed in later |
|
|
|
|
|
Args: |
|
|
suffix_token_ids: suffix token ids (e.g. id of <|im_end|>) |
|
|
context_previous_marker: previous marker prefix, e.g. "\\n\\nprevious: " |
|
|
""" |
|
|
|
|
|
self._preserve_prefix_length = self.get_cache_length() |
|
|
self._previous_content_length = 0 |
|
|
self._suffix_token_ids = suffix_token_ids or [] |
|
|
|
|
|
self._system_preserve_length = self._preserve_prefix_length + len(self._suffix_token_ids) |
|
|
|
|
|
|
|
|
self._previous_marker = context_previous_marker |
|
|
self._previous_marker_token_ids = ( |
|
|
self.tokenizer.encode(context_previous_marker, add_special_tokens=False) if self.tokenizer else [] |
|
|
) |
|
|
self._has_previous = False |
|
|
self._previous_text = "" |
|
|
self._previous_token_ids = [] |
|
|
|
|
|
def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]: |
|
|
"""extract generated text and token ids from units |
|
|
|
|
|
Args: |
|
|
units: list of units to extract |
|
|
|
|
|
Returns: |
|
|
(text, token_ids): concatenated text and token ids (filtered out special tokens) |
|
|
""" |
|
|
text_parts = [] |
|
|
token_ids = [] |
|
|
|
|
|
for u in units: |
|
|
|
|
|
if u.get("is_listen", False): |
|
|
continue |
|
|
gen_text = u.get("generated_text", "") |
|
|
gen_tokens = u.get("generated_tokens", []) |
|
|
|
|
|
|
|
|
if gen_text: |
|
|
clean_text = gen_text |
|
|
for st in self._all_special_tokens_text: |
|
|
clean_text = clean_text.replace(st, "") |
|
|
if clean_text.strip(): |
|
|
text_parts.append(clean_text) |
|
|
|
|
|
|
|
|
if gen_tokens: |
|
|
filtered_tokens = [t for t in gen_tokens if t not in self._all_special_ids] |
|
|
token_ids.extend(filtered_tokens) |
|
|
|
|
|
return "".join(text_parts), token_ids |
|
|
|
|
|
    def _rebuild_cache_with_previous(
        self,
        new_previous_tokens: List[int],
        units_to_keep_len: Optional[int] = None,
    ) -> bool:
        """Rebuild the cache, inserting new previous content between prefix and suffix.

        Cache layout change:
        [prefix] [old_prev] [suffix] [old_units] → [prefix] [new_prev] [suffix] [remaining_units]

        Args:
            new_previous_tokens: new previous token ids
            units_to_keep_len: length of units to keep (counted backwards from
                the cache end); if None, computed from unit_history

        Returns:
            whether the rebuild succeeded
        """
        if self.cache is None:
            return False

        old_previous_len = self._previous_content_length
        new_previous_len = len(new_previous_tokens)
        suffix_len = len(self._suffix_token_ids)
        total_cache_len = self.get_cache_length()

        # Remaining units (after any history-entry removals by the caller).
        if units_to_keep_len is None:
            units_to_keep_len = sum(u["length"] for u in self._unit_history)

        # Fast path: there is no previous content before or after — no forward
        # pass needed, just splice prefix+suffix with the surviving units.
        if new_previous_len == 0 and old_previous_len == 0:
            preserve_len = self._preserve_prefix_length + suffix_len

            if units_to_keep_len > 0:
                prefix_suffix_cache = self._slice_cache(0, preserve_len)
                units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None)

                # Units that slid left must have their RoPE rotation reindexed.
                dropped_tokens = total_cache_len - preserve_len - units_to_keep_len

                if dropped_tokens > 0:
                    old_start = preserve_len + dropped_tokens
                    new_start = preserve_len
                    units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

                self.cache = self._concat_caches(prefix_suffix_cache, units_cache)
            else:
                self.cache = self._slice_cache(0, preserve_len)

            return True

        # General path: keep only the prefix; previous + suffix will be
        # recomputed with a fresh forward pass below.
        prefix_end = self._preserve_prefix_length
        prefix_cache = self._slice_cache(0, prefix_end)

        # Surviving units sit at the tail of the old cache.
        units_start_in_old_cache = total_cache_len - units_to_keep_len
        units_cache = None
        if units_to_keep_len > 0:
            units_cache = self._slice_cache(units_start_in_old_cache, None)

        # The new previous content and the suffix are prefilled together.
        prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids
        prev_suffix_len = len(prev_suffix_tokens)

        new_prefix_prev_suffix_cache = prefix_cache
        if prev_suffix_len > 0:
            # NOTE(review): self.embed_tokens is defined outside this chunk —
            # presumably returns embeddings for a list of token ids; confirm.
            prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens)

            # Positions continue from the prefix end plus the accumulated offset.
            start_pos = self._preserve_prefix_length + self._position_offset

            with torch.no_grad():
                device = prev_suffix_embeds.device
                position_ids = torch.arange(
                    start_pos,
                    start_pos + prev_suffix_len,
                    device=device,
                ).unsqueeze(0)

                # Prefill previous+suffix on top of the prefix cache.
                outputs = self.m(
                    inputs_embeds=(
                        prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds
                    ),
                    position_ids=position_ids,
                    past_key_values=prefix_cache,
                    use_cache=True,
                    return_dict=True,
                )

            new_prefix_prev_suffix_cache = outputs.past_key_values

        # Units now start right after [prefix][new_prev][suffix]; reindex RoPE
        # if their absolute start position changed.
        new_system_total = prefix_end + new_previous_len + suffix_len
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            old_start = units_start_in_old_cache
            new_start = new_system_total

            if old_start != new_start:
                units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache)
        else:
            self.cache = new_prefix_prev_suffix_cache

        # Bookkeeping: new previous length and expanded protected region.
        self._previous_content_length = new_previous_len

        self._system_preserve_length = prefix_end + new_previous_len + suffix_len

        # NOTE(review): the two locals below are computed but never used —
        # leftover debug previews; candidates for removal.
        prev_text_preview = self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text
        suffix_preview = self.tokenizer.decode(self._suffix_token_ids) if self._suffix_token_ids else ""
        return True
|
|
|
|
|
def _slice_cache(self, start: int, end: Optional[int], clone: bool = True): |
|
|
"""slice cache |
|
|
|
|
|
Args: |
|
|
start: start position |
|
|
end: end position (None means to end) |
|
|
clone: whether to clone (default True, to prevent shared memory issues) |
|
|
""" |
|
|
if self.cache is None: |
|
|
return None |
|
|
if isinstance(self.cache, DynamicCache): |
|
|
|
|
|
new_key_cache = [ |
|
|
k[:, :, start:end, :].clone() if clone else k[:, :, start:end, :] for k in self.cache.key_cache |
|
|
] |
|
|
new_value_cache = [ |
|
|
v[:, :, start:end, :].clone() if clone else v[:, :, start:end, :] for v in self.cache.value_cache |
|
|
] |
|
|
new_cache = DynamicCache() |
|
|
new_cache.key_cache = new_key_cache |
|
|
new_cache.value_cache = new_value_cache |
|
|
return new_cache |
|
|
else: |
|
|
|
|
|
if clone: |
|
|
return tuple( |
|
|
(layer[0][:, :, start:end, :].clone(), layer[1][:, :, start:end, :].clone()) for layer in self.cache |
|
|
) |
|
|
else: |
|
|
return tuple((layer[0][:, :, start:end, :], layer[1][:, :, start:end, :]) for layer in self.cache) |
|
|
|
|
|
@staticmethod |
|
|
def _get_cache_len(cache) -> int: |
|
|
if cache is None: |
|
|
return 0 |
|
|
if isinstance(cache, DynamicCache): |
|
|
if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0: |
|
|
return cache.key_cache[0].shape[2] |
|
|
return 0 |
|
|
|
|
|
if cache and cache[0] and cache[0][0] is not None: |
|
|
return cache[0][0].shape[2] |
|
|
return 0 |
|
|
|
|
|
@staticmethod |
|
|
def _concat_caches(cache1, cache2): |
|
|
if cache1 is None: |
|
|
return cache2 |
|
|
if cache2 is None: |
|
|
return cache1 |
|
|
|
|
|
if isinstance(cache1, DynamicCache): |
|
|
new_cache = DynamicCache() |
|
|
new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)] |
|
|
new_cache.value_cache = [ |
|
|
torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache) |
|
|
] |
|
|
return new_cache |
|
|
else: |
|
|
return tuple( |
|
|
( |
|
|
torch.cat([layer1[0], layer2[0]], dim=2), |
|
|
torch.cat([layer1[1], layer2[1]], dim=2), |
|
|
) |
|
|
for layer1, layer2 in zip(cache1, cache2) |
|
|
) |
|
|
|
|
|
def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int): |
|
|
"""reindex RoPE position for cache""" |
|
|
if cache is None or length <= 0: |
|
|
return cache |
|
|
|
|
|
if isinstance(cache, DynamicCache): |
|
|
device = cache.key_cache[0].device if cache.key_cache else None |
|
|
else: |
|
|
device = cache[0][0].device if cache and cache[0] else None |
|
|
|
|
|
if device is None: |
|
|
return cache |
|
|
|
|
|
old_positions = torch.arange(old_start, old_start + length, device=device, dtype=torch.long) |
|
|
new_positions = torch.arange(new_start, new_start + length, device=device, dtype=torch.long) |
|
|
|
|
|
rope_theta = self._get_rope_theta() |
|
|
|
|
|
if isinstance(cache, DynamicCache): |
|
|
new_key_cache = [] |
|
|
for k in cache.key_cache: |
|
|
new_k = realign_rotary_suffix(k, old_positions, new_positions, rope_theta, self._rope_inv_freq_cache) |
|
|
new_key_cache.append(new_k) |
|
|
cache.key_cache = new_key_cache |
|
|
return cache |
|
|
else: |
|
|
new_cache = [] |
|
|
for layer in cache: |
|
|
new_k = realign_rotary_suffix( |
|
|
layer[0], old_positions, new_positions, rope_theta, self._rope_inv_freq_cache |
|
|
) |
|
|
new_cache.append((new_k, layer[1])) |
|
|
return tuple(new_cache) |
|
|
|
|
|
def _update_previous( |
|
|
self, |
|
|
new_text: str, |
|
|
new_tokens: List[int], |
|
|
max_tokens: int, |
|
|
) -> None: |
|
|
"""update previous context (also update cache) |
|
|
|
|
|
when first sliding window, dynamically add marker + text, subsequent sliding window append text |
|
|
when content exceeds max_tokens, truncate content (keep marker) |
|
|
rebuild cache to maintain consistency |
|
|
|
|
|
Args: |
|
|
new_text: new text |
|
|
new_tokens: new token ids |
|
|
max_tokens: previous content maximum token count (without marker) |
|
|
""" |
|
|
marker_len = len(self._previous_marker_token_ids) |
|
|
tokens_to_drop = 0 |
|
|
|
|
|
|
|
|
if not new_tokens and not new_text: |
|
|
|
|
|
self._rebuild_cache_with_previous(self._previous_token_ids) |
|
|
return |
|
|
|
|
|
if not self._has_previous: |
|
|
|
|
|
self._previous_text = new_text |
|
|
self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens |
|
|
self._has_previous = True |
|
|
else: |
|
|
|
|
|
self._previous_text += new_text |
|
|
self._previous_token_ids.extend(new_tokens) |
|
|
|
|
|
|
|
|
content_token_count = len(self._previous_token_ids) - marker_len |
|
|
|
|
|
|
|
|
if content_token_count > max_tokens: |
|
|
|
|
|
tokens_to_drop = content_token_count - max_tokens |
|
|
old_text = self._previous_text |
|
|
|
|
|
content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :] |
|
|
self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens |
|
|
|
|
|
try: |
|
|
self._previous_text = self.tokenizer.decode( |
|
|
content_tokens, |
|
|
skip_special_tokens=True, |
|
|
) |
|
|
except Exception as e: |
|
|
logger.warning("_update_previous: decode failed: %s", e) |
|
|
|
|
|
|
|
|
self._rebuild_cache_with_previous(self._previous_token_ids) |
|
|
|
|
|
def _drop_unit_with_context( |
|
|
self, |
|
|
unit_id: int, |
|
|
max_previous_tokens: int, |
|
|
) -> Tuple[bool, str, List[int]]: |
|
|
"""remove specified unit and return its generated content (for context preserving) |
|
|
|
|
|
process: |
|
|
1. extract generated content of unit |
|
|
2. remove unit from cache (without prefix+previous) |
|
|
3. append generated content to previous |
|
|
4. rebuild cache (in _update_previous) |
|
|
|
|
|
Args: |
|
|
unit_id: unit ID to remove |
|
|
max_previous_tokens: previous maximum token count |
|
|
|
|
|
Returns: |
|
|
(success, extracted_text, extracted_tokens): whether successful, extracted text and tokens |
|
|
""" |
|
|
entries = [u for u in self._unit_history if u["unit_id"] == unit_id] |
|
|
if not entries: |
|
|
return False, "", [] |
|
|
|
|
|
|
|
|
extracted_text, extracted_tokens = self._extract_generated_text(entries) |
|
|
|
|
|
|
|
|
total_len = sum(e["length"] for e in entries) |
|
|
if total_len <= 0: |
|
|
for e in entries: |
|
|
self._unit_history.remove(e) |
|
|
return False, extracted_text, extracted_tokens |
|
|
|
|
|
cache_before = self.get_cache_length() |
|
|
|
|
|
|
|
|
for e in entries: |
|
|
self._unit_history.remove(e) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._update_previous(extracted_text, extracted_tokens, max_previous_tokens) |
|
|
|
|
|
return True, extracted_text, extracted_tokens |
|
|
|
|
|
def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool: |
|
|
"""remove the earliest non-system unit (with context preserving)""" |
|
|
for entry in self._unit_history: |
|
|
unit_id = entry.get("unit_id") |
|
|
if unit_id is None: |
|
|
continue |
|
|
if entry.get("type") == "system": |
|
|
continue |
|
|
success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens) |
|
|
if success: |
|
|
return True |
|
|
return False |
|
|
|
|
|
def enforce_window_with_context(self) -> bool: |
|
|
"""context preserving sliding window execution |
|
|
|
|
|
when unit count exceeds max_units, remove the earliest unit, |
|
|
and accumulate its generated content to previous. |
|
|
Cache will be automatically rebuilt in _update_previous. |
|
|
|
|
|
Returns: |
|
|
whether sliding window is executed |
|
|
""" |
|
|
if not self._window_enabled: |
|
|
return False |
|
|
|
|
|
cfg = self._window_config |
|
|
|
|
|
if cfg.sliding_window_mode != "context": |
|
|
|
|
|
return self.enforce_window() |
|
|
|
|
|
cache_len_before = self.get_cache_length() |
|
|
units_before = len(self._unit_history) |
|
|
|
|
|
|
|
|
|
|
|
if units_before <= cfg.context_max_units: |
|
|
return False |
|
|
|
|
|
|
|
|
dropped_count = 0 |
|
|
while len(self._unit_history) > cfg.context_max_units: |
|
|
if not self._drop_next_unit_with_context(cfg.context_previous_max_tokens): |
|
|
break |
|
|
|
|
|
dropped_count += 1 |
|
|
|
|
|
cache_len_after = self.get_cache_length() |
|
|
|
|
|
if dropped_count > 0: |
|
|
|
|
|
self._sliding_event_count += 1 |
|
|
self._total_dropped_tokens += cache_len_before - cache_len_after |
|
|
self._total_dropped_units += dropped_count |
|
|
|
|
|
|
|
|
expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
|
|
|
|
|
return dropped_count > 0 |
|
|
|
|
|
def get_previous_context(self) -> Tuple[str, List[int]]: |
|
|
"""get current accumulated previous context |
|
|
|
|
|
Returns: |
|
|
(previous_text, previous_token_ids): current accumulated text and token ids |
|
|
""" |
|
|
return self._previous_text, self._previous_token_ids.copy() |
|
|
|
|
|
def get_window_stats(self) -> Dict[str, Any]: |
|
|
"""get sliding window statistics""" |
|
|
unit_lengths = [u["length"] for u in self._unit_history] |
|
|
return { |
|
|
"cache_length": self.get_cache_length(), |
|
|
"unit_count": len(self._unit_history), |
|
|
"unit_lengths": unit_lengths, |
|
|
"unit_total_length": sum(unit_lengths), |
|
|
"system_preserve_length": self._system_preserve_length, |
|
|
"position_offset": self._position_offset, |
|
|
"window_enabled": self._window_enabled, |
|
|
"total_generated_tokens": self.get_total_generated_tokens(), |
|
|
"pending_unit_id": self._pending_unit_id, |
|
|
"next_unit_id": self._next_unit_id, |
|
|
"config": { |
|
|
"sliding_window_mode": self._window_config.sliding_window_mode, |
|
|
"basic_window_high_tokens": self._window_config.basic_window_high_tokens, |
|
|
"basic_window_low_tokens": self._window_config.basic_window_low_tokens, |
|
|
"context_previous_max_tokens": self._window_config.context_previous_max_tokens, |
|
|
"context_max_units": self._window_config.context_max_units, |
|
|
}, |
|
|
|
|
|
"preserve_prefix_length": self._preserve_prefix_length, |
|
|
"previous_content_length": self._previous_content_length, |
|
|
"suffix_token_count": len(self._suffix_token_ids), |
|
|
"previous_text_length": len(self._previous_text), |
|
|
"previous_token_count": len(self._previous_token_ids), |
|
|
"has_system_template": self._system_prompt_template is not None, |
|
|
} |
|
|
|
|
|
def _verify_consistency(self) -> bool: |
|
|
"""verify unit history and cache length consistency""" |
|
|
expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
|
|
actual = self.get_cache_length() |
|
|
return expected == actual |
|
|
|
|
|
def print_verification_summary(self) -> Dict[str, Any]: |
|
|
"""print verification summary (for comparing off/basic/context mode) |
|
|
|
|
|
Returns: |
|
|
dictionary containing key verification data |
|
|
""" |
|
|
cfg = self._window_config |
|
|
|
|
|
|
|
|
all_generated_text = [] |
|
|
all_generated_tokens = [] |
|
|
for u in self._unit_history: |
|
|
if not u.get("is_listen", False): |
|
|
gen_text = u.get("generated_text", "") |
|
|
gen_tokens = u.get("generated_tokens", []) |
|
|
if gen_text: |
|
|
all_generated_text.append(gen_text) |
|
|
if gen_tokens: |
|
|
all_generated_tokens.extend(gen_tokens) |
|
|
|
|
|
combined_text = "".join(all_generated_text) |
|
|
|
|
|
summary = { |
|
|
"mode": cfg.sliding_window_mode, |
|
|
"final_cache_length": self.get_cache_length(), |
|
|
"final_unit_count": len(self._unit_history), |
|
|
"sliding_event_count": self._sliding_event_count, |
|
|
"total_dropped_tokens": self._total_dropped_tokens, |
|
|
"total_dropped_units": self._total_dropped_units, |
|
|
"total_generated_tokens": len(all_generated_tokens), |
|
|
"generated_text": combined_text, |
|
|
"previous_text": self._previous_text, |
|
|
"previous_token_count": len(self._previous_token_ids), |
|
|
"position_offset": self._position_offset, |
|
|
"system_preserve_length": self._system_preserve_length, |
|
|
} |
|
|
|
|
|
return summary |
|
|
|
|
|
def set_window_config(self, config: DuplexWindowConfig) -> None: |
|
|
"""set sliding window configuration""" |
|
|
self._window_config = config |
|
|
|
|
|
def set_window_enabled(self, enabled: bool) -> None: |
|
|
"""enable/disable sliding window""" |
|
|
old_enabled = self._window_enabled |
|
|
self._window_enabled = enabled |
|
|
|
|
|
def get_context(self): |
|
|
return self.context |
|
|
|
|
|
def embed_token(self, tid): |
|
|
if isinstance(tid, int): |
|
|
tid = torch.tensor([tid], device=self.m.device) |
|
|
return self.m.model.embed_tokens(tid) |
|
|
|
|
|
def embed_tokens(self, token_ids: List[int]) -> torch.Tensor: |
|
|
"""batch embed multiple tokens |
|
|
|
|
|
Args: |
|
|
token_ids: list of token ids |
|
|
|
|
|
Returns: |
|
|
embeddings tensor [L, H] |
|
|
""" |
|
|
if not token_ids: |
|
|
return torch.empty(0, self.m.config.hidden_size, device=self.m.device) |
|
|
tids = torch.tensor(token_ids, device=self.m.device) |
|
|
return self.m.model.embed_tokens(tids) |
|
|
|
|
|
@torch.no_grad() |
|
|
def feed(self, embeds: torch.Tensor, return_logits: bool = False): |
|
|
""" |
|
|
embeds : [L, H] —— new embedding sequence fed into model at once |
|
|
""" |
|
|
L = embeds.size(0) |
|
|
device = embeds.device |
|
|
|
|
|
past_len = self.get_cache_length() |
|
|
pos_ids = torch.arange(past_len, past_len + L, device=device).unsqueeze(0) |
|
|
|
|
|
out = self.m( |
|
|
inputs_embeds=embeds.unsqueeze(0), |
|
|
position_ids=pos_ids, |
|
|
past_key_values=self.cache, |
|
|
|
|
|
return_dict=True, |
|
|
output_hidden_states=True, |
|
|
|
|
|
) |
|
|
self.cache = out.past_key_values |
|
|
|
|
|
if return_logits: |
|
|
logits = self.m.lm_head(out.hidden_states[-1])[:, -1] |
|
|
return logits, out.hidden_states[-1] |
|
|
|
|
|
@torch.no_grad() |
|
|
def decode( |
|
|
self, |
|
|
logits, |
|
|
mode: Literal["sampling", "greedy"] = "sampling", |
|
|
temperature=0.7, |
|
|
top_k=20, |
|
|
top_p=0.8, |
|
|
listen_top_k=None, |
|
|
listen_prob_scale=1.0, |
|
|
text_repetition_penalty=1.05, |
|
|
text_repetition_window_size=512, |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
logits: |
|
|
mode: sampling or greedy |
|
|
temperature: |
|
|
top_k: |
|
|
top_p: |
|
|
listen_top_k: force listen_id to be in top-k to keep |
|
|
listen_prob_scale: multiply listen_id probability by a weight (<1 means decrease, >1 means increase) |
|
|
text_repetition_penalty: repetition penalty coefficient, >1.0 means decrease repetition, <1.0 means increase repetition |
|
|
text_repetition_window_size: repetition penalty window size |
|
|
|
|
|
Sampling strategy: |
|
|
1. first sample all tokens with original logits (apply temperature) |
|
|
2. if sampled chunk_eos, return directly (keep the original model's decision of when to stop) |
|
|
3. if not sampled chunk_eos, mask it (set logit to -inf), continue sampling text tokens |
|
|
4. apply repetition penalty, top-k, top-p, etc. to the text tokens for the final sampling |
|
|
""" |
|
|
|
|
|
logits = logits.clone() |
|
|
|
|
|
|
|
|
eos_id = self.chunk_eos_id |
|
|
|
|
|
with torch.no_grad(): |
|
|
if mode == "greedy": |
|
|
sampled_token = torch.argmax(logits[0]).item() |
|
|
else: |
|
|
original_probs = F.softmax(logits[0], dim=-1) |
|
|
sampled_token = torch.multinomial(original_probs, num_samples=1).item() |
|
|
|
|
|
|
|
|
if sampled_token == eos_id: |
|
|
next_token_id = torch.tensor([eos_id], device=logits.device) |
|
|
next_token_str = self.tokenizer.decode(next_token_id) |
|
|
|
|
|
return next_token_id |
|
|
|
|
|
|
|
|
if self.forbidden_token_ids: |
|
|
logits[:, self.forbidden_token_ids] = float("-inf") |
|
|
|
|
|
|
|
|
if text_repetition_penalty != 1.0 and len(self.generated_tokens) > 0: |
|
|
|
|
|
recent_tokens = self.generated_tokens[-text_repetition_window_size:] |
|
|
|
|
|
|
|
|
recent_tokens = list(set(recent_tokens)) |
|
|
|
|
|
|
|
|
for token_id in recent_tokens: |
|
|
if token_id < logits.size(-1): |
|
|
if text_repetition_penalty > 1.0: |
|
|
|
|
|
logits[0, token_id] /= text_repetition_penalty |
|
|
else: |
|
|
|
|
|
logits[0, token_id] *= 1.0 / text_repetition_penalty |
|
|
|
|
|
if listen_prob_scale != 1.0: |
|
|
logits[0, self.listen_id] *= listen_prob_scale |
|
|
|
|
|
listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item() |
|
|
|
|
|
if listen_top_k is not None and listen_rank < listen_top_k: |
|
|
next_token_id = torch.tensor([self.listen_id], device=logits.device) |
|
|
next_token_str = self.tokenizer.decode(next_token_id) |
|
|
|
|
|
if next_token_str == "<|listen|>": |
|
|
self.context += " " |
|
|
else: |
|
|
self.context += next_token_str |
|
|
|
|
|
return next_token_id |
|
|
|
|
|
if mode == "greedy": |
|
|
next_token_id = torch.argmax(logits, dim=-1) |
|
|
elif mode == "sampling": |
|
|
logits = logits / temperature |
|
|
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) |
|
|
probs = F.softmax(logits, dim=-1) |
|
|
next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1) |
|
|
else: |
|
|
raise ValueError(f"Unsupported decode mode: {mode}") |
|
|
|
|
|
if next_token_id.item() not in self.special_token_ids: |
|
|
self.generated_tokens.append(next_token_id.item()) |
|
|
else: |
|
|
self.generated_special_tokens.append(next_token_id.item()) |
|
|
|
|
|
return next_token_id |
|
|
|
|
|
|
|
|
def _download_url_to_tempfile(url: str, suffix: str = "", timeout: int = 60) -> str:
    """
    Download a URL to a temporary file and return the path.

    The file is created with delete=False, so the caller is responsible for
    removing it (os.unlink) when done.

    Args:
        url: HTTP/HTTPS URL to download
        suffix: File suffix (e.g., ".jpg", ".wav", ".mp4")
        timeout: Download timeout in seconds

    Returns:
        Path to the downloaded temporary file

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    import tempfile

    import requests

    # Stream the response so large payloads (e.g. videos) are written to disk
    # in chunks instead of being buffered entirely in memory.
    with requests.get(url, timeout=timeout, stream=True) as response:
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
            return f.name
|
|
|
|
|
|
|
|
def _is_url(path: str) -> bool: |
|
|
return path.startswith(("http://", "https://")) |
|
|
|
|
|
|
|
|
def normalize_content_item(item) -> Union[str, Any, List[Any]]:
    """Normalize structured content item to native format.

    Supports:
        - Native format: str, PIL.Image, np.ndarray (pass through)
        - OpenAI structured format:
            - {"type": "text", "text": "..."} -> str
            - {"type": "image_url", "image_url": {"url": "..."}} -> PIL.Image
            - {"type": "audio_url", "audio_url": {"url": "..."}} -> np.ndarray
            - {"type": "video_url", "video_url": {"url": "...", ...}} -> List[Image, ndarray, ...]

    URL formats supported:
        - Local file path: "/path/to/file.jpg"
        - HTTP/HTTPS URL: "https://example.com/image.jpg"

    Args:
        item: Content item to normalize

    Returns:
        Normalized item. For video_url, returns a tuple ("__video_contents__", list)
        that will be flattened by normalize_content().

    Raises:
        ValueError: If content type is unknown or unsupported
    """
    import os

    import numpy as np
    from PIL import Image

    # Already-native items pass straight through.
    if isinstance(item, str):
        return item
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, np.ndarray):
        return item

    if isinstance(item, dict):
        item_type = item.get("type")

        if item_type == "text":
            return item.get("text", "")

        elif item_type == "image_url":
            image_url_obj = item.get("image_url", {})
            url = image_url_obj.get("url", "") if isinstance(image_url_obj, dict) else image_url_obj

            if _is_url(url):
                temp_path = _download_url_to_tempfile(url, suffix=".jpg", timeout=30)
                try:
                    img = Image.open(temp_path)
                    # BUGFIX: PIL opens lazily — force the pixel data into
                    # memory before the backing temp file is deleted,
                    # otherwise later access can fail (and unlinking an open
                    # file fails on Windows).
                    img.load()
                finally:
                    # Clean up the temp file even if decoding raises.
                    os.unlink(temp_path)
                return img
            else:
                return Image.open(url)

        elif item_type == "audio_url":
            import librosa

            audio_url_obj = item.get("audio_url", {})
            url = audio_url_obj.get("url", "") if isinstance(audio_url_obj, dict) else audio_url_obj

            if _is_url(url):
                temp_path = _download_url_to_tempfile(url, suffix=".wav", timeout=60)
                try:
                    audio_np, _ = librosa.load(temp_path, sr=16000, mono=True)
                finally:
                    # Clean up the temp file even if decoding raises.
                    os.unlink(temp_path)
                return audio_np
            else:
                audio_np, _ = librosa.load(url, sr=16000, mono=True)
                return audio_np

        elif item_type == "video_url":
            from minicpmo.utils import get_video_frame_audio_segments

            video_url_obj = item.get("video_url", {})
            if isinstance(video_url_obj, dict):
                video_url = video_url_obj.get("url", "")
                stack_frames = video_url_obj.get("stack_frames", 1)
                use_ffmpeg = video_url_obj.get("use_ffmpeg", False)
                use_audio = video_url_obj.get("use_audio", True)
            else:
                video_url = video_url_obj
                stack_frames = 1
                use_ffmpeg = False
                use_audio = True

            temp_video_path = None
            if _is_url(video_url):
                temp_video_path = _download_url_to_tempfile(video_url, suffix=".mp4", timeout=120)
                video_path = temp_video_path
            else:
                video_path = video_url

            try:
                video_frames, audio_segments, stacked_frames = get_video_frame_audio_segments(
                    video_path,
                    stack_frames=stack_frames,
                    use_ffmpeg=use_ffmpeg,
                )
            finally:
                # Clean up the downloaded temp file even if extraction raises.
                if temp_video_path is not None:
                    os.unlink(temp_video_path)

            # Interleave frames with their audio segments (and optional
            # stacked frames) in playback order.
            # NOTE(review): assumes audio_segments has at least as many
            # entries as video_frames when use_audio is True — confirm
            # against get_video_frame_audio_segments.
            omni_contents = []
            for i in range(len(video_frames)):
                omni_contents.append(video_frames[i])
                if use_audio:
                    omni_contents.append(audio_segments[i])
                if stacked_frames is not None and i < len(stacked_frames) and stacked_frames[i] is not None:
                    omni_contents.append(stacked_frames[i])

            # Sentinel tuple; normalize_content() flattens it into the list.
            return "__video_contents__", omni_contents
        else:
            raise ValueError(f"Unknown content type: {item_type}")

    raise ValueError(f"Cannot normalize content item of type: {type(item)}")
|
|
|
|
|
|
|
|
def normalize_content(content) -> list:
    """Normalize message content to a list of native items.

    Input formats:
        - str: "hello" -> ["hello"]
        - list of native items: [str, Image, np.ndarray] -> pass through with normalization
        - list of structured items: [{"type": "text", ...}] -> normalize each
        - video type: automatically expanded to omni_contents
        - mixed lists work too

    Args:
        content: Message content in any supported format.

    Returns:
        List of native items (str, PIL.Image, np.ndarray).
    """
    import numpy as np
    from PIL import Image

    def _unwrap(normalized):
        # normalize_content_item signals expanded video content with a
        # ("__video_contents__", [...]) sentinel tuple; unwrap it to the
        # underlying item list, otherwise wrap the single item.
        is_video_pack = (
            isinstance(normalized, tuple)
            and len(normalized) == 2
            and normalized[0] == "__video_contents__"
        )
        return normalized[1] if is_video_pack else [normalized]

    if isinstance(content, str):
        return [content]

    if isinstance(content, list):
        flattened = []
        for element in content:
            flattened.extend(_unwrap(normalize_content_item(element)))
        return flattened

    # A bare native item (single image / audio array).
    if isinstance(content, (Image.Image, np.ndarray)):
        return [content]

    # Any other single item (e.g. a structured dict).
    return _unwrap(normalize_content_item(content))
|
|
|