| | """Audio head for speech-to-speech using a frozen pretrained TTS backbone. |
| | |
| | Architecture: |
| | Text → frozen LLM (SmolLM3-3B) → hidden states (llm_dim) |
| | → Projector MLP (trainable, llm_dim → backbone_dim) |
| | → Concat with codec embeddings → neutts-nano LlamaForCausalLM (frozen) |
| | → lm_head → speech token logits → NeuCodec codes → audio |
| | |
| | The frozen LLM is loaded for standalone S2S training. When used inside a full |
| | ASR pipeline (ASRModel), pre-computed LLM hidden states are passed directly |
| | and the internal LLM is not used. |
| | |
| | neutts-nano (neuphonic/neutts-nano) is a pretrained 24-layer LlamaForCausalLM |
| | (dim=576, ~117M params) that generates NeuCodec codes as <|speech_N|> tokens. |
| | Only the projector MLP is trained. |
| | |
| | NeuCodec uses a single FSQ codebook (levels=[4]*8, vocab=65536) at 50 tokens/sec, |
| | outputting 24kHz audio. Codes 0-65535 map to neutts-nano tokens <|speech_0|>..<|speech_65535|>. |
| | """ |
| |
|
| | import logging |
| | from dataclasses import dataclass |
| | from typing import Iterator, Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel |
| | from transformers.modeling_outputs import ModelOutput |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
# NeuCodec vocabulary: single FSQ codebook (levels=[4]*8 → 65536 entries) at 50 tokens/sec.
NEUCODEC_VOCAB_SIZE = 65536
# NeuCodec decodes to 24 kHz audio.
NEUCODEC_SAMPLE_RATE = 24000


# Collator-side special token IDs, placed just past the codec code range
# (see _map_collator_ids_to_speech / _map_collator_labels_to_speech).
BOS_TOKEN = NEUCODEC_VOCAB_SIZE  # 65536: start of audio sequence
EOS_TOKEN = NEUCODEC_VOCAB_SIZE + 1  # 65537: end of audio sequence
PAD_TOKEN = NEUCODEC_VOCAB_SIZE + 2  # 65538: padding
TOTAL_VOCAB = NEUCODEC_VOCAB_SIZE + 3  # codec codes plus the three specials
| |
|
| |
|
class AudioHeadConfig(PretrainedConfig):
    """Configuration for AudioHead with frozen TTS backbone + trainable projector.

    Attributes:
        tts_model_id: Hub ID of the frozen TTS backbone (neutts-nano).
        llm_model_id: Hub ID of the frozen text LLM.
        projector_hidden: Hidden width of the trainable projector MLP.
        max_audio_tokens: Cap on autoregressively generated speech tokens.
        neucodec_model_id: Hub ID of the NeuCodec audio codec.
        temperature: Sampling temperature used during generation.
        top_k: Top-k cutoff for sampling (0 disables the filter).
    """

    model_type = "audio_head"

    def __init__(
        self,
        tts_model_id: str = "neuphonic/neutts-nano",
        llm_model_id: str = "HuggingFaceTB/SmolLM3-3B",
        projector_hidden: int = 1024,
        max_audio_tokens: int = 500,
        neucodec_model_id: str = "neuphonic/neucodec",
        temperature: float = 1.0,
        top_k: int = 50,
        **kwargs,
    ):
        # Model identifiers.
        self.tts_model_id = tts_model_id
        self.llm_model_id = llm_model_id
        self.neucodec_model_id = neucodec_model_id
        # Architecture and generation knobs.
        self.projector_hidden = projector_hidden
        self.max_audio_tokens = max_audio_tokens
        self.temperature = temperature
        self.top_k = top_k
        super().__init__(**kwargs)
