antigravity commited on
Commit
8f40bf2
·
1 Parent(s): b611ed5

feat: add speed control support

Browse files
app.py CHANGED
@@ -101,6 +101,7 @@ async def upload_and_tts(
101
  text: str = Form(...),
102
  language: str = Form("zh"),
103
  text_lang: str = Form(None),
 
104
  file: UploadFile = File(...)
105
  ):
106
  """
@@ -127,7 +128,7 @@ async def upload_and_tts(
127
 
128
  out_path = f"/tmp/out_{ts}.wav"
129
  # 🟢 执行 TTS
130
- genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang)
131
 
132
  # 🟢 关键:强制等待文件出现(最多等5秒)
133
  wait_time = 0
@@ -161,7 +162,8 @@ async def dynamic_tts(
161
  character_name: str = Form("Base"),
162
  prompt_text: str = Form(None),
163
  prompt_lang: str = Form("zh"),
164
- text_lang: str = Form(None), # 新增:目标文本语言(跨语言TTS)
 
165
  use_default_ref: bool = Form(True)
166
  ):
167
  """
@@ -183,7 +185,7 @@ async def dynamic_tts(
183
  genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)
184
 
185
  out_path = f"/tmp/out_dyn_{int(time.time())}.wav"
186
- genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang)
187
 
188
  # 🟢 同样增加文件等待
189
  wait_time = 0
 
101
  text: str = Form(...),
102
  language: str = Form("zh"),
103
  text_lang: str = Form(None),
104
+ speed: float = Form(1.0), # 语速调节(0.5-2.0)
105
  file: UploadFile = File(...)
106
  ):
107
  """
 
128
 
129
  out_path = f"/tmp/out_{ts}.wav"
130
  # 🟢 执行 TTS
131
+ genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang, speed=speed)
132
 
133
  # 🟢 关键:强制等待文件出现(最多等5秒)
134
  wait_time = 0
 
162
  character_name: str = Form("Base"),
163
  prompt_text: str = Form(None),
164
  prompt_lang: str = Form("zh"),
165
+ text_lang: str = Form(None),
166
+ speed: float = Form(1.0), # 语速调节(0.5-2.0)
167
  use_default_ref: bool = Form(True)
168
  ):
169
  """
 
185
  genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)
186
 
187
  out_path = f"/tmp/out_dyn_{int(time.time())}.wav"
188
+ genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang, speed=speed)
189
 
190
  # 🟢 同样增加文件等待
191
  wait_time = 0
genie_tts/Core/Inference.py CHANGED
@@ -9,6 +9,43 @@ from ..GetPhonesAndBert import get_phones_and_bert
9
  MAX_T2S_LEN = 1000
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class GENIE:
13
  def __init__(self):
14
  self.stop_event: threading.Event = threading.Event()
@@ -23,7 +60,8 @@ class GENIE:
23
  vocoder: ort.InferenceSession,
24
  prompt_encoder: Optional[ort.InferenceSession],
25
  language: str = 'japanese',
26
- text_language: str = None, # 新增:目标文本语言,默认使用参考音频语言
 
27
  ) -> Optional[np.ndarray]:
28
  # 如果未指定 text_language,则使用参考音频的语言
29
  actual_text_language = text_language if text_language else language
@@ -46,6 +84,9 @@ class GENIE:
46
  first_eos_index = eos_indices[-1][0]
47
  semantic_tokens = semantic_tokens[..., :first_eos_index]
48
 
 
 
 
49
  if prompt_encoder is None:
