alexue4 committed on
Commit
7e241f9
·
verified ·
1 Parent(s): ae44b9c

Update src/chatterbox/models/t3/t3.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/models/t3/t3.py +21 -17
src/chatterbox/models/t3/t3.py CHANGED
@@ -286,21 +286,21 @@ class T3(nn.Module):
286
  logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
287
 
288
  # # Run normal generate method, which calls our custom extended methods
289
- return self.patched_model.generate(
290
- inputs=initial_speech_tokens,
291
- decoder_cond=embeds,
292
- bos_token_id=self.hp.start_speech_token,
293
- eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
294
- pad_token_id=self.hp.stop_speech_token,
295
- max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
296
- num_return_sequences=num_return_sequences,
297
- temperature=temperature,
298
- min_p=min_p,
299
- length_penalty=length_penalty,
300
- repetition_penalty=repetition_penalty,
301
- do_sample=do_sample,
302
- # cache_implementation=None if not self.compiled else "static",
303
- )
304
 
305
  device = embeds.device
306
 
@@ -371,8 +371,12 @@ class T3(nn.Module):
371
  # # Convert logits to probabilities and sample the next token.
372
  # probs = torch.softmax(logits, dim=-1)
373
  # next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
374
- gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
375
- next_token = (logits / temperature + gumbel).argmax(dim=-1)
 
 
 
 
376
  next_token = next_token.unsqueeze(-1)
377
  step_sampling_total += (time.perf_counter() - step_t0)
378
 
 
286
  logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
287
 
288
  # # Run normal generate method, which calls our custom extended methods
289
+ # return self.patched_model.generate(
290
+ # inputs=initial_speech_tokens,
291
+ # decoder_cond=embeds,
292
+ # bos_token_id=self.hp.start_speech_token,
293
+ # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
294
+ # pad_token_id=self.hp.stop_speech_token,
295
+ # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
296
+ # num_return_sequences=num_return_sequences,
297
+ # temperature=temperature,
298
+ # min_p=min_p,
299
+ # length_penalty=length_penalty,
300
+ # repetition_penalty=repetition_penalty,
301
+ # do_sample=do_sample,
302
+ # # cache_implementation=None if not self.compiled else "static",
303
+ # )
304
 
305
  device = embeds.device
306
 
 
371
  # # Convert logits to probabilities and sample the next token.
372
  # probs = torch.softmax(logits, dim=-1)
373
  # next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
374
+ top_k = 50
375
+ vals, idx = logits.topk(top_k, dim=-1)
376
+ masked = torch.full_like(logits, float('-inf'))
377
+ masked.scatter_(1, idx, vals)
378
+ g = -torch.log(-torch.log(torch.rand_like(masked)))
379
+ next_token = ((masked / temperature) + g).argmax(dim=-1)
380
  next_token = next_token.unsqueeze(-1)
381
  step_sampling_total += (time.perf_counter() - step_t0)
382