from melo.api import TTS
import logging
from baseHandler import BaseHandler
import librosa
import numpy as np
from rich.console import Console
import torch

logger = logging.getLogger(__name__)

console = Console()

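# Map Whisper's lowercase language codes to MeloTTS language identifiers.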
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
    "en": "EN",
    "fr": "FR",
    "es": "ES",
    "zh": "ZH",
    "ja": "JP",
    "ko": "KR",
}

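# Map Whisper language codes to the Melo speaker key used for the spk2id
# lookup ("EN-BR" selects the British English voice).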
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
    "en": "EN-BR",
    "fr": "FR",
    "es": "ES",
    "zh": "ZH",
    "ja": "JP",
    "ko": "KR",
}


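# TTS stage backed by MeloTTS: consumes sentences from the LLM stage and
# yields fixed-size 16 kHz int16 PCM blocks, followed by a b"END" sentinel.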
class MeloTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        device="auto",
        language="en",
        speaker_to_id="en",
        gen_kwargs={},  # unused in this handler
        blocksize=512,
    ):
        self.should_listen = should_listen
        self.device = device
        console.print(f"[green]Device: {device}")
        self.language = language
        self.model = TTS(
            language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
        )
        console.print(f"[green]Model device: {self.model.device}")
        self.speaker_id = self.model.hps.data.spk2id[
            WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
        ]
        self.blocksize = blocksize
        self.warmup()

    def warmup(self):
        # One short synthesis pass so weights and kernels are loaded before
        # the first real request.
        logger.info(f"Warming up {self.__class__.__name__}")
        _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)

    def process(self, llm_sentence):
        language_code = None

        # The LLM stage may emit either a plain sentence or a
        # (sentence, language_code) tuple when language detection is enabled.
        if isinstance(llm_sentence, tuple):
            llm_sentence, language_code = llm_sentence

        console.print(f"[green]ASSISTANT: {llm_sentence}")

        # Switch models when the detected language changes; fall back to the
        # current language if Melo has no entry for the new code.
        if language_code is not None and self.language != language_code:
            try:
                self.model = TTS(
                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
                    device=self.device,
                )
                self.speaker_id = self.model.hps.data.spk2id[
                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
                ]
                self.language = language_code
            except KeyError:
                console.print(
                    f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
                )

        if self.device == "mps":
            import time

            start = time.time()
            torch.mps.synchronize()  # wait for all pending MPS kernels to finish
            torch.mps.empty_cache()  # release memory cached by the MPS allocator
            _ = time.time() - start  # elapsed sync time; computed but unused

        try:
            audio_chunk = self.model.tts_to_file(
                llm_sentence, self.speaker_id, quiet=True
            )
        except (AssertionError, RuntimeError) as e:
            logger.error(f"Error in MeloTTSHandler: {e}")
            audio_chunk = np.array([])
        if len(audio_chunk) == 0:
            # Nothing was synthesized: hand control back to the listener.
            self.should_listen.set()
            return
        # Melo outputs 44.1 kHz float audio; downsample to 16 kHz and convert
        # to 16-bit PCM for the downstream audio pipeline.
        audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
        audio_chunk = (audio_chunk * 32768).astype(np.int16)
        # Stream fixed-size blocks, zero-padding the final partial block.
        for i in range(0, len(audio_chunk), self.blocksize):
            yield np.pad(
                audio_chunk[i : i + self.blocksize],
                (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
            )

        self.should_listen.set()
        yield b"END"
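

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the pipeline): drives
    # the handler directly, assuming a threading.Event for should_listen.
    # The __new__ call skips BaseHandler's constructor, whose queue wiring is
    # not needed for a one-off synthesis.
    from threading import Event

    handler = MeloTTSHandler.__new__(MeloTTSHandler)
    handler.setup(should_listen=Event(), device="auto", language="en")

    # process() yields int16 blocks of `blocksize` samples, then b"END".
    pcm = b"".join(
        block.tobytes()
        for block in handler.process(("Hello from Melo!", "en"))
        if isinstance(block, np.ndarray)
    )
    print(f"Synthesized {len(pcm) // 2} samples at 16 kHz")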