antigravity commited on
Commit
4379c64
·
1 Parent(s): 82f54c3

fix: add retry mechanism to prevent EOS early termination sentence dropping

Browse files
Files changed (1) hide show
  1. genie_tts/Core/Inference.py +51 -23
genie_tts/Core/Inference.py CHANGED
@@ -115,8 +115,13 @@ class GENIE:
115
  first_stage_decoder: ort.InferenceSession,
116
  stage_decoder: ort.InferenceSession,
117
  ) -> Optional[np.ndarray]:
118
- """在CPU上运行T2S模型"""
119
- # Encoder
 
 
 
 
 
120
  x, prompts = encoder.run(
121
  None,
122
  {
@@ -127,30 +132,53 @@ class GENIE:
127
  "ssl_content": ssl_content,
128
  },
129
  )
130
-
131
- # First Stage Decoder
132
- y, y_emb, *present_key_values = first_stage_decoder.run(
133
- None, {"x": x, "prompts": prompts}
134
- )
135
-
136
- # Stage Decoder
137
  input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]
138
- idx: int = 0
139
- for idx in range(0, 500):
 
 
140
  if self.stop_event.is_set():
141
  return None
142
- input_feed = {
143
- name: data
144
- for name, data in zip(input_names, [y, y_emb, *present_key_values])
145
- }
146
- outputs = stage_decoder.run(None, input_feed)
147
- y, y_emb, stop_condition_tensor, *present_key_values = outputs
148
-
149
- if stop_condition_tensor:
150
- break
151
-
152
- y[0, -1] = 0
153
- return np.expand_dims(y[:, -idx:], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
 
156
  tts_client: GENIE = GENIE()
 
115
  first_stage_decoder: ort.InferenceSession,
116
  stage_decoder: ort.InferenceSession,
117
  ) -> Optional[np.ndarray]:
118
+ """在CPU上运行T2S模型,带重试机制防止 EOS 过早终止"""
119
+
120
+ # 动态阈值:最小期望 tokens 数量(参考 AstraTTS)
121
+ min_expected_tokens = max(8, text_seq.shape[-1] * 2)
122
+ max_retries = 5
123
+
124
+ # Encoder 只需运行一次
125
  x, prompts = encoder.run(
126
  None,
127
  {
 
132
  "ssl_content": ssl_content,
133
  },
134
  )
135
+
 
 
 
 
 
 
136
  input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]
137
+ best_y = None
138
+ best_idx = 0
139
+
140
+ for retry in range(max_retries):
141
  if self.stop_event.is_set():
142
  return None
143
+
144
+ # First Stage Decoder(每次重试都重新运行以获取新的随机采样状态)
145
+ y, y_emb, *present_key_values = first_stage_decoder.run(
146
+ None, {"x": x, "prompts": prompts}
147
+ )
148
+
149
+ # Stage Decoder Loop
150
+ idx: int = 0
151
+ for idx in range(0, 500):
152
+ if self.stop_event.is_set():
153
+ return None
154
+ input_feed = {
155
+ name: data
156
+ for name, data in zip(input_names, [y, y_emb, *present_key_values])
157
+ }
158
+ outputs = stage_decoder.run(None, input_feed)
159
+ y, y_emb, stop_condition_tensor, *present_key_values = outputs
160
+
161
+ if stop_condition_tensor:
162
+ break
163
+
164
+ # 保存最佳结果(tokens 数量最多的)
165
+ if idx > best_idx:
166
+ best_idx = idx
167
+ best_y = y.copy()
168
+
169
+ # 验证生成数量是否达到预期
170
+ if idx >= min_expected_tokens:
171
+ break # 成功,退出重试循环
172
+
173
+ # 否则继续重试
174
+
175
+ # 使用最佳结果
176
+ if best_y is None:
177
+ best_y = y
178
+ best_idx = idx
179
+
180
+ best_y[0, -1] = 0
181
+ return np.expand_dims(best_y[:, -best_idx:], axis=0)
182
 
183
 
184
  tts_client: GENIE = GENIE()