import gradio as gr
import os
import torch
import torchaudio
import numpy as np
import onnxruntime
import whisper
import io
import librosa
import math
from huggingface_hub import snapshot_download
from funasr import AutoModel

# Utils
def resample_audio(wav, original_sample_rate, target_sample_rate):
    if original_sample_rate != target_sample_rate:
        wav = torchaudio.transforms.Resample(
            orig_freq=original_sample_rate, new_freq=target_sample_rate
        )(wav)
    return wav

def energy_norm_fn(wav):
    if isinstance(wav, np.ndarray):
        max_data = np.max(np.abs(wav))
        wav = wav / max(max_data, 0.01) * 0.999
    else:
        max_data = torch.max(torch.abs(wav))
        wav = wav / max(max_data, 0.01) * 0.999
    return wav

def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size=240):
    _, index = librosa.effects.trim(audio, top_db=20, frame_length=512, hop_length=128)
    num_frames = int(math.ceil((index[1] - index[0]) / hop_size))

    left_sil_samples = int(keep_left_time * sr)
    right_sil_samples = int(keep_right_time * sr)

    wav_len = len(audio)
    start_idx = index[0] - left_sil_samples
    trim_wav = audio

    if start_idx > 0:
        trim_wav = trim_wav[start_idx:]
    else:
        trim_wav = np.pad(
            trim_wav, (abs(start_idx), 0), mode="constant", constant_values=0.0
        )
    wav_len = len(trim_wav)
    out_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr)

    if out_len < wav_len:
        trim_wav = trim_wav[:out_len]
    else:
        trim_wav = np.pad(
            trim_wav, (0, (out_len - wav_len)), mode="constant", constant_values=0.0
        )
    return trim_wav
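
# Minimal sketch of how the three helpers above compose (the 0.5 s, 440 Hz
# tone and 24 kHz source rate are arbitrary illustration values, not settings
# used elsewhere in this app):
def _demo_preprocess():
    sr = 24000
    t = torch.linspace(0, 0.5, int(0.5 * sr))
    wav = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)  # (1, T) mono tensor
    wav = resample_audio(wav, sr, 16000)                 # resample to 16 kHz
    wav = energy_norm_fn(wav)                            # peak-normalize to ~0.999
    return trim_silence(wav.squeeze(0).numpy(), 16000)   # 1-D np.ndarray out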

class StepAudioTokenizer:
    def __init__(self):
        model_id = "stepfun-ai/Step-Audio-Tokenizer"
        print(f"Loading model from Hugging Face: {model_id}")
        self.model_dir = snapshot_download(model_id)
        
        # Load FunASR model
        paraformer_dir = os.path.join(self.model_dir, "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online")
        print(f"Initializing AutoModel from {paraformer_dir}")
        self.funasr_model = AutoModel(
            model=paraformer_dir,
            model_revision="main",
            device="cpu",
            disable_update=True
        )

        kms_path = os.path.join(self.model_dir, "linguistic_tokenizer.npy")
        cosy_tokenizer_path = os.path.join(self.model_dir, "speech_tokenizer_v1.onnx")

        if not os.path.exists(kms_path):
            raise FileNotFoundError(f"KMS file not found: {kms_path}")
        if not os.path.exists(cosy_tokenizer_path):
            raise FileNotFoundError(f"Cosy tokenizer file not found: {cosy_tokenizer_path}")

        self.kms = torch.tensor(np.load(kms_path))

        providers = ["CPUExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1
        
        # Locate the streaming encoder entry point; where it is attached
        # varies across funasr versions.
        if hasattr(self.funasr_model, "infer_encoder"):
            self.infer_func = self.funasr_model.infer_encoder
        elif hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
            self.infer_func = self.funasr_model.model.infer_encoder
        else:
            # Not found on this wrapper; retry lazily at first call.
            print("Warning: infer_encoder not found directly. Will check again at runtime.")
            self.infer_func = None

    def __call__(self, audio_path):
        # Load audio
        audio, sr = torchaudio.load(audio_path)
        # Mix to mono if stereo
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
            
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        text = self.merge_vq0206_to_token_str(vq02, vq06)
        return text

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)

        if enable_trim:
            audio = audio.cpu().numpy().squeeze(0)
            audio = trim_silence(audio, 16000)
            audio = torch.from_numpy(audio)
            audio = audio.unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )

        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]

        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori
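
    # Worked example of the interleaving above, with made-up code indices:
    # for vq02 = [a0, a1, a2, a3] and vq06 = [b0, b1, b2, b3, b4, b5],
    # chunk_nums = min(6 // 3, 4 // 2) = 2 and the merged stream is
    # [a0, a1, b0, b1, b2, a2, a3, b3, b4, b5]: each chunk contributes two
    # (offset) vq02 tokens followed by three (offset) vq06 tokens.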

    def get_vq02_code(self, audio):
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)
        
        if self.infer_func is None:
            # Deferred lookup, in case the wrapper exposes infer_encoder
            # somewhere we did not find at init time.
            if hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
                self.infer_func = self.funasr_model.model.infer_encoder
            elif hasattr(self.funasr_model, "infer_encoder"):
                self.infer_func = self.funasr_model.infer_encoder
            else:
                raise RuntimeError("infer_encoder method not found on FunASR model.")

        # Depending on the funasr version, the input may need to be a file
        # path, raw bytes, or a numpy array; here the in-memory WAV is passed
        # as a BytesIO object wrapped in a list, as in the original code.
        
        try:
            res = self.infer_func(
                input=[_tmp_wav], 
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device="cpu",
                is_final=True,
                cache={} 
            )
        except TypeError as e:
            # Some funasr versions do not accept all of these keyword
            # arguments; retry with a minimal call.
            print(f"Error calling infer_encoder: {e}. Retrying with minimal arguments.")
            res = self.infer_func(
                input=[_tmp_wav],
                is_final=True
            )

        if isinstance(res, tuple):
            res = res[0]
        
        c_list = []
        for res_ in res:
            feat = res_["enc_out"]
            if len(feat) > 0:
                c_list = self.dump_label([feat], self.kms)[0]
        return c_list

    def get_vq06_code(self, audio):
        def split_audio(audio, chunk_duration=480000):
            # Split into chunks of at most chunk_duration samples, dropping
            # trailing fragments shorter than 480 samples (30 ms at 16 kHz).
            start = 0
            chunks = []
            while start < len(audio):
                end = min(start + chunk_duration, len(audio))
                chunk = audio[start:end]
                if len(chunk) >= 480:
                    chunks.append(chunk)
                start = end
            return chunks

        audio = audio.squeeze(0)
        chunk_audios = split_audio(audio, chunk_duration=30 * 16000)
        speech_tokens = []
        for chunk in chunk_audios:
            feat = whisper.log_mel_spectrogram(chunk, n_mels=128)
            feat = feat.unsqueeze(0)
            feat_len = np.array([feat.shape[2]], dtype=np.int32)
            ort_inputs = {
                self.ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                self.ort_session.get_inputs()[1].name: feat_len,
            }
            chunk_token = self.ort_session.run(None, ort_inputs)[0].flatten().tolist()
            speech_tokens += chunk_token

        return speech_tokens

    def kmean_cluster(self, samples, means):
        dists = torch.cdist(samples, means)
        indices = dists.argmin(dim=1).cpu().numpy()
        return indices.tolist()

    def dump_label(self, samples, mean):
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            sample_len = sample.shape[1]
            end_len = start_len + sample_len
            x_sel[:, start_len:end_len] = sample
            start_len = end_len
        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)
        indices_list = []
        start_len = 0
        for x_len in x_lens:
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            sublist = vq02[i : i + 2] + _vq06[j : j + 3]
            result.extend(sublist)
            i += 2
            j += 3
        return "".join([f"<audio_{x}>" for x in result])


tokenizer = None

def process_audio(audio_path):
    global tokenizer
    # Lazy-load the tokenizer on first request so the app starts quickly.
    if tokenizer is None:
        try:
            tokenizer = StepAudioTokenizer()
        except Exception as e:
            return f"Error loading model: {e}"

    try:
        if not audio_path:
            return "Please upload an audio file."
        tokens = tokenizer(audio_path)
        return tokens
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing audio: {e}"

if __name__ == "__main__":
    demo = gr.Interface(
        fn=process_audio,
        inputs=gr.Audio(type="filepath", label="Upload WAV"),
        outputs=gr.Textbox(label="Token String"),
        title="Step Audio Tokenizer",
        description="Upload a WAV file to convert it to a token string (<audio_XXX>)."
    )
    demo.launch()
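
# Headless usage sketch (no Gradio UI), assuming a local WAV file at the
# hypothetical path "sample.wav":
#
#     tok = StepAudioTokenizer()
#     print(tok("sample.wav"))  # prints the merged <audio_XXX> token string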