|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import coremltools as ct |
|
|
import numpy as np |
|
|
import typer |
|
|
from pathlib import Path |
|
|
from typing import Tuple, List, Optional |
|
|
import json |
|
|
import shutil |
|
|
|
|
|
|
|
|
import coremltools as ct |
|
|
import numpy as np |
|
|
import argparse |
|
|
from nemo.collections.asr.models import EncDecRNNTBPEModel |
|
|
|
|
|
# Typer CLI application object.
# NOTE(review): `app` has no commands registered and is never invoked --
# main() below parses args with argparse and is called directly from the
# __main__ guard. Confirm whether the typer dependency is still needed.
app = typer.Typer()
|
|
|
|
|
class LoopbackEncoderWrapper(nn.Module):
    """
    Wraps the entire Parakeet Encoder (PreEncode + Conformer) for CoreML Loopback Streaming.

    Each call processes one Mel-spectrogram chunk together with the previous
    chunk's trailing frames (`pre_cache`) and the Conformer layer caches, and
    returns the updated caches so the host app can feed them back on the next
    call (the "loopback" pattern).

    Inputs:
    - audio_signal: [B, D, T] (Mel spectrogram chunk)
    - audio_length: [B]
    - pre_cache: [B, D, pre_cache_size] (Previous audio context)
    - cache_last_channel: [layers, B, cache_size, hidden]
    - cache_last_time: [layers, B, hidden, time_cache]
    - cache_last_channel_len: [B]

    Outputs:
    - encoded_output: [B, D_out, T_out]
    - encoded_length: [B]
    - new_pre_cache: [B, D, pre_cache_size]
    - new_cache_last_channel
    - new_cache_last_time
    - new_cache_last_channel_len
    """

    def __init__(self, encoder, pre_cache_size=16):
        """
        Args:
            encoder: NeMo cache-aware encoder exposing ``cache_aware_stream_step``.
            pre_cache_size: number of trailing Mel frames carried over between
                chunks as left context for the next call.
        """
        super().__init__()
        self.encoder = encoder
        self.pre_cache_size = pre_cache_size

    def forward(
        self,
        audio_signal: torch.Tensor,
        audio_length: torch.Tensor,
        pre_cache: torch.Tensor,
        cache_last_channel: torch.Tensor,
        cache_last_time: torch.Tensor,
        cache_last_channel_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Prepend the carried-over context so the encoder sees a contiguous
        # signal across chunk boundaries.
        full_input = torch.cat([pre_cache, audio_signal], dim=2)
        full_length = audio_length + self.pre_cache_size

        # The last `pre_cache_size` frames of the combined signal become the
        # left context for the NEXT chunk.
        new_pre_cache = full_input[:, :, -self.pre_cache_size:]

        encoded, encoded_len, new_cache_channel, new_cache_time, new_cache_len = self.encoder.cache_aware_stream_step(
            processed_signal=full_input,
            processed_signal_length=full_length,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len
        )

        # CoreML-side consumers expect int32 lengths; NeMo may return int64.
        encoded_len_32 = encoded_len.to(dtype=torch.int32)
        new_channel_len_32 = new_cache_len.to(dtype=torch.int32)

        return encoded, encoded_len_32, new_pre_cache, new_cache_channel, new_cache_time, new_channel_len_32
|
|
|
|
|
def _coreml_convert(
    traced_model,
    inputs,
    outputs,
    compute_units=ct.ComputeUnit.CPU_ONLY
):
    """Convert a traced TorchScript model to a CoreML model.

    Thin wrapper around ``ct.convert`` that pins the deployment target to
    macOS 14 and defaults to CPU-only compute units.
    """
    convert_kwargs = {
        "inputs": inputs,
        "outputs": outputs,
        "compute_units": compute_units,
        "minimum_deployment_target": ct.target.macOS14,
    }
    return ct.convert(traced_model, **convert_kwargs)
|
|
|
|
|
def main():
    """Export the Parakeet streaming encoder to a CoreML ``.mlpackage``.

    Loads the pretrained NeMo RNNT model, wraps its encoder in
    ``LoopbackEncoderWrapper``, traces it with fixed chunk/cache shapes, and
    converts the traced graph to CoreML with explicit loopback cache I/O.
    """
    # Parse CLI arguments first so that a bad flag fails immediately instead
    # of after the (slow) pretrained-model download.
    parser = argparse.ArgumentParser()
    parser.add_argument("--chunk-frames", type=int, default=17, help="Number of frames in the input chunk (e.g. 17 for 160ms, 129 for 1.28s)")
    args = parser.parse_args()

    model_id: str = "nvidia/parakeet_realtime_eou_120m-v1"
    output_dir: str = "temp_swift_models/StreamingLoopback"
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading model: {model_id}...")
    asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=model_id)
    asr_model.eval()

    encoder = asr_model.encoder

    chunk_size_in = args.chunk_frames
    mel_dim = 128  # Parakeet front-end produces 128-dim Mel features
    hidden_dim = encoder.d_model
    num_layers = len(encoder.layers)

    # Streaming-cache geometry; must match the runtime's loopback buffers.
    cache_channel_size = 70
    cache_time_size = 8
    pre_cache_size = 16

    print(f"Config: Chunk={chunk_size_in}, Mel={mel_dim}, Hidden={hidden_dim}, Layers={num_layers}")
    print(f"Cache: Channel={cache_channel_size}, Time={cache_time_size}, Pre={pre_cache_size}")

    wrapper = LoopbackEncoderWrapper(encoder, pre_cache_size=pre_cache_size)
    wrapper.eval()

    # Dummy tensors: values are irrelevant, only the shapes/dtypes matter
    # because torch.jit.trace records shapes from this single example.
    batch_size = 1
    test_mel = torch.randn(batch_size, mel_dim, chunk_size_in)
    test_mel_len = torch.tensor([chunk_size_in], dtype=torch.int32)
    test_pre_cache = torch.zeros(batch_size, mel_dim, pre_cache_size)
    test_cache_channel = torch.zeros(num_layers, batch_size, cache_channel_size, hidden_dim)
    test_cache_time = torch.zeros(num_layers, batch_size, hidden_dim, cache_time_size)
    test_cache_len = torch.zeros(batch_size, dtype=torch.int32)

    print("Tracing model...")
    traced_model = torch.jit.trace(
        wrapper,
        (test_mel, test_mel_len, test_pre_cache, test_cache_channel, test_cache_time, test_cache_len),
        strict=False
    )

    print("Converting to CoreML...")

    # Shapes are expressed via the variables above (not hard-coded literals)
    # so the CoreML interface cannot silently diverge from the traced shapes.
    inputs = [
        ct.TensorType(name="audio_signal", shape=(batch_size, mel_dim, chunk_size_in), dtype=np.float32),
        ct.TensorType(name="audio_length", shape=(batch_size,), dtype=np.int32),
        ct.TensorType(name="pre_cache", shape=(batch_size, mel_dim, pre_cache_size), dtype=np.float32),
        ct.TensorType(name="cache_last_channel", shape=(num_layers, batch_size, cache_channel_size, hidden_dim), dtype=np.float32),
        ct.TensorType(name="cache_last_time", shape=(num_layers, batch_size, hidden_dim, cache_time_size), dtype=np.float32),
        ct.TensorType(name="cache_last_channel_len", shape=(batch_size,), dtype=np.int32),
    ]

    outputs = [
        ct.TensorType(name="encoded_output", dtype=np.float32),
        ct.TensorType(name="encoded_length", dtype=np.int32),
        ct.TensorType(name="new_pre_cache", dtype=np.float32),
        ct.TensorType(name="new_cache_last_channel", dtype=np.float32),
        ct.TensorType(name="new_cache_last_time", dtype=np.float32),
        ct.TensorType(name="new_cache_last_channel_len", dtype=np.int32),
    ]

    mlmodel = _coreml_convert(traced_model, inputs, outputs)

    save_path = output_path / "streaming_encoder.mlpackage"
    mlmodel.save(str(save_path))
    print(f"Saved: {save_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the CoreML export end-to-end.
    main()
|
|
|