| import logging |
| import os |
| import sys |
| from copy import copy |
| from pathlib import Path |
| from queue import Queue |
| from threading import Event |
| from typing import Optional |
| from sys import platform |
| from VAD.vad_handler import VADHandler |
| from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments |
| from arguments_classes.language_model_arguments import LanguageModelHandlerArguments |
| from arguments_classes.mlx_language_model_arguments import ( |
| MLXLanguageModelHandlerArguments, |
| ) |
| from arguments_classes.module_arguments import ModuleArguments |
| from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments |
| from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments |
| from arguments_classes.socket_receiver_arguments import SocketReceiverArguments |
| from arguments_classes.socket_sender_arguments import SocketSenderArguments |
| from arguments_classes.vad_arguments import VADHandlerArguments |
| from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments |
| from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments |
| import torch |
| import nltk |
| from rich.console import Console |
| from transformers import ( |
| HfArgumentParser, |
| ) |
|
|
| from utils.thread_manager import ThreadManager |
|
|
| |
# Ensure the NLTK resources this script requires are available locally,
# downloading each one only when its lookup fails.
try:
    nltk.data.find("tokenizers/punkt_tab")
except (LookupError, OSError):
    nltk.download("punkt_tab")
try:
    # The perceptron tagger is installed under "taggers/", not "tokenizers/".
    # Probing the wrong category never succeeds, which forced a download
    # attempt on every start-up even when the resource was already present.
    nltk.data.find("taggers/averaged_perceptron_tagger_eng")
except (LookupError, OSError):
    nltk.download("averaged_perceptron_tagger_eng")
|
|
| |
| |
# Directory containing this file; the TorchInductor cache is pointed at a
# "tmp" folder located right next to it.
CURRENT_DIR = Path(__file__).resolve().parent
os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(CURRENT_DIR / "tmp")

console = Console()

# Raise numba's logger threshold so that only warnings and above get through.
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
|
|
|
|
def rename_args(args, prefix):
    """
    Rename arguments by removing the prefix and prepares the gen_kwargs.

    Every attribute of ``args`` named ``<prefix>_<rest>`` is removed and then
    stored again, either as ``<rest>`` on ``args`` itself or, when ``<rest>``
    starts with ``gen_``, as an entry of a ``gen_kwargs`` dict (with the
    ``gen_`` part stripped as well). ``args.gen_kwargs`` is always assigned,
    even when it ends up empty. ``args`` is modified in place.

    Args:
        args: Object whose instance ``__dict__`` holds the arguments.
        prefix: Prefix to strip, given without its trailing underscore.
    """
    marker = f"{prefix}_"
    attributes = vars(args)
    gen_kwargs = {}

    # Walk a snapshot of the keys because the dict is mutated inside the loop.
    for key in list(attributes):
        # Require the "_" separator: matching on the bare prefix would also
        # mangle unrelated names that merely share leading characters, or a
        # key equal to the prefix itself.
        if not key.startswith(marker):
            continue

        value = attributes.pop(key)
        new_key = key[len(marker):]
        if new_key.startswith("gen_"):
            gen_kwargs[new_key[len("gen_"):]] = value
        else:
            attributes[new_key] = value

    attributes["gen_kwargs"] = gen_kwargs
|
|
def get_default_arguments(**kwargs):
    """
    Build one default-constructed instance of every argument class.

    Each keyword argument is written onto every instance that already has an
    attribute with that name; instances lacking the attribute are left alone.

    Returns:
        A tuple with the instances, in the same order as the classes listed
        below.
    """
    argument_classes = (
        ModuleArguments,
        SocketReceiverArguments,
        SocketSenderArguments,
        VADHandlerArguments,
        WhisperSTTHandlerArguments,
        ParaformerSTTHandlerArguments,
        LanguageModelHandlerArguments,
        MLXLanguageModelHandlerArguments,
        ParlerTTSHandlerArguments,
        MeloTTSHandlerArguments,
        ChatTTSHandlerArguments,
    )
    instances = [argument_class() for argument_class in argument_classes]

    for instance in instances:
        for name, override in kwargs.items():
            if not hasattr(instance, name):
                continue
            setattr(instance, name, override)

    return tuple(instances)
