alexue4 commited on
Commit
ae44b9c
·
verified ·
1 Parent(s): a26b5ed

Update src/chatterbox/models/t3/t3.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/models/t3/t3.py +15 -15
src/chatterbox/models/t3/t3.py CHANGED
@@ -286,21 +286,21 @@ class T3(nn.Module):
286
  logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
287
 
288
  # # Run normal generate method, which calls our custom extended methods
289
- # return self.patched_model.generate(
290
- # inputs=initial_speech_tokens,
291
- # decoder_cond=embeds,
292
- # bos_token_id=self.hp.start_speech_token,
293
- # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
294
- # pad_token_id=self.hp.stop_speech_token,
295
- # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
296
- # num_return_sequences=num_return_sequences,
297
- # temperature=temperature,
298
- # min_p=min_p,
299
- # length_penalty=length_penalty,
300
- # repetition_penalty=repetition_penalty,
301
- # do_sample=do_sample,
302
- # # cache_implementation=None if not self.compiled else "static",
303
- # )
304
 
305
  device = embeds.device
306
 
 
286
  logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
287
 
288
  # # Run normal generate method, which calls our custom extended methods
289
+ return self.patched_model.generate(
290
+ inputs=initial_speech_tokens,
291
+ decoder_cond=embeds,
292
+ bos_token_id=self.hp.start_speech_token,
293
+ eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
294
+ pad_token_id=self.hp.stop_speech_token,
295
+ max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
296
+ num_return_sequences=num_return_sequences,
297
+ temperature=temperature,
298
+ min_p=min_p,
299
+ length_penalty=length_penalty,
300
+ repetition_penalty=repetition_penalty,
301
+ do_sample=do_sample,
302
+ # cache_implementation=None if not self.compiled else "static",
303
+ )
304
 
305
  device = embeds.device
306