File size: 11,332 Bytes
1cd928a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96019b0
 
 
 
 
1cd928a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96019b0
 
 
 
 
1cd928a
 
 
 
 
 
 
 
 
96019b0
 
 
 
1cd928a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# support audio dataset with text prompt
import os
import librosa
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import sys

from huggingface_hub import hf_hub_download
sys.path.append(os.path.join(os.path.dirname(__file__), '../utils'))
from ddsp.vocoder import F0_Extractor, Volume_Extractor

import torch
from typing import Union
from torch.nn import functional as F
from slicer import Slicer
from transformers import AutoTokenizer, AutoModel
# from ThreeD_Speaker.speakerlab.bin.get_spk_sim import build_model, get_spk_emb, get_spk_emb_t

def edge_padding(f0):
    """Pad unvoiced frames that border voiced frames with the voiced value.

    For every voiced frame (non-zero f0), the immediately preceding and
    following frames inherit its value if they are unvoiced in the
    *original* contour. Exactly one frame of padding per side.

    Args:
        f0 (np.ndarray): 1-D f0 contour where 0 marks unvoiced frames.

    Returns:
        np.ndarray: a padded copy; the input array is left untouched.
    """
    padded = f0.copy()
    # Scan interior frames only; endpoints are never treated as centers.
    for idx in range(1, len(f0) - 1):
        value = f0[idx]
        if value == 0:
            continue
        # Decisions read the original contour, writes go to the copy,
        # so padding never cascades.
        if f0[idx - 1] == 0:
            padded[idx - 1] = value
        if f0[idx + 1] == 0:
            padded[idx + 1] = value
    return padded