| |
|
| |
|
@dataclass
class AudioHeadOutput(ModelOutput):
    """Result of an AudioHead forward pass.

    Which field is filled depends on the mode: training (codec_labels given)
    populates ``loss``; inference populates ``codes``.

    Attributes:
        loss: Cross-entropy loss over speech tokens when codec_labels are provided.
        codes: Sampled NeuCodec codes [batch, gen_len] in inference mode.
    """

    loss: Optional[torch.Tensor] = None
    codes: Optional[torch.Tensor] = None
| |
|
| |
|
class AudioHead(PreTrainedModel):
    """Frozen TTS backbone + trainable projector for speech generation.

    Loads neutts-nano (a pretrained LlamaForCausalLM that generates NeuCodec tokens)
    and freezes it entirely. A frozen LLM converts text to hidden states, and a
    trainable MLP projector maps those hidden states into neutts-nano's input space.

    Standalone training: text_token_ids → frozen LLM → hidden states → projector → backbone → speech codes
    Pipeline inference: llm_hidden_states → projector → backbone → speech codes
    """

    config_class = AudioHeadConfig

    # NOTE(review): presumably disables HF's in-place param/buffer assignment
    # during load because submodules are built eagerly in __init__ rather than
    # restored from a full state_dict — confirm against Transformers docs.
    _supports_param_buffer_assignment = False

    def __init__(self, config: AudioHeadConfig):
        """Build the model; heavy submodules load immediately unless on meta device."""
        super().__init__(config)
        # Upper bound on AR generation steps (see _generate).
        self.max_tokens = config.max_audio_tokens

        # Under a meta-device context (deferred init) skip the heavyweight
        # loading; _load_backbone must then be called later on a real device.
        self._backbone_loaded = False
        if not self._is_meta_init():
            self._load_backbone(config)
| |
|
| | def _is_meta_init(self) -> bool: |
| | """Check if we're inside a meta device context manager.""" |
| | try: |
| | test = torch.empty(1) |
| | return test.device.type == "meta" |
| | except Exception: |
| | return False |
| |
|
    def _load_backbone(self, config: AudioHeadConfig) -> None:
        """Load the frozen TTS backbone, frozen LLM, and initialize the projector.

        Idempotent: returns immediately when already loaded. Everything except
        the projector MLP is frozen (requires_grad False) and kept in eval mode.
        """
        if self._backbone_loaded:
            return

        # neutts-nano: pretrained LlamaForCausalLM that emits <|speech_N|> tokens.
        logger.info("Loading TTS backbone: %s", config.tts_model_id)
        self.backbone = AutoModelForCausalLM.from_pretrained(
            config.tts_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        self.tts_tokenizer = AutoTokenizer.from_pretrained(config.tts_model_id)

        # Anchor IDs in the TTS vocab: codec code c maps to token id
        # speech_token_offset + c; the start/end tokens bracket the speech span.
        self.speech_token_offset = self.tts_tokenizer.convert_tokens_to_ids("<|speech_0|>")
        self.speech_start_id = self.tts_tokenizer.convert_tokens_to_ids(
            "<|SPEECH_GENERATION_START|>"
        )
        self.speech_end_id = self.tts_tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

        # Frozen text LLM, used only in standalone mode (text → hidden states).
        logger.info("Loading frozen LLM: %s", config.llm_model_id)
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.llm.requires_grad_(False)
        self.llm.eval()

        # Fixed instruction prefix prepended to every text prompt before the LLM.
        # Registered as a non-persistent buffer so it follows .to(device) moves
        # without being written to checkpoints.
        llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_id, trust_remote_code=True)
        prompt_enc = llm_tokenizer(
            "Speak the following text aloud: ",
            return_tensors="pt",
            add_special_tokens=True,
        )
        self.register_buffer(
            "_prompt_prefix_ids",
            prompt_enc.input_ids,
            persistent=False,
        )
        self._prompt_len = prompt_enc.input_ids.shape[1]

        llm_dim = self.llm.config.hidden_size

        backbone_dim = self.backbone.config.hidden_size

        # Trainable projector MLP: llm_dim → projector_hidden → backbone_dim.
        # LlamaRMSNorm matches the normalization family of the Llama backbone.
        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        self.projector = nn.Sequential(
            nn.Linear(llm_dim, config.projector_hidden),
            LlamaRMSNorm(config.projector_hidden, eps=1e-6),
            nn.GELU(),
            nn.Linear(config.projector_hidden, backbone_dim),
            LlamaRMSNorm(backbone_dim, eps=1e-6),
        ).to(torch.bfloat16)

        # Sampling hyper-parameters consumed by _generate.
        self.temperature = config.temperature
        self.top_k = config.top_k

        # NeuCodec decoder is loaded lazily on the first decode_to_audio call.
        self.neucodec_model = None

        self._backbone_loaded = True
