| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import torch |
| import torch.distributed as dist |
| from lightning import LightningModule |
| from omegaconf import DictConfig, OmegaConf |
| from peft import PeftModel |
| from torch import Tensor |
| from torch.distributed.fsdp import fully_shard |
| from torch.distributed.tensor import Replicate, Shard |
| from torch.distributed.tensor.parallel import ( |
| ColwiseParallel, |
| PrepareModuleInput, |
| RowwiseParallel, |
| SequenceParallel, |
| loss_parallel, |
| parallelize_module, |
| ) |
| from transformers import DynamicCache |
|
|
| from nemo.collections.audio.parts.utils.transforms import resample |
| from nemo.collections.common.tokenizers import AutoTokenizer |
| from nemo.collections.speechlm2.data.utils import get_pad_id |
| from nemo.collections.speechlm2.models.duplex_s2s_model import replace_control_speech_codes |
| from nemo.collections.speechlm2.modules import TransformerARSpeechDecoder |
| from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin |
| from nemo.collections.speechlm2.parts.lora import maybe_install_lora |
| from nemo.collections.speechlm2.parts.metrics.asr_bleu import ASRBLEU |
| from nemo.collections.speechlm2.parts.metrics.bleu import BLEU |
| from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen |
| from nemo.collections.speechlm2.parts.precision import fp32_precision |
| from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, setup_audio_codec, setup_speech_encoder |
| from nemo.collections.speechlm2.parts.text_utils import tokens_to_str |
| from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType |
| from nemo.utils import logging |
|
|
|
|
class DuplexS2SSpeechDecoderModel(LightningModule, HFHubMixin):
    """
    Full-duplex speech-to-speech model with a separate autoregressive speech decoder.

    The pretrained LLM predicts the text channel from (user audio embeddings + embeddings of the
    previous text tokens), while a ``TransformerARSpeechDecoder`` predicts the audio-codec channel
    from the LLM's last hidden state and the previous frame's audio codes.
    """

    def __init__(self, cfg: dict) -> None:
        """
        Build the model from a plain-dict config.

        Args:
            cfg: model configuration as a Python ``dict`` (wrapped in ``DictConfig`` internally).
                Keys read directly here include ``pretrained_llm``, ``pretrained_weights`` and
                ``speech_decoder``; the ``setup_*`` helpers may read more.

        Raises:
            AssertionError: if ``cfg`` is not a ``dict``.
        """
        assert isinstance(cfg, dict), (
            "You must pass the config to DuplexS2SSpeechDecoderModel as a Python dict to support hyperparameter "
            f"serialization in PTL checkpoints (we got: '{type(cfg)=}')."
        )
        super().__init__()
        self.save_hyperparameters()
        self.cfg = DictConfig(cfg)

        # Audio codec: provides the codebook size / number of codebooks used for the speech vocabulary.
        setup_audio_codec(self)
        self._codebook_size = self.audio_codec.vector_quantizer.codebook_size
        self._num_codebooks = self.audio_codec.vector_quantizer.num_groups

        # Load the pretrained HF LLM and split it into the transformer body and the LM head.
        self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True)
        llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights).train()
        self.llm = llm.model
        self.lm_head = llm.lm_head

        # Move the token embedding out of the LLM: inputs are embedded by this module (see prepare_inputs)
        # and passed to the LLM as ``inputs_embeds``; configure_model also shards it separately.
        self.embed_tokens = self.llm.embed_tokens
        del self.llm.embed_tokens
        maybe_install_lora(self)

        # Speech encoder producing the user-channel embeddings (``self.perception``).
        setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)

        # Independent AR speech decoder conditioned on the LLM's last hidden state.
        self.speech_generation = TransformerARSpeechDecoder(
            speech_decoder_parms=OmegaConf.to_container(self.cfg.speech_decoder),
            lantent_dim=self.llm.config.hidden_size,
            num_audio_codebooks=self._num_codebooks,
            num_audio_tokens_per_codebook=self.speech_vocab_size,
        )

        # NOTE(review): ``embed_audio_tokens`` and ``audio_head`` are not used by forward() in this class;
        # presumably kept for checkpoint/config compatibility with DuplexS2SModel — confirm before removing.
        self.embed_audio_tokens = torch.nn.ModuleList(
            [
                torch.nn.Embedding(self.speech_vocab_size, self.embed_tokens.embedding_dim)
                for _ in range(self._num_codebooks)
            ]
        )
        self.audio_head = torch.nn.Linear(self.llm.config.hidden_size, self.speech_vocab_size * self._num_codebooks)

        # Control speech codes (BOS, EOS, delay) replaced before decoding generated audio to a waveform.
        self.register_buffer(
            "_control_codes",
            torch.tensor([self.speech_bos_id, self.speech_eos_id, self.speech_delay_id], device=self.device),
        )

        # Set to True by configure_model() when FSDP / tensor parallelism is enabled.
        self._use_fsdp = False
        self._use_tp = False

    @property
    def speech_vocab_size(self) -> int:
        """Return the audio codec codebook size plus three control tokens (speech BOS, EOS and delay)."""
        return self._codebook_size + 3

    @property
    def speech_bos_id(self) -> int:
        """Indicates start of utterance generation (not start of inference!)."""
        return self._codebook_size

    @property
    def speech_eos_id(self) -> int:
        """Indicates end of utterance generation."""
        return self._codebook_size + 1

    @property
    def speech_delay_id(self) -> int:
        """Indicates start of inference (the very first frame)."""
        return self._codebook_size + 2

    @property
    def text_vocab_size(self) -> int:
        """Return the size of the text tokenizer."""
        return self.tokenizer.vocab_size

    @property
    def text_bos_id(self) -> int:
        """Text BOS id; marks the start of an assistant turn on the text channel."""
        return self.tokenizer.bos_id

    @property
    def text_eos_id(self) -> int:
        """Text EOS id; marks the end of an assistant turn on the text channel."""
        return self.tokenizer.eos_id

    @property
    def text_pad_id(self) -> int:
        """
        Text pad ID is used as a 'blank' for frames when the model is not speaking
        and for frames where the model is speaking but has already predicted the
        entire text channel's content.

        Example:

            flow:         |---user---||-------assistant--------||-user-|
            text channel:  0000000000  1xxxxxxx0000000000000002  000000

        Where 0 indicates PAD ID, 1 indicates BOS ID, 2 indicates EOS ID,
        and x indicates tokens corresponding to actual text.
        """
        return get_pad_id(self.tokenizer)

    def forward(self, input_embeds: Tensor, cache=None, input_audio_tokens=None, loss_mask=None) -> dict[str, Tensor]:
        """
        Run the LLM and the speech decoder on a chunk of input embeddings.

        Text and speech are predicted separately:
        - text logits come from ``lm_head`` applied to the LLM's last hidden state;
        - audio logits come from an independent AR speech decoder conditioned on the LLM's
          last hidden state and ``input_audio_tokens``.

        KV-cache handling:
        (1) the LLM uses ``cache`` only when it is provided (``use_cache=cache is not None``);
        (2) the speech decoder's cache is controlled by ``reset_input_and_kv_cache``; when
            ``loss_mask`` is given (training), it is reset here with ``use_cache=False``.

        Args:
            input_embeds: LLM input embeddings; the first two dims are read as (B, T).
            cache: optional LLM KV-cache (e.g. ``transformers.DynamicCache``) for incremental decoding.
            input_audio_tokens: audio codes fed to the speech decoder.
            loss_mask: optional mask; only its last channel (``loss_mask[:, :, -1]``) is passed on.

        Returns:
            dict with ``"text_logits"``, ``"audio_logits"`` reshaped to
            (B, T, num_codebooks, speech_vocab_size), and ``"cache"`` when ``cache`` was given.
        """
        out = self.llm(
            inputs_embeds=input_embeds, past_key_values=cache, use_cache=cache is not None, return_dict=True
        )
        B, T = input_embeds.shape[:2]
        text_logits = self.lm_head(out['last_hidden_state'])

        if loss_mask is not None:
            # Teacher-forced pass: keep a single (B, T) mask and disable the decoder's KV-cache.
            loss_mask = loss_mask[:, :, -1].reshape(loss_mask.size(0), loss_mask.size(1))
            self.speech_generation.reset_input_and_kv_cache(use_cache=False)

        # The speech decoder takes the hidden states time-major: (T, B, H).
        _, audio_logits = self.speech_generation(
            out['last_hidden_state'].transpose(0, 1), loss_mask, input_audio_tokens=input_audio_tokens
        )
        audio_logits = audio_logits.view(B, T, self._num_codebooks, self.speech_vocab_size)

        ans = {
            "text_logits": text_logits,
            "audio_logits": audio_logits,
        }
        if cache is not None:
            ans["cache"] = out["past_key_values"]
        return ans

    def prepare_inputs(self, batch: dict) -> dict:
        """
        Build model inputs and labels from a training batch.

        Similar to DuplexS2SModel.prepare_inputs, with the following changes:
        (1) 'input_audio_tokens' and 'loss_mask' are added to the return value for TransformerARSpeechDecoder;
        (2) the audio codec embedding is not added to 'input_embeds'.

        Args:
            batch: dict with ``source_audio``, ``source_audio_lens``, ``target_audio``,
                ``target_audio_lens`` and ``target_tokens``.

        Returns:
            dict with ``input_embeds``, ``input_lens``, ``output_lens``, ``text_labels``,
            ``input_audio_tokens``, ``audio_labels`` and ``loss_mask``. Inputs and labels are
            the aligned sequence shifted by one frame.
        """
        # User channel: encode the source audio.
        source_encoded, source_encoded_lens = self.perception(
            input_signal=batch["source_audio"], input_signal_length=batch["source_audio_lens"]
        )

        # Align the text channel with the encoder frame rate: pad with text_pad_id or truncate.
        target_tokens = batch["target_tokens"]
        if (diff := target_tokens.shape[1] - source_encoded.shape[1]) < 0:
            target_tokens = torch.cat(
                [
                    target_tokens,
                    # Build the padding directly as long; avoids a float32 round-trip of the pad id.
                    torch.full(
                        (source_encoded.shape[0], -diff),
                        fill_value=self.text_pad_id,
                        device=source_encoded.device,
                        dtype=torch.long,
                    ),
                ],
                dim=-1,
            )
        elif diff > 0:
            target_tokens = target_tokens[:, : source_encoded.shape[1]]

        # Assistant audio channel: encode target audio into codec tokens (frozen codec, fp32).
        with fp32_precision(), torch.no_grad():
            target_codes, target_codes_lens = self.audio_codec.encode(
                audio=batch["target_audio"], audio_len=batch["target_audio_lens"]
            )
        target_codes = target_codes.transpose(1, 2)

        # Trim whichever of source/target is longer so both have the same number of frames.
        if (tl := target_codes.shape[1]) != (sl := source_encoded.shape[1]):
            if tl < sl:
                diff = sl - tl
                source_encoded = source_encoded[:, :tl]
                target_tokens = target_tokens[:, :tl]
                torch.clamp_(source_encoded_lens, max=tl)
            else:
                diff = tl - sl
                target_codes = target_codes[:, :sl]
                torch.clamp_(target_codes_lens, max=sl)
            if diff > 2:
                logging.warning(
                    f"A mismatch between source ({sl}) and target ({tl}) sequence length greater than 2 detected. "
                    f"This may indicate significant desynchronization in longer sessions."
                )

        # Mark text BOS/EOS frames with speech BOS/EOS codes, then prepend a delay frame (shift right by one).
        btt = target_tokens[..., None]
        target_codes = torch.where(btt == self.text_bos_id, self.speech_bos_id, target_codes)
        target_codes = torch.where(btt == self.text_eos_id, self.speech_eos_id, target_codes)
        target_codes = torch.cat(
            [
                torch.full(
                    [target_codes.shape[0], 1, target_codes.shape[-1]],
                    fill_value=self.speech_delay_id,
                    device=self.device,
                    dtype=torch.long,
                ),
                target_codes[:, :-1],
            ],
            dim=1,
        )

        # Stack channels as [audio codebooks..., text]; with TP, the shifted length must divide the TP size.
        input_ids = torch.cat([target_codes, target_tokens[..., None]], dim=-1)
        if self._use_tp:
            tp_world_size = self.device_mesh["tensor_parallel"].size()
            if (remainder := (input_ids.shape[1] - 1) % tp_world_size) != 0:
                input_ids = input_ids[:, :-remainder]
                source_encoded = source_encoded[:, :-remainder]
                # Keep lengths consistent with the truncated sequences (as done for the mismatch trim above);
                # otherwise input_lens/output_lens could exceed the actual number of frames.
                torch.clamp_(source_encoded_lens, max=input_ids.shape[1])
                torch.clamp_(target_codes_lens, max=input_ids.shape[1])

        # Next-frame prediction: inputs are frames [0, T-1), labels are frames [1, T).
        text_inputs = input_ids[:, :-1, -1]
        text_labels = input_ids[:, 1:, -1]
        audio_inputs = input_ids[:, :-1, :-1]
        audio_labels = input_ids[:, 1:, :-1]

        # LLM input = previous text token embedding + weighted user-channel embedding.
        input_embeds = self.embed_tokens(text_inputs)
        input_embeds.add_(source_encoded[:, :-1] * self.cfg.get("duplex_user_channel_weight", 1.0))

        # All-ones mask over (text + audio) label channels.
        loss_mask = torch.ones_like(
            torch.cat([text_labels.unsqueeze(-1), audio_labels], dim=-1),
            device=self.device,
            dtype=torch.bool,
        )

        return {
            "input_embeds": input_embeds,
            "input_lens": source_encoded_lens - 1,
            "output_lens": target_codes_lens - 1,
            "text_labels": text_labels,
            "input_audio_tokens": audio_inputs,
            "audio_labels": audio_labels,
            "loss_mask": loss_mask,
        }

    def training_step(self, batch: dict, batch_idx: int):
        """
        Compute the weighted sum of text and audio cross-entropy losses and log training metrics.

        Both losses are summed over all frames and normalized by the number of valid input frames
        (the audio loss additionally by the number of codebooks).
        """
        # Frozen submodules are kept in eval mode (e.g. to disable dropout).
        for m in (self.perception.preprocessor, self.perception.encoder, self.llm, self.speech_generation):
            if is_frozen(m):
                m.eval()
        inputs = self.prepare_inputs(batch)
        forward_outputs = self(
            inputs["input_embeds"],
            input_audio_tokens=inputs["input_audio_tokens"],
            loss_mask=inputs["loss_mask"],
        )
        num_frames = inputs["input_lens"].sum()
        # loss_parallel() lets cross_entropy consume vocab-sharded logits under tensor parallelism.
        with loss_parallel():
            text_loss = (
                torch.nn.functional.cross_entropy(
                    forward_outputs["text_logits"].flatten(0, 1),
                    inputs["text_labels"].flatten(0, 1),
                    reduction="sum",
                )
                / num_frames
            )
            audio_loss = torch.nn.functional.cross_entropy(
                forward_outputs["audio_logits"].flatten(0, 2),
                inputs["audio_labels"].flatten(0, 2),
                reduction="sum",
            ) / (num_frames * self._num_codebooks)
        loss = self.cfg.text_loss_weight * text_loss + self.cfg.audio_loss_weight * audio_loss

        B, T = inputs["input_embeds"].shape[:2]
        ans = {
            "loss": loss,
            "learning_rate": (
                torch.as_tensor(self.trainer.optimizers[0].param_groups[0]['lr'] if self._trainer is not None else 0)
            ),
            "text_loss": text_loss,
            "audio_loss": audio_loss,
            "batch_size": B,
            "sequence_length": T,
            "num_frames": num_frames.to(torch.float32),
            "padding_ratio": num_frames / (B * T),
        }
        self.log_dict(ans, on_step=True)
        return ans

    def on_train_epoch_start(self) -> None:
        """(Re-)create the audio codec at the start of each epoch."""
        setup_audio_codec(self)

    def on_validation_epoch_start(self) -> None:
        """Set up the audio codec and reset the ASR-BLEU and BLEU metrics."""
        self.on_train_epoch_start()
        self.asr_bleu = ASRBLEU(self.cfg.scoring_asr).reset()
        self.bleu = BLEU().reset()

    def on_validation_epoch_end(self, prefix="val") -> None:
        """Compute and log the accumulated ASR-BLEU and BLEU metrics as ``{prefix}_{metric}``."""
        asr_bleu = self.asr_bleu.compute()
        for k, m in asr_bleu.items():
            self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True)
        bleu = self.bleu.compute()
        for k, m in bleu.items():
            self.log(f"{prefix}_{k}", m.to(self.device), on_epoch=True, sync_dist=True)

    def validation_step(self, batch: dict, batch_idx: int):
        """
        Run offline inference for each dataset in ``batch`` and update ASR-BLEU / BLEU metrics.

        ``batch`` maps dataset names to dataset batches; ``None`` entries are skipped.
        """
        for name, dataset_batch in batch.items():
            if dataset_batch is None:
                continue

            results = self.offline_inference(
                dataset_batch["source_audio"],
                dataset_batch["source_audio_lens"],
            )

            # NOTE(review): 22050 is hard-coded as the generated audio's sample rate; presumably it must
            # match the audio codec's output rate — confirm. Audio is resampled to 16 kHz for ASR scoring.
            with fp32_precision():
                self.asr_bleu.update(
                    name=name,
                    refs=dataset_batch["target_texts"],
                    pred_audio=resample(results["audio"], 22050, 16000),
                    pred_audio_lens=(results["audio_len"] / 22050 * 16000).to(torch.long),
                )

            self.bleu.update(name=name, refs=dataset_batch["target_texts"], hyps=results["text"])

    def on_test_epoch_start(self) -> None:
        """Same setup as validation."""
        return self.on_validation_epoch_start()

    def on_test_epoch_end(self) -> None:
        """Log metrics with the ``test`` prefix."""
        return self.on_validation_epoch_end(prefix="test")

    def test_step(self, *args, **kwargs):
        """Same as validation_step."""
        return self.validation_step(*args, **kwargs)

    def _get_bos_embedding(self) -> torch.Tensor:
        """
        Return the embedding that seeds the first AR decoding step, of shape (1, H).

        This is the embedding of the text PAD id only; unlike DuplexS2SModel, no audio codec
        embedding is added.
        """
        text_pad = torch.full((1,), fill_value=self.text_pad_id, device=self.device)
        input_embeds = self.embed_tokens(text_pad)
        return input_embeds

    @torch.no_grad()
    def offline_inference(
        self,
        input_signal: torch.Tensor,
        input_signal_lens: torch.Tensor,
        decode_audio: bool = True,
    ) -> dict[str, torch.Tensor]:
        """
        Autoregressive prediction.

        Args:
            input_signal: a batch of waveforms with shape (B, T) with source sampling rate.
            input_signal_lens: example lengths as number of samples of shape (B,).
            decode_audio: bool, whether to decode audio codes to waveform.

        Returns:
            A dict with keys:
                * "text": generated text, de-tokenized to strings, properly skipping text_pad_id; list of length B.
                * "tokens_text": generated text tokens of shape (B, T2).
                * "tokens_audio": generated audio codes of shape (B, T2, K) where `K=num_codebooks`.
                * "tokens_len" output lengths as number of tokens of shape (B,).
                * "audio": generated waveform of shape (B, T3) (`decode_audio=True`).
                * "audio_len" output lengths as number of waveform samples of shape (B,) (when `decode_audio=True`).
        """
        input_embeds, lengths = self.perception(
            input_signal=input_signal,
            input_signal_length=input_signal_lens,
        )
        B, T_local, _ = input_embeds.shape

        # With FSDP, all ranks must run the same number of decoding steps (sharded modules communicate
        # on every call): pad to the global max length by repeating the last frame.
        if self._use_fsdp:
            T_tensor = torch.tensor([T_local], device=input_embeds.device)
            dist.all_reduce(T_tensor, op=dist.ReduceOp.MAX)
            T = int(T_tensor.item())
            if T > T_local:
                last_frame = input_embeds[:, T_local - 1 : T_local, :]
                pad = last_frame.repeat(1, T - T_local, 1)
                input_embeds = torch.cat([input_embeds, pad], dim=1)
        else:
            T = T_local

        # Scale the user channel the same way as in prepare_inputs.
        input_embeds *= self.cfg.get("duplex_user_channel_weight", 1.0)

        # Fresh KV-caches for both the LLM and the speech decoder.
        cache = DynamicCache()
        self.speech_generation.reset_input_and_kv_cache(use_cache=True)
        gen_text = torch.empty(B, T, device=self.device, dtype=torch.long)
        gen_audio = torch.empty(B, T, self._num_codebooks, device=self.device, dtype=torch.long)

        # First step: seed with the PAD-token embedding and the speech delay code.
        input_embeds[:, 0] += self._get_bos_embedding()
        first_audio = torch.full(
            [B, 1, self._num_codebooks],
            fill_value=self.speech_delay_id,
            device=self.device,
            dtype=torch.long,
        )
        ans = self(input_embeds[:, :1], cache=cache, input_audio_tokens=first_audio, loss_mask=None)
        gen_text[:, 0] = ans["text_logits"][:, -1].argmax(dim=-1)
        gen_audio[:, 0] = ans["audio_logits"][:, -1].argmax(dim=-1)

        # Remaining steps: greedy decoding, feeding back the previous text token and audio codes.
        for t in range(1, T):
            last_emb = self.embed_tokens(gen_text[:, t - 1])
            input_embeds[:, t] += last_emb
            current_audio = gen_audio[:, t - 1 : t, :]
            ans = self(input_embeds[:, t : t + 1], cache=ans["cache"], input_audio_tokens=current_audio)
            gen_text[:, t] = ans["text_logits"][:, -1].argmax(dim=-1)
            gen_audio[:, t] = ans["audio_logits"][:, -1].argmax(dim=-1)

        # Drop the frames generated for FSDP padding.
        if self._use_fsdp and T > T_local:
            gen_text = gen_text[:, :T_local]
            gen_audio = gen_audio[:, :T_local]

        ans = {
            "text": tokens_to_str(gen_text, lengths, tokenizer=self.tokenizer, pad_id=self.text_pad_id),
            "tokens_text": gen_text,
            "tokens_audio": gen_audio,
            "tokens_len": lengths,
        }

        if decode_audio:
            # Replace control codes (BOS/EOS/delay) before codec decoding.
            gen_audio_codes = replace_control_speech_codes(gen_audio, self._control_codes)
            with fp32_precision(), torch.no_grad():
                predicted_audio, predicted_audio_lens = self.audio_codec.decode(
                    tokens=gen_audio_codes.transpose(1, 2), tokens_len=lengths
                )
            ans["audio"] = predicted_audio
            ans["audio_len"] = predicted_audio_lens

        return ans

    def backward(self, *args, **kwargs):
        """Run backward inside ``loss_parallel()`` to support vocab-sharded logits under TP."""
        with loss_parallel():
            super().backward(*args, **kwargs)

    def configure_optimizers(self):
        """Delegate to the shared speechlm2 optimizer setup."""
        return configure_optimizers(self)

    @property
    def oomptimizer_schema(self) -> dict:
        """
        Return a typing schema for optimal batch size calibration for various
        sequence lengths using OOMptimizer.
        """
        return {
            "cls": dict,
            "inputs": [
                {"name": "source_audio", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"},
                {"name": "source_audio_lens", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"},
                {"name": "target_audio", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"},
                {"name": "target_audio_lens", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"},
                {
                    "name": "target_tokens",
                    "type": NeuralType(("B", "T"), LabelsType()),
                    "seq_length": "output",
                    "vocab_size": self.tokenizer.vocab_size,
                },
            ],
        }

    def configure_model(self) -> None:
        """
        Apply tensor parallelism and/or FSDP2 sharding according to ``self.device_mesh``.

        Does nothing when no device mesh is set. TP is applied when the ``tensor_parallel`` mesh
        dimension is larger than 1; FSDP when the ``data_parallel`` dimension is larger than 1.
        """
        device_mesh = self.device_mesh
        if device_mesh is None:
            return

        # Parallelize the underlying HF model even when it is wrapped by PEFT (LoRA).
        llm = self.llm
        if isinstance(llm, PeftModel):
            llm = llm.base_model.model

        if (tp_mesh := device_mesh["tensor_parallel"]).size() > 1:
            self._use_tp = True

            # Shard the (replicated) input embeddings along the sequence dim before the first layer,
            # and run the final norm sequence-parallel.
            plan = {
                "layers.0": PrepareModuleInput(
                    input_layouts=(Replicate(),),
                    desired_input_layouts=(Shard(1),),
                    use_local_output=True,
                ),
                "norm": SequenceParallel(),
            }
            parallelize_module(llm, tp_mesh, plan)

            for transformer_block in llm.layers:
                plan = {
                    "input_layernorm": SequenceParallel(),
                    "self_attn.q_proj": ColwiseParallel(),
                    "self_attn.k_proj": ColwiseParallel(),
                    "self_attn.v_proj": ColwiseParallel(),
                    "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
                    "post_attention_layernorm": SequenceParallel(),
                    "mlp": PrepareModuleInput(
                        input_layouts=(Shard(1),),
                        desired_input_layouts=(Replicate(),),
                    ),
                    "mlp.gate_proj": ColwiseParallel(),
                    "mlp.up_proj": ColwiseParallel(),
                    "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
                }

                # Each rank holds only its share of heads, so the attention module's head counts / hidden
                # size are divided by the TP size.
                attn_layer = transformer_block.self_attn
                for attr in ("num_heads", "num_key_value_heads", "hidden_size"):
                    val = getattr(attn_layer, attr)
                    if val % tp_mesh.size() != 0:
                        logging.warning(
                            f"attn_layer.{attr}={val} is not divisible by {tp_mesh.size()=}: "
                            f"set a different tensor parallelism size to avoid errors."
                        )
                    setattr(attn_layer, attr, val // tp_mesh.size())

                parallelize_module(transformer_block, tp_mesh, plan)

            # Output heads: vocab-sharded DTensor outputs, consumed by loss_parallel() in training.
            for m in (self.lm_head, self.audio_head):
                parallelize_module(
                    m,
                    tp_mesh,
                    ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1),
                        use_local_output=False,
                    ),
                )

        if (dp_mesh := device_mesh["data_parallel"]).size() > 1:
            assert dp_mesh.ndim == 1
            self._use_fsdp = True

            fsdp_config = {"mesh": dp_mesh}

            # Shard each transformer layer individually, then the remaining top-level modules.
            for idx, layer in enumerate(llm.layers):
                llm.layers[idx] = fully_shard(layer, **fsdp_config)
            self.embed_tokens = fully_shard(self.embed_tokens, **fsdp_config)
            self.llm = fully_shard(self.llm, **fsdp_config)
            self.lm_head = fully_shard(self.lm_head, **fsdp_config)
            self.perception = fully_shard(self.perception, **fsdp_config)
            self.speech_generation = fully_shard(self.speech_generation, **fsdp_config)
|
|