import gc
import queue
from typing import Generator

import numpy as np
import torch
from loguru import logger

from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.dac.modded_dac import DAC
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
)
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest


class TTSInferenceEngine(ReferenceLoader, VQManager):
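    """
    High-level TTS pipeline: resolves reference prompts, sends a generation
    request to the LLAMA worker queue, and decodes the returned VQ tokens
    into audio with the DAC decoder.
    """
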
    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: DAC,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        super().__init__()

        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.

        Yields InferenceResult objects: a "header" (streaming only), zero or
        more "segment" chunks, then either a "final" result or an "error".
        """
        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # Load the reference audio and text, either by id or by hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided, for reproducible sampling
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Queue the request to the LLAMA model to get the semantic tokens
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        if hasattr(self.decoder_model, "spec_transform"):
            sample_rate = self.decoder_model.spec_transform.sample_rate
        else:
            sample_rate = self.decoder_model.sample_rate

        # If streaming, send the WAV header first
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )
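
        # Streaming clients are expected to append the raw "segment" chunks
        # yielded below directly after this WAV header.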
        segments = []

        while True:
            # Wait for the next response from the LLAMA worker
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    "Expected GenerateResponse, got "
                    f"{type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:  # Segments are only yielded in streaming mode
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                # The worker signals completion with action "next"
                break

        # Clean up the GPU memory between requests
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio was generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, always yield the concatenated audio at the end
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the semantic tokens,
        and return the queue on which its responses will arrive.
        """
        # Prepare the request payload
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=req.text,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # Create a queue on which the LLAMA worker will put its responses
        response_queue = queue.Queue()

        # Hand the request over to the LLAMA worker
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """
        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the VQ codes to an audio waveform
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to float32 numpy on the CPU
        return segment.float().cpu().numpy()
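

# A minimal usage sketch, kept as a comment. `launch_llama_worker` and
# `load_decoder_model` below are hypothetical placeholders, not part of this
# module: they stand in for however the caller builds the LLAMA worker queue
# and the DAC decoder instance.
#
#     llama_queue = launch_llama_worker(...)   # queue.Queue read by the worker thread
#     decoder = load_decoder_model(...)        # DAC decoder instance
#     engine = TTSInferenceEngine(
#         llama_queue=llama_queue,
#         decoder_model=decoder,
#         precision=torch.bfloat16,
#         compile=False,
#     )
#
#     for result in engine.inference(ServeTTSRequest(text="Hello world", streaming=False)):
#         if result.code == "final":
#             sample_rate, audio = result.audio  # numpy waveform, ready to save or play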