| |
|
| | @classmethod |
| | def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): |
| | """Load AudioHead: config + projector weights from disk/Hub, backbone from HF Hub.""" |
| | from pathlib import Path |
| |
|
| | from safetensors.torch import load_file |
| |
|
| | path = Path(pretrained_model_name_or_path) |
| |
|
| | |
| | if not path.is_dir(): |
| | from huggingface_hub import snapshot_download |
| |
|
| | path = Path(snapshot_download(pretrained_model_name_or_path)) |
| |
|
| | |
| | config = AudioHeadConfig.from_pretrained(path) |
| |
|
| | |
| | model = cls(config) |
| |
|
| | |
| | safetensors_path = path / "model.safetensors" |
| | if safetensors_path.exists(): |
| | projector_state = load_file(safetensors_path) |
| | model.load_state_dict(projector_state, strict=False) |
| | logger.info("Loaded projector weights from %s", safetensors_path) |
| |
|
| | return model |
| |
|
| | def train(self, mode: bool = True): |
| | """Override to keep backbone and LLM in eval mode (disables dropout, etc.).""" |
| | super().train(mode) |
| | |
| | self.backbone.eval() |
| | if self.llm is not None: |
| | self.llm.eval() |
| | return self |
| |
|
| | def _embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: |
| | """Embed tokens using the frozen backbone's embedding table.""" |
| | return self.backbone.model.embed_tokens(token_ids) |
| |
|
| | def _codec_to_speech_ids(self, codec_codes: torch.Tensor) -> torch.Tensor: |
| | """Map NeuCodec codes [0, 65535] to neutts-nano speech token IDs.""" |
| | return codec_codes + self.speech_token_offset |
| |
|
| | def _speech_ids_to_codec(self, speech_ids: torch.Tensor) -> torch.Tensor: |
| | """Map neutts-nano speech token IDs back to NeuCodec codes [0, 65535].""" |
| | return speech_ids - self.speech_token_offset |
| |
|
    def forward(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        codec_labels: Optional[torch.Tensor] = None,
        codec_input_ids: Optional[torch.Tensor] = None,
        codec_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> AudioHeadOutput:
        """Forward pass for training or inference.

        Args:
            text_token_ids: Text token IDs [batch, seq_len] (LLM tokenizer vocab).
                Run through frozen LLM to get hidden states. Mutually exclusive
                with llm_hidden_states.
            attention_mask: Text attention mask [batch, seq_len] (1=real, 0=padding)
            llm_hidden_states: Pre-computed LLM hidden states [batch, seq_len, llm_dim].
                Used in pipeline mode when ASRModel provides hidden states directly.
            codec_labels: Target NeuCodec codes [batch, audio_len] (-100 for ignore)
            codec_input_ids: Teacher-forced NeuCodec codes [batch, audio_len]
            codec_attention_mask: Codec attention mask [batch, audio_len]
            **kwargs: Absorbed silently (Trainer may pass extra keys).

        Returns:
            AudioHeadOutput with loss (training) or codes (inference).
        """
        # --- Stage 1: obtain LLM hidden states for the text ---
        if llm_hidden_states is not None:
            # Pipeline mode: hidden states supplied directly by the caller.
            hidden_states = llm_hidden_states
        elif text_token_ids is not None:
            # Standalone mode: prepend the fixed instruction prefix and run the
            # frozen LLM over [prefix | text].
            batch_size = text_token_ids.shape[0]
            device = text_token_ids.device
            prompt = self._prompt_prefix_ids.expand(batch_size, -1).to(device)
            full_ids = torch.cat([prompt, text_token_ids], dim=1)

            if attention_mask is not None:
                # Prefix positions are always real tokens.
                prompt_mask = torch.ones(
                    batch_size, self._prompt_len, device=device, dtype=attention_mask.dtype
                )
                full_mask = torch.cat([prompt_mask, attention_mask], dim=1)
            else:
                full_mask = None

            with torch.no_grad():
                llm_out = self.llm.model(
                    input_ids=full_ids,
                    attention_mask=full_mask,
                )
            # Drop prompt positions so downstream lengths match text_token_ids.
            hidden_states = llm_out.last_hidden_state[:, self._prompt_len :]
        else:
            raise ValueError("Either text_token_ids or llm_hidden_states must be provided")

        batch_size, text_len = hidden_states.shape[:2]
        device = hidden_states.device

        # --- Stage 2: project into the backbone's input embedding space ---
        # This is the only trainable computation in the module.
        prefix = self.projector(hidden_states)

        if codec_labels is None:
            # Inference: autoregressively sample speech codes from the prefix.
            codes = self._generate(prefix, attention_mask)
            return AudioHeadOutput(codes=codes)

        # --- Stage 3 (training): teacher-forced pass through the backbone ---
        assert codec_input_ids is not None, "codec_input_ids required when codec_labels provided"

        # Translate collator-convention IDs into the backbone's vocabulary.
        speech_input = self._map_collator_ids_to_speech(codec_input_ids)

        with torch.no_grad():
            token_emb = self._embed_tokens(speech_input)

        audio_len = token_emb.shape[1]

        # Sequence layout: [projected text prefix | teacher-forced audio tokens].
        hidden = torch.cat([prefix, token_emb], dim=1)

        # Combined attention mask; default to all-ones where a mask is absent.
        prefix_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones(batch_size, text_len, device=device, dtype=torch.long)
        )
        audio_mask = (
            codec_attention_mask
            if codec_attention_mask is not None
            else torch.ones(batch_size, audio_len, device=device, dtype=torch.long)
        )
        combined_mask = torch.cat([prefix_mask, audio_mask], dim=1)

        # Frozen backbone forward. Note: NOT under no_grad — gradients flow back
        # into `prefix` (and thus the projector) via inputs_embeds even though
        # the backbone weights themselves are frozen.
        outputs = self.backbone.model(
            inputs_embeds=hidden,
            attention_mask=combined_mask,
        )

        # Only the audio positions produce speech-token predictions.
        audio_hidden = outputs.last_hidden_state[:, text_len:]

        # Frozen lm_head maps to the full TTS vocabulary (incl. <|speech_N|>).
        logits = self.backbone.lm_head(audio_hidden)

        # Remap labels into the same vocabulary as the logits.
        speech_labels = self._map_collator_labels_to_speech(codec_labels)

        # Cross-entropy over all audio positions; -100 entries are ignored.
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            speech_labels.view(-1),
            ignore_index=-100,
        )

        return AudioHeadOutput(loss=loss)
