File size: 3,507 Bytes
6fbfc17
 
d05dcd7
 
6fbfc17
 
d05dcd7
 
6fbfc17
 
 
028cdeb
 
6fbfc17
 
 
 
 
028cdeb
6fbfc17
 
 
 
 
 
 
 
 
 
 
d05dcd7
028cdeb
6fbfc17
028cdeb
6fbfc17
 
 
 
 
028cdeb
d05dcd7
6fbfc17
028cdeb
 
 
6fbfc17
 
 
028cdeb
6fbfc17
 
 
 
 
028cdeb
 
 
 
 
 
 
 
6fbfc17
 
 
 
 
 
 
 
 
 
 
028cdeb
 
 
 
 
 
 
 
6fbfc17
 
 
 
028cdeb
6fbfc17
 
 
028cdeb
 
 
 
 
 
 
 
6fbfc17
028cdeb
 
 
 
 
 
 
 
 
 
6fbfc17
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import soundfile
from openvoice import utils
from openvoice import commons
import os
import librosa
from openvoice.mel_processing import spectrogram_torch
from openvoice.models import SynthesizerTrn


class OpenVoiceBaseClass(object):
    """Build a ``SynthesizerTrn`` from a config file and keep it in eval mode.

    Attributes:
        model: the ``SynthesizerTrn`` instance, moved to ``device``.
        hps: hyper-parameters parsed from ``config_path``.
        device: the device the model was placed on, as passed by the caller.
    """

    def __init__(self, config_path, device="cuda:0"):
        """Load hyper-parameters from ``config_path`` and instantiate the model.

        Args:
            config_path: path to the config read by ``utils.get_hparams_from_file``.
            device: device string (e.g. ``"cuda:0"``, ``"cpu"``) or a
                ``torch.device``. Defaults to ``"cuda:0"``.

        Raises:
            RuntimeError: if a CUDA device is requested but CUDA is unavailable.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``;
        # ``str()`` lets callers pass a ``torch.device`` as well as a string.
        if "cuda" in str(device) and not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA device {device!r} requested but CUDA is not available"
            )

        hps = utils.get_hparams_from_file(config_path)

        model = SynthesizerTrn(
            # Symbol-table size; 0 when the config defines no ``symbols``.
            len(getattr(hps, "symbols", [])),
            # Presumably the spectrogram channel count (filter_length // 2 + 1 STFT bins).
            hps.data.filter_length // 2 + 1,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).to(device)

        model.eval()
        self.model = model
        self.hps = hps
        self.device = device

    def load_ckpt(self, ckpt_path):
        """Load weights from the ``"model"`` entry of a torch checkpoint file.

        Loading is non-strict: missing and unexpected keys are printed rather
        than raising an error.

        Args:
            ckpt_path: path to a file produced by ``torch.save`` containing a
                dict with a ``"model"`` state dict.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
        missing, unexpected = self.model.load_state_dict(
            checkpoint_dict["model"], strict=False
        )
        print("Loaded checkpoint '{}'".format(ckpt_path))
        print("missing/unexpected keys:", missing, unexpected)


class ToneColorConverter(OpenVoiceBaseClass):
    """Extract speaker embeddings and convert audio between speakers.

    Builds on :class:`OpenVoiceBaseClass` and uses the model's ``ref_enc``
    (embedding extraction) and ``voice_conversion`` (conversion) methods.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Config format version; configs without ``_version_`` are treated as "v1".
        self.version = getattr(self.hps, "_version_", "v1")

    def _load_spectrogram(self, audio_path):
        """Load ``audio_path`` at the model's sampling rate and return its spectrogram on ``self.device``."""
        hps = self.hps
        audio, _ = librosa.load(audio_path, sr=hps.data.sampling_rate)
        # Add a leading batch dimension of size 1.
        y = torch.as_tensor(audio, dtype=torch.float32, device=self.device).unsqueeze(0)
        return spectrogram_torch(
            y,
            hps.data.filter_length,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            center=False,
        ).to(self.device)

    def extract_se(self, ref_wav_list, se_save_path=None):
        """Compute a speaker embedding averaged over one or more reference clips.

        Args:
            ref_wav_list: a single audio path, or an iterable of audio paths.
            se_save_path: optional file path; when given, the embedding is
                saved there (on CPU) with ``torch.save``, creating parent
                directories as needed.

        Returns:
            The mean of the per-clip ``ref_enc`` outputs (each with a trailing
            dimension added), on ``self.device``.

        Raises:
            ValueError: if no reference clips are given.
        """
        # A bare path would otherwise be iterated character by character.
        if isinstance(ref_wav_list, (str, os.PathLike)):
            ref_wav_list = [ref_wav_list]
        ref_wav_list = list(ref_wav_list)
        if not ref_wav_list:
            raise ValueError("ref_wav_list must contain at least one audio path")

        embeddings = []
        with torch.no_grad():
            for fname in ref_wav_list:
                spec = self._load_spectrogram(fname)
                # Swap dims 1 and 2 before ref_enc (presumably it wants time-major input — TODO confirm).
                embeddings.append(self.model.ref_enc(spec.transpose(1, 2)).unsqueeze(-1))
        se = torch.stack(embeddings).mean(0)

        if se_save_path is not None:
            save_dir = os.path.dirname(se_save_path)
            # dirname is "" for a bare filename; os.makedirs("") would raise.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(se.cpu(), se_save_path)

        return se

    def convert(
        self,
        audio_src_path,
        src_se,
        tgt_se,
        output_path=None,
        tau=0.3,
    ):
        """Re-render ``audio_src_path`` from speaker ``src_se`` to speaker ``tgt_se``.

        Args:
            audio_src_path: path to the source audio file.
            src_se: speaker embedding of the source audio.
            tgt_se: speaker embedding to convert to.
            output_path: if given, the result is written there with
                ``soundfile.write`` at the model's sampling rate.
            tau: forwarded unchanged to ``model.voice_conversion``.

        Returns:
            The converted waveform as a float NumPy array when ``output_path``
            is None; otherwise None (the audio is written to disk).
        """
        hps = self.hps
        with torch.no_grad():
            spec = self._load_spectrogram(audio_src_path)
            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
            converted = self.model.voice_conversion(
                spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau
            )[0]
        audio = converted[0, 0].detach().cpu().float().numpy()

        if output_path is None:
            return audio
        soundfile.write(output_path, audio, hps.data.sampling_rate)