|
|
|
|
|
"""Export Parakeet Realtime EOU 120M RNNT components into CoreML.""" |
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import coremltools as ct |
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass
class ExportSettings:
    """Options shared by the CoreML conversion helpers in this module.

    Fields are declared in a fixed order, so instances may be built either
    positionally or by keyword.
    """

    # Destination for exported artifacts (not read in the visible code;
    # presumably consumed by the export driver).
    output_dir: Path
    # Default ``compute_units`` handed to ``ct.convert``; a per-call override
    # in ``_coreml_convert`` takes precedence.
    compute_units: ct.ComputeUnit
    # Forwarded as ``minimum_deployment_target`` when not None.
    deployment_target: Optional[ct.target]
    # Fallback ``compute_precision`` used when no per-call value is supplied.
    compute_precision: Optional[ct.precision]
    # Upper bound on input audio duration in seconds (consumer not visible here).
    max_audio_seconds: float
    # Upper bound on decoder symbol steps (consumer not visible here).
    max_symbol_steps: int
|
|
|
|
|
|
|
|
class PreprocessorWrapper(torch.nn.Module):
    """Positional-argument front end for the audio preprocessor (mel extraction).

    The wrapped module is called with keyword arguments ``input_signal`` and
    ``length``. Its two outputs are returned unchanged as a tuple.
    """

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self, audio_signal: torch.Tensor, length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Lengths are handed to the wrapped module as int64 whatever dtype the caller used.
        lengths_i64 = length.to(dtype=torch.long)
        features, feature_lengths = self.module(
            input_signal=audio_signal, length=lengths_i64
        )
        return features, feature_lengths