|
|
def parse_arguments():
    """
    Parse the pipeline configuration into the argument classes.

    When the script receives exactly one command-line argument and it ends
    with ``.json``, that file is parsed; otherwise the regular command-line
    arguments are parsed.

    Returns:
        Whatever the ``HfArgumentParser`` parsing call returns for the
        argument classes listed below.
    """
    argument_classes = (
        ModuleArguments,
        SocketReceiverArguments,
        SocketSenderArguments,
        VADHandlerArguments,
        WhisperSTTHandlerArguments,
        ParaformerSTTHandlerArguments,
        LanguageModelHandlerArguments,
        MLXLanguageModelHandlerArguments,
        ParlerTTSHandlerArguments,
        MeloTTSHandlerArguments,
        ChatTTSHandlerArguments,
    )
    parser = HfArgumentParser(argument_classes)

    cli_arguments = sys.argv[1:]
    if len(cli_arguments) == 1 and cli_arguments[0].endswith(".json"):
        # A lone ".json" argument is treated as the path of a config file.
        config_path = os.path.abspath(cli_arguments[0])
        return parser.parse_json_file(json_file=config_path)

    return parser.parse_args_into_dataclasses()
|
|
|
|
def setup_logger(log_level):
    """
    Configure root logging and create the module-level ``logger``.

    Args:
        log_level: Level name in any letter case (it is upper-cased before
            being handed to ``logging.basicConfig``).

    Side effects:
        Binds the global ``logger`` to ``logging.getLogger(__name__)`` and,
        for the debug level, turns on torch's graph-break, recompile and
        cudagraph logs.
    """
    global logger
    logging.basicConfig(
        level=log_level.upper(),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger = logging.getLogger(__name__)

    # The level is accepted case-insensitively above, so compare the same
    # way here; otherwise "DEBUG" would silently skip the torch debug logs.
    if log_level.lower() == "debug":
        torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
|
|
|
|
def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
    """
    Force the Mac-oriented preset onto the given argument objects.

    Nothing happens when ``mac_optimal_settings`` is falsy. Otherwise each
    object in ``handler_kwargs`` gets the preset value for every attribute it
    already defines; attributes it lacks are not created.
    """
    if not mac_optimal_settings:
        return

    preset = (
        ("device", "mps"),
        ("mode", "local"),
        ("stt", "whisper-mlx"),
        ("llm", "mlx-lm"),
        ("tts", "melo"),
    )
    for target in handler_kwargs:
        for attribute, forced_value in preset:
            if hasattr(target, attribute):
                setattr(target, attribute, forced_value)
|
|
|
|
def check_mac_settings(module_kwargs):
    """
    Validate the module settings when running on macOS.

    On any other platform this is a no-op.

    Args:
        module_kwargs: Object exposing ``device``, ``llm`` and ``tts``.

    Raises:
        ValueError: If the device is ``"cuda"`` on macOS.
    """
    if platform != "darwin":
        return

    # Resolve the logger here instead of relying on the global created by
    # setup_logger(); the warnings below would otherwise raise NameError when
    # this function runs before logging has been configured.
    log = logging.getLogger(__name__)

    if module_kwargs.device == "cuda":
        raise ValueError(
            "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'."
        )
    if module_kwargs.llm != "mlx-lm":
        log.warning(
            "For macOS users, it is recommended to use mlx-lm. You can activate it by passing --llm mlx-lm."
        )
    if module_kwargs.tts != "melo":
        log.warning(
            "If you experiences issues generating the voice, considering setting the tts to melo."
        )
|
|
|
|
def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
    """
    Apply one shared device to every handler argument object.

    Does nothing when ``common_device`` is falsy. Otherwise each of the
    device attributes listed below is overwritten on every object that
    already defines it.
    """
    if not common_device:
        return

    device_attributes = (
        "lm_device",
        "tts_device",
        "stt_device",
        "paraformer_stt_device",
    )
    for target in handler_kwargs:
        for attribute in device_attributes:
            if hasattr(target, attribute):
                setattr(target, attribute, common_device)
|
|
|
|
def prepare_module_args(module_kwargs, *handler_kwargs):
    """
    Apply the module-level settings that affect the handler arguments.

    Steps, in order: apply the Mac preset to ``module_kwargs`` when requested,
    validate the settings on macOS, then propagate the module device to the
    handler argument objects.
    """
    mac_preset = module_kwargs.local_mac_optimal_settings
    optimal_mac_settings(mac_preset, module_kwargs)

    running_on_mac = platform == "darwin"
    if running_on_mac:
        check_mac_settings(module_kwargs)

    # Read the device only now: the Mac preset above may have replaced it.
    shared_device = module_kwargs.device
    overwrite_device_argument(shared_device, *handler_kwargs)
|
|
|
|
def prepare_all_args(
    module_kwargs,
    whisper_stt_handler_kwargs,
    paraformer_stt_handler_kwargs,
    language_model_handler_kwargs,
    mlx_language_model_handler_kwargs,
    parler_tts_handler_kwargs,
    melo_tts_handler_kwargs,
    chat_tts_handler_kwargs,
):
    """
    Finalize all handler argument objects before the pipeline is built.

    The module-level settings are applied first, then the prefix paired with
    each handler below is stripped from that handler's argument names.
    """
    prefixed_handlers = (
        (whisper_stt_handler_kwargs, "stt"),
        (paraformer_stt_handler_kwargs, "paraformer_stt"),
        (language_model_handler_kwargs, "lm"),
        (mlx_language_model_handler_kwargs, "mlx_lm"),
        (parler_tts_handler_kwargs, "tts"),
        (melo_tts_handler_kwargs, "melo"),
        (chat_tts_handler_kwargs, "chat_tts"),
    )

    handlers_only = [handler_args for handler_args, _ in prefixed_handlers]
    prepare_module_args(module_kwargs, *handlers_only)

    for handler_args, prefix in prefixed_handlers:
        rename_args(handler_args, prefix)
|
|
|
|
def initialize_queues_and_events():
    """
    Create the events and queues shared by the pipeline stages.

    Returns:
        A dict mapping each name below to a fresh ``Event`` or ``Queue``.
    """
    event_names = ("stop_event", "should_listen")
    queue_names = (
        "recv_audio_chunks_queue",
        "send_audio_chunks_queue",
        "spoken_prompt_queue",
        "text_prompt_queue",
        "lm_response_queue",
    )

    shared = {name: Event() for name in event_names}
    shared.update((name, Queue()) for name in queue_names)
    return shared
|
|
|
|
def build_pipeline(
    module_kwargs,
    socket_receiver_kwargs,
    socket_sender_kwargs,
    vad_handler_kwargs,
    whisper_stt_handler_kwargs,
    paraformer_stt_handler_kwargs,
    language_model_handler_kwargs,
    mlx_language_model_handler_kwargs,
    parler_tts_handler_kwargs,
    melo_tts_handler_kwargs,
    chat_tts_handler_kwargs,
    queues_and_events,
):
    """
    Wire the communication, VAD, STT, LLM and TTS handlers together.

    The handlers are connected through the queues and events found in
    ``queues_and_events``. The communication handlers depend on
    ``module_kwargs.mode``: a local audio streamer for ``"local"``, a socket
    receiver/sender pair for ``"socket"``, and none for any other value.

    Returns:
        A ``ThreadManager`` built from the communication handlers followed by
        the VAD, STT, LLM and TTS handlers.
    """
    (
        stop_event,
        should_listen,
        recv_audio_chunks_queue,
        send_audio_chunks_queue,
        spoken_prompt_queue,
        text_prompt_queue,
        lm_response_queue,
    ) = (
        queues_and_events[name]
        for name in (
            "stop_event",
            "should_listen",
            "recv_audio_chunks_queue",
            "send_audio_chunks_queue",
            "spoken_prompt_queue",
            "text_prompt_queue",
            "lm_response_queue",
        )
    )

    mode = module_kwargs.mode
    if mode == "socket":
        from connections.socket_receiver import SocketReceiver
        from connections.socket_sender import SocketSender

        receiver = SocketReceiver(
            stop_event,
            recv_audio_chunks_queue,
            should_listen,
            host=socket_receiver_kwargs.recv_host,
            port=socket_receiver_kwargs.recv_port,
            chunk_size=socket_receiver_kwargs.chunk_size,
        )
        sender = SocketSender(
            stop_event,
            send_audio_chunks_queue,
            host=socket_sender_kwargs.send_host,
            port=socket_sender_kwargs.send_port,
        )
        comms_handlers = [receiver, sender]
    elif mode == "local":
        from connections.local_audio_streamer import LocalAudioStreamer

        comms_handlers = [
            LocalAudioStreamer(
                input_queue=recv_audio_chunks_queue,
                output_queue=send_audio_chunks_queue,
            )
        ]
        # In this branch the event is set right away; in the socket branch it
        # is handed to the receiver instead.
        should_listen.set()
    else:
        # Any other mode: no communication handlers, listening flag set.
        comms_handlers = []
        should_listen.set()

    vad = VADHandler(
        stop_event,
        queue_in=recv_audio_chunks_queue,
        queue_out=spoken_prompt_queue,
        setup_args=(should_listen,),
        setup_kwargs=vars(vad_handler_kwargs),
    )

    stt = get_stt_handler(
        module_kwargs,
        stop_event,
        spoken_prompt_queue,
        text_prompt_queue,
        whisper_stt_handler_kwargs,
        paraformer_stt_handler_kwargs,
    )
    lm = get_llm_handler(
        module_kwargs,
        stop_event,
        text_prompt_queue,
        lm_response_queue,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
    )
    tts = get_tts_handler(
        module_kwargs,
        stop_event,
        lm_response_queue,
        send_audio_chunks_queue,
        should_listen,
        parler_tts_handler_kwargs,
        melo_tts_handler_kwargs,
        chat_tts_handler_kwargs,
    )

    stages = list(comms_handlers)
    stages.extend((vad, stt, lm, tts))
    return ThreadManager(stages)
|
|
|
|
def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
    """
    Instantiate the speech-to-text handler selected by ``module_kwargs.stt``.

    The handler class is imported only once it has been selected. Both
    whisper variants are configured from ``whisper_stt_handler_kwargs``; the
    paraformer handler uses ``paraformer_stt_handler_kwargs``.

    Raises:
        ValueError: If ``module_kwargs.stt`` is not a supported value.
    """
    choice = module_kwargs.stt

    if choice == "whisper":
        from STT.whisper_stt_handler import WhisperSTTHandler as handler_class

        handler_args = whisper_stt_handler_kwargs
    elif choice == "whisper-mlx":
        from STT.lightning_whisper_mlx_handler import (
            LightningWhisperSTTHandler as handler_class,
        )

        handler_args = whisper_stt_handler_kwargs
    elif choice == "paraformer":
        from STT.paraformer_handler import ParaformerSTTHandler as handler_class

        handler_args = paraformer_stt_handler_kwargs
    else:
        raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")

    return handler_class(
        stop_event,
        queue_in=spoken_prompt_queue,
        queue_out=text_prompt_queue,
        setup_kwargs=vars(handler_args),
    )
|
|
|
|
def get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, mlx_language_model_handler_kwargs):
    """
    Instantiate the language-model handler selected by ``module_kwargs.llm``.

    The handler class is imported only once it has been selected.

    Raises:
        ValueError: If ``module_kwargs.llm`` is not a supported value.
    """
    choice = module_kwargs.llm

    if choice == "transformers":
        from LLM.language_model import LanguageModelHandler as handler_class

        handler_args = language_model_handler_kwargs
    elif choice == "mlx-lm":
        from LLM.mlx_language_model import MLXLanguageModelHandler as handler_class

        handler_args = mlx_language_model_handler_kwargs
    else:
        raise ValueError("The LLM should be either transformers or mlx-lm")

    return handler_class(
        stop_event,
        queue_in=text_prompt_queue,
        queue_out=lm_response_queue,
        setup_kwargs=vars(handler_args),
    )
