File size: 1,021 Bytes
c7f3ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn

from soulxsinger.models.modules.flow_matching import FlowMatchingTransformer


class CFMDecoder(nn.Module):
    def __init__(self, config):
        super(CFMDecoder, self).__init__()
        self.model = FlowMatchingTransformer(cfg=config, **config)

    def forward(self, mel, x_mask, decoder_inp, is_prompt):
        outputs = self.model(mel, x_mask, decoder_inp, is_prompt)

        noise, x, flow_pred, final_mask, prompt_len = outputs["output"]
        return noise, x, flow_pred, final_mask, prompt_len

    def reverse_diffusion(self, pt_mel, pt_decoder_inp, gt_decoder_inp, n_timesteps=32, cfg=1):
        diffusion_cond = torch.cat([pt_decoder_inp, gt_decoder_inp], dim=1)
        diffusion_cond_emb = self.model.cond_emb(diffusion_cond)
        diffusion_prompt = pt_mel

        generated = self.model.reverse_diffusion(
            diffusion_cond_emb,
            diffusion_prompt,
            n_timesteps=n_timesteps,
            cfg=cfg
        )
        return generated