Spaces:
Runtime error
Runtime error
Update src/chatterbox/models/t3/t3.py
Browse files- src/chatterbox/models/t3/t3.py +15 -15
src/chatterbox/models/t3/t3.py
CHANGED
|
@@ -286,21 +286,21 @@ class T3(nn.Module):
|
|
| 286 |
logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
|
| 287 |
|
| 288 |
# # Run normal generate method, which calls our custom extended methods
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
|
| 305 |
device = embeds.device
|
| 306 |
|
|
|
|
| 286 |
logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
|
| 287 |
|
| 288 |
# # Run normal generate method, which calls our custom extended methods
|
| 289 |
+
return self.patched_model.generate(
|
| 290 |
+
inputs=initial_speech_tokens,
|
| 291 |
+
decoder_cond=embeds,
|
| 292 |
+
bos_token_id=self.hp.start_speech_token,
|
| 293 |
+
eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
|
| 294 |
+
pad_token_id=self.hp.stop_speech_token,
|
| 295 |
+
max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
|
| 296 |
+
num_return_sequences=num_return_sequences,
|
| 297 |
+
temperature=temperature,
|
| 298 |
+
min_p=min_p,
|
| 299 |
+
length_penalty=length_penalty,
|
| 300 |
+
repetition_penalty=repetition_penalty,
|
| 301 |
+
do_sample=do_sample,
|
| 302 |
+
# cache_implementation=None if not self.compiled else "static",
|
| 303 |
+
)
|
| 304 |
|
| 305 |
device = embeds.device
|
| 306 |
|