Huakang Chen committed on
Commit
040d82e
·
1 Parent(s): d4f7955

update app.py and requirements

Browse files
Files changed (2) hide show
  1. app.py +132 -113
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,20 +1,17 @@
1
  import os
2
  import traceback
3
- import spaces
4
  import gradio as gr
5
  import numpy as np
6
- import pyrootutils
7
  import torch
 
8
  from loguru import logger
9
- from transformers import AutoTokenizer
10
- from vllm import LLM, SamplingParams, TokensPrompt
11
  from funasr_onnx import Paraformer
12
  from huggingface_hub import snapshot_download
13
 
14
  from tools.wer import compute_wers
15
 
16
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
17
- os.environ["VLLM_USE_V1"] = "0"
18
 
19
  from i18n import i18n
20
  from text.chn_text_norm.text import Text as ChnNormedText
@@ -31,54 +28,27 @@ PARAFORMER_REPO_ID = "funasr/Paraformer-large"
31
  # logo
32
  LOGO_URL = "https://raw.githubusercontent.com/ASLP-lab/VoiceSculptor/main/assets/logo.png"
33
 
34
-
35
  model = None
36
  codec_model = None
37
  asr_model = None
38
  tokenizer = None
39
-
40
- @spaces.GPU
41
- def load_models():
42
- global model, codec_model, asr_model, tokenizer
43
-
44
- # 只有当模型为空时才加载
45
- if tokenizer is None:
46
- tokenizer = AutoTokenizer.from_pretrained(LLASA_MODEL_ID)
47
-
48
- if model is None:
49
- logger.info("🚀 Loading vLLM model on GPU...")
50
- model = LLM(
51
- model=LLASA_MODEL_ID,
52
- gpu_memory_utilization=0.90,
53
- max_model_len=2048,
54
- enable_prefix_caching=True,
55
- dtype='auto',
56
- quantization=None,
57
- enforce_eager=False,
58
- kv_cache_dtype='auto'
59
- )
60
-
61
- if codec_model is None:
62
- logger.info("🚀 Loading XCodec2...")
63
- codec_model = XCodec2Model.from_pretrained(XCODEC_MODEL_ID).eval().to("cuda")
64
-
65
- if asr_model is None:
66
- logger.info("🚀 Loading Paraformer...")
67
- paraformer_dir = snapshot_download(repo_id=PARAFORMER_REPO_ID, local_dir="checkpoints/Paraformer-large")
68
- asr_model = Paraformer(paraformer_dir, batch_size=5, quantize=True)
69
 
70
  def normalize_text_final(user_input: str) -> str:
71
  return ChnNormedText(raw_text=user_input).normalize()
72
 
73
 
74
- def extract_speech_ids(speech_tokens_str):
 
75
  speech_ids = []
76
- for token_str in speech_tokens_str:
77
- if token_str.startswith("<|s_") and token_str.endswith("|>"):
78
- num_str = token_str[4:-2]
79
- speech_ids.append(int(num_str))
80
- else:
81
- logger.warning(f"Unexpected token: {token_str}")
 
82
  return speech_ids
83
 
84
 
@@ -97,7 +67,6 @@ def get_asr(asr_model: Paraformer, wav_list: list[np.ndarray]) -> list[str]:
97
  else:
98
  texts.append(preds[0] if len(preds) > 0 else "")
99
 
100
- # 容错:batch 返回数量不一致 -> fallback
101
  if len(texts) != len(wav_list):
102
  logger.warning(f"[ASR] batch返回数量不一致: got {len(texts)} expect {len(wav_list)},fallback逐条补齐")
103
  texts = []
@@ -136,17 +105,71 @@ def get_asr(asr_model: Paraformer, wav_list: list[np.ndarray]) -> list[str]:
136
  texts.append("")
