import torch
from funasr import AutoModel
from loguru import logger

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import (
    launch_thread_safe_queue,
    launch_thread_safe_queue_agent,
)
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference

ASR_MODEL_NAME = "iic/SenseVoiceSmall"

class ModelManager:
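    """Load the server's models (LLAMA, decoder, and optionally ASR) and
    expose a ready-to-use TTSInferenceEngine."""
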
    def __init__(
        self,
        mode: str,
        device: str,
        half: bool,
        compile: bool,
        asr_enabled: bool,
        llama_checkpoint_path: str,
        decoder_checkpoint_path: str,
        decoder_config_name: str,
    ) -> None:

        self.mode = mode
        self.device = device
        self.half = half
        self.compile = compile

        self.precision = torch.half if half else torch.bfloat16

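        # Resolve the runtime device: prefer MPS when available, and fall
        # back to CPU when CUDA is missing.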
        if torch.backends.mps.is_available():
            self.device = "mps"
            logger.info("mps is available, running on mps.")
        elif not torch.cuda.is_available():
            self.device = "cpu"
            logger.info("CUDA is not available, running on CPU.")

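        # Load the ASR model only when transcription support is requested.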
        if asr_enabled:
            self.load_asr_model(self.device)

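        # Load the TTS models and assemble the inference engine around them.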
        self.load_llama_model(
            llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
        )
        self.load_decoder_model(
            decoder_config_name, decoder_checkpoint_path, self.device
        )
        self.tts_inference_engine = TTSInferenceEngine(
            llama_queue=self.llama_queue,
            decoder_model=self.decoder_model,
            precision=self.precision,
            compile=self.compile,
        )

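        # Warm up once so the first real request does not pay the
        # first-inference (and optional torch.compile) cost.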
        if self.mode == "tts":
            self.warm_up(self.tts_inference_engine)

    def load_asr_model(self, device, hub="ms") -> None:
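        # hub="ms" fetches the model from ModelScope, funasr's default hub.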
        self.asr_model = AutoModel(
            model=ASR_MODEL_NAME,
            device=device,
            disable_pbar=True,
            hub=hub,
        )
        logger.info("ASR model loaded.")

    def load_llama_model(
        self, checkpoint_path, device, precision, compile, mode
    ) -> None:

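        # "tts" launches the plain synthesis queue; "agent" also returns the
        # tokenizer and config needed for conversational use.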
        if mode == "tts":
            self.llama_queue = launch_thread_safe_queue(
                checkpoint_path=checkpoint_path,
                device=device,
                precision=precision,
                compile=compile,
            )
        elif mode == "agent":
            self.llama_queue, self.tokenizer, self.config = (
                launch_thread_safe_queue_agent(
                    checkpoint_path=checkpoint_path,
                    device=device,
                    precision=precision,
                    compile=compile,
                )
            )
        else:
            raise ValueError(f"Invalid mode: {mode}")

        logger.info("LLAMA model loaded.")

    def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
        self.decoder_model = load_decoder_model(
            config_name=config_name,
            checkpoint_path=checkpoint_path,
            device=device,
        )
        logger.info("Decoder model loaded.")

    def warm_up(self, tts_inference_engine) -> None:
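        # Synthesize one short dummy sentence so weights (and compiled
        # kernels, when compile=True) are initialized before serving.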
        request = ServeTTSRequest(
            text="Hello world.",
            references=[],
            reference_id=None,
            max_new_tokens=1024,
            chunk_length=200,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
            format="wav",
        )
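        # list() drains the streamed inference results so the whole
        # pipeline actually runs end to end.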
        list(inference(request, tts_inference_engine))
        logger.info("Models warmed up.")