File size: 3,346 Bytes
b3c4dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

from torch import Tensor, nn, no_grad
from .autoencoders import OobleckDecoder, OobleckEncoder

from .transformer import ContinuousTransformer
LRELU_SLOPE = 0.1
padding_mode = "zeros"
sample_eps = 1e-6

def vae_sample(mean: Tensor, scale: Tensor, eps: float = 1e-6):
    """Reparameterized sample from a diagonal Gaussian posterior.

    Args:
        mean: Posterior mean; assumed channel-first ``[B, C, T]`` (the KL
            below sums over dim 1) — TODO confirm against callers.
        scale: Unconstrained scale; mapped to a positive stdev via softplus.
        eps: Variance floor for numerical stability. Default matches the
            module-level ``sample_eps`` (1e-6).

    Returns:
        Tuple ``(latents, kl)``: a sample with the same shape as ``mean``
        (reparameterization trick, so gradients flow through ``mean`` and
        ``scale``), and a scalar KL-style penalty against N(0, I).

    NOTE(review): the closed-form KL normally carries a factor of 0.5;
    it is omitted here — presumably an intentional loss weighting, so
    behavior is kept unchanged.
    """
    stdev = nn.functional.softplus(scale)
    var = stdev * stdev + eps
    logvar = torch.log(var)
    # z = mean + stdev * epsilon, epsilon ~ N(0, I)
    latents = torch.randn_like(mean) * stdev + mean

    # Sum over the channel dim, average over the remaining dims.
    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl


class EAR_VAE(nn.Module):
    """Audio VAE built from an Oobleck encoder/decoder pair, with an
    optional ContinuousTransformer applied to the latent sequence before
    decoding.

    The encoder output is split along dim 1 into (mean, scale); sampling
    and the KL penalty come from :func:`vae_sample`.
    """

    def __init__(self, model_config: dict = None):
        """
        Args:
            model_config: Nested config dict with ``encoder``/``decoder``
                (each carrying a ``config`` kwargs dict) and optionally a
                ``transformer`` section. ``None`` selects the built-in
                default: stereo I/O, 1024x downsampling, 64-dim latent.
        """
        super().__init__()

        if model_config is None:
            # Default: encoder latent_dim is 128 = 2 * decoder latent_dim,
            # because the encoder output is chunked into (mean, scale).
            model_config = {
                "encoder": {
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 128,
                        "use_snake": True
                    }
                },
                "decoder": {
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 64,
                        "use_nearest_upsample": False,
                        "use_snake": True,
                        "final_tanh": False,
                    },
                },
                "latent_dim": 64,
                "downsampling_ratio": 1024,
                "io_channels": 2,
            }

        # Optional transformer over the latent sequence (applied in decode).
        if model_config.get("transformer") is not None:
            self.transformers = ContinuousTransformer(
                dim=model_config["decoder"]["config"]["latent_dim"],
                depth=model_config["transformer"]["depth"],
                **model_config["transformer"].get("config", {}),
            )
        else:
            self.transformers = None

        self.encoder = OobleckEncoder(**model_config["encoder"]["config"])
        self.decoder = OobleckDecoder(**model_config["decoder"]["config"])

    def forward(self, audio) -> tuple:
        """Encode, sample, and reconstruct.

        Args:
            audio: Input audio tensor ``[B, C, T]``.

        Returns:
            Tuple ``(x, kl)``: the reconstruction and the scalar KL penalty.
        """
        status = self.encoder(audio)
        mean, scale = status.chunk(2, dim=1)
        z, kl = vae_sample(mean, scale)

        # Delegate to decode() so the optional transformer pass lives in
        # exactly one place.
        x = self.decode(z)
        return x, kl

    def encode(self, audio, use_sample=True):
        """Encode audio to a latent.

        Args:
            audio: Input audio tensor ``[B, C, T]``.
            use_sample: If True, draw a reparameterized sample; otherwise
                return the posterior mean (deterministic).

        Returns:
            Latent tensor ``z``.
        """
        x = self.encoder(audio)
        mean, scale = x.chunk(2, dim=1)
        if use_sample:
            z, _ = vae_sample(mean, scale)
        else:
            z = mean
        return z

    def decode(self, z):
        """Decode a latent back to audio, applying the optional latent
        transformer first.

        Args:
            z: Latent tensor ``[B, C, T]``.

        Returns:
            Reconstructed audio tensor.
        """
        if self.transformers is not None:
            # ContinuousTransformer expects [B, T, C]; permute in and out.
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)

        x = self.decoder(z)
        return x

    @no_grad()
    def inference(self, audio):
        """Gradient-free round trip: encode (with sampling) then decode."""
        z = self.encode(audio)
        recon_audio = self.decode(z)
        return recon_audio