#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import List, Optional, Union
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
class TTSStreamingGenerator:
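    """Streaming TTS token generator for MiniCPM-o-4.5 (FlagOS build): consumes
    text condition embeddings chunk by chunk, autoregressively samples audio
    codec tokens against a persistent KV cache, and yields them in fixed-size
    chunks.
    """
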
def __init__(
self,
model,
temperature: float,
eos_token: Union[int, torch.Tensor],
        chunk_size: int = 25,  # s3tokenizer: 1 s of audio = 25 tokens
        tts_last_turn_tokens: Optional[torch.Tensor] = None,
        logits_processors=None,
        logits_warpers=None,
):
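        """
        Args:
            model: TTS module exposing emb_text / emb_code / head_code and config.
            temperature: sampling temperature applied to the logits.
            eos_token: audio EOS token id (or tensor) that ends a turn.
            chunk_size: tokens per yielded chunk; 25 tokens ≈ 1 s of s3tokenizer audio.
            tts_last_turn_tokens: optional audio tokens carried over from the previous turn.
            logits_processors: processors applied to the logits (defaults to none).
            logits_warpers: warpers such as top-p / top-k, applied after the processors.
        """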
self.tts = model
self.device = model.device
self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
self.eos_token = (
torch.tensor(eos_token, device=self.device) if isinstance(eos_token, int) else eos_token.to(self.device)
)
self.num_vq = model.num_vq
self.num_audio_tokens = model.num_audio_tokens
self.window_size = model.window_size
self.recomputed_chunks = model.recomputed_chunks
self.emb_code = model.emb_code
self.head_code = model.head_code
        # Logits processors (avoid mutable default arguments; fall back to empty lists)
        self.logits_processors = logits_processors if logits_processors is not None else []
        # Logits warpers (like top-p / top-k), kept separate from the processors
        self.logits_warpers = logits_warpers if logits_warpers is not None else []
# initialize state
self.past_key_values = None
self.text_start_pos = 0
        self.idx = -1  # starts at -1; incremented to 0 on the first call
self.all_conditions = []
self.all_generated_tokens = []
self.tts_last_turn_tokens = tts_last_turn_tokens
self.spk_emb = None
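        # Pre-compute embeddings for the audio-BOS and text-EOS special tokens;
        # they are appended to incoming condition chunks in generate_with_buffer.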
        audio_bos = torch.tensor(
            [self.tts.audio_bos_token_id],
            device=self.tts.emb_text.weight.device,
            dtype=torch.long,
        )
        self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
self.text_eos_embed = self.tts.emb_text(
torch.tensor(
[self.tts.config.text_eos_token_id],
device=self.tts.emb_text.weight.device,
dtype=torch.long,
)
).unsqueeze(0)
        # buffering: accumulate generated tokens until chunk_size is reached, then yield
self.chunk_size = chunk_size
self._token_buffer: List[torch.Tensor] = []
@torch.inference_mode()
def generate_with_buffer(
self,
condition: torch.Tensor,
text_finished: bool = False,
max_new_token: int = 500,
):
"""input a condition embedding chunk, generate audio token each time,
and accumulate to buffer, only yield when buffer satisfies chunk_size.
Yields:
torch.Tensor of shape [chunk_size] (2D: [1, chunk_size])
"""
self.idx += 1
self.device = self.tts.device
        # if the text stream has ended, first append the text-EOS embedding
        if text_finished:
            condition = torch.cat([condition, self.text_eos_embed], dim=1)
        # always append the audio-BOS embedding to start audio generation
        condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)
self.all_conditions.append(condition)
current_condition = condition
condition_length = current_condition.shape[1]
finished = torch.zeros(1, dtype=torch.bool, device=self.device)
chunk_generated_tokens = []
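        # Autoregressive loop: step 0 prefills the new condition chunk through the
        # persistent KV cache; later steps feed back the previously sampled token.
        # Position ids are absolute, continuing from self.text_start_pos so they
        # stay aligned with the cache across calls.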
for t in range(max_new_token):
if t == 0:
inputs_embeds = current_condition
pos_ids = torch.arange(
self.text_start_pos,
self.text_start_pos + condition_length,
dtype=torch.long,
device=self.device,
).unsqueeze(0)
else:
last = self.all_generated_tokens[-1]
                # last: [1, 1]; fed directly as code ids to the first codebook's embedding
inputs_embeds = self.emb_code[0](last)
pos_ids = torch.tensor(
[self.text_start_pos + condition_length + t - 1],
dtype=torch.long,
device=self.device,
).unsqueeze(0)
outputs = self.tts.model(
position_ids=pos_ids,
past_key_values=self.past_key_values,
inputs_embeds=inputs_embeds,
use_cache=True,
)
hidden_states = outputs.last_hidden_state
self.past_key_values = outputs.past_key_values
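            # P.cached() memoizes parametrized weights (e.g. weight-norm) so the
            # per-codebook head projections below do not recompute them.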
with P.cached():
logits = torch.empty(
hidden_states.size(0),
hidden_states.size(1),
self.num_audio_tokens,
self.num_vq,
dtype=torch.float,
device=self.device,
)
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter] = x
del x
del hidden_states
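            # logits: [B, T, num_audio_tokens, num_vq] → keep the last position,
            # then flatten codebooks to [(B * num_vq), num_audio_tokens] for sampling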
logits = logits[:, -1].float()
logits = logits.permute(0, 2, 1)
logits = logits.reshape(-1, logits.size(2))
logits /= self.temperature
            is_first_token = len(self.all_generated_tokens) == 0 and t == 0
            if not is_first_token:
                # feed all tokens generated so far (across chunks) to the
                # processors/warpers, matching modeling_minicpmo's generate
                all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device)  # [1, T]
for processor in self.logits_processors:
logits = processor(all_generated_tokens, logits)
for warper in self.logits_warpers:
logits = warper(all_generated_tokens, logits)
del all_generated_tokens
# sample next token (only use first codebook, same as generate)
scores = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(scores, num_samples=1) # [(B*num_vq), 1]
next_id = idx_next.view(-1, self.num_vq)[:, 0:1] # only take first codebook → [B, 1]
del scores
            if next_id.eq(self.eos_token).any():
                # audio EOS generated: this turn is finished, stop producing new tokens
                finished[:] = True
            else:
                # the EOS token is never added to the buffer; it produces no audio
                # reshape next_id to [1, 1] (no num_vq dimension)
if next_id.dim() == 0: # if scalar
next_tok = next_id.unsqueeze(0).unsqueeze(0) # [1, 1]
elif next_id.dim() == 1: # if 1D [1]
next_tok = next_id.unsqueeze(0) # [1, 1]
else:
next_tok = next_id
self.all_generated_tokens.append(next_tok)
chunk_generated_tokens.append(next_tok)
self._token_buffer.append(next_tok)
            if len(self._token_buffer) == 0:
                # case 1: on the last text chunk, yield an empty chunk with finished=True
                if text_finished:
                    yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
                    break
                # case 2: otherwise, just end this call
                else:
                    break
            else:  # the buffer is non-empty
                # case 1: buffer holds at least chunk_size tokens, so yield a full chunk
                if len(self._token_buffer) >= self.chunk_size:
                    batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1)  # [1, chunk_size]
                    yield batch, False
                    # drop the part that was just yielded
                    self._token_buffer = self._token_buffer[self.chunk_size :]
                # case 2: buffer holds fewer than chunk_size tokens
                else:
                    # if generation finished on the last text chunk, flush all remaining tokens, then stop
                    if finished.all():
                        if text_finished:
                            batch = torch.cat(self._token_buffer, dim=1)  # [1, <chunk_size]
                            yield batch, True
                            self._token_buffer = []
                            break
                        else:
                            # not the last text chunk: keep the buffer and wait for the
                            # next text chunk to fill it; this call ends here
                            break
                    else:
                        # this audio chunk is not finished yet; keep generating
                        continue
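        # Advance the absolute position cursor past this chunk's condition and its
        # generated tokens so the next call's position ids line up with the KV cache.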
self.text_start_pos += condition_length + len(chunk_generated_tokens)
        # note: any leftover tokens stay in the buffer and are flushed on a later call
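

# Usage sketch (illustrative only): assumes `tts_model` is a loaded MiniCPM-o TTS
# module exposing the attributes referenced above (emb_text, emb_code, head_code,
# audio_bos_token_id, config.text_eos_token_id, num_vq, num_audio_tokens, ...);
# `audio_eos_id`, `text_condition_chunks`, and `decode_audio` are hypothetical
# placeholders, not a confirmed API.
#
# generator = TTSStreamingGenerator(
#     model=tts_model,
#     temperature=0.8,
#     eos_token=audio_eos_id,  # the model's audio EOS token id (assumption)
# )
# for cond_chunk, is_last in text_condition_chunks:  # cond_chunk: [1, T, hidden]
#     for tokens, finished in generator.generate_with_buffer(cond_chunk, text_finished=is_last):
#         decode_audio(tokens)  # tokens: [1, chunk_size] codec token ids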