File size: 4,906 Bytes
ea47387
 
 
 
 
a17e6b8
 
ea47387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a17e6b8
 
ea47387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a17e6b8
 
 
 
 
 
 
 
 
 
 
 
ea47387
 
 
 
 
 
 
 
 
 
 
 
 
 
a17e6b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea47387
a17e6b8
 
ea47387
a17e6b8
ea47387
 
 
 
 
 
 
 
 
 
 
a17e6b8
 
ea47387
a17e6b8
 
 
 
 
 
 
 
 
 
 
 
 
 
ea47387
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import soundfile as sf
import torch


@dataclass
class VocosFbankConfig:
    """Fixed parameters of the log-mel front end used by ``LocalVocosFbank``."""

    # Audio rate (Hz) the extractor accepts; ``extract`` rejects any other rate.
    sampling_rate: int = 24000
    # Number of mel bands, i.e. the feature dimension of each frame.
    n_mels: int = 100
    # FFT size; also used as the STFT window length.
    n_fft: int = 1024
    # Stride between successive frames, in samples.
    hop_length: int = 256


def compute_num_frames(num_samples: int, hop_length: int) -> int:
    """Return the expected number of feature frames for ``num_samples`` samples.

    For integer inputs this is ``num_samples / hop_length`` rounded to the
    nearest integer (exact halves round up).
    """
    half_hop = hop_length // 2
    frames = (num_samples + half_hop) // hop_length
    return int(frames)


class LocalVocosFbank:
    """Log-mel filterbank extractor configured by ``VocosFbankConfig`` defaults.

    Features are the natural log (floored at ``1e-7``) of a magnitude STFT
    (Hann window, centred frames, reflect padding) projected onto the mel
    filterbank built by ``_create_mel_filterbank``.
    """

    def __init__(self) -> None:
        self.config = VocosFbankConfig()
        # Window of length n_fft (torch.hann_window default, i.e. periodic).
        self.window = torch.hann_window(self.config.n_fft)
        # Shape [n_fft // 2 + 1, n_mels]; transposed when applied in ``extract``.
        self.mel_basis = _create_mel_filterbank(
            sample_rate=self.config.sampling_rate,
            n_fft=self.config.n_fft,
            n_mels=self.config.n_mels,
        )

    def extract(self, samples: torch.Tensor, sampling_rate: int) -> torch.Tensor:
        """Compute log-mel features for a waveform.

        Args:
            samples: floating-point waveform of shape ``[T]`` or ``[C, T]``.
                Multi-channel input is down-mixed to mono by averaging channels.
            sampling_rate: rate of ``samples``; must equal ``config.sampling_rate``.

        Returns:
            Tensor of shape ``[num_frames, n_mels]`` where ``num_frames`` is
            ``compute_num_frames(T, hop_length)``.

        Raises:
            ValueError: on a sampling-rate mismatch, or if ``samples`` is not
                1-D or 2-D.
        """
        if sampling_rate != self.config.sampling_rate:
            raise ValueError(
                f"Mismatched sampling rate: expected {self.config.sampling_rate}, got {sampling_rate}"
            )
        if samples.ndim == 1:
            samples = samples.unsqueeze(0)
        if samples.ndim != 2:
            raise ValueError(f"Expected waveform shape [C, T], got {tuple(samples.shape)}")
        if samples.shape[0] > 1:
            # Average every channel into mono. Down-mixing only exactly two
            # channels let >2-channel input through, and the reshape below then
            # yielded a [frames, C * n_mels] matrix instead of [frames, n_mels].
            samples = samples.mean(dim=0, keepdim=True)

        # Match the waveform's device AND dtype so e.g. float64 input does not
        # hit a dtype mismatch in stft / matmul (no-op for float32 input).
        window = self.window.to(device=samples.device, dtype=samples.dtype)
        mel_basis = self.mel_basis.to(device=samples.device, dtype=samples.dtype)

        stft = torch.stft(
            samples,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.n_fft,
            window=window,
            center=True,
            pad_mode="reflect",
            return_complex=True,
        )
        # Magnitude (not power) spectrum: [1, n_fft // 2 + 1, frames].
        spec = stft.abs()
        # [n_mels, n_freqs] @ [1, n_freqs, frames] -> [1, n_mels, frames]; the
        # clamp keeps log() finite for silent bins.
        mel = torch.matmul(mel_basis.t(), spec).clamp(min=1e-7).log()
        mel = mel.reshape(-1, mel.shape[-1]).t()
        num_frames = compute_num_frames(samples.shape[1], self.config.hop_length)

        # Align the frame count with compute_num_frames: trim surplus frames
        # (centred stft yields 1 + T // hop_length), or repeat the last frame
        # if there are too few.
        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = torch.nn.functional.pad(
                mel.unsqueeze(0),
                (0, 0, 0, num_frames - mel.shape[0]),
                mode="replicate",
            ).squeeze(0)
        return mel


def _hz_to_mel(freq: torch.Tensor) -> torch.Tensor:
    return 2595.0 * torch.log10(1.0 + freq / 700.0)