|
|
|
|
|
|
|
|
class EncoderWrapper(torch.nn.Module):
    """Wrapper for the cache-aware FastConformer encoder.

    Note: For the realtime EOU model, the encoder is cache-aware which means
    it can operate in a streaming fashion. For CoreML export, we export
    without cache state for simplicity (full-context mode).

    Args:
        module: Encoder called with keyword arguments ``audio_signal`` and
            ``length`` (int64). It must return ``(encoded, encoded_lengths)``.
        frame_duration_s: Seconds covered by one encoder output frame. Used
            only to build the ``frame_times`` output. The default 0.08
            preserves the previous hard-coded value. It presumably corresponds
            to a 10 ms mel hop with 8x subsampling; confirm against the model
            config.

    Raises:
        ValueError: If ``frame_duration_s`` is not strictly positive.
    """

    def __init__(self, module: torch.nn.Module, frame_duration_s: float = 0.08) -> None:
        super().__init__()
        if frame_duration_s <= 0:
            raise ValueError(
                f"frame_duration_s must be positive, got {frame_duration_s!r}"
            )
        self.module = module
        # Plain Python float: tracing bakes it into the graph as a constant.
        self.frame_duration_s = float(frame_duration_s)

    def forward(
        self, features: torch.Tensor, length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder and attach per-frame timestamps.

        Returns:
            ``(encoded, encoded_lengths, frame_times)``. ``frame_times`` is a
            float32 vector with one entry per element of ``encoded``'s last
            axis, equal to ``index * frame_duration_s``.
        """
        encoded, encoded_lengths = self.module(
            audio_signal=features, length=length.to(dtype=torch.long)
        )
        frame_times = (
            torch.arange(encoded.shape[-1], device=encoded.device, dtype=torch.float32)
            * self.frame_duration_s
        )
        return encoded, encoded_lengths, frame_times
|
|
|
|
|
|
|
|
class DecoderWrapper(torch.nn.Module):
    """RNNT prediction-network (decoder) wrapper with state passed as separate tensors.

    The wrapped module takes ``targets``, ``target_length`` and a ``states``
    list. Its second return value is discarded, and the first two entries of
    its returned state are exposed as separate outputs.
    """

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        h_in: torch.Tensor,
        c_in: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        token_ids = targets.to(dtype=torch.long)
        token_counts = target_lengths.to(dtype=torch.long)
        prediction, _ignored, state_out = self.module(
            targets=token_ids,
            target_length=token_counts,
            states=[h_in, c_in],
        )
        hidden_out = state_out[0]
        cell_out = state_out[1]
        return prediction, hidden_out, cell_out
|
|
|
|
|
|
|
|
class JointWrapper(torch.nn.Module):
    """Wrapper for the RNNT joint network.

    Note: Unlike Parakeet TDT v3, the realtime EOU model does NOT have
    duration outputs (num_extra_outputs). The joint network outputs only
    token logits over the vocabulary + blank.

    ``module`` must expose ``enc`` and ``pred`` projections and an iterable
    ``joint_net`` head (e.g. ``nn.Sequential``), whose layers are applied in
    order.
    """

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(
        self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor
    ) -> torch.Tensor:
        """Combine encoder and decoder features into joint logits.

        Args:
            encoder_outputs: Channels-first ``[B, D_enc, T]``.
            decoder_outputs: Channels-first ``[B, D_dec, U]``.

        Returns:
            Logits of shape ``[B, T, U, V]``, where ``V`` is the output width
            of the last layer in ``joint_net``.
        """
        # Move features to the last axis so the projections act on them.
        enc_proj = self.module.enc(encoder_outputs.transpose(1, 2))
        dec_proj = self.module.pred(decoder_outputs.transpose(1, 2))

        # Broadcast-add so every (t, u) pair gets enc[t] + dec[u].
        x = enc_proj.unsqueeze(2) + dec_proj.unsqueeze(1)

        # Apply every layer of the head. The previous version indexed layers
        # 0..2 explicitly: a two-layer head (e.g. one without an optional
        # dropout layer) raised IndexError, and any fourth layer was skipped.
        for layer in self.module.joint_net:
            x = layer(x)
        return x
|
|
|
|
|
|
|
|
class MelEncoderWrapper(torch.nn.Module):
    """Single graph chaining the preprocessor into the encoder (waveform -> mel -> encoder).

    Inputs:
        - audio_signal: [B, S]
        - audio_length: [B]

    Outputs:
        - encoder: [B, D, T_enc]
        - encoder_length: [B]
        - frame_times: [T_enc]
    """

    def __init__(
        self, preprocessor: PreprocessorWrapper, encoder: EncoderWrapper
    ) -> None:
        super().__init__()
        self.preprocessor = preprocessor
        self.encoder = encoder

    def forward(
        self, audio_signal: torch.Tensor, audio_length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        features, feature_lengths = self.preprocessor(audio_signal, audio_length)
        # Mel lengths are cast to int32 before they reach the encoder.
        feature_lengths_i32 = feature_lengths.to(dtype=torch.int32)
        hidden, hidden_lengths, timestamps = self.encoder(features, feature_lengths_i32)
        return hidden, hidden_lengths, timestamps
|
|
|
|
|
|
|
|
class JointDecisionWrapper(torch.nn.Module):
    """Greedy decision head on top of the joint network.

    Runs the joint, then reduces the last (vocabulary) axis to the most likely
    label and its softmax probability. Unlike Parakeet TDT v3, this model does
    NOT have duration outputs.

    Inputs:
        - encoder_outputs: [B, D, T]
        - decoder_outputs: [B, D, U]

    Returns:
        - token_id: [B, T, U] int32
        - token_prob: [B, T, U] float32
    """

    def __init__(self, joint: JointWrapper, vocab_size: int) -> None:
        super().__init__()
        self.joint = joint
        # vocab_size + 1 (blank). Stored only; forward does not read it.
        self.vocab_with_blank = int(vocab_size) + 1

    def forward(self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor):
        scores = self.joint(encoder_outputs, decoder_outputs)
        probabilities = torch.softmax(scores, dim=-1)
        # argmax yields int64 indices, which gather requires directly.
        winner = torch.argmax(scores, dim=-1)
        winner_prob = probabilities.gather(-1, winner.unsqueeze(-1)).squeeze(-1)
        return winner.to(dtype=torch.int32), winner_prob
|
|
|
|
|
|
|
|
class JointDecisionSingleStep(torch.nn.Module):
    """Streaming variant: one encoder step and one decoder step -> token decision.

    In addition to the greedy label and its softmax probability, returns the
    ``top_k`` highest raw logits and their ids. ``top_k`` is capped at the
    width of the last logits axis.

    Inputs:
        - encoder_step: [B=1, D, T=1]
        - decoder_step: [B=1, D, U=1]

    Returns:
        - token_id: [1, 1, 1] int32
        - token_prob: [1, 1, 1] float32
        - top_k_ids: [1, 1, 1, K] int32
        - top_k_logits: [1, 1, 1, K] float32
    """

    def __init__(self, joint: JointWrapper, vocab_size: int, top_k: int = 64) -> None:
        super().__init__()
        self.joint = joint
        # vocab_size + 1 (blank). Stored only; forward does not read it.
        self.vocab_with_blank = int(vocab_size) + 1
        self.top_k = int(top_k)

    def forward(self, encoder_step: torch.Tensor, decoder_step: torch.Tensor):
        scores = self.joint(encoder_step, decoder_step)

        best = torch.argmax(scores, dim=-1)
        best_prob = (
            torch.softmax(scores, dim=-1).gather(-1, best.unsqueeze(-1)).squeeze(-1)
        )

        # Never request more candidates than the logits axis holds.
        k = min(self.top_k, scores.shape[-1])
        candidate_logits, candidate_ids = torch.topk(scores, k=k, dim=-1)

        return (
            best.to(dtype=torch.int32),
            best_prob,
            candidate_ids.to(dtype=torch.int32),
            candidate_logits,
        )
|
|
|
|
|
|
|
|
def _coreml_convert(
    traced: torch.jit.ScriptModule,
    inputs,
    outputs,
    settings: ExportSettings,
    compute_units_override: Optional[ct.ComputeUnit] = None,
    compute_precision: Optional[ct.precision] = None,
) -> ct.models.MLModel:
    """Convert a traced module to a CoreML ML Program via ``ct.convert``.

    Args:
        traced: The traced/scripted module to convert.
        inputs: Passed through as ``ct.convert(inputs=...)``.
        outputs: Passed through as ``ct.convert(outputs=...)``.
        settings: Supplies the default compute units, the optional deployment
            target, and the fallback compute precision.
        compute_units_override: When not None, replaces
            ``settings.compute_units``.
        compute_precision: When not None, replaces
            ``settings.compute_precision``.

    Returns:
        Whatever ``ct.convert`` returns for ``convert_to="mlprogram"``.
    """
    cu = (
        compute_units_override
        if compute_units_override is not None
        else settings.compute_units
    )
    kwargs = {
        "convert_to": "mlprogram",
        "inputs": inputs,
        "outputs": outputs,
        "compute_units": cu,
    }
    if settings.deployment_target is not None:
        kwargs["minimum_deployment_target"] = settings.deployment_target

    # A per-call precision wins over the session-wide default.
    precision = (
        compute_precision
        if compute_precision is not None
        else settings.compute_precision
    )
    if precision is not None:
        kwargs["compute_precision"] = precision

    # Log only once kwargs are final. The previous version printed them before
    # the deployment target and precision were added, so the log did not
    # match what was actually passed to ct.convert.
    print("Converting:", traced.__class__.__name__)
    print("Conversion kwargs:", kwargs)

    return ct.convert(traced, **kwargs)
|
|
|