def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000):
    """Slice audio at silence boundaries and return frame-aligned chunks.

    Fix: the original bound the slicer to ``slnpicer`` but then called
    ``slicer.slice(audio)``, which raised NameError on every call.

    Args:
        audio (np.ndarray): mono waveform.
        sample_rate (int): sample rate of ``audio``.
        hop_size (int): samples per frame; chunk edges snap to frames.
        db_thresh (float): silence threshold in dB for the slicer.
        min_len (int): minimum slice length (slicer units, ms).

    Returns:
        list[tuple[int, np.ndarray]]: (start_frame, chunk_samples) pairs;
        zero-length segments are dropped.
    """
    slicer = Slicer(
                sr=sample_rate,
                threshold=db_thresh,
                min_length=min_len)
    chunks = dict(slicer.slice(audio))
    result = []
    for k, v in chunks.items():
        tag = v["split_time"].split(",")
        if tag[0] != tag[1]:
            # Snap boundaries from samples to whole frames.
            start_frame = int(int(tag[0]) // hop_size)
            end_frame = int(int(tag[1]) // hop_size)
            if end_frame > start_frame:
                result.append((
                        start_frame,
                        audio[int(start_frame * hop_size) : int(end_frame * hop_size)]))
    return result

def wav_pad(wav, multiple=200):
    """Stretch a waveform so its length is the next multiple of ``multiple``.

    Uses ``repeat_expand`` (interpolation), not zero-padding, to reach the
    target length.

    Args:
        wav: 1-D waveform (tensor or ndarray).
        multiple (int): length granularity to round up to.

    Returns:
        The length-adjusted waveform.
    """
    length = wav.shape[0]
    # Ceiling division via negation keeps everything in integer arithmetic.
    target = -(-length // multiple) * multiple
    return repeat_expand(wav, target)

def repeat_expand(
    content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
):
    """Repeat/interpolate content to target length.
    This is a wrapper of torch.nn.functional.interpolate.

    Fix: a 3-D input previously fell through every return and the function
    silently returned None; it now returns the interpolated 3-D result.

    Args:
        content (torch.Tensor | np.ndarray): 1-D, 2-D, or 3-D data; the last
            dimension is the one resized.
        target_len (int): target length of the last dimension.
        mode (str, optional): interpolation mode. Defaults to "nearest".

    Returns:
        torch.Tensor | np.ndarray: same type and rank as the input, with the
        last dimension resized to ``target_len``.
    """

    ndim = content.ndim

    # interpolate needs (batch, channel, length); lift lower ranks.
    if ndim == 1:
        content = content[None, None]
    elif ndim == 2:
        content = content[None]

    assert content.ndim == 3

    is_np = isinstance(content, np.ndarray)
    if is_np:
        content = torch.from_numpy(content)

    results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)

    if is_np:
        results = results.numpy()

    # Strip the dimensions we added so the output rank matches the input.
    if ndim == 1:
        return results[0, 0]
    if ndim == 2:
        return results[0]
    return results

def repeat_expand_2d(content, target_len, mode = 'left'):
    """Expand a [h, t] feature map to [h, target_len].

    ``mode='left'`` uses the step-hold expansion; any other mode is passed
    through to torch interpolation.
    """
    if mode == 'left':
        return repeat_expand_2d_left(content, target_len)
    return repeat_expand_2d_other(content, target_len, mode)


def repeat_expand_2d_left(content, target_len):
    """Step-hold expand a [h, t] map to [h, target_len].

    Each source column is held over a contiguous run of output columns; run
    boundaries are at ``k * target_len / src_len``.

    Args:
        content (torch.Tensor): [h, t] features.
        target_len (int): number of output columns.

    Returns:
        torch.Tensor: float [h, target_len] on ``content``'s device.
    """
    src_len = content.shape[-1]
    out = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
    # boundaries[k] marks the first output column owned by source column k.
    boundaries = torch.arange(src_len + 1) * target_len / src_len
    pos = 0
    for col in range(target_len):
        # Advance (at most one step) once the next boundary is reached.
        if col >= boundaries[pos + 1]:
            pos += 1
        out[:, col] = content[:, pos]
    return out


# mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area'
def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
    """Resize a [h, t] map to [h, target_len] via torch interpolation."""
    # interpolate wants a batch dimension; add it, resize, then drop it.
    expanded = F.interpolate(content.unsqueeze(0), size=target_len, mode=mode)
    return expanded.squeeze(0)

def align_data(data, max_len):
    """Zero-pad or truncate ``data`` along its last dimension to ``max_len``.

    Fix: the original padded the last dimension but truncated dim 0
    (``data[:max_len]``), which is wrong for 2-D input; truncation now uses
    ``data[..., :max_len]`` so both branches act on the same axis.
    Behavior is unchanged for the 1-D tensors this file passes in.

    Args:
        data (torch.Tensor): tensor whose last dim is the time axis.
        max_len (int): target length of the last dim.

    Returns:
        torch.Tensor: tensor with last dim exactly ``max_len``.
    """
    data_len = data.shape[-1]
    if data_len < max_len:
        data = F.pad(data, (0, max_len - data_len))
    elif data_len > max_len:
        data = data[..., :max_len]
    return data

def adjust_length(feature, target_len):
    """Linearly resample a (length, dim) feature matrix along its time axis.

    Args:
        feature (torch.Tensor): features shaped (current_len, dim).
        target_len (int): desired time length.

    Returns:
        torch.Tensor: features shaped (target_len, dim); the original object
        is returned untouched when the length already matches.
    """
    if feature.shape[0] == target_len:
        return feature

    # interpolate operates on (batch, channels, length), so transpose to put
    # time last, add a batch dim, resample, then undo both transforms.
    resampled = F.interpolate(
        feature.t().unsqueeze(0),
        size=target_len,
        mode='linear',
        align_corners=False,
    )
    return resampled.squeeze(0).t()

def load_bert_model(model_name, device):
    """Load a HuggingFace tokenizer/encoder pair for ``model_name``.

    Returns:
        tuple: (tokenizer, model) with the model moved to ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer, AutoModel.from_pretrained(model_name).to(device)

def get_style_embed(style_prompt, tokenizer, model):
    """Encode a text style prompt and return the model's last output entry.

    NOTE(review): ``outputs[-1]`` — presumably the pooled embedding; verify
    against the output layout of the chosen BERT-style model.
    """
    encoded = tokenizer(style_prompt, return_tensors="pt").to(model.device)
    model_out = model(**encoded)
    return model_out[-1]

def load_facodec(device):
    """Build the FACodec V2 encoder/decoder pair, load local checkpoints,
    and return both models in eval mode on ``device``.

    Args:
        device: torch device string or object for both models.

    Returns:
        tuple: (fa_encoder, fa_decoder), each in ``.eval()`` mode.
    """
    from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2
    fa_encoder = FACodecEncoderV2(
        ngf=32,
        up_ratios=[2, 4, 5, 5],
        out_channels=256,
    )

    # Decoder mirrors the encoder's up_ratios in reverse order.
    fa_decoder = FACodecDecoderV2(
        in_channels=256,
        upsample_initial_channel=1024,
        ngf=32,
        up_ratios=[5, 5, 4, 2],
        vq_num_q_c=2,
        vq_num_q_p=1,
        vq_num_q_r=3,
        vq_dim=256,
        codebook_dim=8,
        codebook_size_prosody=10,
        codebook_size_content=10,
        codebook_size_residual=10,
        use_gr_x_timbre=True,
        use_gr_residual_f0=True,
        use_gr_residual_phone=True,
    )
    # Disabled download path; checkpoints are expected to exist locally below.
    # encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin", local_dir="utils/pretrain")
    # decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin", local_dir="utils/pretrain")

    # NOTE(review): hard-coded relative paths — these resolve against the
    # process working directory, not this file's location; confirm callers
    # run from the repo root.
    encoder_ckpt = "utils/pretrain/ns3_facodec_encoder_v2.bin"
    decoder_ckpt = "utils/pretrain/ns3_facodec_decoder_v2.bin"

    fa_encoder.load_state_dict(torch.load(encoder_ckpt))
    fa_decoder.load_state_dict(torch.load(decoder_ckpt))
    
    fa_encoder = fa_encoder.to(device).eval()
    fa_decoder = fa_decoder.to(device).eval()
    
    return fa_encoder, fa_decoder

def load_f0_extractor(args):
    """Build an F0_Extractor from ``args``, substituting defaults for any
    field that is ``None``.

    Defaults: extractor 'rmvpe', sr 44100, block_size 512, f0 range 60-1200.
    """
    defaults = {
        'f0_extractor': 'rmvpe',
        'sr': 44100,
        'block_size': 512,
        'f0_min': 60,
        'f0_max': 1200,
    }

    def pick(name):
        # Fall back to the table only when the arg is explicitly None.
        value = getattr(args, name)
        return defaults[name] if value is None else value

    return F0_Extractor(pick('f0_extractor'),
                        pick('sr'),
                        pick('block_size'),
                        pick('f0_min'),
                        pick('f0_max'))

def load_volume_extractor(args):
    """Build a Volume_Extractor; block size defaults to 512 when unset."""
    block_size = 512 if args.block_size is None else args.block_size
    return Volume_Extractor(block_size)

def load_audio(input_path, sr):
    """Load an audio file resampled to ``sr`` and downmixed to mono.

    Args:
        input_path (str): path to the audio file.
        sr (int): target sample rate for librosa.

    Returns:
        np.ndarray: 1-D mono waveform.
    """
    audio, _ = librosa.load(input_path, sr=sr)
    # librosa may return multi-channel data; collapse it if so.
    return librosa.to_mono(audio) if len(audio.shape) > 1 else audio

def resample_and_normalize(audio, max_gain=0.6):
    """Peak-normalize ``audio`` to ``max_gain`` of int16 full scale.

    Fix: the first division had no floor on the peak, so silent input
    (all zeros) produced NaNs and undefined int16 output; the peak is now
    floored at a tiny epsilon. Output is unchanged for any non-silent audio.

    Args:
        audio (np.ndarray): float waveform.
        max_gain (float): fraction of full scale for the output peak.

    Returns:
        np.ndarray: int16 waveform with peak ~= 32767 * max_gain.
    """
    peak = np.abs(audio).max()
    audio = audio / max(1e-9, peak) * max_gain
    # Second pass keeps the original 0.01 floor for very quiet signals.
    audio = audio / max(0.01, np.max(np.abs(audio))) * 32767 * max_gain
    return audio.astype(np.int16)

def get_processed_file(input_path, sr, encoder_sr, mel_extractor, volume_extractor, f0_extractor, 
                       fa_encoder=None, fa_decoder=None, content_encoder=None, spk_encoder=None, 
                       device='cuda', max_sec=None, f0_interpolate_mode='full'):
    """Load one audio file and extract its training features (f0, volume,
    mel, content / speaker embeddings), running the extractors in parallel
    worker threads.

    Args:
        input_path: path to the audio file.
        sr: sample rate for the f0/volume/mel branch.
        encoder_sr: sample rate for the codec/content-encoder branch.
        mel_extractor, volume_extractor, f0_extractor: feature extractors.
        fa_encoder, fa_decoder: optional FACodec encoder/decoder pair.
        content_encoder, spk_encoder: accepted but never used below —
            NOTE(review): confirm whether a non-FACodec branch was removed.
        device: torch device for GPU work.
        max_sec: optional cap on audio length, in seconds.
        f0_interpolate_mode: 'full' interpolates f0 through unvoiced gaps;
            'part' only edge-pads one frame around voiced segments.

    Returns:
        dict with keys vq_post, spk, f0, f0_origin, vol, name, mel,
        or None on any load/extraction failure.
    """

    if max_sec is not None:
        max_audio_44k_len = sr * max_sec
        max_audio_len = encoder_sr * max_sec
    
    # 1. Load the audio serially (features can only be extracted once the
    #    data is in hand).
    if not os.path.exists(input_path):
        print(f'\n[Error] {input_path} does not exist!')
        return None
    try:
        name = input_path.split('/')[-1].split('.')[0]
        audio_44k = load_audio(input_path, sr)
        audio = load_audio(input_path, encoder_sr)

        if max_sec is not None and max_audio_44k_len > 0:
            audio_44k = audio_44k[:min(len(audio_44k), max_audio_44k_len)]
            audio = audio[:min(len(audio), max_audio_len)]
        # Convert to a Tensor for the GPU tasks.
        audio_44k_t = torch.from_numpy(audio_44k).float().to(device).unsqueeze(0)
    except Exception as e:
        print(f'\n[Error] Failed to load audio. Error: {e}')
        return None

    # --- begin internal parallelization ---
    # Define the sub-task functions.
    def task_f0():
        # uv_interp=False: unvoiced frames stay 0 here; interpolation is
        # applied later according to f0_interpolate_mode.
        return f0_extractor.extract(audio_44k, uv_interp=False)

    def task_volume():
        return volume_extractor.extract(audio_44k)

    def task_mel():
        return mel_extractor.extract(audio_44k_t, sr).squeeze()

    def task_encoder():
        # Contains the FACodec content/speaker extraction logic.
        with torch.no_grad():
            if fa_encoder is not None and fa_decoder is not None:
                audio_t = torch.from_numpy(wav_pad(audio)).unsqueeze(0).unsqueeze(0).to(device)
                enc_out = fa_encoder(audio_t)
                prosody = fa_encoder.get_prosody_feature(audio_t)
                content_emb_t, _, _, _, spk_emb_t = fa_decoder(enc_out, prosody, eval_vq=False, vq=True)
                return content_emb_t.squeeze(0), spk_emb_t
        return None, None

    # Run the tasks on a thread pool.
    # Despite the GIL, PyTorch and C++ extensions (e.g. the F0 extractor)
    # release it, so this achieves genuine parallelism.
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_f0 = executor.submit(task_f0)
        future_vol = executor.submit(task_volume)
        future_mel = executor.submit(task_mel)
        future_enc = executor.submit(task_encoder)

        # Gather results (blocks until all tasks finish).
        f0 = future_f0.result()
        volume = future_vol.result()
        mel_t = future_mel.result()
        content_emb_t, spk_emb_t = future_enc.result()

    # --- end internal parallelization ---

    # 3. Post-processing (these steps depend on all the results above).
    if f0 is None or volume is None or mel_t is None:
        return None

    # Mel frame count defines the common sequence length for all features.
    seq_len = mel_t.shape[0]
    volume_t = align_data(torch.from_numpy(volume).float(), seq_len)
    
    # Align the encoder output length.
    # NOTE(review): when fa_encoder is None, content_emb_t is None here and
    # adjust_length(None, ...) will raise — confirm a content_encoder code
    # path was intended for the else branch.
    if fa_encoder is not None:
        content_emb_t = repeat_expand_2d(content_emb_t, seq_len).T
    else:
        content_emb_t = adjust_length(content_emb_t, seq_len)

    # F0 interpolation and post-processing.
    f0_origin = f0.copy()
    if f0_interpolate_mode == 'full':
        # Fill unvoiced frames by linear interpolation over voiced anchors.
        uv = (f0 == 0)
        if len(f0[~uv]) > 0:
            f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
        else:
            # Entirely unvoiced clip — unusable sample.
            return None
    elif f0_interpolate_mode == 'part':
        f0 = edge_padding(f0)
    
    f0_t = align_data(torch.from_numpy(f0).float(), seq_len)

    return dict(
        vq_post=content_emb_t, 
        spk=spk_emb_t, 
        f0=f0_t, 
        f0_origin=f0_origin, 
        vol=volume_t, 
        name=name, 
        mel=mel_t
    )