from threading import Thread
from time import perf_counter
from baseHandler import BaseHandler
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
import librosa
import logging
from rich.console import Console
from utils.utils import next_power_of_2
from transformers.utils.import_utils import (
    is_flash_attn_2_available,
)
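
# Cache compiled FX graphs on disk and raise dynamo's cache-size limit so that
# every padded prompt length warmed up below can keep its own compiled graph.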
torch._inductor.config.fx_graph_cache = True
torch._dynamo.config.cache_size_limit = 15

logger = logging.getLogger(__name__)

console = Console()

if not is_flash_attn_2_available() and torch.cuda.is_available():
    logger.warning(
        """Parler TTS works best with flash attention 2, but it is not installed.
        Since CUDA is available on this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`"""
    )


class ParlerTTSHandler(BaseHandler):
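    """Streams Parler-TTS speech for each text chunk it receives.

    Yields fixed-size int16 PCM blocks at 16 kHz, then a b"END" marker, and
    sets should_listen once the sentence has been fully spoken.
    """
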
    def setup(
        self,
        should_listen,
        model_name="ylacombe/parler-tts-mini-jenny-30H",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
        max_prompt_pad_length=8,
        description=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        play_steps_s=1,
        blocksize=512,
    ):
        self.should_listen = should_listen
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.gen_kwargs = gen_kwargs
        self.compile_mode = compile_mode
        self.max_prompt_pad_length = max_prompt_pad_length
        self.description = description

        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=self.torch_dtype
        ).to(device)

        # Number of codec frames to generate between streamed audio chunks.
        framerate = self.model.audio_encoder.config.frame_rate
        self.play_steps = int(framerate * play_steps_s)
        self.blocksize = blocksize

        if self.compile_mode not in (None, "default"):
            logger.warning(
                "Torch compilation modes that capture CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'"
            )
            self.compile_mode = "default"

        if self.compile_mode:
            # A static KV cache keeps tensor shapes fixed so the forward pass
            # can be compiled with fullgraph=True.
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )

        self.warmup()
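
    # The description is tokenized to condition the voice; the prompt is the
    # text to speak. When compiling, the prompt is padded to a fixed max_length
    # so generation always sees the same tensor shapes.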
    def prepare_model_inputs(
        self,
        prompt,
        max_length_prompt=50,
        pad=False,
    ):
        pad_args_prompt = (
            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
        )

        tokenized_description = self.description_tokenizer(
            self.description, return_tensors="pt"
        )
        input_ids = tokenized_description.input_ids.to(self.device)
        attention_mask = tokenized_description.attention_mask.to(self.device)

        tokenized_prompt = self.prompt_tokenizer(
            prompt, return_tensors="pt", **pad_args_prompt
        )
        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)

        gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
            **self.gen_kwargs,
        }

        return gen_kwargs
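
    # Runs generate() ahead of time. With compilation enabled, every
    # power-of-2 prompt length is exercised so each padded shape that
    # process() can request is already traced.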
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

        n_steps = 1 if self.compile_mode == "default" else 2

        if self.device == "cuda":
            torch.cuda.synchronize()
            start_event.record()
        if self.compile_mode:
            # Warm up every padded prompt length, longest first.
            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
            for pad_length in pad_lengths[::-1]:
                model_kwargs = self.prepare_model_inputs(
                    "dummy prompt", max_length_prompt=pad_length, pad=True
                )
                for _ in range(n_steps):
                    _ = self.model.generate(**model_kwargs)
                logger.info(f"Warmed up length {pad_length} tokens!")
        else:
            model_kwargs = self.prepare_model_inputs("dummy prompt")
            for _ in range(n_steps):
                _ = self.model.generate(**model_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )
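
    # generate() runs on a background thread while this one drains the
    # streamer, so audio blocks are yielded as soon as they are decoded.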
    def process(self, llm_sentence):
        if isinstance(llm_sentence, tuple):
            llm_sentence, _ = llm_sentence

        console.print(f"[green]ASSISTANT: {llm_sentence}")
        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)

        pad_args = {}
        if self.compile_mode:
            # Pad the prompt to the next power of 2 (e.g. 37 tokens -> 64) so
            # it hits one of the shapes compiled during warmup().
            pad_length = next_power_of_2(nb_tokens)
            logger.debug(f"padding to {pad_length}")
            pad_args["pad"] = True
            pad_args["max_length_prompt"] = pad_length

        tts_gen_kwargs = self.prepare_model_inputs(
            llm_sentence,
            **pad_args,
        )

        streamer = ParlerTTSStreamer(
            self.model, device=self.device, play_steps=self.play_steps
        )
        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
        torch.manual_seed(0)
        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
        thread.start()

        for i, audio_chunk in enumerate(streamer):
            # pipeline_start, if set elsewhere in the pipeline, marks when the
            # current turn started; use it to report time-to-first-audio.
            global pipeline_start
            if i == 0 and "pipeline_start" in globals():
                logger.info(
                    f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
                )
            # Resample from the model's 44.1 kHz output to 16 kHz and convert
            # float samples in [-1, 1] to int16 PCM.
            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            # Re-chunk into fixed-size blocks, zero-padding the final one.
            for j in range(0, len(audio_chunk), self.blocksize):
                block = audio_chunk[j : j + self.blocksize]
                yield np.pad(block, (0, self.blocksize - len(block)))

        self.should_listen.set()
        yield b"END"