File size: 23,532 Bytes
1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca 1467bed 64278ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 
307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 | """Audio head for speech-to-speech using a frozen pretrained TTS backbone.
Architecture:
Text → frozen LLM (SmolLM3-3B) → hidden states (llm_dim)
→ Projector MLP (trainable, llm_dim → backbone_dim)
→ Concat with codec embeddings → neutts-nano LlamaForCausalLM (frozen)
→ lm_head → speech token logits → NeuCodec codes → audio
The frozen LLM is loaded for standalone S2S training. When used inside a full
ASR pipeline (ASRModel), pre-computed LLM hidden states are passed directly
and the internal LLM is not used.
neutts-nano (neuphonic/neutts-nano) is a pretrained 24-layer LlamaForCausalLM
(dim=576, ~117M params) that generates NeuCodec codes as <|speech_N|> tokens.
Only the projector MLP is trained.
NeuCodec uses a single FSQ codebook (levels=[4]*8, vocab=65536) at 50 tokens/sec,
outputting 24kHz audio. Codes 0-65535 map to neutts-nano tokens <|speech_0|>..<|speech_65535|>.
"""
import logging
from dataclasses import dataclass
from typing import Iterator, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F # noqa: N812
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
logger = logging.getLogger(__name__)

# --- NeuCodec FSQ codec -------------------------------------------------------
# One FSQ codebook with levels=[4]*8, i.e. 4**8 distinct codes.
NEUCODEC_VOCAB_SIZE = 4 ** 8
NEUCODEC_SAMPLE_RATE = 24_000