def _mel_to_hz(mels: torch.Tensor) -> torch.Tensor:
    return 700.0 * (torch.pow(10.0, mels / 2595.0) - 1.0)


def _create_mel_filterbank(sample_rate: int, n_fft: int, n_mels: int) -> torch.Tensor:
    """Build a bank of triangular HTK-scale mel filters (no area normalisation).

    The filters cover 0 Hz to ``sample_rate // 2``, with ``n_mels + 2`` edge
    points spaced uniformly on the mel scale.

    Returns:
        Tensor of shape ``[n_fft // 2 + 1, n_mels]``: one row per STFT bin,
        one column per mel filter.

    Raises:
        ValueError: if any filter has no non-zero weight (too many mel bands
            for the given FFT resolution).
    """
    nyquist = sample_rate // 2
    bin_freqs = torch.linspace(0, nyquist, n_fft // 2 + 1)
    mel_edges = torch.linspace(
        _hz_to_mel(torch.tensor(0.0)),
        _hz_to_mel(torch.tensor(float(nyquist))),
        n_mels + 2,
    )
    hz_edges = _mel_to_hz(mel_edges)
    edge_gaps = torch.diff(hz_edges)

    # offsets[i, j] = hz_edges[j] - bin_freqs[i]
    offsets = hz_edges[None, :] - bin_freqs[:, None]
    # Filter j rises from edge j to edge j+1 and falls from edge j+1 to edge j+2.
    rising = -offsets[:, :-2] / edge_gaps[:-1]
    falling = offsets[:, 2:] / edge_gaps[1:]
    filters = torch.minimum(rising, falling).clamp(min=0.0)

    if (filters.amax(dim=0) == 0.0).any():
        raise ValueError("Mel filterbank has empty filters")
    return filters


def _resample_linear(wav: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    if orig_freq == new_freq:
        return wav
    old_len = wav.shape[-1]
    new_len = max(1, int(round(old_len * new_freq / orig_freq)))
    old_pos = np.arange(old_len, dtype=np.float64)
    new_pos = np.linspace(0, old_len - 1, new_len, dtype=np.float64)
    channels = []
    for channel in wav.cpu().numpy():
        channels.append(np.interp(new_pos, old_pos, channel).astype(np.float32))
    return torch.from_numpy(np.stack(channels, axis=0))


def load_prompt_wav(prompt_wav: str | Path, sampling_rate: int) -> torch.Tensor:
    """Read an audio file as a channel-first float32 tensor at ``sampling_rate``.

    Args:
        prompt_wav: path to an audio file readable by ``soundfile``.
        sampling_rate: target rate; the audio is linearly resampled when the
            file's native rate differs.

    Returns:
        Tensor of shape ``[channels, samples]``.
    """
    audio, native_sr = sf.read(str(prompt_wav), always_2d=True, dtype="float32")
    # soundfile returns [frames, channels]; transpose to channel-first and copy
    # so the tensor is backed by its own C-contiguous buffer.
    waveform = torch.from_numpy(audio.T.copy())
    if native_sr == sampling_rate:
        return waveform
    return _resample_linear(waveform, orig_freq=native_sr, new_freq=sampling_rate)


def rms_norm(wav: torch.Tensor, target_rms: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Raise a quiet waveform's loudness up to ``target_rms``.

    The waveform is scaled up only when its RMS is positive and below
    ``target_rms``. Louder waveforms are never attenuated, and silent
    (all-zero) waveforms are returned unchanged.

    Args:
        wav: waveform tensor of any shape; the RMS is taken over all elements.
        target_rms: minimum RMS level to boost quiet input to.

    Returns:
        ``(wav, wav_rms)``: the possibly rescaled waveform and the RMS of the
        original input, as a 0-dim tensor.
    """
    wav_rms = torch.sqrt(torch.mean(torch.square(wav)))
    # A silent waveform has RMS 0; dividing by it would fill the output with
    # NaNs (0 * target / 0), so only boost when there is signal to boost.
    if 0 < wav_rms < target_rms:
        wav = wav * target_rms / wav_rms
    return wav, wav_rms


def load_local_vocos(vocoder_dir: str | Path):
    """Build a ``LocalVocos`` and load weights from ``<vocoder_dir>/pytorch_model.bin``.

    Only checkpoint entries whose keys start with ``backbone.`` or ``head.``
    are kept. All other entries are dropped before ``load_state_dict`` (called
    with its default strictness).
    """
    from scripts.local_vocos import LocalVocos

    checkpoint_path = str(Path(vocoder_dir) / "pytorch_model.bin")
    model = LocalVocos()
    try:
        raw_state = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    except TypeError:
        # Fallback for torch versions whose ``torch.load`` does not accept the
        # ``weights_only`` keyword.
        raw_state = torch.load(checkpoint_path, map_location="cpu")
    kept_prefixes = ("backbone.", "head.")
    model.load_state_dict(
        {name: tensor for name, tensor in raw_state.items() if name.startswith(kept_prefixes)}
    )
    return model