"""Audio head for speech-to-speech using a frozen pretrained TTS backbone.

Architecture:
    Text → frozen LLM (SmolLM3-3B) → hidden states (llm_dim)
        → Projector MLP (trainable, llm_dim → backbone_dim)
        → Concat with codec embeddings
        → neutts-nano LlamaForCausalLM (frozen)
        → lm_head → speech token logits
        → NeuCodec codes → audio

The frozen LLM is loaded for standalone S2S training. When used inside a full
ASR pipeline (ASRModel), pre-computed LLM hidden states are passed directly and
the internal LLM is not used.

neutts-nano (neuphonic/neutts-nano) is a pretrained 24-layer LlamaForCausalLM
(dim=576, ~117M params) that generates NeuCodec codes as <|speech_N|> tokens.
Only the projector MLP is trained.

NeuCodec uses a single FSQ codebook (levels=[4]*8, vocab=65536) at 50 tokens/sec,
outputting 24kHz audio. Codes 0-65535 map to neutts-nano tokens
<|speech_0|>..<|speech_65535|>.
"""

import logging
from dataclasses import dataclass
from typing import Iterator, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F  # noqa: N812
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

logger = logging.getLogger(__name__)

# NeuCodec FSQ constants
NEUCODEC_VOCAB_SIZE = 65536
NEUCODEC_SAMPLE_RATE = 24000

# Special token IDs used by S2SDataCollator (above NeuCodec vocab range)
BOS_TOKEN = NEUCODEC_VOCAB_SIZE  # 65536
EOS_TOKEN = NEUCODEC_VOCAB_SIZE + 1  # 65537
PAD_TOKEN = NEUCODEC_VOCAB_SIZE + 2  # 65538
TOTAL_VOCAB = NEUCODEC_VOCAB_SIZE + 3  # 65539 (for backwards compat)


class AudioHeadConfig(PretrainedConfig):
    """Configuration for AudioHead with frozen TTS backbone + trainable projector.

    Args:
        tts_model_id: Hub id of the frozen TTS backbone (neutts-nano).
        llm_model_id: Hub id of the frozen text LLM used in standalone training.
        projector_hidden: Hidden width of the trainable projector MLP.
        max_audio_tokens: Maximum number of codec tokens generated at inference.
        neucodec_model_id: Hub id of the NeuCodec decoder (loaded lazily).
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_k: Keep only the top-k candidates when sampling; 0 disables top-k.
    """

    model_type = "audio_head"

    def __init__(
        self,
        tts_model_id: str = "neuphonic/neutts-nano",
        llm_model_id: str = "HuggingFaceTB/SmolLM3-3B",
        projector_hidden: int = 1024,
        max_audio_tokens: int = 500,
        neucodec_model_id: str = "neuphonic/neucodec",
        temperature: float = 1.0,
        top_k: int = 50,
        **kwargs,
    ):
        self.tts_model_id = tts_model_id
        self.llm_model_id = llm_model_id
        self.projector_hidden = projector_hidden
        self.max_audio_tokens = max_audio_tokens
        self.neucodec_model_id = neucodec_model_id
        self.temperature = temperature
        self.top_k = top_k
        super().__init__(**kwargs)


@dataclass
class AudioHeadOutput(ModelOutput):
    """Output of AudioHead forward pass.

    Attributes:
        loss: Cross-entropy loss when codec_labels are provided.
        codes: Generated NeuCodec codes in inference mode [batch, gen_len].
            For batched generation, positions at or after an item's end-of-speech
            token are filled with code 0 (padding).
    """

    loss: Optional[torch.Tensor] = None
    codes: Optional[torch.Tensor] = None


