Spaces:
Runtime error
Runtime error
Update src/chatterbox/models/t3/t3.py
Browse files — src/chatterbox/models/t3/t3.py (+21 lines, −17 lines)
src/chatterbox/models/t3/t3.py
CHANGED
|
@@ -286,21 +286,21 @@ class T3(nn.Module):
|
|
| 286 |
logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
|
| 287 |
|
| 288 |
# # Run normal generate method, which calls our custom extended methods
|
| 289 |
-            return self.patched_model.generate(
| 290 |
-                inputs=initial_speech_tokens,
| 291 |
-                decoder_cond=embeds,
| 292 |
-                bos_token_id=self.hp.start_speech_token,
| 293 |
-                eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
| 294 |
-                pad_token_id=self.hp.stop_speech_token,
| 295 |
-                max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
| 296 |
-                num_return_sequences=num_return_sequences,
| 297 |
-                temperature=temperature,
| 298 |
-                min_p=min_p,
| 299 |
-                length_penalty=length_penalty,
| 300 |
-                repetition_penalty=repetition_penalty,
| 301 |
-                do_sample=do_sample,
| 302 |
-                # cache_implementation=None if not self.compiled else "static",
| 303 |
-            )
|
| 304 |
|
| 305 |
device = embeds.device
|
| 306 |
|
|
@@ -371,8 +371,12 @@ class T3(nn.Module):
|
|
| 371 |
# # Convert logits to probabilities and sample the next token.
|
| 372 |
# probs = torch.softmax(logits, dim=-1)
|
| 373 |
# next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
|
| 374 |
-            (deleted line — content lost in page extraction; not recoverable from the added side)
| 375 |
-            (deleted line — content lost in page extraction; not recoverable from the added side)
| 376 |
next_token = next_token.unsqueeze(-1)
|
| 377 |
step_sampling_total += (time.perf_counter() - step_t0)
|
| 378 |
|
|
|
|
| 286 |
logger.warning(f"t3.inference: patch/compile backend took {time.perf_counter() - compile_start:.4f}s (compiled={self.compiled})")
|
| 287 |
|
| 288 |
# # Run normal generate method, which calls our custom extended methods
|
| 289 |
+
# return self.patched_model.generate(
|
| 290 |
+
# inputs=initial_speech_tokens,
|
| 291 |
+
# decoder_cond=embeds,
|
| 292 |
+
# bos_token_id=self.hp.start_speech_token,
|
| 293 |
+
# eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
|
| 294 |
+
# pad_token_id=self.hp.stop_speech_token,
|
| 295 |
+
# max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
|
| 296 |
+
# num_return_sequences=num_return_sequences,
|
| 297 |
+
# temperature=temperature,
|
| 298 |
+
# min_p=min_p,
|
| 299 |
+
# length_penalty=length_penalty,
|
| 300 |
+
# repetition_penalty=repetition_penalty,
|
| 301 |
+
# do_sample=do_sample,
|
| 302 |
+
# # cache_implementation=None if not self.compiled else "static",
|
| 303 |
+
# )
|
| 304 |
|
| 305 |
device = embeds.device
|
| 306 |
|
|
|
|
| 371 |
# # Convert logits to probabilities and sample the next token.
|
| 372 |
# probs = torch.softmax(logits, dim=-1)
|
| 373 |
# next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
|
| 374 |
+
top_k = 50
|
| 375 |
+
vals, idx = logits.topk(top_k, dim=-1)
|
| 376 |
+
masked = torch.full_like(logits, float('-inf'))
|
| 377 |
+
masked.scatter_(1, idx, vals)
|
| 378 |
+
g = -torch.log(-torch.log(torch.rand_like(masked)))
|
| 379 |
+
next_token = ((masked / temperature) + g).argmax(dim=-1)
|
| 380 |
next_token = next_token.unsqueeze(-1)
|
| 381 |
step_sampling_total += (time.perf_counter() - step_t0)
|
| 382 |
|