File size: 5,259 Bytes
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import json

import torch
import torchaudio.transforms as T
from torch import nn

from .autoencoders import create_autoencoder_from_config
from .utils import load_ckpt_state_dict


class PadCrop(nn.Module):
    """Pad or crop a (channels, samples) signal to a fixed length.

    When `randomize` is True, the crop start is drawn uniformly at random;
    otherwise the crop always starts at sample 0. Signals shorter than
    `n_samples` are zero-padded on the right.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        channels, length = signal.shape
        if self.randomize:
            # Any start in [0, length - n_samples] keeps the crop in bounds;
            # for short signals max_start is 0, so start is always 0.
            max_start = max(0, length - self.n_samples)
            start = torch.randint(0, max_start + 1, []).item()
        else:
            start = 0
        chunk = signal[:, start : start + self.n_samples]
        padded = signal.new_zeros([channels, self.n_samples])
        # The chunk covers the first min(length, n_samples) columns; the
        # remainder (if any) stays zero.
        padded[:, : min(length, self.n_samples)] = chunk
        return padded


def set_audio_channels(audio, target_channels):
    """Convert a (batch, channels, samples) tensor to `target_channels`.

    Mono targets downmix by averaging; stereo targets duplicate a mono
    channel or truncate to the first two channels. Any other target (or an
    already-matching stereo input) returns the audio unchanged.
    """
    if target_channels == 1:
        # Downmix: average across the channel dimension.
        return audio.mean(1, keepdim=True)
    if target_channels == 2:
        n_channels = audio.shape[1]
        if n_channels == 1:
            # Duplicate mono to both stereo channels.
            return audio.repeat(1, 2, 1)
        if n_channels > 2:
            # Keep only the first two channels.
            return audio[:, :2, :]
    return audio


def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move audio to `device`, resample, pad/crop, batch, and set channels.

    Args:
        audio: Audio tensor; 1-D (samples) or 2-D (channels, samples).
        in_sr: Sample rate of the input audio.
        target_sr: Desired sample rate; resampled when it differs from `in_sr`.
        target_length: Fixed output length in samples, or None to keep the
            (post-resample) length unchanged.
        target_channels: Desired channel count (see `set_audio_channels`).
        device: Target device for the returned tensor.

    Returns:
        A (1, target_channels, target_length) tensor on `device`.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    # Fix: a leftover `assert target_length is None` previously made any
    # explicit target_length raise AssertionError; None still means
    # "keep the current length" (no pad/crop).
    if target_length is None:
        target_length = audio.shape[-1]

    audio = PadCrop(target_length, randomize=False)(audio)

    # Add batch dimension
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)

    audio = set_audio_channels(audio, target_channels)

    return audio


class StableAudioInfer(nn.Module):
    """Inference wrapper around a Stable Audio VAE autoencoder.

    Builds the autoencoder from a JSON model config (optionally loading a
    checkpoint) and exposes encode/decode entry points for audio <-> latent
    round trips via `forward(func_type=...)`.
    """

    def __init__(self, model_config_path, model_ckpt_path=None):
        """
        Args:
            model_config_path: Path to the JSON model configuration file.
            model_ckpt_path: Optional checkpoint path; when given, its state
                dict is loaded into the autoencoder.
        """
        super().__init__()

        with open(model_config_path) as f:
            self.model_config = json.load(f)

        self.model = create_autoencoder_from_config(self.model_config)
        if model_ckpt_path is not None:
            self.model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))

        self.sample_rate = self.model_config["sample_rate"]
        self.io_channels = self.model.io_channels
        # NOTE(review): hard-coded override of the configured sample_size
        # (previously assigned from the config and immediately clobbered).
        # Confirm 24576 is intentional for this VAE before relying on it.
        self.sample_size = 24576

    @property
    def device(self):
        """Device of the module's parameters."""
        return next(self.parameters()).device

    def normalize_audio(self, y, target_dbfs=0):
        """Peak-normalize audio to a specific dBFS level.

        Args:
            y: Audio tensor.
            target_dbfs: Target peak level in dBFS (0 -> peak amplitude 1.0).

        Returns:
            Scaled tensor; returned unchanged when `y` is all zeros to avoid
            a division by zero (previously produced NaNs for silent input).
        """
        max_amplitude = torch.max(torch.abs(y))
        if max_amplitude == 0:
            return y
        target_amplitude = 10.0 ** (target_dbfs / 20.0)
        scale_factor = target_amplitude / max_amplitude
        return y * scale_factor

    def encode_audio(self, input_audio, in_sr):
        """Encode audio waveform into VAE latent representation.

        Args:
            input_audio: Input audio tensor.
            in_sr: Input sample rate.

        Returns:
            Latent tensor from the VAE encoder.
        """
        input_audio = prepare_audio(
            input_audio,
            in_sr=in_sr,
            target_sr=self.model.sample_rate,
            target_length=None,  # Determined after resampling
            target_channels=self.io_channels,
            device=self.device,
        )
        # Normalize to -6 dBFS peak before encoding.
        input_audio = self.normalize_audio(input_audio, -6)

        with torch.no_grad():
            # Use chunked encoding for long audio to save memory.
            # Threshold: 138 seconds of audio — presumably matched to the
            # decoder's 138-latent-frame threshold; confirm with model config.
            if input_audio.shape[-1] > (128 + 10) * self.model.sample_rate:
                latent = self.model.encode_audio(input_audio, chunked=True)
            else:
                latent = self.model.encode_audio(input_audio, chunked=False)

        return latent

    def decode_audio(self, latent):
        """Decode VAE latent back to audio waveform.

        Args:
            latent: Latent tensor.

        Returns:
            Decoded audio tensor.
        """
        with torch.no_grad():
            # Use chunked decoding for long latents to save memory
            if latent.shape[-1] > 128 + 10:
                output = self.model.decode_audio(latent, chunked=True)
            else:
                output = self.model.decode_audio(latent, chunked=False)
        return output

    def forward(self, func_type, x, sr=None):
        """Dispatch to encode or decode.

        Args:
            func_type: "encode" or "decode".
            x: Audio tensor (encode) or latent tensor (decode).
            sr: Input sample rate; required for "encode".

        Raises:
            ValueError: If `func_type` is not "encode" or "decode".
        """
        x = x.to(self.device)  # reuse the device property instead of re-deriving
        if func_type == "encode":
            assert sr is not None, "sr is required for encoding"
            return self.encode_audio(input_audio=x, in_sr=sr)
        elif func_type == "decode":
            return self.decode_audio(x)
        else:
            raise ValueError(f"Unknown func_type: {func_type}")


if __name__ == "__main__":
    import torchaudio

    device = "cuda"
    vae_model = StableAudioInfer(
        model_config_path="config/stable_audio_2_0_vae_20hz_official.json",
        model_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
    )
    vae_model = vae_model.eval().to(device)

    # Round-trip an audio file through the VAE: encode to latent, decode back.
    input_audio, in_sr = torchaudio.load("path/to/input.wav")
    latent = vae_model(func_type="encode", x=input_audio, sr=in_sr)

    output_audio = vae_model(func_type="decode", x=latent, sr=None)
    output_audio = output_audio.squeeze(0).cpu()
    # Fix: save at the model's configured rate rather than a hard-coded
    # 44100, which would pitch-shift output for any other sample rate.
    torchaudio.save("output.wav", output_audio, sample_rate=vae_model.sample_rate)