| |
|
| | def _map_collator_ids_to_speech(self, codec_input_ids: torch.Tensor) -> torch.Tensor: |
| | """Map S2SDataCollator codec_input_ids to neutts-nano token IDs. |
| | |
| | S2SDataCollator produces: |
| | - BOS_TOKEN (65536) at position 0 |
| | - NeuCodec codes (0-65535) for real audio |
| | - PAD_TOKEN (65538) for padding |
| | |
| | Maps to: |
| | - BOS_TOKEN → <|SPEECH_GENERATION_START|> |
| | - codes 0-65535 → <|speech_0|>..<|speech_65535|> |
| | - PAD_TOKEN → pad_token_id |
| | """ |
| | result = codec_input_ids.clone() |
| |
|
| | |
| | bos_mask = codec_input_ids == NEUCODEC_VOCAB_SIZE |
| | result[bos_mask] = self.speech_start_id |
| |
|
| | |
| | eos_mask = codec_input_ids == (NEUCODEC_VOCAB_SIZE + 1) |
| | result[eos_mask] = self.speech_end_id |
| |
|
| | |
| | pad_mask = codec_input_ids == (NEUCODEC_VOCAB_SIZE + 2) |
| | result[pad_mask] = self.tts_tokenizer.pad_token_id |
| |
|
| | |
| | codec_mask = codec_input_ids < NEUCODEC_VOCAB_SIZE |
| | result[codec_mask] = codec_input_ids[codec_mask] + self.speech_token_offset |
| |
|
| | return result |
| |
|
| | def _map_collator_labels_to_speech(self, codec_labels: torch.Tensor) -> torch.Tensor: |
| | """Map S2SDataCollator codec_labels to neutts-nano token IDs. |
| | |
| | codec_labels contains: |
| | - NeuCodec codes (0-65535) for real targets |
| | - EOS_TOKEN (65537) at the end |
| | - -100 for ignore positions |
| | """ |
| | result = codec_labels.clone() |
| |
|
| | valid = codec_labels != -100 |
| |
|
| | |
| | eos_mask = valid & (codec_labels == (NEUCODEC_VOCAB_SIZE + 1)) |
| | result[eos_mask] = self.speech_end_id |
| |
|
| | |
| | codec_mask = valid & (codec_labels < NEUCODEC_VOCAB_SIZE) |
| | result[codec_mask] = codec_labels[codec_mask] + self.speech_token_offset |
| |
|
| | return result |
| |
|
    def _generate(
        self, prefix: torch.Tensor, prefix_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """AR generation with KV cache on frozen backbone.

        Args:
            prefix: Projected text embeddings [batch, text_len, 576].
            prefix_mask: Attention mask for prefix tokens (unused for now,
                reserved for batched generation with padding).

        Returns:
            Sampled NeuCodec codes [batch, gen_len], values in [0, 65535].
        """
        _ = prefix_mask
        batch_size, text_len, _ = prefix.shape
        device = prefix.device

        all_codes = []

        # Prime the KV cache with [prefix | <|SPEECH_GENERATION_START|>].
        start_token = torch.full(
            (batch_size, 1), self.speech_start_id, dtype=torch.long, device=device
        )
        start_emb = self._embed_tokens(start_token)
        hidden = torch.cat([prefix, start_emb], dim=1)

        # Explicit position ids so incremental steps can continue the sequence.
        position_ids = torch.arange(text_len + 1, device=device).unsqueeze(0).expand(batch_size, -1)

        # NOTE(review): only this priming pass is under no_grad; the decode loop
        # below is not, so callers appear to rely on an outer no_grad/inference
        # context — confirm, otherwise the graph through `prefix` is retained.
        with torch.no_grad():
            outputs = self.backbone.model(
                inputs_embeds=hidden,
                position_ids=position_ids,
                use_cache=True,
            )
        past_key_values = outputs.past_key_values
        last_hidden = outputs.last_hidden_state[:, -1:]

        for step in range(self.max_tokens):
            # Full-vocab logits for the next token.
            logits = self.backbone.lm_head(last_hidden.squeeze(1))

            # Restrict sampling to the speech-token range...
            speech_logits = logits[
                :, self.speech_token_offset : self.speech_token_offset + NEUCODEC_VOCAB_SIZE
            ]

            # ...plus the end-of-speech token.
            end_logit = logits[:, self.speech_end_id : self.speech_end_id + 1]

            # In `combined`, index NEUCODEC_VOCAB_SIZE represents end-of-speech.
            combined = torch.cat([speech_logits, end_logit], dim=-1)

            # Temperature scaling and top-k filtering.
            if self.temperature != 1.0:
                combined = combined / self.temperature
            if self.top_k > 0:
                topk_vals, _ = combined.topk(min(self.top_k, combined.size(-1)))
                combined[combined < topk_vals[:, -1:]] = float("-inf")

            probs = F.softmax(combined, dim=-1)
            sampled = torch.multinomial(probs, 1).squeeze(-1)

            # Stop only when EVERY sequence sampled EOS on this step.
            is_eos = sampled == NEUCODEC_VOCAB_SIZE
            if is_eos.all():
                break

            # NOTE(review): sequences that sampled EOS keep generating until all
            # finish — their EOS is clamped to code 65535 and appended, so codes
            # past a sequence's EOS may be junk. Confirm callers tolerate this,
            # or add per-sequence finished tracking.
            codec_code = sampled.clamp(0, NEUCODEC_VOCAB_SIZE - 1)
            all_codes.append(codec_code)

            # Feed the sampled token back in; EOS'd rows feed the end token.
            next_token_id = codec_code + self.speech_token_offset

            next_token_id[is_eos] = self.speech_end_id

            next_emb = self._embed_tokens(next_token_id.unsqueeze(1))

            # Position continues after prefix + start token + generated steps.
            next_pos = torch.full(
                (batch_size, 1),
                text_len + 1 + step + 1,
                dtype=torch.long,
                device=device,
            )

            outputs = self.backbone.model(
                inputs_embeds=next_emb,
                position_ids=next_pos,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.last_hidden_state

        # Empty result if EOS was sampled on the very first step.
        if all_codes:
            codes = torch.stack(all_codes, dim=1)
        else:
            codes = torch.empty(batch_size, 0, dtype=torch.long, device=device)

        return codes
| |
|
| | def state_dict(self, *args, **kwargs): |
| | """Only save projector weights (backbone is frozen/pretrained).""" |
| | full = super().state_dict(*args, **kwargs) |
| | return {k: v for k, v in full.items() if k.startswith("projector.")} |
| |
|
| | def _load_neucodec(self): |
| | """Load frozen NeuCodec model for audio decoding.""" |
| | from neucodec import NeuCodec |
| |
|
| | self.neucodec_model = NeuCodec.from_pretrained(self.config.neucodec_model_id) |
| | self.neucodec_model.eval() |
| | self.neucodec_model.requires_grad_(False) |
| | logger.info("Loaded frozen NeuCodec model for audio decoding") |
| |
|
| | def decode_to_audio(self, codes: torch.Tensor) -> list[torch.Tensor]: |
| | """Decode NeuCodec FSQ tokens to audio waveforms. |
| | |
| | Args: |
| | codes: Codec tokens [batch, seq_len] (values 0-65535) |
| | |
| | Returns: |
| | List of audio waveform tensors (one per batch item) |
| | """ |
| | if self.neucodec_model is None: |
| | self._load_neucodec() |
| | assert self.neucodec_model is not None |
| |
|
| | codes_3d = codes.unsqueeze(1).to(self.neucodec_model.device) |
| |
|
| | with torch.no_grad(): |
| | audio_values = self.neucodec_model.decode_code(codes_3d) |
| |
|
| | return [audio_values[i, 0] for i in range(audio_values.shape[0])] |
| |
|
| | def generate_streaming( |
| | self, |
| | text_token_ids: Optional[torch.Tensor] = None, |
| | llm_hidden_states: Optional[torch.Tensor] = None, |
| | chunk_samples: int = 24000, |
| | ) -> Iterator[torch.Tensor]: |
| | """Generate audio and yield waveform chunks for streaming playback.""" |
| | output = self(text_token_ids=text_token_ids, llm_hidden_states=llm_hidden_states) |
| | codes = output.codes |
| | audios = self.decode_to_audio(codes) |
| |
|
| | for audio in audios: |
| | for start in range(0, audio.shape[-1], chunk_samples): |
| | end = min(start + chunk_samples, audio.shape[-1]) |
| | yield audio[..., start:end] |
| |
|