File size: 3,587 Bytes
2cba492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import logging
from typing import Literal

import torch
import torch.nn as nn

# Configure the package-wide logger: INFO level, timestamped records on a stream handler.
logger = logging.getLogger("kanade_tokenizer")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(name)s: %(message)s"))
# logging.getLogger returns the same object for the same name, so re-executing this
# module (importlib.reload, notebook autoreload) would otherwise stack one more
# handler per execution and print every record multiple times.
if not logger.handlers:
    logger.addHandler(handler)


def get_logger() -> logging.Logger:
    """Return the module-level ``logger`` object."""
    return logger


def freeze_modules(modules: list[nn.Module] | None):
    for module in modules:
        if module is not None:
            for param in module.parameters():
                param.requires_grad = False


def _load_audio_internal(
    path: str, frame_offset: int | None = None, num_frames: int | None = None
) -> tuple[torch.Tensor, int]:
    """Decode an audio file with soundfile and return it as a float32 tensor.

    soundfile is used instead of ``torchaudio.load`` because TorchAudio >= 2.9.0
    moved its decoding and encoding capabilities to TorchCodec.
    See: https://github.com/pytorch/audio/issues/3902

    Args:
        path: Path of the audio file to read.
        frame_offset: Frame index to seek to before reading; ``None`` reads
            from the current (initial) position.
        num_frames: Number of frames to read; ``None`` or ``0`` reads to the
            end of the file.

    Returns:
        A ``(waveform, sample_rate)`` tuple. ``waveform`` is the transpose of
        the 2-D array soundfile returns.
    """
    import soundfile as sf

    # soundfile interprets -1 as "read everything that is left".
    frames_to_read = num_frames or -1

    with sf.SoundFile(path) as audio_file:
        sample_rate = audio_file.samplerate
        if frame_offset is not None:
            audio_file.seek(frame_offset)
        samples = audio_file.read(frames=frames_to_read, dtype="float32", always_2d=True)

    waveform = torch.from_numpy(samples.T)
    return waveform, sample_rate


def load_audio(audio_path: str, sample_rate: int = 24000) -> torch.Tensor:
    """Load an audio file as a mono, resampled, peak-normalized waveform.

    Args:
        audio_path: Path of the audio file to load.
        sample_rate: Target sample rate; the audio is resampled when the
            file's own rate differs.

    Returns:
        The waveform with the leading channel dimension squeezed out.
    """
    # The docstring must be the first statement to be picked up as __doc__,
    # so the lazy import comes after it.
    import torchaudio

    waveform, sr = _load_audio_internal(audio_path)

    # Convert to mono by averaging channels
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if necessary
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)

    # Peak-normalize to [-1, 1]; the epsilon avoids division by zero on silence
    max_val = torch.max(torch.abs(waveform)) + 1e-8
    waveform = waveform / max_val

    return waveform.squeeze(0)  # Remove channel dimension


def load_vocoder(name: Literal["vocos", "hift"] = "vocos") -> torch.nn.Module:
    """Instantiate a pretrained vocoder and switch it to eval mode.

    Args:
        name: Which vocoder to load, ``"vocos"`` or ``"hift"``.

    Returns:
        The loaded vocoder, after ``eval()`` has been called on it.

    Raises:
        ValueError: If ``name`` is not one of the supported vocoders.
    """
    if name == "vocos":
        from vocos import Vocos

        return Vocos.from_pretrained("charactr/vocos-mel-24khz").eval()

    if name == "hift":
        from huggingface_hub import hf_hub_download
        from .module.hift import HiFTGenerator

        # Fetch the HiFT checkpoint from the FunAudioLLM/CosyVoice2-0.5B repository
        checkpoint_path = hf_hub_download(repo_id="FunAudioLLM/CosyVoice2-0.5B", filename="hift.pt")
        generator = HiFTGenerator()
        generator.load_weights(checkpoint_path)
        return generator.eval()

    raise ValueError(f"Unsupported vocoder name: {name}")


def vocode(vocoder, mel_spectrogram: torch.Tensor) -> torch.Tensor:
    """Synthesize a waveform from a mel spectrogram with the given vocoder.

    The vocoder is dispatched on its class name: a name containing ``"Vocos"``
    uses ``decode``, a name containing ``"HiFT"`` uses ``inference``.

    Args:
        vocoder: Pretrained vocoder model.
        mel_spectrogram (torch.Tensor): Input mel spectrogram tensor (..., n_mels, frame).
    Returns:
        torch.Tensor: Generated audio waveform tensor (..., samples).
    Raises:
        ValueError: If the vocoder's class name matches neither family.
    """
    # The vocoder is always fed float32 input, whatever dtype the caller passed
    mel_fp32 = mel_spectrogram.to(torch.float32)

    cls_name = vocoder.__class__.__name__
    if "Vocos" in cls_name:
        return vocoder.decode(mel_fp32)
    if "HiFT" in cls_name:
        return vocoder.inference(mel_fp32)

    raise ValueError(f"Unsupported vocoder class: {cls_name}")