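"""Prepare zero-shot prompt files for CosyVoice.

Runs CosyVoiceFrontEnd over a prompt (reference) audio clip and its
transcript, then dumps every resulting tensor (prompt text tokens, speech
tokens, speech features, and speaker embeddings) as a flattened text file
under --output for downstream inference.
"""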
import argparse
import os
import torch
import torchaudio
import numpy as np
from frontend import CosyVoiceFrontEnd

def load_wav(wav, target_sr):
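    """Load an audio file as a mono (1, T) tensor and resample it to target_sr (downsampling only)."""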
    speech, sample_rate = torchaudio.load(wav, backend='soundfile')
    speech = speech.mean(dim=0, keepdim=True)  # downmix multi-channel audio to mono
    if sample_rate != target_sr:
        assert sample_rate > target_sr, f'wav sample rate {sample_rate} must be greater than target sample rate {target_sr}'
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="scripts/CosyVoice-BlankEN", help="tokenizer configuration directory")
    parser.add_argument('--wetext_dir', type=str, default="pengzhendong/wetext", help="path to wetext")
    parser.add_argument('--sample_rate', type=int, default=24000, help="sampling rate for prompt audio")
    parser.add_argument('--prompt_text', type=str, default="希望你以后能够做的比我还好呦。", help="text content of the prompt (reference) audio; literal text or a file path")
    parser.add_argument('--prompt_speech', type=str, default="asset/zero_shot_prompt.wav", help="path to the prompt (reference) audio")
    parser.add_argument('--output', type=str, default="prompt_files", help="output directory for the dumped prompt tensors")
    args = parser.parse_args()
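    # Example invocation (the script name below is illustrative):
    #   python prepare_prompt.py --model_dir scripts/CosyVoice-BlankEN \
    #       --prompt_speech asset/zero_shot_prompt.wav --output prompt_files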

    os.makedirs(args.output, exist_ok=True)

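    # CosyVoiceFrontEnd (see frontend.py) wires together the tokenizer config
    # in model_dir, wetext text normalization, the CAM++ speaker-embedding
    # model (campplus.onnx), the speech tokenizer (speech_tokenizer_v2.onnx),
    # and cached speaker info (spk2info.pt); the trailing "all" selector is
    # repo-specific.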
    frontend = CosyVoiceFrontEnd(args.model_dir,
                                 args.wetext_dir,
                                 "frontend-onnx/campplus.onnx",
                                 "frontend-onnx/speech_tokenizer_v2.onnx",
                                 f"{args.model_dir}/spk2info.pt",
                                 "all")

    prompt_speech_16k = load_wav(args.prompt_speech, 16000)
    zero_shot_spk_id = ""

    # --prompt_text may be either the literal transcript or a path to a text file.
    if os.path.isfile(args.prompt_text):
        with open(args.prompt_text, "r") as f:
            prompt_text = f.read()
    else:
        prompt_text = args.prompt_text
    print("prompt_text:", prompt_text)
    model_input = frontend.process_prompt(prompt_text, prompt_speech_16k, args.sample_rate, zero_shot_spk_id)
    
    # model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
    #                'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
    #                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
    #                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
    #                'llm_embedding': embedding, 'flow_embedding': embedding}
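    # The flow decoder needs a sufficiently long prompt; assuming the
    # tokenizer's 25 Hz token rate (as in CosyVoice 2), 75 tokens is about
    # 3 seconds of audio.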
    print("prompt speech token size:", model_input["flow_prompt_speech_token"].shape)
    assert model_input["flow_prompt_speech_token"].shape[1] >=75, f"speech_token length should >= 75, bug get {model_input['flow_prompt_speech_token'].shape[1]}"
    # Dump each tensor (skipping the *_len length entries) as a flattened
    # text file named after its key, one value per line.
    for k, v in model_input.items():
        if "_len" in k:
            continue
        arr = v.detach().cpu().numpy().reshape(-1)
        if v.dtype in (torch.int32, torch.int64):
            np.savetxt(f"{args.output}/{k}.txt", arr, fmt="%d", delimiter=",")
        else:
            np.savetxt(f"{args.output}/{k}.txt", arr, delimiter=",")