|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as patches |
|
|
import matplotlib |
|
|
import seaborn as sns |
|
|
import numpy as np |
|
|
import threading |
|
|
import onnx2torch |
|
|
import onnxscript |
|
|
from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
from pydub import AudioSegment |
|
|
import coremltools as ct |
|
|
from pydub.playback import play as play_audio |
|
|
|
|
|
|
|
|
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") |
|
|
audio_file = "audio.wav" |
|
|
|
|
|
|
|
|
print("Loading audio file for playback...") |
|
|
full_audio = AudioSegment.from_wav(audio_file) |
|
|
|
|
|
|
|
|
model = SortformerEncLabelModel.from_pretrained( |
|
|
"nvidia/diar_streaming_sortformer_4spk-v2.1", |
|
|
map_location=device |
|
|
) |
|
|
model.eval() |
|
|
model.to(device) |
|
|
|
|
|
print(model.output_names) |
|
|
|
|
|
def streaming_input_examples(self): |
|
|
"""Input tensor examples for exporting streaming version of model""" |
|
|
batch_size = 4 |
|
|
feat_in = self.cfg.get("preprocessor", {}).get("features", 128) |
|
|
chunk = torch.rand([batch_size, 120, feat_in]).to(self.device) |
|
|
chunk_lengths = torch.tensor([120] * batch_size).to(self.device) |
|
|
spkcache = torch.randn([batch_size, 188, 512]).to(self.device) |
|
|
spkcache_lengths = torch.tensor([40, 188, 0, 68]).to(self.device) |
|
|
fifo = torch.randn([batch_size, 188, 512]).to(self.device) |
|
|
fifo_lengths = torch.tensor([50, 88, 0, 90]).to(self.device) |
|
|
return chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths |
|
|
|
|
|
|
|
|
inputs = streaming_input_examples(model) |
|
|
|
|
|
export_out = model.export("streaming-sortformer.onnx", input_example=inputs) |
|
|
scripted_model = onnx2torch.convert('streaming-sortformer.onnx') |
|
|
|
|
|
BATCH_SIZE = 4 |
|
|
CHUNK_LEN = 120 |
|
|
FEAT_DIM = 128 |
|
|
CACHE_LEN = 188 |
|
|
EMBED_DIM = 512 |
|
|
|
|
|
ct_inputs = [ |
|
|
ct.TensorType(name="chunk", shape=(BATCH_SIZE, CHUNK_LEN, FEAT_DIM)), |
|
|
ct.TensorType(name="chunk_lens", shape=(BATCH_SIZE,)), |
|
|
ct.TensorType(name="spkcache", shape=(BATCH_SIZE, CACHE_LEN, EMBED_DIM)), |
|
|
ct.TensorType(name="spkcache_lens", shape=(BATCH_SIZE,)), |
|
|
ct.TensorType(name="fifo", shape=(BATCH_SIZE, CACHE_LEN, EMBED_DIM)), |
|
|
ct.TensorType(name="fifo_lens", shape=(BATCH_SIZE,)), |
|
|
] |
|
|
|
|
|
ct_outputs = [ |
|
|
ct.TensorType(name="preds"), |
|
|
ct.TensorType(name="new_spkcache"), |
|
|
ct.TensorType(name="new_spkcache_lens"), |
|
|
ct.TensorType(name="new_fifo"), |
|
|
ct.TensorType(name="new_fifo_lens"), |
|
|
] |
|
|
|
|
|
|
|
|
ct.convert( |
|
|
scripted_model, |
|
|
inputs=ct_inputs, |
|
|
outputs=ct_outputs, |
|
|
convert_to="mlprogram", |
|
|
minimum_deployment_target=ct.target.iOS17, |
|
|
compute_precision=ct.precision.FLOAT16, |
|
|
) |
|
|
|