File size: 4,564 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Optional, Tuple, Union

import torch

from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM
from TorchJaekwon.Model.Diffusion.External.diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from TorchJaekwon.Model.Diffusion.External.diffusers.DiffusersWrapper import DiffusersWrapper

from FlashSR.AudioSR.AudioSRUnet import AudioSRUnet
from FlashSR.VAEWrapper import VAEWrapper
from FlashSR.SRVocoder import SRVocoder
from FlashSR.Util.UtilAudioSR import UtilAudioSR
from FlashSR.Util.UtilAudioLowPassFilter import UtilAudioLowPassFilter

class FlashSR(DDPM):
    def __init__(
            self, 
            student_ldm_ckpt_path:str,
            sr_vocoder_ckpt_path:str,
            autoencoder_ckpt_path:str,
            model_output_type:str = 'v_prediction',
            beta_schedule_type:str = 'cosine',
            **kwargs
        ) -> None:
        
        super().__init__(model = AudioSRUnet(), model_output_type=model_output_type, beta_schedule_type=beta_schedule_type, **kwargs)
        
        student_ldm_state_dict = torch.load(student_ldm_ckpt_path)
        self.load_state_dict(student_ldm_state_dict)

        self.vae = VAEWrapper(autoencoder_ckpt_path)

        self.sr_vocoder = SRVocoder()
        sr_vocoder_state_dict = torch.load(sr_vocoder_ckpt_path)
        self.sr_vocoder.load_state_dict(sr_vocoder_state_dict)

    def forward(self, 
                lr_audio:torch.Tensor, #[batch, time] ex) [4, 245760]
                num_steps:int = 1,
                lowpass_input:bool = True,
                lowpass_cutoff_freq:int = None
                ) -> torch.Tensor: #[batch, time] ex) [4, 245760]
        
        if lowpass_input:
            device = lr_audio.device
            if lowpass_cutoff_freq is None:
                lowpass_cutoff_freq:int = UtilAudioSR.find_cutoff_freq(lr_audio)
            lr_audio = lr_audio.cpu().numpy()
            lr_audio = UtilAudioLowPassFilter.lowpass(lr_audio, 48000, filter_name='cheby', filter_order=8, cutoff_freq=lowpass_cutoff_freq)
            lr_audio = torch.from_numpy(lr_audio).to(device)

        with torch.no_grad():
            pred_hr_audio = DiffusersWrapper.infer(
                ddpm_module=self, 
                diffusers_scheduler_class=DPMSolverMultistepScheduler, 
                x_shape=None, 
                cond = lr_audio,
                num_steps=num_steps,
                device=lr_audio.device
            )
            pred_hr_audio = pred_hr_audio[...,:lr_audio.shape[-1]]
            return pred_hr_audio
    
    def preprocess(self, 
                   x_start:torch.Tensor, # [batch, time]
                   cond:Optional[Union[dict,torch.Tensor]] = None, # [batch, time]
                   ) -> Tuple[torch.Tensor, torch.Tensor]: #( [batch, 1 , mel, time//hop] ,  [batch, 1 , mel, time//hop] )
        device = cond.device
        if self.vae.device != device:
            self.vae.to(device=device)
        x_dict = dict()

        cond_dict = self.vae.encode_to_z(cond)

        if x_start is not None:
            state_dict:dict = {
                'mean_scale_factor': cond_dict['mean_scale_factor'],
                'var_scale_factor': cond_dict['var_scale_factor']
            }
            x_dict = self.vae.encode_to_z(x_start, scale_dict=state_dict) ##[batch, 16, time / (hop * 8), mel_bin / 8]

        return x_dict.get('z', None), cond_dict['z'], cond_dict
        
    def postprocess(self, 
                    x:torch.Tensor, #[batch, 1, mel, time]
                    additional_data_dict:dict) -> torch.Tensor:
        mel_spec = self.vae.z_to_mel(x)
        mel_spec = mel_spec.squeeze(1).transpose(1,2)
        pred_hr_audio = self.sr_vocoder(mel_spec, additional_data_dict['norm_wav'])['pred_hr_audio']
        pred_hr_audio = self.vae.denormalize_wav(pred_hr_audio, additional_data_dict)
        return pred_hr_audio
    
    def get_x_shape(self, cond):
        return cond.shape
    
    def get_unconditional_condition(self,
                                    cond:Optional[Union[dict,torch.Tensor]] = None, 
                                    cond_shape:Optional[tuple] = None,
                                    condition_device:Optional[torch.device] = None
                                    ) -> torch.Tensor:
        if cond_shape is None: cond_shape = cond.shape
        if cond is not None and isinstance(cond,torch.Tensor): condition_device = cond.device
        return (-11.4981 + torch.zeros(cond_shape)).to(condition_device) * self.vae.scale_factor_z