# --- Collator special tokens --------------------------------------------------
# S2SDataCollator reserves the three IDs directly after the codec range, so
# they can never collide with a real codec code.
BOS_TOKEN, EOS_TOKEN, PAD_TOKEN = range(NEUCODEC_VOCAB_SIZE, NEUCODEC_VOCAB_SIZE + 3)
TOTAL_VOCAB = PAD_TOKEN + 1  # codec codes + 3 specials (kept for backwards compat)
class AudioHeadConfig(PretrainedConfig):
    """Settings for AudioHead: which frozen models to load and how to sample.

    Args:
        tts_model_id: Hub ID of the frozen TTS backbone that emits speech tokens.
        llm_model_id: Hub ID of the frozen text LLM used in standalone training.
        projector_hidden: Width of the hidden layer of the trainable projector MLP.
        max_audio_tokens: Upper bound on codec tokens generated per utterance.
        neucodec_model_id: Hub ID of the NeuCodec model used to decode audio.
        temperature: Softmax temperature used when sampling codec tokens.
        top_k: Sample only among the k highest-scoring tokens (0 disables).
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "audio_head"

    def __init__(
        self,
        tts_model_id: str = "neuphonic/neutts-nano",
        llm_model_id: str = "HuggingFaceTB/SmolLM3-3B",
        projector_hidden: int = 1024,
        max_audio_tokens: int = 500,
        neucodec_model_id: str = "neuphonic/neucodec",
        temperature: float = 1.0,
        top_k: int = 50,
        **kwargs,
    ):
        head_settings = {
            "tts_model_id": tts_model_id,
            "llm_model_id": llm_model_id,
            "projector_hidden": projector_hidden,
            "max_audio_tokens": max_audio_tokens,
            "neucodec_model_id": neucodec_model_id,
            "temperature": temperature,
            "top_k": top_k,
        }
        # Attributes are set before the base initializer runs, matching the
        # original ordering relative to PretrainedConfig.__init__.
        for attr_name, attr_value in head_settings.items():
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
@dataclass
class AudioHeadOutput(ModelOutput):
    """Output of AudioHead forward pass.

    Exactly one of the two fields is populated by AudioHead.forward: the
    training path (codec_labels given) returns only ``loss``, the inference
    path (no codec_labels) returns only ``codes``.

    Attributes:
        loss: Cross-entropy loss when codec_labels are provided.
        codes: Generated NeuCodec codes in inference mode [batch, gen_len].
    """

    # Scalar teacher-forcing loss; None on the inference path.
    loss: Optional[torch.Tensor] = None
    # Sampled codec codes; None on the training path.
    codes: Optional[torch.Tensor] = None
class AudioHead(PreTrainedModel):
    """Frozen TTS backbone + trainable projector for speech generation.

    Loads neutts-nano (a pretrained LlamaForCausalLM that generates NeuCodec tokens)
    and freezes it entirely. A frozen LLM converts text to hidden states, and a
    trainable MLP projector maps those hidden states into neutts-nano's input space.

    Standalone training: text_token_ids → frozen LLM → hidden states → projector → backbone → speech codes
    Pipeline inference: llm_hidden_states → projector → backbone → speech codes
    """

    config_class = AudioHeadConfig
    # Prevent from_pretrained from using meta device init (which conflicts
    # with loading the backbone inside __init__ via its own from_pretrained)
    _supports_param_buffer_assignment = False

    def __init__(self, config: AudioHeadConfig):
        super().__init__(config)
        self.max_tokens = config.max_audio_tokens
        # Under a meta-device context no real weights can be loaded, so loading
        # is skipped here and must be done later via _load_backbone().
        self._backbone_loaded = False
        if not self._is_meta_init():
            self._load_backbone(config)

    def _is_meta_init(self) -> bool:
        """Return True if new tensors are currently being created on the meta device."""
        try:
            probe = torch.empty(1)
        except Exception:  # noqa: BLE001 — best-effort probe; assume a real device
            return False
        return probe.device.type == "meta"

    def _require_token_id(self, token: str) -> int:
        """Resolve ``token`` in the TTS tokenizer, raising if it is unknown.

        Raises:
            ValueError: If the tokenizer maps ``token`` to None or to its unk token.
        """
        token_id = self.tts_tokenizer.convert_tokens_to_ids(token)
        unk_id = self.tts_tokenizer.unk_token_id
        if token_id is None or (unk_id is not None and token_id == unk_id):
            raise ValueError(
                f"TTS tokenizer {self.config.tts_model_id!r} has no {token!r} token"
            )
        return token_id

    def _load_backbone(self, config: AudioHeadConfig) -> None:
        """Load the frozen TTS backbone, frozen LLM, and initialize the projector.

        Returns immediately if loading already happened.

        Raises:
            ValueError: If the TTS tokenizer lacks a required speech token.
        """
        if self._backbone_loaded:
            return

        # Frozen TTS backbone (neutts-nano): no grads, always in eval mode.
        logger.info("Loading TTS backbone: %s", config.tts_model_id)
        self.backbone = AutoModelForCausalLM.from_pretrained(
            config.tts_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        # The tokenizer is only used to resolve speech/special token IDs.
        self.tts_tokenizer = AutoTokenizer.from_pretrained(config.tts_model_id)
        self.speech_token_offset = self._require_token_id("<|speech_0|>")
        self.speech_start_id = self._require_token_id("<|SPEECH_GENERATION_START|>")
        self.speech_end_id = self._require_token_id("<|SPEECH_GENERATION_END|>")

        # Frozen LLM for standalone training (text → hidden states).
        # In pipeline mode (ASRModel), the duplicate is freed after creation
        # since ASRModel provides pre-computed hidden states.
        logger.info("Loading frozen LLM: %s", config.llm_model_id)
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.llm.requires_grad_(False)
        self.llm.eval()

        # Cache a prompt prefix so training hidden states are conditioned on
        # conversational context (matching inference where LLM sees full prompt).
        llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_id, trust_remote_code=True)
        prompt_enc = llm_tokenizer(
            "Speak the following text aloud: ",
            return_tensors="pt",
            add_special_tokens=True,
        )
        self.register_buffer("_prompt_prefix_ids", prompt_enc.input_ids, persistent=False)
        self._prompt_len = prompt_enc.input_ids.shape[1]

        llm_dim = self.llm.config.hidden_size
        backbone_dim = self.backbone.config.hidden_size  # 576 for neutts-nano

        # Trainable projector: Linear → RMSNorm → GELU → Linear → RMSNorm.
        # Final RMSNorm matches output scale to neutts-nano embedding norms.
        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        self.projector = nn.Sequential(
            nn.Linear(llm_dim, config.projector_hidden),
            LlamaRMSNorm(config.projector_hidden, eps=1e-6),
            nn.GELU(),
            nn.Linear(config.projector_hidden, backbone_dim),
            LlamaRMSNorm(backbone_dim, eps=1e-6),
        ).to(torch.bfloat16)

        # Sampling parameters for inference
        self.temperature = config.temperature
        self.top_k = config.top_k

        # NeuCodec model (loaded lazily, frozen, inference only)
        self.neucodec_model = None

        self._backbone_loaded = True

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):  # noqa: ARG003
        """Load AudioHead: config + projector weights from disk/Hub, backbone from HF Hub.

        Extra positional/keyword arguments are accepted for signature
        compatibility with ``PreTrainedModel.from_pretrained`` but are ignored.
        """
        from pathlib import Path

        from safetensors.torch import load_file

        path = Path(pretrained_model_name_or_path)
        # If not a local directory, download from Hub
        if not path.is_dir():
            from huggingface_hub import snapshot_download

            path = Path(snapshot_download(pretrained_model_name_or_path))

        config = AudioHeadConfig.from_pretrained(path)
        model = cls(config)  # loads backbone + LLM from HF Hub

        safetensors_path = path / "model.safetensors"
        if not safetensors_path.exists():
            logger.warning(
                "No projector checkpoint at %s; projector keeps its random init",
                safetensors_path,
            )
            return model

        # strict=False because the checkpoint only holds projector weights;
        # backbone/LLM keys are expected to be "missing".
        result = model.load_state_dict(load_file(safetensors_path), strict=False)
        missing_projector = [k for k in result.missing_keys if k.startswith("projector.")]
        if missing_projector:
            logger.warning("Projector weights missing from checkpoint: %s", missing_projector)
        if result.unexpected_keys:
            logger.warning("Unexpected keys in checkpoint: %s", result.unexpected_keys)
        logger.info("Loaded projector weights from %s", safetensors_path)
        return model

    def train(self, mode: bool = True):
        """Override to keep backbone and LLM in eval mode (disables dropout, etc.)."""
        super().train(mode)
        # Either frozen model may be absent: the backbone before _load_backbone()
        # on the meta-init path, the LLM once pipeline mode has freed it.
        for frozen in (getattr(self, "backbone", None), getattr(self, "llm", None)):
            if frozen is not None:
                frozen.eval()
        return self

    def _embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed tokens using the frozen backbone's embedding table."""
        return self.backbone.model.embed_tokens(token_ids)

    def _codec_to_speech_ids(self, codec_codes: torch.Tensor) -> torch.Tensor:
        """Map NeuCodec codes [0, 65535] to neutts-nano speech token IDs."""
        return codec_codes + self.speech_token_offset

    def _speech_ids_to_codec(self, speech_ids: torch.Tensor) -> torch.Tensor:
        """Map neutts-nano speech token IDs back to NeuCodec codes [0, 65535]."""
        return speech_ids - self.speech_token_offset

    @staticmethod
    def _positions_from_mask(mask: torch.Tensor) -> torch.Tensor:
        """Build position ids that skip padding: the i-th attended token gets position i.

        Leading pad slots get position 0 and trailing ones repeat the last real
        position; pad slots are excluded as attention keys by the mask, so
        their values do not shift the positions of real tokens.
        """
        return (mask.long().cumsum(dim=-1) - 1).clamp(min=0)

    def _encode_text(
        self, text_token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Run the frozen LLM over [prompt prefix, text] and return text-position hidden states.

        Raises:
            RuntimeError: If the internal LLM is not available (e.g. freed in pipeline mode).
        """
        if getattr(self, "llm", None) is None:
            raise RuntimeError(
                "text_token_ids requires the internal frozen LLM, which is not loaded; "
                "pass llm_hidden_states instead"
            )
        batch_size = text_token_ids.shape[0]
        device = text_token_ids.device
        prompt = self._prompt_prefix_ids.to(device).expand(batch_size, -1)
        full_ids = torch.cat([prompt, text_token_ids], dim=1)

        full_mask = None
        if attention_mask is not None:
            prompt_mask = torch.ones(
                batch_size, self._prompt_len, device=device, dtype=attention_mask.dtype
            )
            full_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        with torch.no_grad():
            llm_out = self.llm.model(input_ids=full_ids, attention_mask=full_mask, use_cache=False)
        # Drop the prompt-prefix positions; keep only hidden states for the text.
        return llm_out.last_hidden_state[:, self._prompt_len :]

    def forward(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        codec_labels: Optional[torch.Tensor] = None,
        codec_input_ids: Optional[torch.Tensor] = None,
        codec_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # noqa: ARG002 — absorbs extra keys from Trainer
    ) -> AudioHeadOutput:
        """Forward pass for training or inference.

        Args:
            text_token_ids: Text token IDs [batch, seq_len] (LLM tokenizer vocab).
                Run through frozen LLM to get hidden states. Ignored when
                llm_hidden_states is given.
            attention_mask: Text attention mask [batch, seq_len] (1=real, 0=padding)
            llm_hidden_states: Pre-computed LLM hidden states [batch, seq_len, llm_dim].
                Used in pipeline mode when ASRModel provides hidden states directly.
            codec_labels: Target NeuCodec codes [batch, audio_len] (-100 for ignore)
            codec_input_ids: Teacher-forced NeuCodec codes [batch, audio_len]
            codec_attention_mask: Codec attention mask [batch, audio_len]
            **kwargs: Absorbed silently (Trainer may pass extra keys).

        Returns:
            AudioHeadOutput with loss (training) or codes (inference).

        Raises:
            ValueError: If neither text_token_ids nor llm_hidden_states is given,
                or codec_labels is given without codec_input_ids.
        """
        if llm_hidden_states is not None:
            hidden_states = llm_hidden_states
        elif text_token_ids is not None:
            hidden_states = self._encode_text(text_token_ids, attention_mask)
        else:
            raise ValueError("Either text_token_ids or llm_hidden_states must be provided")

        batch_size, text_len = hidden_states.shape[:2]
        device = hidden_states.device

        # Pipeline callers may pass fp32 states; match the projector's dtype.
        hidden_states = hidden_states.to(next(self.projector.parameters()).dtype)
        # Gradients flow through the projector (LLM hidden states carry no grad).
        prefix = self.projector(hidden_states)  # [batch, text_len, backbone_dim]

        if codec_labels is None:
            # Inference: autoregressive generation
            return AudioHeadOutput(codes=self._generate(prefix, attention_mask))

        # Training: teacher forcing
        if codec_input_ids is None:
            raise ValueError("codec_input_ids required when codec_labels provided")

        speech_input = self._map_collator_ids_to_speech(codec_input_ids)
        with torch.no_grad():
            token_emb = self._embed_tokens(speech_input)  # [batch, audio_len, backbone_dim]
        audio_len = token_emb.shape[1]

        # [projected_text (has grad), codec token embeddings (frozen lookup)]
        hidden = torch.cat([prefix, token_emb], dim=1)

        # 2D padding mask — backbone handles causal masking internally
        prefix_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones(batch_size, text_len, device=device, dtype=torch.long)
        )
        audio_mask = (
            codec_attention_mask
            if codec_attention_mask is not None
            else torch.ones(batch_size, audio_len, device=device, dtype=torch.long)
        )
        combined_mask = torch.cat([prefix_mask, audio_mask], dim=1)
        # Positions skip text padding so audio tokens sit right after the real
        # text, exactly as in unpadded generation.
        position_ids = self._positions_from_mask(combined_mask)

        # Run the frozen backbone WITHOUT torch.no_grad(): its weights have
        # requires_grad=False, but the graph is still built so the loss can
        # backpropagate through backbone → hidden → prefix → projector.
        outputs = self.backbone.model(
            inputs_embeds=hidden,
            attention_mask=combined_mask,
            position_ids=position_ids,
            use_cache=False,
        )
        audio_hidden = outputs.last_hidden_state[:, text_len:]  # [batch, audio_len, dim]

        # Frozen lm_head; gradients still flow through the matmul.
        logits = self.backbone.lm_head(audio_hidden)  # [batch, audio_len, vocab_size]

        speech_labels = self._map_collator_labels_to_speech(codec_labels)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            speech_labels.reshape(-1),
            ignore_index=-100,
        )
        return AudioHeadOutput(loss=loss)

    def _collator_pad_id(self) -> int:
        """Token ID substituted for collator padding in backbone inputs."""
        pad_id = self.tts_tokenizer.pad_token_id
        # Any valid ID works here as long as padded positions are masked or
        # follow the real tokens; fall back to the speech end token when the
        # tokenizer defines no pad token.
        return self.speech_end_id if pad_id is None else pad_id

    def _map_collator_ids_to_speech(self, codec_input_ids: torch.Tensor) -> torch.Tensor:
        """Map S2SDataCollator codec_input_ids to neutts-nano token IDs.

        S2SDataCollator produces:
            - BOS_TOKEN (65536) at position 0
            - NeuCodec codes (0-65535) for real audio
            - PAD_TOKEN (65538) for padding
        Maps to:
            - BOS_TOKEN → <|SPEECH_GENERATION_START|>
            - EOS_TOKEN → <|SPEECH_GENERATION_END|>
            - codes 0-65535 → <|speech_0|>..<|speech_65535|>
            - PAD_TOKEN and negative values (e.g. -100) → pad token ID
        """
        pad_id = self._collator_pad_id()
        result = codec_input_ids.clone()
        result[codec_input_ids == BOS_TOKEN] = self.speech_start_id
        result[codec_input_ids == EOS_TOKEN] = self.speech_end_id
        result[(codec_input_ids == PAD_TOKEN) | (codec_input_ids < 0)] = pad_id
        codec_mask = (codec_input_ids >= 0) & (codec_input_ids < NEUCODEC_VOCAB_SIZE)
        result[codec_mask] = self._codec_to_speech_ids(codec_input_ids[codec_mask])
        return result

    def _map_collator_labels_to_speech(self, codec_labels: torch.Tensor) -> torch.Tensor:
        """Map S2SDataCollator codec_labels to neutts-nano token IDs.

        codec_labels contains:
            - NeuCodec codes (0-65535) for real targets
            - EOS_TOKEN (65537) at the end
            - -100 for ignore positions
        Any other value (e.g. BOS/PAD) is not a valid target and becomes -100,
        instead of silently training on an unrelated vocab entry.
        """
        result = torch.full_like(codec_labels, -100)
        result[codec_labels == EOS_TOKEN] = self.speech_end_id
        codec_mask = (codec_labels >= 0) & (codec_labels < NEUCODEC_VOCAB_SIZE)
        result[codec_mask] = self._codec_to_speech_ids(codec_labels[codec_mask])
        return result

    def _sample_next(self, logits: torch.Tensor) -> torch.Tensor:
        """Pick one index per row among the 65536 speech codes plus the end token.

        Args:
            logits: Backbone logits over the full vocab [batch, vocab].

        Returns:
            Long tensor [batch] with values in [0, NEUCODEC_VOCAB_SIZE]; the
            value NEUCODEC_VOCAB_SIZE means "end of speech". Uses greedy
            argmax when temperature <= 0, otherwise temperature + top-k sampling.
        """
        start = self.speech_token_offset
        speech_logits = logits[:, start : start + NEUCODEC_VOCAB_SIZE]
        end_logit = logits[:, self.speech_end_id : self.speech_end_id + 1]
        # Sample in float32 to avoid bf16 rounding across 65537 candidates.
        combined = torch.cat([speech_logits, end_logit], dim=-1).float()

        if self.temperature <= 0:
            return combined.argmax(dim=-1)
        if self.temperature != 1.0:
            combined = combined / self.temperature
        if self.top_k > 0:
            k = min(self.top_k, combined.size(-1))
            kth_best = combined.topk(k, dim=-1).values[:, -1:]
            combined = combined.masked_fill(combined < kth_best, float("-inf"))
        probs = F.softmax(combined, dim=-1)
        return torch.multinomial(probs, 1).squeeze(-1)

    def _generate(
        self, prefix: torch.Tensor, prefix_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """AR generation with KV cache on frozen backbone.

        Args:
            prefix: Projected text embeddings [batch, text_len, backbone_dim].
            prefix_mask: Optional padding mask for the prefix [batch, text_len]
                (1=real, 0=padding). A mask of any other shape is ignored with
                a warning.

        Returns:
            NeuCodec codes [batch, gen_len] (long). Generation stops once every
            row has produced the end token or after max_tokens steps. Rows that
            finish early are right-padded with code 0.
        """
        batch_size, text_len, _ = prefix.shape
        device = prefix.device

        if prefix_mask is not None and tuple(prefix_mask.shape) != (batch_size, text_len):
            logger.warning(
                "Ignoring prefix_mask of shape %s; expected (%d, %d)",
                tuple(prefix_mask.shape), batch_size, text_len,
            )
            prefix_mask = None
        if prefix_mask is None:
            prefix_mask = torch.ones(batch_size, text_len, dtype=torch.long, device=device)

        one_column = torch.ones(batch_size, 1, dtype=torch.long, device=device)
        # Mask over [prefix, SPEECH_GENERATION_START]; grows one column per step.
        attn_mask = torch.cat([prefix_mask.to(device=device, dtype=torch.long), one_column], dim=1)
        position_ids = self._positions_from_mask(attn_mask)
        # Position of the first generated token: right after the start token,
        # matching teacher forcing where code j sits at start_pos + 1 + j.
        next_pos = position_ids[:, -1:] + 1

        start_token = torch.full(
            (batch_size, 1), self.speech_start_id, dtype=torch.long, device=device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        all_codes = []

        with torch.no_grad():
            hidden = torch.cat([prefix, self._embed_tokens(start_token)], dim=1)
            outputs = self.backbone.model(
                inputs_embeds=hidden,
                attention_mask=attn_mask,
                position_ids=position_ids,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.last_hidden_state[:, -1]  # [batch, dim]

            for step in range(self.max_tokens):
                sampled = self._sample_next(self.backbone.lm_head(last_hidden))
                finished = finished | (sampled == NEUCODEC_VOCAB_SIZE)
                if finished.all():
                    break

                # Finished rows (including one that just emitted the end
                # token) contribute padding code 0, never a clamped end index.
                codec_code = sampled.masked_fill(finished, 0)
                all_codes.append(codec_code)
                if step == self.max_tokens - 1:
                    break  # nothing more will be sampled; skip the extra forward

                next_token_id = self._codec_to_speech_ids(codec_code).masked_fill(
                    finished, self.speech_end_id
                )
                attn_mask = torch.cat([attn_mask, one_column], dim=1)
                outputs = self.backbone.model(
                    inputs_embeds=self._embed_tokens(next_token_id.unsqueeze(1)),
                    attention_mask=attn_mask,
                    position_ids=next_pos,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = outputs.past_key_values
                last_hidden = outputs.last_hidden_state[:, -1]
                next_pos = next_pos + 1

        if all_codes:
            return torch.stack(all_codes, dim=1)  # [batch, gen_len]
        return torch.empty(batch_size, 0, dtype=torch.long, device=device)

    def state_dict(self, *args, **kwargs):
        """Only save projector weights (backbone and LLM are frozen/pretrained).

        Honors the ``prefix`` keyword so the filter still matches when keys are
        prefixed (e.g. when this module is serialized as a submodule).
        """
        full = super().state_dict(*args, **kwargs)
        wanted = kwargs.get("prefix", "") + "projector."
        return {k: v for k, v in full.items() if k.startswith(wanted)}

    def _load_neucodec(self):
        """Load frozen NeuCodec model for audio decoding."""
        from neucodec import NeuCodec

        self.neucodec_model = NeuCodec.from_pretrained(self.config.neucodec_model_id)
        self.neucodec_model.eval()
        self.neucodec_model.requires_grad_(False)
        logger.info("Loaded frozen NeuCodec model for audio decoding")

    def decode_to_audio(self, codes: torch.Tensor) -> list[torch.Tensor]:
        """Decode NeuCodec FSQ tokens to audio waveforms.

        Args:
            codes: Codec tokens [batch, seq_len] (values 0-65535)

        Returns:
            List of audio waveform tensors (one per batch item). When
            ``seq_len`` is 0, each entry is an empty tensor and the codec is
            neither loaded nor called.

        Raises:
            RuntimeError: If the NeuCodec model could not be loaded.
        """
        batch_size = codes.shape[0]
        if codes.shape[-1] == 0:
            return [torch.zeros(0) for _ in range(batch_size)]

        if self.neucodec_model is None:
            self._load_neucodec()
        codec = self.neucodec_model
        if codec is None:
            raise RuntimeError("NeuCodec model failed to load")

        codes_3d = codes.unsqueeze(1).to(codec.device)  # [batch, 1, seq_len]
        with torch.no_grad():
            audio_values = codec.decode_code(codes_3d)
        return [audio_values[i, 0] for i in range(audio_values.shape[0])]

    def generate_streaming(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        chunk_samples: int = 24000,
    ) -> Iterator[torch.Tensor]:
        """Generate audio and yield waveform chunks for streaming playback.

        Generation and decoding complete first; each batch item's waveform is
        then yielded in consecutive slices of at most ``chunk_samples`` samples
        (the default is one second at NEUCODEC_SAMPLE_RATE).

        Raises:
            ValueError: If ``chunk_samples`` is not positive (raised on first iteration).
        """
        if chunk_samples <= 0:
            raise ValueError(f"chunk_samples must be positive, got {chunk_samples}")
        with torch.no_grad():
            output = self(text_token_ids=text_token_ids, llm_hidden_states=llm_hidden_states)
        for audio in self.decode_to_audio(output.codes):
            total = audio.shape[-1]
            for start in range(0, total, chunk_samples):
                yield audio[..., start : min(start + chunk_samples, total)]
|