Spaces:
Running
Running
| import logging | |
| from typing import Literal | |
| import torch | |
| import torch.nn as nn | |
# Configure the package-wide logger: INFO level, timestamped records written
# through a StreamHandler (which defaults to stderr).
logger = logging.getLogger("kanade_tokenizer")
logger.setLevel(logging.INFO)
# logging.getLogger returns the same object for the same name, so executing
# this module a second time (e.g. importlib.reload) would otherwise stack
# another handler and emit every record twice. Reuse the installed handler
# when there is one; `handler` stays defined either way.
if logger.handlers:
    handler = logger.handlers[0]
else:
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(name)s: %(message)s"))
    logger.addHandler(handler)
def get_logger() -> logging.Logger:
    """Return the module-level ``kanade_tokenizer`` logger."""
    return logger
| def freeze_modules(modules: list[nn.Module] | None): | |
| for module in modules: | |
| if module is not None: | |
| for param in module.parameters(): | |
| param.requires_grad = False | |
def _load_audio_internal(
    path: str, frame_offset: int | None = None, num_frames: int | None = None
) -> tuple[torch.Tensor, int]:
    """Read (part of) an audio file with ``soundfile``.

    Args:
        path: Audio file to open.
        frame_offset: Frame to seek to before reading; ``None`` reads from the
            current (start) position.
        num_frames: Number of frames to read; ``None`` or ``0`` reads through
            the end of the file.

    Returns:
        A ``(waveform, sample_rate)`` pair. ``waveform`` is the float32 data
        transposed from soundfile's 2-D read result, so channels come first.
    """
    # TorchAudio >= 2.9.0 moved decoding/encoding out to TorchCodec
    # (see https://github.com/pytorch/audio/issues/3902), so the file is read
    # with soundfile instead of torchaudio.load.
    import soundfile as sf

    frames_to_read = num_frames or -1
    with sf.SoundFile(path) as audio_file:
        sample_rate = audio_file.samplerate
        if frame_offset is not None:
            audio_file.seek(frame_offset)
        samples = audio_file.read(frames=frames_to_read, dtype="float32", always_2d=True)
    return torch.from_numpy(samples.T), sample_rate
def load_audio(audio_path: str, sample_rate: int = 24000) -> torch.Tensor:
    """Load an audio file as a peak-normalized mono waveform.

    Args:
        audio_path: Path of the audio file to read.
        sample_rate: Target sample rate in Hz; the audio is resampled when the
            file's native rate differs.

    Returns:
        The waveform with the channel dimension removed, scaled so that its
        largest absolute sample is approximately 1.
    """
    waveform, sr = _load_audio_internal(audio_path)

    # Down-mix multi-channel audio to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample only when the native rate differs from the requested one.
    if sr != sample_rate:
        # Imported here so torchaudio is only required when resampling is
        # actually needed.
        import torchaudio

        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)

    # Peak-normalize to [-1, 1]; the epsilon keeps silent audio from dividing
    # by zero.
    max_val = torch.max(torch.abs(waveform)) + 1e-8
    waveform = waveform / max_val
    return waveform.squeeze(0)  # Remove channel dimension
def load_vocoder(name: Literal["vocos", "hift"] = "vocos") -> torch.nn.Module:
    """Instantiate a pretrained vocoder and put it in eval mode.

    Args:
        name: ``"vocos"`` loads ``charactr/vocos-mel-24khz``; ``"hift"`` loads
            the HiFT generator weights (``hift.pt``) published in
            ``FunAudioLLM/CosyVoice2-0.5B``.

    Returns:
        The loaded vocoder module, switched to eval mode.

    Raises:
        ValueError: If ``name`` is not one of the supported vocoders.
    """
    if name == "vocos":
        from vocos import Vocos

        return Vocos.from_pretrained("charactr/vocos-mel-24khz").eval()

    if name == "hift":
        from huggingface_hub import hf_hub_download

        from .module.hift import HiFTGenerator

        # Download the HiFT model from FunAudioLLM/CosyVoice2-0.5B
        checkpoint_path = hf_hub_download(repo_id="FunAudioLLM/CosyVoice2-0.5B", filename="hift.pt")
        vocoder = HiFTGenerator()
        vocoder.load_weights(checkpoint_path)
        return vocoder.eval()

    raise ValueError(f"Unsupported vocoder name: {name}")
def vocode(vocoder, mel_spectrogram: torch.Tensor) -> torch.Tensor:
    """Synthesize a waveform from a mel spectrogram with a Vocos or HiFT vocoder.

    The vocoder family is detected from its class name: classes containing
    ``"Vocos"`` are driven through ``decode`` and classes containing
    ``"HiFT"`` through ``inference``.

    Args:
        vocoder: Pretrained vocoder model.
        mel_spectrogram (torch.Tensor): Input mel spectrogram tensor (..., n_mels, frame).

    Returns:
        torch.Tensor: Generated audio waveform tensor (..., samples).

    Raises:
        ValueError: If the vocoder's class name matches neither family.
    """
    # Cast the input to float32 before handing it to the vocoder.
    mel = mel_spectrogram.to(torch.float32)
    class_name = vocoder.__class__.__name__

    if "Vocos" in class_name:
        return vocoder.decode(mel)
    if "HiFT" in class_name:
        return vocoder.inference(mel)
    raise ValueError(f"Unsupported vocoder class: {class_name}")