Huakang Chen committed on
Commit 07cdf55 · 1 Parent(s): 3ba0af4

update app.py

Files changed (1)
  1. app.py +87 -83
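
The diff below moves model construction out of build_app() into a module-level load_models() that is called only from functions decorated with @spaces.GPU, so CUDA is touched only while ZeroGPU hardware is attached and the loaded weights are cached in globals between calls. As a reference, here is a minimal, self-contained sketch of that decorator pattern, not taken from app.py: the placeholder torch.nn.Linear and the _load_model/run names stand in for the real LLaSA, XCodec2 and Paraformer models.

import spaces
import torch

_model = None  # module-level cache so the "model" persists across GPU calls

def _load_model():
    # Lazy init: build the placeholder model only the first time it is needed.
    global _model
    if _model is None:
        _model = torch.nn.Linear(4, 4).to("cuda")
    return _model

@spaces.GPU(duration=60)  # a GPU is attached only while this function runs
def run(x: list[float]) -> list[float]:
    model = _load_model()
    with torch.no_grad():
        return model(torch.tensor(x, device="cuda")).tolist()

In the actual commit the cached globals are model, codec_model, asr_model and tokenizer, and inference_select_best3 calls load_models() before running inference.
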
app.py CHANGED
@@ -1,6 +1,6 @@
  import os
  import traceback
-
+ import spaces
  import gradio as gr
  import numpy as np
  import pyrootutils
@@ -32,6 +32,38 @@ PARAFORMER_REPO_ID = "funasr/Paraformer-large"
  LOGO_URL = "https://raw.githubusercontent.com/ASLP-lab/VoiceSculptor/main/assets/logo.png"
 
 
+ model = None
+ codec_model = None
+ asr_model = None
+ tokenizer = None
+
+ @spaces.GPU
+ def load_models():
+     global model, codec_model, asr_model, tokenizer
+
+     # Only load the models when they have not been initialized yet
+     if tokenizer is None:
+         tokenizer = AutoTokenizer.from_pretrained(LLASA_MODEL_ID)
+
+     if model is None:
+         logger.info("🚀 Loading vLLM model on GPU...")
+         model = LLM(
+             model=LLASA_MODEL_ID,
+             gpu_memory_utilization=0.8,
+             max_model_len=2048,
+             enforce_eager=True,
+             device="cuda"
+         )
+
+     if codec_model is None:
+         logger.info("🚀 Loading XCodec2...")
+         codec_model = XCodec2Model.from_pretrained(XCODEC_MODEL_ID).eval().to("cuda")
+
+     if asr_model is None:
+         logger.info("🚀 Loading Paraformer...")
+         paraformer_dir = snapshot_download(repo_id=PARAFORMER_REPO_ID, local_dir="checkpoints/Paraformer-large")
+         asr_model = Paraformer(paraformer_dir, batch_size=5, quantize=True)
+
  def normalize_text_final(user_input: str) -> str:
      return ChnNormedText(raw_text=user_input).normalize()
 
@@ -101,7 +133,7 @@ def get_asr(asr_model: Paraformer, wav_list: list[np.ndarray]) -> list[str]:
          texts.append("")
      return texts
 
-
+ @spaces.GPU
  def inference_batch(
      model: LLM,
      codec_model: XCodec2Model,
@@ -183,37 +215,61 @@ def inference_batch(
 
      return audios
 
+ def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
+     tag_map = {
+         "小孩": "<|小孩|>", "青年": "<|青年|>", "中年": "<|中年|>", "老年": "<|老年|>",
+         "男性": "<|男性|>", "女性": "<|女性|>",
+         "音调很高": "<|音调很高|>", "音调较高": "<|音调较高|>", "音调中等": "<|音调中等|>",
+         "音调较低": "<|音调较低|>", "音调很低": "<|音调很低|>",
+         "音调变化很强": "<|音调变化很强|>", "音调变化较强": "<|音调变化较强|>", "音调变化一般": "<|音调变化一般|>",
+         "音调变化较弱": "<|音调变化较弱|>", "音调变化很弱": "<|音调变化很弱|>",
+         "音量很大": "<|音量很大|>", "音量较大": "<|音量较大|>", "音量中等": "<|音量中等|>",
+         "音量较小": "<|音量较小|>", "音量很小": "<|音量很小|>",
+         "语速很快": "<|语速很快|>", "语速较快": "<|语速较快|>", "语速中等": "<|语速中等|>",
+         "语速较慢": "<|语速较慢|>", "语速很慢": "<|语速很慢|>",
+         "开心": "<|开心|>", "生气": "<|生气|>", "难过": "<|难过|>", "惊讶": "<|惊讶|>", "厌恶": "<|厌恶|>", "害怕": "<|害怕|>",
+     }
+     tags = []
+     for v in [gender, age, speed, volume, pitch, pitch_var, emo]:
+         if v != "不指定":
+             tags.append(tag_map[v])
+     return "".join(tags)
+
+ @spaces.GPU(duration=120)
+ def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitch_var, volume, speed, emo):
+     load_models()
+     control_tags = build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo)
+     try:
+         audios5 = inference_batch(
+             model=model,
+             codec_model=codec_model,
+             device='cuda',
+             tokenizer=tokenizer,
+             refined_text=refined_text,
+             instruct_text=instruct_text,
+             control_tags=control_tags,
+             batch_size=5,
+         )
+         wav_list = [wav for (_, wav) in audios5]
+         asr_texts = get_asr(asr_model, wav_list)
+
+         refined_text_norm = normalize_text_final(refined_text)
+         gt_texts = [refined_text_norm] * len(asr_texts)
+         wers = compute_wers(gt_texts, asr_texts, lang="zh")
+
+         for i, (hyp, w) in enumerate(zip(asr_texts, wers)):
+             logger.info(f"[ASR/WER] idx={i} wer={w:.4f} gt='{refined_text_norm}' asr='{hyp}'")
+
+         best_idx = np.argsort(np.array(wers))[:3].tolist()
+         logger.info(f"[ASR/WER] best_idx={best_idx} best_wers={[float(wers[i]) for i in best_idx]}")
+         best3 = [audios5[i] for i in best_idx]
+         return best3[0], best3[1], best3[2]
+     except Exception as e:
+         logger.error(f"推理/ASR/WER 失败: {e}", exc_info=True)
+         logger.error("错误详细信息:\n" + traceback.format_exc())
+         return None, None, None
 
  def build_app():
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     logger.info(f"✅ Loading models on device={device}")
-
-     # ===== LLaSA =====
-     tokenizer = AutoTokenizer.from_pretrained(LLASA_MODEL_ID)
-
-     model = LLM(
-         model=LLASA_MODEL_ID,
-         gpu_memory_utilization=0.90,
-         max_model_len=2048,
-         enable_prefix_caching=True,
-         dtype="auto",
-         quantization=None,
-         enforce_eager=False,
-         kv_cache_dtype="auto",
-     )
-
-     # ===== XCodec2 =====
-     codec_model = XCodec2Model.from_pretrained(XCODEC_MODEL_ID).eval().to(device)
-
-     # ===== Paraformer =====
-     paraformer_dir = snapshot_download(
-         repo_id=PARAFORMER_REPO_ID,
-         local_dir="checkpoints/Paraformer-large",
-         local_dir_use_symlinks=False,
-     )
-     asr_model = Paraformer(paraformer_dir, batch_size=5, quantize=True)
-
-     logger.info("✅ Models loaded: VoiceSculptor + xcodec2 + Paraformer")
 
      INSTRUCT_TEMPLATES = {
          "自定义": "",
@@ -263,58 +319,6 @@ def build_app():
          "ASMR-气声耳语": "现在,让我在你耳边轻声细语。听到我的声音了吗?放松你的头皮,感受每一个毛孔都在呼吸。",
      }
 
-     def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
-         tag_map = {
-             "小孩": "<|小孩|>", "青年": "<|青年|>", "中年": "<|中年|>", "老年": "<|老年|>",
-             "男性": "<|男性|>", "女性": "<|女性|>",
-             "音调很高": "<|音调很高|>", "音调较高": "<|音调较高|>", "音调中等": "<|音调中等|>",
-             "音调较低": "<|音调较低|>", "音调很低": "<|音调很低|>",
-             "音调变化很强": "<|音调变化很强|>", "音调变化较强": "<|音调变化较强|>", "音调变化一般": "<|音调变化一般|>",
-             "音调变化较弱": "<|音调变化较弱|>", "音调变化很弱": "<|音调变化很弱|>",
-             "音量很大": "<|音量很大|>", "音量较大": "<|音量较大|>", "音量中等": "<|音量中等|>",
-             "音量较小": "<|音量较小|>", "音量很小": "<|音量很小|>",
-             "语速很快": "<|语速很快|>", "语速较快": "<|语速较快|>", "语速中等": "<|语速中等|>",
-             "语速较慢": "<|语速较慢|>", "语速很慢": "<|语速很慢|>",
-             "开心": "<|开心|>", "生气": "<|生气|>", "难过": "<|难过|>", "惊讶": "<|惊讶|>", "厌恶": "<|厌恶|>", "害怕": "<|害怕|>",
-         }
-         tags = []
-         for v in [gender, age, speed, volume, pitch, pitch_var, emo]:
-             if v != "不指定":
-                 tags.append(tag_map[v])
-         return "".join(tags)
-
-     def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitch_var, volume, speed, emo):
-         control_tags = build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo)
-         try:
-             audios5 = inference_batch(
-                 model=model,
-                 codec_model=codec_model,
-                 device=device,
-                 tokenizer=tokenizer,
-                 refined_text=refined_text,
-                 instruct_text=instruct_text,
-                 control_tags=control_tags,
-                 batch_size=5,
-             )
-             wav_list = [wav for (_, wav) in audios5]
-             asr_texts = get_asr(asr_model, wav_list)
-
-             refined_text_norm = normalize_text_final(refined_text)
-             gt_texts = [refined_text_norm] * len(asr_texts)
-             wers = compute_wers(gt_texts, asr_texts, lang="zh")
-
-             for i, (hyp, w) in enumerate(zip(asr_texts, wers)):
-                 logger.info(f"[ASR/WER] idx={i} wer={w:.4f} gt='{refined_text_norm}' asr='{hyp}'")
-
-             best_idx = np.argsort(np.array(wers))[:3].tolist()
-             logger.info(f"[ASR/WER] best_idx={best_idx} best_wers={[float(wers[i]) for i in best_idx]}")
-             best3 = [audios5[i] for i in best_idx]
-             return best3[0], best3[1], best3[2]
-         except Exception as e:
-             logger.error(f"推理/ASR/WER 失败: {e}", exc_info=True)
-             logger.error("错误详细信息:\n" + traceback.format_exc())
-             return None, None, None
-
      THEME = gr.themes.Soft(
          primary_hue="orange",
          secondary_hue="cyan",