| | import io |
| | from hashlib import sha256 |
| | from pathlib import Path |
| | from typing import Callable, Literal, Tuple |
| |
|
| | import torch |
| | import torchaudio |
| | from loguru import logger |
| |
|
| | from fish_speech.models.dac.modded_dac import DAC |
| | from fish_speech.utils.file import ( |
| | AUDIO_EXTENSIONS, |
| | audio_to_bytes, |
| | list_files, |
| | read_ref_text, |
| | ) |
| | from fish_speech.utils.schema import ServeReferenceAudio |
| |
|
| |
|
class ReferenceLoader:
    """Load reference audio/text pairs and cache their encoded prompt tokens.

    Two caches are kept, both plain unbounded dictionaries:

    * ``ref_by_id`` maps a reference id (a sub-directory of ``references/``)
      to a ``(prompt_tokens, prompt_texts)`` tuple of lists.
    * ``ref_by_hash`` maps the SHA-256 hex digest of raw audio bytes to a
      ``(prompt_tokens, text)`` tuple for a single reference.
    """

    def __init__(self) -> None:
        """
        Component of the TTSInferenceEngine class.
        Loads and manages the cache for the reference audio and text.
        """
        self.ref_by_id: dict = {}
        self.ref_by_hash: dict = {}

        # Declared only, never assigned in this class: the owning engine is
        # expected to provide them before load_by_id / load_by_hash are used.
        # ``encode_reference`` is called below with the keyword arguments
        # ``reference_audio`` and ``enable_reference_audio``.
        self.decoder_model: DAC
        self.encode_reference: Callable

        # Prefer the ffmpeg backend when torchaudio reports it as available,
        # otherwise fall back to soundfile.
        backends = torchaudio.list_audio_backends()
        if "ffmpeg" in backends:
            self.backend = "ffmpeg"
        else:
            self.backend = "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Encode every reference stored under ``references/<id>``.

        Each audio file found (recursively, filtered by ``AUDIO_EXTENSIONS``)
        is encoded with ``self.encode_reference``; its transcript is read from
        the sibling file with the same stem and a ``.lab`` suffix.

        Args:
            id: Name of the reference sub-directory. Must be a relative path
                without ``..`` components so it stays inside ``references/``.
            use_cache: ``"off"`` always re-encodes and refreshes the cache;
                any other value reuses the cached result for ``id`` if present.

        Returns:
            ``(prompt_tokens, prompt_texts)``: two lists, one entry per audio
            file, in the order returned by ``list_files``.

        Raises:
            ValueError: If ``id`` is absolute/anchored or contains ``..``.
        """
        # ``id`` is joined onto the references directory and then created on
        # disk, so refuse anything that could escape that directory.
        id_path = Path(id)
        if id_path.anchor or ".." in id_path.parts:
            raise ValueError(f"Invalid reference id: {id!r}")

        ref_folder = Path("references") / id
        ref_folder.mkdir(parents=True, exist_ok=True)

        if use_cache == "off" or id not in self.ref_by_id:
            # Only walk the directory when we actually need to (re-)encode.
            ref_audios = list_files(
                ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
            )
            prompt_tokens, prompt_texts = [], []
            for ref_audio in ref_audios:
                prompt_tokens.append(
                    self.encode_reference(
                        reference_audio=audio_to_bytes(str(ref_audio)),
                        enable_reference_audio=True,
                    )
                )
                prompt_texts.append(
                    read_ref_text(str(ref_audio.with_suffix(".lab")))
                )
            self.ref_by_id[id] = (prompt_tokens, prompt_texts)
        else:
            # Reuse the references encoded on a previous call.
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: "list[ServeReferenceAudio]",
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Encode in-request references, reusing cached tokens by audio hash.

        Args:
            references: References carrying raw audio bytes (``.audio``) and
                their transcript (``.text``).
            use_cache: ``"off"`` forces re-encoding (and refreshes the cache);
                any other value reuses tokens cached under the same audio hash.

        Returns:
            ``(prompt_tokens, prompt_texts)`` lists aligned with
            ``references``. Texts always come from the request itself: the
            cache is keyed on the audio alone, so a transcript cached for the
            same audio may differ from the one supplied now.
        """
        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for ref in references:
            audio_hash = sha256(ref.audio).hexdigest()
            if use_cache == "off" or audio_hash not in self.ref_by_hash:
                tokens = self.encode_reference(
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                self.ref_by_hash[audio_hash] = (tokens, ref.text)
            else:
                # Only the (expensive) token encoding is reused from the cache.
                tokens, _cached_text = self.ref_by_hash[audio_hash]
                cache_used = True
            prompt_tokens.append(tokens)
            prompt_texts.append(ref.text)

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr):
        """Load reference audio as a mono numpy waveform at sample rate ``sr``.

        Args:
            reference_audio: Raw encoded audio (``bytes``, ``bytearray`` or
                ``memoryview``), or anything else ``torchaudio.load`` accepts
                (a ``str``/``os.PathLike`` path or a binary file-like object).
            sr: Target sample rate; audio is resampled when its native rate
                differs.

        Returns:
            The squeezed waveform as a numpy array.
        """
        if isinstance(reference_audio, (bytes, bytearray, memoryview)):
            # Raw encoded audio: wrap it so torchaudio can read it as a file.
            reference_audio = io.BytesIO(reference_audio)

        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)

        # torchaudio.load returns channels first; average channels to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        audio = waveform.squeeze().numpy()
        return audio
| |
|