class AudioHead(PreTrainedModel):
    """Frozen TTS backbone + trainable projector for speech generation.

    Loads neutts-nano (a pretrained LlamaForCausalLM that generates NeuCodec
    tokens) and freezes it entirely. A frozen LLM converts text to hidden
    states, and a trainable MLP projector maps those hidden states into
    neutts-nano's input space.

    Standalone training:
        text_token_ids → frozen LLM → hidden states → projector → backbone → speech codes
    Pipeline inference:
        llm_hidden_states → projector → backbone → speech codes
    """

    config_class = AudioHeadConfig

    # Prevent from_pretrained from using meta device init (which conflicts
    # with loading the backbone inside __init__ via its own from_pretrained)
    _supports_param_buffer_assignment = False

    def __init__(self, config: AudioHeadConfig):
        super().__init__(config)
        self.max_tokens = config.max_audio_tokens

        # Load frozen TTS backbone unless we're inside a meta device context,
        # in which case loading real weights would conflict with meta init.
        self._backbone_loaded = False
        if not self._is_meta_init():
            self._load_backbone(config)

    def _is_meta_init(self) -> bool:
        """Check if we're inside a meta device context manager."""
        try:
            test = torch.empty(1)
            return test.device.type == "meta"
        except Exception:
            return False

    def _load_backbone(self, config: AudioHeadConfig) -> None:
        """Load the frozen TTS backbone, frozen LLM, and initialize the projector.

        Idempotent: returns immediately if already loaded.
        """
        if self._backbone_loaded:
            return

        # Load frozen TTS backbone (neutts-nano)
        logger.info("Loading TTS backbone: %s", config.tts_model_id)
        self.backbone = AutoModelForCausalLM.from_pretrained(
            config.tts_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        # Load tokenizer to resolve speech token IDs
        self.tts_tokenizer = AutoTokenizer.from_pretrained(config.tts_model_id)

        # Cache key token IDs
        self.speech_token_offset = self.tts_tokenizer.convert_tokens_to_ids("<|speech_0|>")
        self.speech_start_id = self.tts_tokenizer.convert_tokens_to_ids(
            "<|SPEECH_GENERATION_START|>"
        )
        self.speech_end_id = self.tts_tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

        # Load frozen LLM for standalone training (text → hidden states).
        # In pipeline mode (ASRModel) this copy is expected to be freed
        # (set to None) since ASRModel provides pre-computed hidden states.
        logger.info("Loading frozen LLM: %s", config.llm_model_id)
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.llm.requires_grad_(False)
        self.llm.eval()

        # Cache a prompt prefix so training hidden states are conditioned on
        # conversational context (matching inference where LLM sees full prompt).
        llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_id, trust_remote_code=True)
        prompt_enc = llm_tokenizer(
            "Speak the following text aloud: ",
            return_tensors="pt",
            add_special_tokens=True,
        )
        self.register_buffer(
            "_prompt_prefix_ids",
            prompt_enc.input_ids,
            persistent=False,
        )
        self._prompt_len = prompt_enc.input_ids.shape[1]

        # Auto-detect dimensions
        llm_dim = self.llm.config.hidden_size
        backbone_dim = self.backbone.config.hidden_size  # 576 for neutts-nano

        # Trainable projector: 2-layer MLP (llm_dim → hidden → backbone_dim)
        # Linear → RMSNorm → GELU → Linear → RMSNorm
        # Final RMSNorm matches output scale to neutts-nano embedding norms.
        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        self.projector = nn.Sequential(
            nn.Linear(llm_dim, config.projector_hidden),
            LlamaRMSNorm(config.projector_hidden, eps=1e-6),
            nn.GELU(),
            nn.Linear(config.projector_hidden, backbone_dim),
            LlamaRMSNorm(backbone_dim, eps=1e-6),
        ).to(torch.bfloat16)

        # Sampling parameters for inference
        self.temperature = config.temperature
        self.top_k = config.top_k

        # NeuCodec model (loaded lazily, frozen, inference only)
        self.neucodec_model = None

        self._backbone_loaded = True

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load AudioHead: config + projector weights from disk/Hub, backbone from HF Hub.

        Only the hub options ``revision``, ``token`` and ``cache_dir`` are
        honoured (forwarded to ``snapshot_download``); other kwargs are ignored.
        """
        from pathlib import Path

        from safetensors.torch import load_file

        path = Path(pretrained_model_name_or_path)

        # If not a local directory, download from Hub
        if not path.is_dir():
            from huggingface_hub import snapshot_download

            hub_kwargs = {
                key: kwargs[key] for key in ("revision", "token", "cache_dir") if key in kwargs
            }
            path = Path(snapshot_download(pretrained_model_name_or_path, **hub_kwargs))

        # Load config
        config = AudioHeadConfig.from_pretrained(path)

        # Create model (loads backbone from HF Hub)
        model = cls(config)

        # Load projector weights from saved checkpoint. strict=False because the
        # checkpoint intentionally holds only projector.* keys (see state_dict).
        safetensors_path = path / "model.safetensors"
        if safetensors_path.exists():
            projector_state = load_file(safetensors_path)
            result = model.load_state_dict(projector_state, strict=False)
            missing = [key for key in result.missing_keys if key.startswith("projector.")]
            if missing:
                logger.warning(
                    "Projector weights missing from %s: %s", safetensors_path, missing
                )
            logger.info("Loaded projector weights from %s", safetensors_path)
        else:
            logger.warning(
                "No model.safetensors found in %s; projector is randomly initialized", path
            )

        return model

    def train(self, mode: bool = True):
        """Override to keep frozen submodules in eval mode (disables dropout, etc.).

        Tolerates submodules that were never loaded (meta init) or were freed.
        """
        super().train(mode)
        # Always keep frozen models in eval mode regardless of parent training state.
        # neucodec_model becomes a registered submodule once lazily loaded.
        for name in ("backbone", "llm", "neucodec_model"):
            frozen = getattr(self, name, None)
            if frozen is not None:
                frozen.eval()
        return self

    def _embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed tokens using the frozen backbone's embedding table."""
        return self.backbone.model.embed_tokens(token_ids)

    def _codec_to_speech_ids(self, codec_codes: torch.Tensor) -> torch.Tensor:
        """Map NeuCodec codes [0, 65535] to neutts-nano speech token IDs."""
        return codec_codes + self.speech_token_offset

    def _speech_ids_to_codec(self, speech_ids: torch.Tensor) -> torch.Tensor:
        """Map neutts-nano speech token IDs back to NeuCodec codes [0, 65535]."""
        return speech_ids - self.speech_token_offset

    def _encode_text(
        self, text_token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Run the frozen LLM over ``prompt prefix + text`` and return text-position states.

        Raises:
            ValueError: If the internal LLM is not available (e.g. freed in pipeline mode).
        """
        llm = getattr(self, "llm", None)
        if llm is None:
            raise ValueError(
                "text_token_ids requires the internal LLM, which is not loaded; "
                "pass llm_hidden_states instead"
            )

        # Prepend cached prompt prefix so hidden states are conditioned on
        # conversational context (matching inference where LLM sees full prompt).
        batch_size = text_token_ids.shape[0]
        device = text_token_ids.device
        prompt = self._prompt_prefix_ids.expand(batch_size, -1).to(device)
        full_ids = torch.cat([prompt, text_token_ids], dim=1)

        if attention_mask is not None:
            prompt_mask = torch.ones(
                batch_size, self._prompt_len, device=device, dtype=attention_mask.dtype
            )
            full_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        else:
            full_mask = None

        with torch.no_grad():
            llm_out = llm.model(
                input_ids=full_ids,
                attention_mask=full_mask,
            )

        # Extract hidden states for text tokens only (skip prompt prefix)
        return llm_out.last_hidden_state[:, self._prompt_len :]

    def forward(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        codec_labels: Optional[torch.Tensor] = None,
        codec_input_ids: Optional[torch.Tensor] = None,
        codec_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # noqa: ARG002 — absorbs extra keys from Trainer
    ) -> AudioHeadOutput:
        """Forward pass for training or inference.

        Args:
            text_token_ids: Text token IDs [batch, seq_len] (LLM tokenizer vocab).
                Run through frozen LLM to get hidden states. Ignored when
                llm_hidden_states is given.
            attention_mask: Text attention mask [batch, seq_len] (1=real, 0=padding)
            llm_hidden_states: Pre-computed LLM hidden states [batch, seq_len, llm_dim].
                Used in pipeline mode when ASRModel provides hidden states directly.
            codec_labels: Target NeuCodec codes [batch, audio_len] (-100 for ignore)
            codec_input_ids: Teacher-forced NeuCodec codes [batch, audio_len]
            codec_attention_mask: Codec attention mask [batch, audio_len]
            **kwargs: Absorbed silently (Trainer may pass extra keys).

        Returns:
            AudioHeadOutput with loss (training) or codes (inference).

        Raises:
            ValueError: If neither text input is given, if codec_labels is given
                without codec_input_ids, or if the internal LLM is unavailable.
        """
        # Get LLM hidden states: either pre-computed or from frozen LLM
        if llm_hidden_states is not None:
            hidden_states = llm_hidden_states
        elif text_token_ids is not None:
            hidden_states = self._encode_text(text_token_ids, attention_mask)
        else:
            raise ValueError("Either text_token_ids or llm_hidden_states must be provided")

        batch_size, text_len = hidden_states.shape[:2]
        device = hidden_states.device

        # Project LLM hidden states into neutts-nano's input space via trainable projector.
        # Cast to the projector's dtype so fp32 pre-computed states don't cause a mismatch.
        # Gradients flow through the projector (LLM hidden states are detached).
        hidden_states = hidden_states.to(dtype=self.projector[0].weight.dtype)
        prefix = self.projector(hidden_states)  # [batch, text_len, backbone_dim]

        if codec_labels is None:
            # Inference: autoregressive generation
            codes = self._generate(prefix, attention_mask)
            return AudioHeadOutput(codes=codes)

        # Training: teacher forcing
        if codec_input_ids is None:
            raise ValueError("codec_input_ids required when codec_labels provided")

        # Map collator ids (BOS, codec codes 0-65535, EOS, PAD) into neutts-nano token space
        speech_input = self._map_collator_ids_to_speech(codec_input_ids)

        with torch.no_grad():
            token_emb = self._embed_tokens(speech_input)  # [batch, audio_len, 576]

        audio_len = token_emb.shape[1]

        # Concatenate: [projected_text, codec_token_embeddings]
        # prefix has grad (from projector), token_emb is detached (frozen embedding lookup)
        hidden = torch.cat([prefix, token_emb], dim=1)

        # Build 2D padding mask — backbone handles causal masking internally
        prefix_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones(batch_size, text_len, device=device, dtype=torch.long)
        )
        audio_mask = (
            codec_attention_mask
            if codec_attention_mask is not None
            else torch.ones(batch_size, audio_len, device=device, dtype=torch.long)
        )
        combined_mask = torch.cat([prefix_mask, audio_mask], dim=1)

        # Run through frozen backbone WITHOUT torch.no_grad().
        # The backbone weights have requires_grad=False so they won't accumulate grads,
        # but PyTorch still builds the computation graph through the matmuls, allowing
        # gradients to flow back from the loss through backbone → hidden → prefix → projector.
        outputs = self.backbone.model(
            inputs_embeds=hidden,
            attention_mask=combined_mask,
        )

        # Extract audio-position hidden states
        audio_hidden = outputs.last_hidden_state[:, text_len:]  # [batch, audio_len, 576]

        # Project through frozen lm_head to get logits over full vocab.
        # Same principle: lm_head weights are frozen but gradients flow through the
        # matmul back to audio_hidden (and ultimately to the projector).
        logits = self.backbone.lm_head(audio_hidden)  # [batch, audio_len, vocab_size]

        # Map codec_labels to speech token IDs for CE loss target
        speech_labels = self._map_collator_labels_to_speech(codec_labels)

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            speech_labels.reshape(-1),
            ignore_index=-100,
        )

        return AudioHeadOutput(loss=loss)

    def _map_collator_ids_to_speech(self, codec_input_ids: torch.Tensor) -> torch.Tensor:
        """Map S2SDataCollator codec_input_ids to neutts-nano token IDs.

        S2SDataCollator produces:
        - BOS_TOKEN (65536) at position 0
        - NeuCodec codes (0-65535) for real audio
        - EOS_TOKEN (65537), if present
        - PAD_TOKEN (65538) for padding

        Maps to:
        - BOS_TOKEN → <|SPEECH_GENERATION_START|>
        - codes 0-65535 → <|speech_0|>..<|speech_65535|>
        - EOS_TOKEN → <|SPEECH_GENERATION_END|>
        - PAD_TOKEN → tokenizer pad_token_id (or the speech-end id if the
          tokenizer defines no pad token)
        """
        result = codec_input_ids.clone()

        # All masks are computed from the original ids, so assignment order is irrelevant.
        result[codec_input_ids == BOS_TOKEN] = self.speech_start_id
        result[codec_input_ids == EOS_TOKEN] = self.speech_end_id

        pad_id = self.tts_tokenizer.pad_token_id
        if pad_id is None:
            # Any valid id works here; padding is expected to be masked via
            # codec_attention_mask (NOTE(review): confirm the collator always sends it).
            pad_id = self.speech_end_id
        result[codec_input_ids == PAD_TOKEN] = pad_id

        # Map codec codes (0-65535) → speech tokens
        codec_mask = codec_input_ids < NEUCODEC_VOCAB_SIZE
        result[codec_mask] = codec_input_ids[codec_mask] + self.speech_token_offset

        return result

    def _map_collator_labels_to_speech(self, codec_labels: torch.Tensor) -> torch.Tensor:
        """Map S2SDataCollator codec_labels to neutts-nano token IDs.

        codec_labels contains:
        - NeuCodec codes (0-65535) for real targets
        - EOS_TOKEN (65537) at the end
        - -100 for ignore positions

        A stray PAD_TOKEN is treated as an ignore position rather than a target.
        """
        result = codec_labels.clone()
        valid = codec_labels != -100

        # Map EOS (65537)
        eos_mask = valid & (codec_labels == EOS_TOKEN)
        result[eos_mask] = self.speech_end_id

        # Never train on padding, even if the collator leaves PAD in labels
        result[valid & (codec_labels == PAD_TOKEN)] = -100

        # Map codec codes (0-65535) → speech tokens
        codec_mask = valid & (codec_labels < NEUCODEC_VOCAB_SIZE)
        result[codec_mask] = codec_labels[codec_mask] + self.speech_token_offset

        return result

    def _sample_next_token(self, logits: torch.Tensor) -> torch.Tensor:
        """Sample one index per row from ``logits`` [batch, n] using temperature/top-k.

        Logits are upcast to float32 so the softmax over ~65k classes is not done
        in bf16. A temperature <= 0 selects greedy (argmax) decoding.
        """
        logits = logits.float()
        if self.temperature <= 0:
            return logits.argmax(dim=-1)
        if self.temperature != 1.0:
            logits = logits / self.temperature
        if self.top_k > 0:
            k = min(self.top_k, logits.size(-1))
            kth_value = logits.topk(k, dim=-1).values[:, -1:]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1).squeeze(-1)

    def _generate(
        self, prefix: torch.Tensor, prefix_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """AR generation with KV cache on frozen backbone.

        Args:
            prefix: Projected text embeddings [batch, text_len, 576].
            prefix_mask: Optional padding mask for the prefix [batch, text_len]
                (1=real, 0=padding). Applied exactly as in training.

        Returns:
            NeuCodec codes [batch, gen_len]. Positions at or after an item's
            end-of-speech token are filled with code 0.
        """
        batch_size, text_len, _ = prefix.shape
        device = prefix.device
        end_id = self.speech_end_id
        codec_lo = self.speech_token_offset
        codec_hi = codec_lo + NEUCODEC_VOCAB_SIZE
        all_codes = []

        with torch.no_grad():
            # Build initial input: prefix + SPEECH_GENERATION_START token
            start_token = torch.full(
                (batch_size, 1), self.speech_start_id, dtype=torch.long, device=device
            )
            start_emb = self._embed_tokens(start_token)  # [batch, 1, 576]
            hidden = torch.cat([prefix, start_emb], dim=1)  # [batch, text_len+1, 576]
            position_ids = (
                torch.arange(text_len + 1, device=device).unsqueeze(0).expand(batch_size, -1)
            )

            if prefix_mask is not None:
                attn_mask = torch.cat(
                    [
                        prefix_mask.to(device=device, dtype=torch.long),
                        torch.ones(batch_size, 1, dtype=torch.long, device=device),
                    ],
                    dim=1,
                )
            else:
                attn_mask = None

            outputs = self.backbone.model(
                inputs_embeds=hidden,
                attention_mask=attn_mask,
                position_ids=position_ids,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.last_hidden_state[:, -1]  # [batch, 576]
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

            for step in range(self.max_tokens):
                logits = self.backbone.lm_head(last_hidden)  # [batch, vocab]

                # Restrict sampling to the 65536 speech tokens plus the end token;
                # index NEUCODEC_VOCAB_SIZE in `combined` is the end token.
                combined = torch.cat(
                    [logits[:, codec_lo:codec_hi], logits[:, end_id : end_id + 1]], dim=-1
                )
                sampled = self._sample_next_token(combined)  # [batch]

                # An item stays finished once it has emitted end-of-speech.
                finished = finished | (sampled == NEUCODEC_VOCAB_SIZE)
                if finished.all():
                    break

                codec_code = sampled.clamp(0, NEUCODEC_VOCAB_SIZE - 1).masked_fill(finished, 0)
                all_codes.append(codec_code)

                # Finished items are fed the end token; their outputs are discarded.
                next_token_id = (codec_code + codec_lo).masked_fill(finished, end_id)
                next_emb = self._embed_tokens(next_token_id.unsqueeze(1))  # [batch, 1, 576]

                # Start token sits at position text_len, so the code sampled at
                # step k is fed at position text_len + 1 + k (matches training).
                next_pos = torch.full(
                    (batch_size, 1), text_len + 1 + step, dtype=torch.long, device=device
                )
                if attn_mask is not None:
                    attn_mask = torch.cat([attn_mask, attn_mask.new_ones(batch_size, 1)], dim=1)

                outputs = self.backbone.model(
                    inputs_embeds=next_emb,
                    attention_mask=attn_mask,
                    position_ids=next_pos,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = outputs.past_key_values
                last_hidden = outputs.last_hidden_state[:, -1]  # [batch, 576]

        if all_codes:
            return torch.stack(all_codes, dim=1)  # [batch, gen_len]
        return torch.empty(batch_size, 0, dtype=torch.long, device=device)

    def state_dict(self, *args, **kwargs):
        """Only save projector weights (backbone is frozen/pretrained)."""
        full = super().state_dict(*args, **kwargs)
        return {k: v for k, v in full.items() if k.startswith("projector.")}

    def _load_neucodec(self):
        """Load frozen NeuCodec model for audio decoding."""
        from neucodec import NeuCodec

        self.neucodec_model = NeuCodec.from_pretrained(self.config.neucodec_model_id)
        self.neucodec_model.eval()
        self.neucodec_model.requires_grad_(False)
        logger.info("Loaded frozen NeuCodec model for audio decoding")

    def decode_to_audio(self, codes: torch.Tensor) -> list[torch.Tensor]:
        """Decode NeuCodec FSQ tokens to audio waveforms.

        Args:
            codes: Codec tokens [batch, seq_len] (values 0-65535)

        Returns:
            List of audio waveform tensors (one per batch item). Zero-length
            code sequences yield empty waveforms without loading NeuCodec.
        """
        if codes.shape[-1] == 0:
            return [torch.zeros(0) for _ in range(codes.shape[0])]

        if self.neucodec_model is None:
            self._load_neucodec()
        assert self.neucodec_model is not None

        # Derive the device from parameters: a plain nn.Module has no .device attribute.
        first_param = next(self.neucodec_model.parameters(), None)
        codec_device = first_param.device if first_param is not None else codes.device

        codes_3d = codes.unsqueeze(1).to(codec_device)
        with torch.no_grad():
            audio_values = self.neucodec_model.decode_code(codes_3d)

        return [audio_values[i, 0] for i in range(audio_values.shape[0])]

    def generate_streaming(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        chunk_samples: int = 24000,
    ) -> Iterator[torch.Tensor]:
        """Generate audio and yield waveform chunks for streaming playback.

        Raises:
            ValueError: If chunk_samples is not positive.
        """
        if chunk_samples <= 0:
            raise ValueError("chunk_samples must be positive")

        # Keep no_grad scoped to generation only; never hold it across a yield.
        with torch.no_grad():
            output = self(text_token_ids=text_token_ids, llm_hidden_states=llm_hidden_states)
        audios = self.decode_to_audio(output.codes)

        for audio in audios:
            for start in range(0, audio.shape[-1], chunk_samples):
                # Slicing past the end clips automatically.
                yield audio[..., start : start + chunk_samples]