import os
import sys
import torch

import numpy as np

from scipy.signal import medfilt

sys.path.append(os.getcwd())

from infer.lib.predictors.DJCM.spec import Spectrogram

# 16 kHz input, 1024-sample analysis window, 360 pitch classes (20-cent bins)
SAMPLE_RATE, WINDOW_LENGTH, N_CLASS = 16000, 1024, 360

class DJCM:
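    """DJCM pitch (F0) predictor.

    Splits audio into half-overlapping segments, extracts spectrograms, runs
    them through a PyTorch checkpoint or an ONNX session, and decodes the
    per-frame salience into F0 values in Hz at a 10 ms hop (16 kHz input).
    """
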
    def __init__(
        self, 
        model_path, 
        device = "cpu", 
        is_half = False, 
        onnx = False, 
        svs = False, 
        providers = ["CPUExecutionProvider"], 
        batch_size = 1, 
        segment_len = 5.12, 
        kernel_size = 3
    ):
        # Assigning to the module-level WINDOW_LENGTH here would make it a local
        # name and raise UnboundLocalError when svs is False, so use a local copy.
        window_length = 2048 if svs else WINDOW_LENGTH
        self.onnx = onnx

        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            from infer.lib.predictors.DJCM.model import DJCMM

            model = DJCMM(1, 1, 1, svs=svs, window_length=window_length, n_class=N_CLASS)
            model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
            model.eval()
            if is_half: model = model.half()
            self.model = model.to(device)

        self.batch_size = batch_size
        self.seg_len = int(segment_len * SAMPLE_RATE)  # segment length in samples
        self.seg_frames = int(self.seg_len // int(SAMPLE_RATE // 100))  # frames per segment at a 10 ms hop

        self.device = device
        self.is_half = is_half
        self.kernel_size = kernel_size

        self.spec_extractor = Spectrogram(int(SAMPLE_RATE // 100), window_length).to(device)
        # Map each of the N_CLASS bins to a pitch in cents (20-cent steps),
        # padded by 4 bins per side for the local-average window used in decoding.
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def spec2hidden(self, spec):
        # Forward pass through either the ONNX session or the PyTorch module.
        if self.onnx:
            spec = spec.cpu().numpy().astype(np.float32)

            hidden = torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name], 
                    {self.model.get_inputs()[0].name: spec}
                )[0], 
                device=self.device
            )
        else:
            if self.is_half: spec = spec.half()
            hidden = self.model(spec)

        return hidden

    def infer_from_audio(self, audio, thred=0.03):
        if torch.is_tensor(audio): audio = audio.cpu().numpy()
        if audio.ndim > 1: audio = audio.squeeze()

        with torch.no_grad():
            padded_audio = self.pad_audio(audio)
            # keep only the frames that cover the original (unpadded) audio
            hidden = self.inference(padded_audio)[:(audio.shape[-1] // int(SAMPLE_RATE // 100) + 1)]

            f0 = self.decode(hidden.squeeze(0).cpu().numpy(), thred)
            # optional median filter to suppress single-frame outliers
            if self.kernel_size is not None: f0 = medfilt(f0, kernel_size=self.kernel_size)

            return f0
        
    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        f0 = self.infer_from_audio(audio, thred)
        # mark estimates outside the expected pitch range as unvoiced
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0

        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        # Refine the argmax pitch class with a salience-weighted average over
        # a 9-bin window centered on the peak.
        center = np.argmax(salience, axis=1)
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4
        todo_salience, todo_cents_mapping = [], []
        starts = center - 4
        ends = center + 5

        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx] : ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])

        todo_salience = np.array(todo_salience)
        divided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
        # frames whose peak salience falls below the threshold are treated as unvoiced
        divided[np.max(salience, axis=1) <= thred] = 0

        return divided
        
    def decode(self, hidden, thred=0.03):
        # cents -> Hz: f0 = 10 * 2**(cents / 1200), e.g. 6551.3 cents ≈ 440 Hz (A4);
        # frames decoded to exactly 10 Hz had zero cents (unvoiced) and are zeroed.
        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
        f0[f0 == 10] = 0
        return f0

    def pad_audio(self, audio):
        audio_len = audio.shape[-1]

        # Pad so the audio splits evenly into half-overlapping segments; the
        # quarter-segment left pad aligns the center crop taken in inference().
        seg_nums = int(np.ceil(audio_len / self.seg_len)) + 1
        pad_len = int(seg_nums * self.seg_len - audio_len + self.seg_len // 2)

        left_pad = np.zeros(int(self.seg_len // 4), dtype=np.float32)
        right_pad = np.zeros(int(pad_len - self.seg_len // 4), dtype=np.float32)
        padded_audio = np.concatenate([left_pad, audio, right_pad], axis=-1)

        segments = [
            padded_audio[start: start + int(self.seg_len)] 
            for start in range(
                0, 
                len(padded_audio) - int(self.seg_len) + 1, 
                int(self.seg_len // 2)
            )
        ]

        segments = np.stack(segments, axis=0)
        segments = torch.from_numpy(segments).unsqueeze(1).to(self.device)

        return segments

    def inference(self, segments):
        # Run the segments through the spectrogram extractor and model in batches.
        hidden_segments = torch.cat([
            self.spec2hidden(self.spec_extractor(segments[i:i + self.batch_size].float()))
            for i in range(0, len(segments), self.batch_size)
        ], dim=0)

        # Segments overlap by half, so keep only the center half of each
        # segment's frames; concatenated, they tile the audio without seams.
        hidden = torch.cat([
            seg[self.seg_frames // 4: int(self.seg_frames * 0.75)]
            for seg in hidden_segments
        ], dim=0)

        return hidden
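

if __name__ == "__main__":
    # Minimal usage sketch (assumption: "assets/djcm.pt" is a placeholder
    # checkpoint path; any 16 kHz mono float32 signal works as input).
    t = np.arange(2 * SAMPLE_RATE, dtype=np.float32) / SAMPLE_RATE
    audio = 0.5 * np.sin(2 * np.pi * 220.0 * t)  # two seconds of a 220 Hz tone

    djcm = DJCM("assets/djcm.pt", device="cpu")
    f0 = djcm.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)

    voiced = f0[f0 > 0]
    # the median voiced estimate should land near 220 Hz
    print(f0.shape, float(np.median(voiced)) if voiced.size else "all frames unvoiced")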