137
  return texts
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  @spaces.GPU
140
- def inference_batch(
141
- model: LLM,
142
- codec_model: XCodec2Model,
143
- device: str,
144
- tokenizer: AutoTokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  refined_text: str,
146
  instruct_text: str,
147
  control_tags: str,
148
  batch_size: int = 5,
 
149
  ) -> list[tuple[int, np.ndarray]]:
 
150
  refined_text_norm = normalize_text_final(refined_text)
151
  instruct_text_norm = normalize_text_final(instruct_text)
152
 
@@ -162,61 +185,53 @@ def inference_batch(
162
  {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
163
  ]
164
 
165
- with torch.no_grad():
166
- input_ids = tokenizer.apply_chat_template(
167
- chat,
168
- tokenize=True,
169
- return_tensors="pt",
170
- continue_final_message=True,
171
- ).to(device)
172
-
173
- speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
174
- prompt_ids = input_ids.squeeze(0).tolist()
175
- prompts = [TokensPrompt(prompt_token_ids=prompt_ids) for _ in range(batch_size)]
176
-
177
- base_seed = int.from_bytes(os.urandom(4), "little")
178
-
179
- try:
180
- sampling_params_list = [
181
- SamplingParams(
182
- temperature=0.9,
183
- top_p=0.95,
184
- top_k=15,
185
- max_tokens=2048,
186
- repetition_penalty=1.05,
187
- stop_token_ids=[speech_end_id],
188
- seed=base_seed + i,
189
- )
190
- for i in range(batch_size)
191
- ]
192
- outputs = model.generate(prompts=prompts, sampling_params=sampling_params_list)
193
- except TypeError:
194
- logger.warning("[vLLM] 当前版本不支持 SamplingParams(seed=...),将不带 seed 生成")
195
- sampling_params = SamplingParams(
196
- temperature=0.9,
197
- top_p=0.95,
198
- top_k=15,
199
- max_tokens=2048,
200
- repetition_penalty=1.05,
201
- stop_token_ids=[speech_end_id],
202
- )
203
- outputs = model.generate(prompts=prompts, sampling_params=sampling_params)
204
-
205
- audios: list[tuple[int, np.ndarray]] = []
206
- for out in outputs:
207
- token_ids = out.outputs[0].token_ids
208
- if len(token_ids) > 0 and token_ids[-1] == speech_end_id:
209
- token_ids = token_ids[:-1]
210
-
211
- speech_tokens = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
212
- speech_tokens = extract_speech_ids(speech_tokens)
213
-
214
- speech_tokens_t = torch.tensor(speech_tokens, device=device).unsqueeze(0).unsqueeze(0)
215
- wav = codec_model.decode_code(speech_tokens_t)
216
- wav = wav.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32)
217
- audios.append((16000, wav))
218
-
219
- return audios
220
 
221
  def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
222
  tag_map = {
@@ -230,7 +245,8 @@ def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
230
  "音量较小": "<|音量较小|>", "音量很小": "<|音量很小|>",
231
  "语速很快": "<|语速很快|>", "语速较快": "<|语速较快|>", "语速中等": "<|语速中等|>",
232
  "语速较慢": "<|语速较慢|>", "语速很慢": "<|语速很慢|>",
233
- "开心": "<|开心|>", "生气": "<|生气|>", "难过": "<|难过|>", "惊讶": "<|惊讶|>", "厌恶": "<|厌恶|>", "害怕": "<|害怕|>",
 
234
  }
235
  tags = []
236
  for v in [gender, age, speed, volume, pitch, pitch_var, emo]:
@@ -238,21 +254,23 @@ def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
238
  tags.append(tag_map[v])
239
  return "".join(tags)
240
 
241
- @spaces.GPU(duration=120)
242
  def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitch_var, volume, speed, emo):
243
- load_models()
244
  control_tags = build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo)
 
245
  try:
