|
|
|
|
|
"""Export Parakeet Realtime EOU RNNT components into CoreML. |
|
|
|
|
|
This model uses a cache-aware streaming FastConformer encoder. |
|
|
The encoder requires splitting into: |
|
|
1. Initial encoder (no cache, for first chunk) |
|
|
2. Streaming encoder (with cache inputs/outputs) |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import coremltools as ct |
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ExportSettings: |
|
|
output_dir: Path |
|
|
compute_units: ct.ComputeUnit |
|
|
deployment_target: Optional[ct.target] |
|
|
compute_precision: Optional[ct.precision] |
|
|
max_audio_seconds: float |
|
|
max_symbol_steps: int |
|
|
|
|
|
chunk_size_frames: int |
|
|
cache_size: int |
|
|
|
|
|
|
|
|
class PreprocessorWrapper(torch.nn.Module): |
|
|
"""Wrapper for the preprocessor (mel spectrogram extraction).""" |
|
|
|
|
|
def __init__(self, module: torch.nn.Module) -> None: |
|
|
super().__init__() |
|
|
self.module = module |
|
|
|
|
|
def forward( |
|
|
self, audio_signal: torch.Tensor, length: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
mel, mel_length = self.module( |
|
|
input_signal=audio_signal, length=length.to(dtype=torch.long) |
|
|
) |
|
|
return mel, mel_length |
|
|
|
|
|
|
|
|
class EncoderInitialWrapper(torch.nn.Module): |
|
|
"""Encoder wrapper for the initial chunk (no cache input). |
|
|
|
|
|
This is used for the first chunk of audio where there's no previous cache. |
|
|
It outputs the encoder features and initial cache states. |
|
|
""" |
|
|
|
|
|
def __init__(self, module: torch.nn.Module) -> None: |
|
|
super().__init__() |
|
|
self.module = module |
|
|
|
|
|
def forward( |
|
|
self, features: torch.Tensor, length: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Forward pass for initial chunk without cache. |
|
|
|
|
|
Args: |
|
|
features: Mel spectrogram [B, D, T] |
|
|
length: Sequence lengths [B] |
|
|
|
|
|
Returns: |
|
|
encoded: Encoder output [B, D, T_enc] |
|
|
encoded_lengths: Output lengths [B] |
|
|
""" |
|
|
|
|
|
encoded, encoded_lengths = self.module( |
|
|
audio_signal=features, length=length.to(dtype=torch.long) |
|
|
) |
|
|
return encoded, encoded_lengths |
|
|
|
|
|
|
|
|
class EncoderStreamingWrapper(torch.nn.Module): |
|
|
"""Encoder wrapper for streaming with cache. |
|
|
|
|
|
This is used for subsequent chunks where cache states are available. |
|
|
It takes cache states as input and outputs updated cache states. |
|
|
""" |
|
|
|
|
|
def __init__(self, module: torch.nn.Module) -> None: |
|
|
super().__init__() |
|
|
self.module = module |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
features: torch.Tensor, |
|
|
length: torch.Tensor, |
|
|
cache_last_channel: torch.Tensor, |
|
|
cache_last_time: torch.Tensor, |
|
|
cache_last_channel_len: torch.Tensor, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
"""Forward pass with cache for streaming.""" |
|
|
|
|
|
cache_last_channel_t = cache_last_channel.transpose(0, 1) |
|
|
cache_last_time_t = cache_last_time.transpose(0, 1) |
|
|
cache_len_i64 = cache_last_channel_len.to(dtype=torch.int64) |
|
|
|
|
|
|
|
|
encoded, encoded_lengths, cache_ch_next, cache_t_next, cache_len_next = self.module( |
|
|
audio_signal=features, |
|
|
length=length.to(dtype=torch.long), |
|
|
cache_last_channel=cache_last_channel_t, |
|
|
cache_last_time=cache_last_time_t, |
|
|
cache_last_channel_len=cache_len_i64, |
|
|
) |
|
|
|
|
|
|
|
|
cache_ch_next = cache_ch_next.transpose(0, 1) |
|
|
cache_t_next = cache_t_next.transpose(0, 1) |
|
|
|
|
|
return ( |
|
|
encoded, |
|
|
encoded_lengths.to(dtype=torch.int32), |
|
|
cache_ch_next, |
|
|
cache_t_next, |
|
|
cache_len_next.to(dtype=torch.int32), |
|
|
) |
|
|
|
|
|
|
|
|
class DecoderWrapper(torch.nn.Module): |
|
|
"""Wrapper for the RNNT prediction network (decoder).""" |
|
|
|
|
|
def __init__(self, module: torch.nn.Module) -> None: |
|
|
super().__init__() |
|
|
self.module = module |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
targets: torch.Tensor, |
|
|
target_lengths: torch.Tensor, |
|
|
h_in: torch.Tensor, |
|
|
c_in: torch.Tensor, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
state = [h_in, c_in] |
|
|
decoder_output, _, new_state = self.module( |
|
|
targets=targets.to(dtype=torch.long), |
|
|
target_length=target_lengths.to(dtype=torch.long), |
|
|
states=state, |
|
|
) |
|
|
return decoder_output, new_state[0], new_state[1] |
|
|
|
|
|
|
|
|
class JointWrapper(torch.nn.Module): |
|
|
"""Wrapper for the RNNT joint network.""" |
|
|
|
|
|
def __init__(self, module: torch.nn.Module) -> None: |
|
|
super().__init__() |
|
|
self.module = module |
|
|
|
|
|
def forward( |
|
|
self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
encoder_outputs = encoder_outputs.transpose(1, 2) |
|
|
decoder_outputs = decoder_outputs.transpose(1, 2) |
|
|
|
|
|
|
|
|
enc_proj = self.module.enc(encoder_outputs) |
|
|
dec_proj = self.module.pred(decoder_outputs) |
|
|
|
|
|
|
|
|
x = enc_proj.unsqueeze(2) + dec_proj.unsqueeze(1) |
|
|
x = self.module.joint_net[0](x) |
|
|
x = self.module.joint_net[1](x) |
|
|
out = self.module.joint_net[2](x) |
|
|
return out |
|
|
|
|
|
|
|
|
class MelEncoderWrapper(torch.nn.Module): |
|
|
"""Fused wrapper: waveform -> mel -> encoder (no cache, initial chunk).""" |
|
|
|
|
|
def __init__( |
|
|
self, preprocessor: PreprocessorWrapper, encoder: EncoderInitialWrapper |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.preprocessor = preprocessor |
|
|
self.encoder = encoder |
|
|
|
|
|
def forward( |
|
|
self, audio_signal: torch.Tensor, audio_length: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
mel, mel_length = self.preprocessor(audio_signal, audio_length) |
|
|
encoded, enc_len = self.encoder(mel, mel_length.to(dtype=torch.int32)) |
|
|
return encoded, enc_len |
|
|
|
|
|
|
|
|
class MelEncoderStreamingWrapper(torch.nn.Module): |
|
|
"""Fused wrapper: waveform -> mel -> encoder (with cache, streaming).""" |
|
|
|
|
|
def __init__( |
|
|
self, preprocessor: PreprocessorWrapper, encoder: EncoderStreamingWrapper |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.preprocessor = preprocessor |
|
|
self.encoder = encoder |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
audio_signal: torch.Tensor, |
|
|
audio_length: torch.Tensor, |
|
|
cache_last_channel: torch.Tensor, |
|
|
cache_last_time: torch.Tensor, |
|
|
cache_last_channel_len: torch.Tensor, |
|
|
) -> Tuple[ |
|
|
torch.Tensor, |
|
|
torch.Tensor, |
|
|
torch.Tensor, |
|
|
torch.Tensor, |
|
|
torch.Tensor, |
|
|
]: |
|
|
mel, mel_length = self.preprocessor(audio_signal, audio_length) |
|
|
return self.encoder( |
|
|
mel, |
|
|
mel_length.to(dtype=torch.int32), |
|
|
cache_last_channel, |
|
|
cache_last_time, |
|
|
cache_last_channel_len, |
|
|
) |
|
|
|
|
|
|
|
|
class JointDecisionWrapper(torch.nn.Module): |
|
|
"""Joint + decision head: outputs label id, label prob. |
|
|
|
|
|
Unlike TDT, EOU models don't have duration outputs. |
|
|
They have a special EOU token that marks end of utterance. |
|
|
""" |
|
|
|
|
|
def __init__(self, joint: JointWrapper, vocab_size: int) -> None: |
|
|
super().__init__() |
|
|
self.joint = joint |
|
|
self.vocab_with_blank = int(vocab_size) + 1 |
|
|
|
|
|
def forward( |
|
|
self, encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
logits = self.joint(encoder_outputs, decoder_outputs) |
|
|
token_logits = logits[..., : self.vocab_with_blank] |
|
|
|
|
|
|
|
|
token_ids = torch.argmax(token_logits, dim=-1).to(dtype=torch.int32) |
|
|
token_probs_all = torch.softmax(token_logits, dim=-1) |
|
|
token_prob = torch.gather( |
|
|
token_probs_all, dim=-1, index=token_ids.long().unsqueeze(-1) |
|
|
).squeeze(-1) |
|
|
|
|
|
return token_ids, token_prob |
|
|
|
|
|
|
|
|
class JointDecisionSingleStep(torch.nn.Module): |
|
|
"""Single-step variant for streaming: encoder_step [1, D, 1] -> [1,1,1]. |
|
|
|
|
|
Inputs: |
|
|
- encoder_step: [B=1, D, T=1] |
|
|
- decoder_step: [B=1, D, U=1] |
|
|
|
|
|
Returns: |
|
|
- token_id: [1, 1, 1] int32 |
|
|
- token_prob: [1, 1, 1] float32 |
|
|
- top_k_ids: [1, 1, 1, K] int32 |
|
|
- top_k_logits: [1, 1, 1, K] float32 |
|
|
""" |
|
|
|
|
|
def __init__(self, joint: JointWrapper, vocab_size: int, top_k: int = 64) -> None: |
|
|
super().__init__() |
|
|
self.joint = joint |
|
|
self.vocab_with_blank = int(vocab_size) + 1 |
|
|
self.top_k = int(top_k) |
|
|
|
|
|
def forward( |
|
|
self, encoder_step: torch.Tensor, decoder_step: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
logits = self.joint(encoder_step, decoder_step) |
|
|
token_logits = logits[..., : self.vocab_with_blank] |
|
|
|
|
|
token_ids = torch.argmax(token_logits, dim=-1, keepdim=False).to( |
|
|
dtype=torch.int32 |
|
|
) |
|
|
token_probs_all = torch.softmax(token_logits, dim=-1) |
|
|
token_prob = torch.gather( |
|
|
token_probs_all, dim=-1, index=token_ids.long().unsqueeze(-1) |
|
|
).squeeze(-1) |
|
|
|
|
|
|
|
|
topk_logits, topk_ids_long = torch.topk( |
|
|
token_logits, k=min(self.top_k, token_logits.shape[-1]), dim=-1 |
|
|
) |
|
|
topk_ids = topk_ids_long.to(dtype=torch.int32) |
|
|
return token_ids, token_prob, topk_ids, topk_logits |
|
|
|
|
|
|
|
|
def _coreml_convert( |
|
|
traced: torch.jit.ScriptModule, |
|
|
inputs, |
|
|
outputs, |
|
|
settings: ExportSettings, |
|
|
compute_units_override: Optional[ct.ComputeUnit] = None, |
|
|
) -> ct.models.MLModel: |
|
|
cu = ( |
|
|
compute_units_override |
|
|
if compute_units_override is not None |
|
|
else settings.compute_units |
|
|
) |
|
|
kwargs = { |
|
|
"convert_to": "mlprogram", |
|
|
"inputs": inputs, |
|
|
"outputs": outputs, |
|
|
"compute_units": cu, |
|
|
} |
|
|
print("Converting:", traced.__class__.__name__) |
|
|
print("Conversion kwargs:", kwargs) |
|
|
if settings.deployment_target is not None: |
|
|
kwargs["minimum_deployment_target"] = settings.deployment_target |
|
|
if settings.compute_precision is not None: |
|
|
kwargs["compute_precision"] = settings.compute_precision |
|
|
return ct.convert(traced, **kwargs) |
|
|
|