"""Export Parakeet Realtime EOU 120M RNNT components into CoreML."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import coremltools as ct
import torch


@dataclass
class ExportSettings:
    output_dir: Path
    compute_units: ct.ComputeUnit
    deployment_target: Optional[ct.target]
    compute_precision: Optional[ct.precision]
    max_audio_seconds: float
    max_symbol_steps: int
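
# A minimal construction sketch (illustrative values, not defaults taken from
# this script):
#
#     settings = ExportSettings(
#         output_dir=Path("coreml_out"),
#         compute_units=ct.ComputeUnit.ALL,
#         deployment_target=ct.target.iOS17,
#         compute_precision=ct.precision.FLOAT16,
#         max_audio_seconds=30.0,
#         max_symbol_steps=10,
#     )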


class PreprocessorWrapper(torch.nn.Module):
    """Wrapper for the audio preprocessor (mel spectrogram extraction)."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self, audio_signal: torch.Tensor, length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The wrapped preprocessor expects keyword arguments and an int64
        # length tensor; it returns the mel features and their lengths.
        mel, mel_length = self.module(
            input_signal=audio_signal, length=length.to(dtype=torch.long)
        )
        return mel, mel_length
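
# Call sketch (illustrative; the 16 kHz sample rate is an assumption about the
# model's expected input):
#
#     pre = PreprocessorWrapper(model.preprocessor)
#     mel, mel_len = pre(torch.randn(1, 16000), torch.tensor([16000]))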


class EncoderWrapper(torch.nn.Module):
    """Wrapper for the cache-aware FastConformer encoder.

    Note: For the realtime EOU model, the encoder is cache-aware, which means
    it can operate in a streaming fashion. For CoreML export, we export
    without cache state for simplicity (full-context mode).
    """

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self, features: torch.Tensor, length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        encoded, encoded_lengths = self.module(
            audio_signal=features, length=length.to(dtype=torch.long)
        )
        # Each encoder frame covers 80 ms of audio (10 ms feature hop with 8x
        # FastConformer subsampling), so frame t starts at t * 0.08 seconds.
        frame_times = (
            torch.arange(encoded.shape[-1], device=encoded.device, dtype=torch.float32)
            * 0.08
        )
        return encoded, encoded_lengths, frame_times


class DecoderWrapper(torch.nn.Module):
    """Wrapper for the RNNT prediction network (decoder)."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        h_in: torch.Tensor,
        c_in: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # The prediction network carries LSTM state as a [hidden, cell] pair;
        # CoreML I/O is flat tensors, so the pair is split across two inputs
        # and two outputs.
        state = [h_in, c_in]
        decoder_output, _, new_state = self.module(
            targets=targets.to(dtype=torch.long),
            target_length=target_lengths.to(dtype=torch.long),
            states=state,
        )
        return decoder_output, new_state[0], new_state[1]
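
# First-step sketch (illustrative; num_layers and hidden_size are assumptions
# that must match the checkpoint's prediction network):
#
#     h0 = torch.zeros(num_layers, 1, hidden_size)
#     c0 = torch.zeros(num_layers, 1, hidden_size)
#     dec_out, h1, c1 = decoder(targets, target_lengths, h0, c0)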


class JointWrapper(torch.nn.Module):
    """Wrapper for the RNNT joint network.

    Note: Unlike Parakeet TDT v3, the realtime EOU model does NOT have
    duration outputs (num_extra_outputs). The joint network outputs only
    token logits over the vocabulary + blank.
    """

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor
    ) -> torch.Tensor:
        # [B, D, T] -> [B, T, D] and [B, D, U] -> [B, U, D] so the linear
        # projections below act on the feature dimension.
        encoder_outputs = encoder_outputs.transpose(1, 2)
        decoder_outputs = decoder_outputs.transpose(1, 2)

        # Project both streams into the joint hidden space.
        enc_proj = self.module.enc(encoder_outputs)
        dec_proj = self.module.pred(decoder_outputs)

        # Broadcast-add to [B, T, U, H], then apply the remaining joint layers
        # (indexed here as activation, dropout, output projection) for logits.
        x = enc_proj.unsqueeze(2) + dec_proj.unsqueeze(1)
        x = self.module.joint_net[0](x)
        x = self.module.joint_net[1](x)
        out = self.module.joint_net[2](x)
        return out
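
# Shape sketch (illustrative): with B=1, T=10, U=3 and matching hidden sizes,
#
#     joint = JointWrapper(model.joint)
#     logits = joint(enc, dec)  # [1, 10, 3, vocab + 1]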


class MelEncoderWrapper(torch.nn.Module):
    """Fused wrapper: waveform -> mel -> encoder.

    Inputs:
      - audio_signal: [B, S]
      - audio_length: [B]

    Outputs:
      - encoder: [B, D, T_enc]
      - encoder_length: [B]
      - frame_times: [T_enc]
    """

    def __init__(
        self, preprocessor: PreprocessorWrapper, encoder: EncoderWrapper
    ) -> None:
        super().__init__()
        self.preprocessor = preprocessor
        self.encoder = encoder

    def forward(
        self, audio_signal: torch.Tensor, audio_length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mel, mel_length = self.preprocessor(audio_signal, audio_length)
        # Cast lengths to int32 for CoreML-friendly I/O; EncoderWrapper
        # converts back to int64 internally.
        encoded, enc_len, frame_times = self.encoder(
            mel, mel_length.to(dtype=torch.int32)
        )
        return encoded, enc_len, frame_times
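
# Fused-call sketch (illustrative; 16 kHz input is an assumption):
#
#     mel_enc = MelEncoderWrapper(pre, enc)
#     encoded, enc_len, times = mel_enc(
#         torch.randn(1, 16000), torch.tensor([16000], dtype=torch.int32)
#     )
#
# A fuller trace-and-convert sketch appears at the end of this file.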


class JointDecisionWrapper(torch.nn.Module):
    """Joint + decision head: outputs label id and label prob.

    Unlike Parakeet TDT v3, this model does NOT have duration outputs.

    Inputs:
      - encoder_outputs: [B, D, T]
      - decoder_outputs: [B, D, U]

    Returns:
      - token_id: [B, T, U] int32
      - token_prob: [B, T, U] float32
    """

    def __init__(self, joint: JointWrapper, vocab_size: int) -> None:
        super().__init__()
        self.joint = joint
        self.vocab_with_blank = int(vocab_size) + 1

    def forward(self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor):
        logits = self.joint(encoder_outputs, decoder_outputs)

        # Greedy decision: the argmax token id plus its softmax probability,
        # gathered from the full distribution.
        token_ids = torch.argmax(logits, dim=-1).to(dtype=torch.int32)
        token_probs_all = torch.softmax(logits, dim=-1)
        token_prob = torch.gather(
            token_probs_all, dim=-1, index=token_ids.long().unsqueeze(-1)
        ).squeeze(-1)
        return token_ids, token_prob
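
# Decision sketch (illustrative): token_id[b, t, u] is the greedy label for
# encoder frame t given decoder step u; blank is the last id (vocab_size) by
# the usual NeMo RNNT convention.
#
#     decide = JointDecisionWrapper(joint, vocab_size)
#     ids, probs = decide(enc, dec)  # [B, T, U] each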


class JointDecisionSingleStep(torch.nn.Module):
    """Single-step variant for streaming: encoder_step -> token decision.

    Inputs:
      - encoder_step: [B=1, D, T=1]
      - decoder_step: [B=1, D, U=1]

    Returns:
      - token_id: [1, 1, 1] int32
      - token_prob: [1, 1, 1] float32
      - top_k_ids: [1, 1, 1, K] int32
      - top_k_logits: [1, 1, 1, K] float32
    """

    def __init__(self, joint: JointWrapper, vocab_size: int, top_k: int = 64) -> None:
        super().__init__()
        self.joint = joint
        self.vocab_with_blank = int(vocab_size) + 1
        self.top_k = int(top_k)

    def forward(self, encoder_step: torch.Tensor, decoder_step: torch.Tensor):
        logits = self.joint(encoder_step, decoder_step)

        token_ids = torch.argmax(logits, dim=-1).to(dtype=torch.int32)
        token_probs_all = torch.softmax(logits, dim=-1)
        token_prob = torch.gather(
            token_probs_all, dim=-1, index=token_ids.long().unsqueeze(-1)
        ).squeeze(-1)

        # Also expose the top-K logits so a caller can rescore or beam-search
        # outside the model; K is clamped to the number of output classes.
        topk_logits, topk_ids_long = torch.topk(
            logits, k=min(self.top_k, logits.shape[-1]), dim=-1
        )
        topk_ids = topk_ids_long.to(dtype=torch.int32)
        return token_ids, token_prob, topk_ids, topk_logits
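
# Streaming sketch (illustrative): one greedy RNNT decision per encoder frame,
# where t indexes the frame and dec_out is the latest decoder output.
#
#     step = JointDecisionSingleStep(joint, vocab_size)
#     tok, prob, topk_ids, topk_logits = step(
#         encoded[:, :, t : t + 1], dec_out[:, :, -1:]
#     )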


def _coreml_convert(
    traced: torch.jit.ScriptModule,
    inputs,
    outputs,
    settings: ExportSettings,
    compute_units_override: Optional[ct.ComputeUnit] = None,
    compute_precision: Optional[ct.precision] = None,
) -> ct.models.MLModel:
    cu = (
        compute_units_override
        if compute_units_override is not None
        else settings.compute_units
    )
    kwargs = {
        "convert_to": "mlprogram",
        "inputs": inputs,
        "outputs": outputs,
        "compute_units": cu,
    }
    # Traced modules all share the same Python class, so prefer the wrapped
    # module's original name when it is available.
    print("Converting:", getattr(traced, "original_name", traced.__class__.__name__))
    print("Conversion kwargs:", kwargs)
    if settings.deployment_target is not None:
        kwargs["minimum_deployment_target"] = settings.deployment_target

    # A per-call precision override takes priority over the global setting.
    if compute_precision is not None:
        kwargs["compute_precision"] = compute_precision
    elif settings.compute_precision is not None:
        kwargs["compute_precision"] = settings.compute_precision

    return ct.convert(traced, **kwargs)
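

# --- Illustrative end-to-end sketch (not part of the original pipeline). It
# shows how the fused MelEncoderWrapper could be traced and converted through
# _coreml_convert. The 16 kHz sample rate and the flexible audio-length bounds
# are assumptions, not values from the model config; the tensor names follow
# the MelEncoderWrapper docstring.
def _example_export_mel_encoder(
    mel_encoder: MelEncoderWrapper, settings: ExportSettings
) -> ct.models.MLModel:
    import numpy as np  # local import: used only by this sketch

    sample_rate = 16_000  # assumed input rate for Parakeet models
    max_samples = int(settings.max_audio_seconds * sample_rate)
    example_audio = torch.zeros(1, max_samples, dtype=torch.float32)
    example_length = torch.tensor([max_samples], dtype=torch.int32)

    # Trace with fixed example shapes, then declare a flexible audio length
    # for the converted model.
    traced = torch.jit.trace(mel_encoder.eval(), (example_audio, example_length))
    inputs = [
        ct.TensorType(
            name="audio_signal",
            shape=(1, ct.RangeDim(1, max_samples)),
            dtype=np.float32,
        ),
        ct.TensorType(name="audio_length", shape=(1,), dtype=np.int32),
    ]
    outputs = [
        ct.TensorType(name="encoder"),
        ct.TensorType(name="encoder_length"),
        ct.TensorType(name="frame_times"),
    ]
    mlmodel = _coreml_convert(traced, inputs, outputs, settings)
    mlmodel.save(str(settings.output_dir / "MelEncoder.mlpackage"))
    return mlmodel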