246
- audios5 = inference_batch(
247
- model=model,
248
- codec_model=codec_model,
249
- device='cuda',
250
- tokenizer=tokenizer,
251
  refined_text=refined_text,
252
  instruct_text=instruct_text,
253
  control_tags=control_tags,
254
  batch_size=5,
 
255
  )
 
256
  wav_list = [wav for (_, wav) in audios5]
257
  asr_texts = get_asr(asr_model, wav_list)
258
 
@@ -264,14 +282,15 @@ def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitc
264
  logger.info(f"[ASR/WER] idx={i} wer={w:.4f} gt='{refined_text_norm}' asr='{hyp}'")
265
 
266
  best_idx = np.argsort(np.array(wers))[:3].tolist()
267
- logger.info(f"[ASR/WER] best_idx={best_idx} best_wers={[float(wers[i]) for i in best_idx]}")
268
  best3 = [audios5[i] for i in best_idx]
269
  return best3[0], best3[1], best3[2]
 
270
  except Exception as e:
271
  logger.error(f"推理/ASR/WER 失败: {e}", exc_info=True)
272
  logger.error("错误详细信息:\n" + traceback.format_exc())
273
  return None, None, None
274
 
 
275
  def build_app():
276
 
277
  INSTRUCT_TEMPLATES = {
 
1
  import os
2
  import traceback
 
3
  import gradio as gr
4
  import numpy as np
 
5
  import torch
6
+ import spaces
7
  from loguru import logger
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
9
  from funasr_onnx import Paraformer
10
  from huggingface_hub import snapshot_download
11
 
12
  from tools.wer import compute_wers
13
 
14
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
 
15
 
16
  from i18n import i18n
17
  from text.chn_text_norm.text import Text as ChnNormedText
 
28
  # logo
29
  LOGO_URL = "https://raw.githubusercontent.com/ASLP-lab/VoiceSculptor/main/assets/logo.png"
30
 
31
+ # ===== Global cache =====
32
  model = None
33
  codec_model = None
34
  asr_model = None
35
  tokenizer = None
36
+ device= 'cuda'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def normalize_text_final(user_input: str) -> str:
39
  return ChnNormedText(raw_text=user_input).normalize()
40
 
41
 
42
def extract_speech_ids(token_strs: list[str]) -> list[int]:
    """Collect the integer ids from speech tokens shaped like ``<|s_123|>``.

    Tokens without the ``<|s_`` prefix and ``|>`` suffix are skipped silently.
    Tokens with the right shape but a non-integer payload are skipped with a
    warning.

    Args:
        token_strs: Token strings to scan.

    Returns:
        The parsed ids, in the order their tokens appear in ``token_strs``.
    """
    prefix, suffix = "<|s_", "|>"
    speech_ids = []
    for t in token_strs:
        if not (t.startswith(prefix) and t.endswith(suffix)):
            continue
        try:
            speech_ids.append(int(t[len(prefix):-len(suffix)]))
        except ValueError:
            # int() raises only ValueError for a malformed str payload, so
            # nothing broader needs to be caught here.
            logger.warning(f"Bad speech token: {t}")
    return speech_ids
53
 
54
 
 
67
  else:
68
  texts.append(preds[0] if len(preds) > 0 else "")
69
 
 
70
  if len(texts) != len(wav_list):
71
  logger.warning(f"[ASR] batch返回数量不一致: got {len(texts)} expect {len(wav_list)},fallback逐条补齐")
72
  texts = []
 
105
  texts.append("")
106
  return texts
107
 
108
+
109
def _safe_load_tokenizer(model_id: str) -> AutoTokenizer:
    """Load the tokenizer for ``model_id`` and make sure it has a pad token.

    The first attempt uses the default tokenizer class; if that raises
    ``TypeError`` the load is retried with ``use_fast=False``. When the
    tokenizer has no pad token id but does have an EOS token id, the EOS token
    is reused as the pad token.

    Args:
        model_id: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.
    """
    common = {"trust_remote_code": True}
    try:
        loaded = AutoTokenizer.from_pretrained(model_id, **common)
    except TypeError:
        loaded = AutoTokenizer.from_pretrained(model_id, use_fast=False, **common)

    # Fall back to EOS for padding only when no pad token is configured.
    if loaded.pad_token_id is None and loaded.eos_token_id is not None:
        loaded.pad_token = loaded.eos_token
    return loaded
119
+
120
+
121
def _safe_load_lm(model_id: str, device: str) -> AutoModelForCausalLM:
    """Load the causal LM for ``model_id``, switch it to eval mode and move it to ``device``.

    Args:
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        device: Target device string handed to ``.to()``.

    Returns:
        The loaded model.
    """
    lm = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    lm.eval()
    lm.to(device)
    return lm
128
+
129
  @spaces.GPU
130
def load_models(force_device: str | None = None):
    """Load the tokenizer, LM, codec and ASR model into the module-level cache.

    Each model is loaded only while its global is still ``None``, so a
    repeated call does not reload anything that is already cached.

    Args:
        force_device: Device for the LM and the codec. When given, it replaces
            the module-level ``device`` so that other code reading that global
            targets the same device. ``None`` keeps the current ``device``.
            Models that are already cached are not moved.
    """
    global model, codec_model, asr_model, tokenizer, device

    # Apply the override before anything is placed, so the log line and the
    # loads below agree on the target device.
    if force_device is not None:
        device = force_device

    logger.info(f"Using device: {device}")

    if tokenizer is None:
        logger.info("Loading tokenizer...")
        tokenizer = _safe_load_tokenizer(LLASA_MODEL_ID)

    if model is None:
        logger.info("Loading AutoModelForCausalLM...")
        model = _safe_load_lm(LLASA_MODEL_ID, device=device)

    if codec_model is None:
        logger.info("Loading XCodec2...")
        codec_model = XCodec2Model.from_pretrained(XCODEC_MODEL_ID).eval().to(device)

    if asr_model is None:
        logger.info("Loading Paraformer (funasr_onnx)...")
        # Download the ASR checkpoint into a fixed local directory and load
        # it from there.
        paraformer_dir = snapshot_download(
            repo_id=PARAFORMER_REPO_ID,
            local_dir="checkpoints/Paraformer-large",
            local_dir_use_symlinks=False,
        )
        asr_model = Paraformer(paraformer_dir, batch_size=5, quantize=True)

    logger.info("✅ All models loaded.")
158
+
159
+ load_models()
160
+
161
+ @torch.inference_mode()
162
+ def inference_batch_transformers(
163
+ lm: AutoModelForCausalLM,
164
+ codec: XCodec2Model,
165
+ tok: AutoTokenizer,
166
  refined_text: str,
167
  instruct_text: str,
168
  control_tags: str,
169
  batch_size: int = 5,
170
+ max_new_tokens: int = 2048,
171
  ) -> list[tuple[int, np.ndarray]]:
172
+
173
  refined_text_norm = normalize_text_final(refined_text)
174
  instruct_text_norm = normalize_text_final(instruct_text)
175
 
 
185
  {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
186
  ]
187
 
188
+ input_ids_1 = tok.apply_chat_template(
189
+ chat,
190
+ tokenize=True,
191
+ return_tensors="pt",
192
+ continue_final_message=True,
193
+ ).to(device)
194
+
195
+ speech_end_id = tok.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
196
+ pad_id = tok.pad_token_id if tok.pad_token_id is not None else (tok.eos_token_id or speech_end_id)
197
+
198
+ outputs = lm.generate(
199
+ input_ids=input_ids_1,
200
+ do_sample=True,
201
+ top_p=0.95,
202
+ temperature=0.9,
203
+ top_k=15,
204
+ repetition_penalty=1.05,
205
+ max_new_tokens=max_new_tokens,
206
+ eos_token_id=speech_end_id,
207
+ pad_token_id=pad_id,
208
+ num_return_sequences=batch_size,
209
+ use_cache=True,
210
+ )
211
+
212
+ prompt_len = input_ids_1.shape[1]
213
+ audios: list[tuple[int, np.ndarray]] = []
214
+
215
+ for i in range(outputs.shape[0]):
216
+ gen_ids = outputs[i, prompt_len:].tolist()
217
+ if len(gen_ids) > 0 and gen_ids[-1] == speech_end_id:
218
+ gen_ids = gen_ids[:-1]
219
+
220
+ token_strs = tok.convert_ids_to_tokens(gen_ids, skip_special_tokens=False)
221
+ speech_ids = extract_speech_ids(token_strs)
222
+
223
+ if len(speech_ids) == 0:
224
+ logger.warning("[TTS] No speech tokens extracted, outputting silence.")
225
+ audios.append((16000, np.zeros((16000,), dtype=np.float32)))
226
+ continue
227
+
228
+ speech_tokens_t = torch.tensor(speech_ids, device=device).unsqueeze(0).unsqueeze(0)
229
+ wav = codec.decode_code(speech_tokens_t)
230
+ wav = wav.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32)
231
+ audios.append((16000, wav))
232
+
233
+ return audios
234
+
 
 
 
 
 
 
 
 
235
 
