from typing import Callable


import torch
from loguru import logger


from fish_speech.models.dac.modded_dac import DAC
|
|
|
|
class VQManager:
    """Mixin bridging VQ-token decoding and reference-audio encoding.

    The host class is expected to supply ``decoder_model`` (a ``DAC`` codec)
    and ``load_audio`` (a callable loading audio resampled to a target rate).
    """

    def __init__(self):
        # Type-stub declarations only — the host class assigns the real
        # values; nothing is initialized here.
        self.decoder_model: DAC
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        """Decode a (codebooks, frames) tensor of VQ indices to a waveform."""
        lengths = torch.tensor(
            [codes.shape[1]], device=self.decoder_model.device
        )
        logger.info(f"VQ features: {codes.shape}")

        # Guard clause: only DAC decoders are supported.
        if not isinstance(self.decoder_model, DAC):
            raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

        # decode() consumes a batch; add the batch dim, then unbatch and
        # drop singleton dimensions from the result.
        decoded = self.decoder_model.decode(
            indices=codes[None],
            feature_lengths=lengths,
        )
        return decoded[0].squeeze()

    def encode_reference(self, reference_audio, enable_reference_audio):
        """Encode a reference clip into prompt tokens; None when disabled."""
        if not enable_reference_audio or reference_audio is None:
            logger.info("No reference audio provided")
            return None

        # Prefer the spectrogram transform's sample rate when the model
        # exposes one; otherwise fall back to the model's own rate.
        if hasattr(self.decoder_model, "spec_transform"):
            sample_rate = self.decoder_model.spec_transform.sample_rate
        else:
            sample_rate = self.decoder_model.sample_rate

        waveform = self.load_audio(reference_audio, sample_rate)

        # Shape the mono waveform as (batch=1, channels=1, samples).
        audios = torch.from_numpy(waveform).to(
            self.decoder_model.device
        )[None, None, :]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
        )

        if not isinstance(self.decoder_model, DAC):
            raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

        prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
        return prompt_tokens
|
|