zsc committed on
Commit
cda2a68
·
1 Parent(s): acd204b
Files changed (2) hide show
  1. app.py +288 -140
  2. requirements.txt +24 -4
app.py CHANGED
@@ -1,154 +1,302 @@
1
  import gradio as gr
 
 
 
2
  import numpy as np
3
- import random
 
 
 
 
 
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
 
 
 
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
 
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
 
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
 
 
 
 
 
 
 
152
 
153
  if __name__ == "__main__":
 
 
 
 
 
 
 
154
  demo.launch()
 
 
1
  import gradio as gr
2
+ import os
3
+ import torch
4
+ import torchaudio
5
  import numpy as np
6
+ import onnxruntime
7
+ import whisper
8
+ import io
9
+ import librosa
10
+ import math
11
+ from huggingface_hub import snapshot_download
12
+ from funasr import AutoModel
13
 
14
# Utils
def resample_audio(wav, original_sample_rate, target_sample_rate):
    """Resample `wav` to `target_sample_rate`.

    Returns the input unchanged when the rates already match, so the
    torchaudio resampler is only built when actually needed.
    """
    if original_sample_rate == target_sample_rate:
        return wav
    resampler = torchaudio.transforms.Resample(
        orig_freq=original_sample_rate, new_freq=target_sample_rate
    )
    return resampler(wav)
21
 
22
def energy_norm_fn(wav):
    """Peak-normalize `wav` to roughly +/-0.999 amplitude.

    Works for both numpy arrays and torch tensors. The 0.01 floor on the
    measured peak keeps near-silent input from being amplified into noise.
    """
    if type(wav) is np.ndarray:
        peak = np.max(np.abs(wav))
    else:
        peak = torch.max(torch.abs(wav))
    return wav / max(peak, 0.01) * 0.999
30
+
31
def trim_silence(audio, sr, keep_left_time=0.05, keep_right_time=0.22, hop_size=240):
    """Trim leading/trailing silence from a 1-D numpy waveform.

    Uses librosa's 20 dB trim to locate the voiced span, keeps
    `keep_left_time` seconds of context before it, and then cuts or
    zero-pads the result so its length equals a whole number of
    `hop_size` frames plus the left/right context budget.
    """
    _, (trim_start, trim_end) = librosa.effects.trim(
        audio, top_db=20, frame_length=512, hop_length=128
    )
    num_frames = int(math.ceil((trim_end - trim_start) / hop_size))

    left_sil_samples = int(keep_left_time * sr)
    # NOTE(review): computed but never used; the right-side padding is
    # covered by the out_len formula below.
    right_sil_samples = int(keep_right_time * sr)

    start_idx = trim_start - left_sil_samples
    if start_idx > 0:
        # Enough leading audio: just drop everything before the context window.
        trim_wav = audio[start_idx:]
    else:
        # Not enough leading audio: zero-pad the front to make up the context.
        trim_wav = np.pad(
            audio, (abs(start_idx), 0), mode="constant", constant_values=0.0
        )

    wav_len = len(trim_wav)
    out_len = int(num_frames * hop_size + (keep_left_time + keep_right_time) * sr)
    if out_len < wav_len:
        return trim_wav[:out_len]
    return np.pad(
        trim_wav, (0, out_len - wav_len), mode="constant", constant_values=0.0
    )