236
  def build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo):
237
  tag_map = {
 
245
  "音量较小": "<|音量较小|>", "音量很小": "<|音量很小|>",
246
  "语速很快": "<|语速很快|>", "语速较快": "<|语速较快|>", "语速中等": "<|语速中等|>",
247
  "语速较慢": "<|语速较慢|>", "语速很慢": "<|语速很慢|>",
248
+ "开心": "<|开心|>", "生气": "<|生气|>", "难过": "<|难过|>", "惊讶": "<|惊讶|>",
249
+ "厌恶": "<|厌恶|>", "害怕": "<|害怕|>",
250
  }
251
  tags = []
252
  for v in [gender, age, speed, volume, pitch, pitch_var, emo]:
 
254
  tags.append(tag_map[v])
255
  return "".join(tags)
256
 
257
+ @spaces.GPU(duration=240)
258
  def inference_select_best3(refined_text, instruct_text, age, gender, pitch, pitch_var, volume, speed, emo):
259
+
260
  control_tags = build_control_tags(age, gender, pitch, pitch_var, volume, speed, emo)
261
+
262
  try:
263
+ audios5 = inference_batch_transformers(
264
+ lm=model,
265
+ codec=codec_model,
266
+ tok=tokenizer,
 
267
  refined_text=refined_text,
268
  instruct_text=instruct_text,
269
  control_tags=control_tags,
270
  batch_size=5,
271
+ max_new_tokens=2048,
272
  )
273
+
274
  wav_list = [wav for (_, wav) in audios5]
275
  asr_texts = get_asr(asr_model, wav_list)
276
 
 
282
  logger.info(f"[ASR/WER] idx={i} wer={w:.4f} gt='{refined_text_norm}' asr='{hyp}'")
283
 
284
  best_idx = np.argsort(np.array(wers))[:3].tolist()
 
285
  best3 = [audios5[i] for i in best_idx]
286
  return best3[0], best3[1], best3[2]
287
+
288
  except Exception as e:
289
  logger.error(f"推理/ASR/WER 失败: {e}", exc_info=True)
290
  logger.error("错误详细信息:\n" + traceback.format_exc())
291
  return None, None, None
292
 
293
+
294
  def build_app():
295
 
296
  INSTRUCT_TEMPLATES = {
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  gradio
2
  torch
3
  transformers
4
- vllm
5
  funasr-onnx
6
  huggingface_hub
7
  jiwer
 
1
  gradio
2
  torch
3
  transformers
4
+ spaces
5
  funasr-onnx
6
  huggingface_hub
7
  jiwer