File size: 3,649 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# source: https://github.com/haoheliu/versatile_audio_super_resolution
import librosa
import numpy as np
import torch

from TorchJaekwon.Util.UtilData import UtilData
from TorchJaekwon.Util.UtilTorch import UtilTorch

class UtilAudioSR:

    @staticmethod
    def mel_replace_ops(predict_mel:torch.Tensor,       #[batch, 1, time, melbin], log mel spectrogram
                        gt_low_pass_mel:torch.Tensor,   #[batch, 1, time, melbin], log mel spectrogram
                        debug_message:bool=False
                        ) -> torch.Tensor:              #[batch, 1, time, melbin]
        batch_size = predict_mel.size(0)
        for i in range(batch_size):
            cutoff_melbin = UtilAudioSR.locate_cutoff_freq(torch.exp(gt_low_pass_mel[i].squeeze()))

            if debug_message:
                ratio = predict_mel[i][...,:cutoff_melbin]/gt_low_pass_mel[i][...,:cutoff_melbin]
                print(torch.mean(ratio), torch.max(ratio), torch.min(ratio))

            predict_mel[i][..., :cutoff_melbin] = gt_low_pass_mel[i][..., :cutoff_melbin]
        return predict_mel
    
    @staticmethod
    def locate_cutoff_freq(stft, percentile=0.985):
        magnitude = torch.abs(stft)
        energy = torch.cumsum(torch.sum(magnitude, dim=0), dim=0)
        return UtilAudioSR.find_cutoff(energy, percentile)
    
    @staticmethod
    def find_cutoff(x, percentile=0.95):
        percentile = x[-1] * percentile
        for i in range(1, x.shape[0]):
            if x[-i] < percentile:
                return x.shape[0] - i
        return 0
    
    @staticmethod
    def wav_replace_ops(pred_wav:torch.Tensor,          #[batch, 1, time]
                        gt_low_pass_wav:torch.Tensor    #[batch, 1, time]
                        ) -> torch.Tensor:              #[batch, 1, time]
        device = pred_wav.device
        pred_wav = UtilData.fit_shape_length(pred_wav, 2).cpu().detach().numpy()
        gt_low_pass_wav = UtilData.fit_shape_length(gt_low_pass_wav, 2).cpu().detach().numpy()
        for i in range(pred_wav.shape[0]):

            out = pred_wav[i]
            x = gt_low_pass_wav[i]
            cutoffratio = UtilAudioSR.get_cutoff_index_np(x)

            length = out.shape[0]
            stft_gt = librosa.stft(x)

            stft_out = librosa.stft(out)
            energy_ratio = np.mean(
                np.sum(np.abs(stft_gt[cutoffratio]))
                / np.sum(np.abs(stft_out[cutoffratio, ...]))
            )
            energy_ratio = min(max(energy_ratio, 0.8), 1.2)
            stft_out[:cutoffratio, ...] = stft_gt[:cutoffratio, ...] / energy_ratio

            out_renewed = librosa.istft(stft_out, length=length)
            pred_wav[i] = out_renewed
        
        return torch.FloatTensor(pred_wav).to(device)
    
    @staticmethod
    def get_cutoff_index_np(x):
        stft_x = np.abs(librosa.stft(x))
        energy = np.cumsum(np.sum(stft_x, axis=-1))
        return UtilAudioSR.find_cutoff(energy, 0.985)
    
    @staticmethod
    def find_cutoff_freq(audio:torch.Tensor) -> int:
        stft_spec = torch.stft(
            input = audio,
            n_fft = 2048,
            hop_length=480,
            win_length=2048,
            window=torch.hann_window(2048).to(audio.device),
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        stft_spec = stft_spec[0].T.float()
        cutoff_freq = (UtilAudioSR.locate_cutoff_freq(stft_spec, percentile=0.983) / 1024) * 24000
        if(cutoff_freq < 1000):
            cutoff_freq = 24000
        return cutoff_freq