50
  return vocoder.run(None, {
51
  "text_seq": text_seq,
 
9
  MAX_T2S_LEN = 1000
10
 
11
 
12
+ def stretch_semantic_tokens(tokens: np.ndarray, speed: float) -> np.ndarray:
13
+ """
14
+ 语义 Token 插值(最近邻),用于实现语速调节。
15
+ 借鉴自 AstraTTS 的 StretchSemanticTokens 算法。
16
+
17
+ Args:
18
+ tokens: 原始 semantic tokens [1, 1, T]
19
+ speed: 语速系数,>1 加速,<1 减速
20
+ Returns:
21
+ 插值后的 tokens
22
+ """
23
+ if tokens is None or tokens.size == 0:
24
+ return tokens
25
+ if abs(speed - 1.0) < 0.01:
26
+ return tokens
27
+
28
+ # 提取原始 token 序列(去除批次维度)
29
+ original = tokens.flatten()
30
+ original_len = len(original)
31
+
32
+ # 计算新长度
33
+ new_len = int(round(original_len / speed))
34
+ if new_len < 1:
35
+ new_len = 1
36
+
37
+ # 最近邻插值
38
+ result = np.zeros(new_len, dtype=original.dtype)
39
+ for i in range(new_len):
40
+ old_idx = int(i * speed)
41
+ if old_idx >= original_len:
42
+ old_idx = original_len - 1
43
+ result[i] = original[old_idx]
44
+
45
+ # 恢复原始形状 [1, 1, new_len]
46
+ return result.reshape(1, 1, -1)
47
+
48
+
49
  class GENIE:
50
  def __init__(self):
51
  self.stop_event: threading.Event = threading.Event()
 
60
  vocoder: ort.InferenceSession,
61
  prompt_encoder: Optional[ort.InferenceSession],
62
  language: str = 'japanese',
63
+ text_language: str = None,
64
+ speed: float = 1.0, # 语速调节
65
  ) -> Optional[np.ndarray]:
66
  # 如果未指定 text_language,则使用参考音频的语言
67
  actual_text_language = text_language if text_language else language
 
84
  first_eos_index = eos_indices[-1][0]
85
  semantic_tokens = semantic_tokens[..., :first_eos_index]
86
 
87
+ # 🔥 语速调节:在 vocoder 前对 semantic tokens 进行插值
88
+ semantic_tokens = stretch_semantic_tokens(semantic_tokens, speed)
89
+
90
  if prompt_encoder is None:
91
  return vocoder.run(None, {
92
  "text_seq": text_seq,
genie_tts/Core/TTSPlayer.py CHANGED
@@ -93,7 +93,8 @@ class TTSPlayer:
93
  vocoder=gsv_model.VITS,
94
  prompt_encoder=gsv_model.PROMPT_ENCODER,
95
  language=gsv_model.LANGUAGE,
96
- text_language=context.current_text_language, # 新增:跨语言TTS支持
 
97
  )
98
 
99
  if audio_chunk is not None:
 
93
  vocoder=gsv_model.VITS,
94
  prompt_encoder=gsv_model.PROMPT_ENCODER,
95
  language=gsv_model.LANGUAGE,
96
+ text_language=context.current_text_language,
97
+ speed=context.current_speed, # 🔥 语速调节
98
  )
99
 
100
  if audio_chunk is not None:
genie_tts/Internal.py CHANGED
@@ -193,7 +193,8 @@ async def tts_async(
193
  play: bool = False,
194
  split_sentence: bool = False,
195
  save_path: Union[str, PathLike, None] = None,
196
- text_language: str = None, # 新增:目标文本语言,用于跨语言TTS
 
197
  ) -> AsyncIterator[bytes]:
198
  """
199
  Asynchronously generates speech from text and yields audio chunks.
@@ -242,6 +243,8 @@ async def tts_async(
242
  )
243
  # 设置目标文本语言(跨语言TTS)
244
  context.current_text_language = normalize_language(text_language) if text_language else None
 
 
245
 
246
  # 3. 使用新的回调接口启动 TTS 会话
247
  tts_player.start_session(
@@ -269,7 +272,8 @@ def tts(
269
  play: bool = False,
270
  split_sentence: bool = True,
271
  save_path: Union[str, PathLike, None] = None,
272
- text_language: str = None, # 新增:目标文本语言,用于跨语言TTS
 
273
  ) -> None:
274
  """
275
  Synchronously generates speech from text.
@@ -303,6 +307,8 @@ def tts(
303
  )
304
  # 设置目标文本语言(跨语言TTS)
305
  context.current_text_language = normalize_language(text_language) if text_language else None
 
 
306
 
307
  tts_player.start_session(
308
  play=play,
 
193
  play: bool = False,
194
  split_sentence: bool = False,
195
  save_path: Union[str, PathLike, None] = None,
196
+ text_language: str = None,
197
+ speed: float = 1.0, # 语速调节(0.5-2.0)
198
  ) -> AsyncIterator[bytes]:
199
  """
200
  Asynchronously generates speech from text and yields audio chunks.
 
243
  )
244
  # 设置目标文本语言(跨语言TTS)
245
  context.current_text_language = normalize_language(text_language) if text_language else None
246
+ # 设置语速
247
+ context.current_speed = speed
248
 
249
  # 3. 使用新的回调接口启动 TTS 会话
250
  tts_player.start_session(
 
272
  play: bool = False,
273
  split_sentence: bool = True,
274
  save_path: Union[str, PathLike, None] = None,
275
+ text_language: str = None,
276
+ speed: float = 1.0, # 语速调节(0.5-2.0)
277
  ) -> None:
278
  """
279
  Synchronously generates speech from text.
 
307
  )
308
  # 设置目标文本语言(跨语言TTS)
309
  context.current_text_language = normalize_language(text_language) if text_language else None
310
+ # 设置语速
311
+ context.current_speed = speed
312
 
313
  tts_player.start_session(
314
  play=play,
genie_tts/Utils/Shared.py CHANGED
@@ -8,7 +8,8 @@ class Context:
8
  def __init__(self):
9
  self.current_speaker: str = ''
10
  self.current_prompt_audio: Optional['ReferenceAudio'] = None
11
- self.current_text_language: Optional[str] = None # 新增:目标文本语言(跨语言TTS)
 
12
 
13
 
14
  context: Context = Context()
 
8
  def __init__(self):
9
  self.current_speaker: str = ''
10
  self.current_prompt_audio: Optional['ReferenceAudio'] = None
11
+ self.current_text_language: Optional[str] = None # 目标文本语言(跨语言TTS)
12
+ self.current_speed: float = 1.0 # 语速调节(0.5-2.0)
13
 
14
 
15
  context: Context = Context()