from slam_llm.utils.train_utils import print_module_size
import torch
import torchaudio
import os
import torch.nn as nn
import uuid
import logging
logger = logging.getLogger(__name__)

def setup_codec(train_config, model_config, **kwargs):
    """Load the CosyVoice (v1 or v2) codec decoder specified by ``model_config`` and report its size."""
    import sys
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/Matcha-TTS"))
    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
    if model_config.cosyvoice_version==1:
        codec_decoder = CosyVoice(model_config.codec_decoder_path, load_jit=False, load_trt=False, fp16=False)
    elif model_config.cosyvoice_version==2:
        codec_decoder = CosyVoice2(model_config.codec_decoder_path, load_jit=False, load_trt=False, fp16=False)
    else:
        raise NotImplementedError
    codec_decoder_module = nn.ModuleList((codec_decoder.model.flow,codec_decoder.model.hift))

    print_module_size(codec_decoder_module, model_config.codec_decoder_type + " Codec", int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    
    return codec_decoder
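
# Usage sketch (illustrative; the config objects come from the surrounding training setup,
# and the path below is a placeholder, not a real checkpoint):
#   model_config.codec_decoder_path = "/path/to/CosyVoice-300M"
#   model_config.cosyvoice_version = 1
#   codec_decoder = setup_codec(train_config, model_config)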

def get_single_layer_answer_token(audio_tokens, num_latency_tokens, padding_token, end_of_audio):
    """Pack a single-stream answer: prepend latency padding tokens and append the end-of-audio token."""
    audio_length = len(audio_tokens) + num_latency_tokens + 1  # +1 for the end-of-audio token
    result = [padding_token] * num_latency_tokens + list(audio_tokens) + [end_of_audio]
    result_tensor = torch.tensor(result).unsqueeze(0)
    return result_tensor, audio_length
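
# Illustrative example (token values are arbitrary placeholders, not real codec ids):
# get_single_layer_answer_token([11, 12, 13], num_latency_tokens=2, padding_token=0,
# end_of_audio=9) returns tensor([[0, 0, 11, 12, 13, 9]]) of shape (1, 6) together with
# audio_length = 6 (3 audio tokens + 2 latency pads + 1 EOA).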

def get_group_answer_token(audio_tokens, num_latency_tokens, padding_token, end_of_audio, num_layers):
    """Pack a multi-stream answer: interleave the tokens across ``num_layers`` parallel layers with latency padding."""
    padded_audio_tokens = audio_tokens + [end_of_audio]
    padding_needed = (num_layers - len(padded_audio_tokens) % num_layers ) % num_layers
    
    # Add padding to ensure even distribution across layers
    padded_audio_tokens = padded_audio_tokens + [padding_token] * padding_needed
    total_length = len(padded_audio_tokens)
    audio_length = total_length // num_layers + num_latency_tokens

    # Create the result for each layer
    result = []
    for layer in range(num_layers):
        layer_tokens = [padding_token] * num_latency_tokens
        layer_tokens.extend(padded_audio_tokens[layer::num_layers])
        result.append(torch.tensor(layer_tokens))
    
    result_tensor = torch.stack(result)
    return result_tensor, audio_length
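
# Illustrative example (placeholder token values): with audio_tokens=[11, 12, 13, 14, 15],
# num_latency_tokens=1, padding_token=0, end_of_audio=9 and num_layers=3, the padded stream
# is [11, 12, 13, 14, 15, 9] and the returned tensor is
# [[0, 11, 14], [0, 12, 15], [0, 13, 9]] with audio_length = 3.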

def audio_decode_cosyvoice(audio_tokens, model_config, codec_decoder, audio_prompt_path=None, code_layer=1, num_latency_tokens=1, speed=1.0, replace_token=4095):
    """
    Generate an audio waveform from discrete audio tokens using the CosyVoice decoder and
    a prompt utterance for the speaker embedding and prompt features.

    Args:
        audio_tokens (list): List of audio token tensors to be processed.
        model_config: Configuration object containing vocab settings.
        codec_decoder: Codec decoder for generating audio.
        audio_prompt_path (str, optional): Path to the prompt audio file used to extract the
            prompt speech tokens, prompt features, and speaker embedding.
        code_layer (int, optional): Number of code layers. Defaults to 1.
        num_latency_tokens (int, optional): Number of leading latency tokens to discard. Defaults to 1.
        speed (float, optional): Speed factor for audio generation. Defaults to 1.0.
        replace_token (int, optional): Token id used to replace stray padding tokens. Defaults to 4095.

    Returns:
        torch.Tensor: Generated audio waveform, or None if no end-of-audio token is found or
        the token sequence is empty. If ``model_config.save_audio_token`` is set, the truncated
        token tensor is returned instead.
    """
    # Reshape audio tokens based on code_layer
    if code_layer > 1:
        audio_tokens_tensor = torch.stack(audio_tokens, dim=0)
        audio_tokens_permuted = audio_tokens_tensor.permute(1, 0)
        audio_tokens = audio_tokens_permuted.reshape(-1).unsqueeze(0)
        audio_tokens = audio_tokens[..., num_latency_tokens * code_layer:]
    elif code_layer == 1:
        audio_tokens = torch.cat(audio_tokens, dim=-1).unsqueeze(0)
        audio_tokens = audio_tokens[..., num_latency_tokens:]
    else:
        audio_tokens = audio_tokens[..., num_latency_tokens:]

    # Get vocabulary configuration for end of audio (EOA) and padding token
    eoa = model_config.vocab_config.eoa
    pad_a = model_config.vocab_config.pad_a

    # Truncate audio tokens at the EOA token
    if eoa not in audio_tokens[0]:
        return None
    end_index = torch.nonzero(audio_tokens[0] == eoa)[0]
    audio_tokens = audio_tokens[..., :end_index]

    # Handle padding tokens if present
    # FIXME: temporary fix for the padding issue, where padding tokens may leak into the audio tokens
    if pad_a in audio_tokens:
        audio_tokens = audio_tokens.masked_fill(audio_tokens == pad_a, replace_token)
    if model_config.save_audio_token:
        return audio_tokens
    if audio_tokens.numel()==0: 
        return None

    this_uuid = str(uuid.uuid1())  # Generate a unique ID for this audio generation

    from utils.cosyvoice.utils.file_utils import load_wav
    prompt_speech_16k = load_wav(audio_prompt_path, 16000)
    flow_prompt_speech_token, flow_prompt_speech_token_len = codec_decoder.frontend._extract_speech_token(prompt_speech_16k)
    if model_config.cosyvoice_version==1:
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = codec_decoder.frontend._extract_speech_feat(prompt_speech_22050)
    elif model_config.cosyvoice_version==2:
        prompt_speech_24000 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = codec_decoder.frontend._extract_speech_feat(prompt_speech_24000)
    else:
        raise NotImplementedError
    flow_embedding = codec_decoder.frontend._extract_spk_embedding(prompt_speech_16k)

    # Convert tokens to audio waveform
    if model_config.cosyvoice_version==1:
        audio_hat = codec_decoder.model.token2wav(
            token=audio_tokens,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=flow_embedding,
            uuid=this_uuid,
            finalize=True,
            speed=speed
        )
    elif model_config.cosyvoice_version==2:
        audio_hat = codec_decoder.model.token2wav(
            token=audio_tokens,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=flow_embedding,
            uuid=this_uuid,
            token_offset=0,
            finalize=True,
            speed=speed
        )
    else:
        raise NotImplementedError
    return audio_hat
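
# Usage sketch (hedged; "prompt.wav" is a placeholder and the 22.05 kHz save rate assumes
# CosyVoice v1 -- v2 generates 24 kHz audio):
#   audio = audio_decode_cosyvoice(generated_tokens, model_config, codec_decoder,
#                                  audio_prompt_path="prompt.wav")
#   if audio is not None:
#       torchaudio.save("answer.wav", audio.cpu(), 22050)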

def layershift(input_id, layer, stride=4160, shift=152000):
    """Shift a token id into the vocabulary range reserved for the given codebook layer."""
    return input_id + shift + layer * stride

def simple_shift(input_id, layer, stride=4160, shift=152000):
    """Apply a constant shift regardless of the layer; ``layer`` and ``stride`` are kept for signature compatibility."""
    return input_id + shift
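
# Worked example with the default stride/shift values defined above:
# layershift(10, layer=2) = 10 + 152000 + 2 * 4160 = 160330, while
# simple_shift(10, layer=2) = 10 + 152000 = 152010 regardless of the layer.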