58
+
59
class StepAudioTokenizer:
    """Convert a waveform into a "<audio_N>" token string.

    Two token streams are produced and interleaved:
      * vq02 — FunASR paraformer encoder features quantized against a
        k-means codebook (2 tokens per chunk),
      * vq06 — an ONNX speech tokenizer run on whisper log-mel features
        (3 tokens per chunk).
    """

    def __init__(self):
        """Download the model snapshot and build the FunASR + ONNX sessions.

        Raises:
            FileNotFoundError: if the codebook or ONNX tokenizer file is
                missing from the downloaded snapshot.
        """
        model_id = "dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online"
        print(f"Loading model from Hugging Face: {model_id}")
        self.model_dir = snapshot_download(model_id)

        # Load FunASR model
        print(f"Initializing AutoModel from {self.model_dir}")
        self.funasr_model = AutoModel(
            model=self.model_dir,
            model_revision="main",
            device="cpu",
            disable_update=True
        )

        kms_path = os.path.join(self.model_dir, "linguistic_tokenizer.npy")
        cosy_tokenizer_path = os.path.join(self.model_dir, "speech_tokenizer_v1.onnx")

        if not os.path.exists(kms_path):
            raise FileNotFoundError(f"KMS file not found: {kms_path}")
        if not os.path.exists(cosy_tokenizer_path):
            raise FileNotFoundError(f"Cosy tokenizer file not found: {cosy_tokenizer_path}")

        # K-means codebook (centroids) used to quantize encoder features.
        self.kms = torch.tensor(np.load(kms_path))

        # Single-threaded CPU ONNX session for the vq06 tokenizer.
        providers = ["CPUExecutionProvider"]
        session_option = onnxruntime.SessionOptions()
        session_option.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        session_option.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            cosy_tokenizer_path, sess_options=session_option, providers=providers
        )
        # Streaming-ASR chunk configuration forwarded to infer_encoder.
        self.chunk_size = [0, 4, 5]
        self.encoder_chunk_look_back = 4
        self.decoder_chunk_look_back = 1

        # The location of infer_encoder differs across funasr versions, so
        # probe both the wrapper and the wrapped model.
        if hasattr(self.funasr_model, "infer_encoder"):
            self.infer_func = self.funasr_model.infer_encoder
        elif hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
            self.infer_func = self.funasr_model.model.infer_encoder
        else:
            # Defer resolution to get_vq02_code, which retries at runtime.
            print("Warning: infer_encoder not found directly. Will check at runtime.")
            self.infer_func = None

    def __call__(self, audio_path):
        """Tokenize the audio file at `audio_path` and return the token string."""
        audio, sr = torchaudio.load(audio_path)
        # Mix to mono if stereo.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)

        # enable_trim=False: tokenize the full clip without silence trimming.
        _, vq02, vq06 = self.wav2token(audio, sr, False)
        return self.merge_vq0206_to_token_str(vq02, vq06)

    def preprocess_wav(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Resample to 16 kHz and optionally peak-normalize / trim silence.

        Expects `audio` shaped (1, samples); returns the same shape.
        """
        audio = resample_audio(audio, sample_rate, 16000)
        if energy_norm:
            audio = energy_norm_fn(audio)

        if enable_trim:
            # trim_silence works on 1-D numpy, so round-trip through numpy.
            audio = audio.cpu().numpy().squeeze(0)
            audio = trim_silence(audio, 16000)
            audio = torch.from_numpy(audio)
            audio = audio.unsqueeze(0)
        return audio

    def wav2token(self, audio, sample_rate, enable_trim=True, energy_norm=True):
        """Return (interleaved_tokens, raw_vq02, raw_vq06) for `audio`.

        The interleaved stream alternates 2 offset vq02 ids (+65536) with
        3 offset vq06 ids (+65536+1024) per chunk; trailing tokens that do
        not fill a whole chunk are dropped.
        """
        audio = self.preprocess_wav(
            audio, sample_rate, enable_trim=enable_trim, energy_norm=energy_norm
        )

        vq02_ori = self.get_vq02_code(audio)
        vq02 = [int(x) + 65536 for x in vq02_ori]
        vq06_ori = self.get_vq06_code(audio)
        vq06 = [int(x) + 65536 + 1024 for x in vq06_ori]

        chunk = 1
        chunk_nums = min(len(vq06) // (3 * chunk), len(vq02) // (2 * chunk))
        speech_tokens = []
        for idx in range(chunk_nums):
            speech_tokens += vq02[idx * chunk * 2 : (idx + 1) * chunk * 2]
            speech_tokens += vq06[idx * chunk * 3 : (idx + 1) * chunk * 3]
        return speech_tokens, vq02_ori, vq06_ori

    def get_vq02_code(self, audio):
        """Run the FunASR encoder on `audio` and quantize its features.

        Returns a list of codebook indices (possibly empty if the encoder
        yields no features).
        """
        # Serialize to an in-memory WAV; funasr accepts file-like inputs.
        _tmp_wav = io.BytesIO()
        torchaudio.save(_tmp_wav, audio, 16000, format="wav")
        _tmp_wav.seek(0)

        if self.infer_func is None:
            # Retry the probe in case the model attribute appeared after init.
            if hasattr(self.funasr_model, "model") and hasattr(self.funasr_model.model, "infer_encoder"):
                self.infer_func = self.funasr_model.model.infer_encoder
            elif hasattr(self.funasr_model, "infer_encoder"):
                self.infer_func = self.funasr_model.infer_encoder
            else:
                raise RuntimeError("infer_encoder method not found on FunASR model.")

        # Depending on funasr version the accepted kwargs differ, so fall
        # back to a minimal call signature on TypeError.
        try:
            res = self.infer_func(
                input=[_tmp_wav],
                chunk_size=self.chunk_size,
                encoder_chunk_look_back=self.encoder_chunk_look_back,
                decoder_chunk_look_back=self.decoder_chunk_look_back,
                device="cpu",
                is_final=True,
                cache={}
            )
        except TypeError as e:
            print(f"Error calling infer_encoder: {e}. Trying different arguments.")
            res = self.infer_func(
                input=[_tmp_wav],
                is_final=True
            )

        if isinstance(res, tuple):
            res = res[0]

        # NOTE(review): if `res` contains multiple results, each iteration
        # overwrites c_list, so only the last non-empty result is returned —
        # confirm this matches funasr's output format (expected: one entry).
        c_list = []
        for j, res_ in enumerate(res):
            feat = res_["enc_out"]
            if len(feat) > 0:
                c_list = self.dump_label([feat], self.kms)[0]
        return c_list

    def get_vq06_code(self, audio):
        """Tokenize `audio` with the ONNX speech tokenizer in 30 s chunks."""

        def split_audio(audio, chunk_duration=480000):
            # Split into fixed-length chunks, dropping any chunk shorter
            # than 480 samples (30 ms @ 16 kHz), which the model can't use.
            start = 0
            chunks = []
            while start < len(audio):
                end = min(start + chunk_duration, len(audio))
                chunk = audio[start:end]
                if len(chunk) >= 480:
                    chunks.append(chunk)
                start = end
            return chunks

        audio = audio.squeeze(0)
        chunk_audios = split_audio(audio, chunk_duration=30 * 16000)
        speech_tokens = []
        for chunk in chunk_audios:
            duration = round(chunk.shape[0] / 16000, 2)  # NOTE(review): unused
            feat = whisper.log_mel_spectrogram(chunk, n_mels=128)
            feat = feat.unsqueeze(0)
            feat_len = np.array([feat.shape[2]], dtype=np.int32)
            chunk_token = (
                self.ort_session.run(
                    None,
                    {
                        self.ort_session.get_inputs()[0]
                        .name: feat.detach()
                        .cpu()
                        .numpy(),
                        self.ort_session.get_inputs()[1].name: feat_len,
                    },
                )[0]
                .flatten()
                .tolist()
            )
            speech_tokens += chunk_token

        return speech_tokens

    def kmean_cluster(self, samples, means):
        """Return the index of the nearest centroid in `means` for each row
        of `samples` (Euclidean distance)."""
        dists = torch.cdist(samples, means)
        indices = dists.argmin(dim=1).cpu().numpy()
        return indices.tolist()

    def dump_label(self, samples, mean):
        """Quantize a list of (1, len_i, dims) feature tensors against the
        `mean` codebook and return one list of indices per input tensor."""
        dims = samples[0].shape[-1]
        x_lens = [x.shape[1] for x in samples]
        total_len = sum(x_lens)
        # Concatenate all samples along the time axis so cdist runs once.
        x_sel = torch.FloatTensor(1, total_len, dims)
        start_len = 0
        for sample in samples:
            sample_len = sample.shape[1]
            end_len = start_len + sample_len
            x_sel[:, start_len:end_len] = sample
            start_len = end_len
        dense_x = x_sel.squeeze(0)
        indices = self.kmean_cluster(dense_x, mean)
        # Split the flat index list back per input sample.
        # BUGFIX: the original computed `end_len = start_len + end_len`
        # (reusing the stale end_len from the copy loop above) and never
        # advanced start_len, so every slice covered the same wrong span.
        indices_list = []
        start_len = 0
        for x_len in x_lens:
            end_len = start_len + x_len
            indices_list.append(indices[start_len:end_len])
            start_len = end_len
        return indices_list

    def merge_vq0206_to_token_str(self, vq02, vq06):
        """Interleave raw vq02/vq06 ids (2:3 per chunk, vq06 offset by 1024)
        into a "<audio_N><audio_M>..." string; leftovers are dropped."""
        _vq06 = [1024 + x for x in vq06]
        result = []
        i = 0
        j = 0
        while i < len(vq02) - 1 and j < len(_vq06) - 2:
            sublist = vq02[i : i + 2] + _vq06[j : j + 3]
            result.extend(sublist)
            i += 2
            j += 3
        return "".join([f"<audio_{x}>" for x in result])
271
+
272
+
273
# Lazily-constructed singleton; building StepAudioTokenizer downloads models,
# so it is deferred until the first request.
tokenizer = None

def process_audio(audio_path):
    """Gradio callback: tokenize the uploaded file and return the token
    string, or a human-readable error message on failure."""
    global tokenizer
    if tokenizer is None:
        try:
            tokenizer = StepAudioTokenizer()
        except Exception as e:
            return f"Error loading model: {e}"

    try:
        if not audio_path:
            return "Please upload an audio file."
        return tokenizer(audio_path)
    except Exception as e:
        # Surface the traceback in the logs but keep the UI message short.
        import traceback
        traceback.print_exc()
        return f"Error processing audio: {e}"
292
 
293
if __name__ == "__main__":
    # Build the UI from named components and launch the app.
    audio_input = gr.Audio(type="filepath", label="Upload WAV")
    token_output = gr.Textbox(label="Token String")
    demo = gr.Interface(
        fn=process_audio,
        inputs=audio_input,
        outputs=token_output,
        title="Step Audio Tokenizer",
        description="Upload a WAV file to convert it to token string (<audio_XXX>)."
    )
    demo.launch()
302
+
requirements.txt CHANGED
@@ -1,6 +1,26 @@
1
  accelerate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
 
 
 
 
 
1
  accelerate
2
+ xformers
3
+ torch==2.8.0
4
+ torchaudio==2.8.0
5
+ torchvision==0.23.0
6
+ transformers==4.53.3
7
+ openai-whisper==20240930
8
+ onnxruntime
9
+ omegaconf==2.3.0
10
+ librosa==0.10.2.post1
11
+ sox==1.5.0
12
+ modelscope
13
+ numpy==2.2.6
14
+ six==1.16.0
15
+ hyperpyyaml
16
+ conformer==0.3.2
17
  diffusers
18
+ pillow
19
+ sentencepiece
20
+ funasr>=1.1.3
21
+ protobuf==5.29.3
22
+ gradio==5.49.1
23
+ spaces==0.42.1
24
+ matplotlib==3.10.7
25
+ llmcompressor==0.8.1
26
+ datasets==4.0.0