|
|
|
|
def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
    """
    Instantiate the text-to-speech handler selected by ``module_kwargs.tts``.

    The handler class is imported only once it has been selected. A
    ``RuntimeError`` raised while importing the melo or chatTTS handler is
    logged and then re-raised unchanged.

    Raises:
        ValueError: If ``module_kwargs.tts`` is not a supported value.
        RuntimeError: Propagated from the melo / chatTTS handler import.
    """
    # Resolve the logger locally rather than through the global created by
    # setup_logger(): if logging was never configured, a NameError raised in
    # the error paths below would hide the real import failure.
    log = logging.getLogger(__name__)
    choice = module_kwargs.tts

    if choice == "parler":
        from TTS.parler_handler import ParlerTTSHandler as handler_class

        handler_args = parler_tts_handler_kwargs
    elif choice == "melo":
        try:
            from TTS.melo_handler import MeloTTSHandler as handler_class
        except RuntimeError:
            log.error(
                "Error importing MeloTTSHandler. You might need to run: python -m unidic download"
            )
            # Bare raise keeps the original exception and traceback intact.
            raise
        handler_args = melo_tts_handler_kwargs
    elif choice == "chatTTS":
        try:
            from TTS.chatTTS_handler import ChatTTSHandler as handler_class
        except RuntimeError:
            log.error("Error importing ChatTTSHandler")
            raise
        handler_args = chat_tts_handler_kwargs
    else:
        raise ValueError("The TTS should be either parler, melo or chatTTS")

    return handler_class(
        stop_event,
        queue_in=lm_response_queue,
        queue_out=send_audio_chunks_queue,
        setup_args=(should_listen,),
        setup_kwargs=vars(handler_args),
    )
|
|
|
|
def main():
    """
    Run the pipeline: parse arguments, configure logging, build, start.

    A ``KeyboardInterrupt`` raised while the pipeline is starting or running
    triggers ``stop()`` on the pipeline manager.
    """
    parsed_args = parse_arguments()
    # The first entry holds the module-level settings; the socket and VAD
    # entries are only needed by build_pipeline, the rest are the model
    # handler arguments.
    module_kwargs, _receiver, _sender, _vad, *model_handler_kwargs = parsed_args

    setup_logger(module_kwargs.log_level)
    prepare_all_args(module_kwargs, *model_handler_kwargs)

    queues_and_events = initialize_queues_and_events()
    pipeline_manager = build_pipeline(*parsed_args, queues_and_events)

    try:
        pipeline_manager.start()
    except KeyboardInterrupt:
        pipeline_manager.stop()


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()