|
|
|
|
|
""" |
|
|
Convert Sortformer to CoreML with proper dynamic length handling. |
|
|
|
|
|
The key issue: Original conversion traced with fixed lengths (spkcache=120, fifo=40), |
|
|
but at runtime we need to handle empty state (spkcache=0, fifo=0) for first chunk. |
|
|
|
|
|
Solution: Use scripting instead of tracing, or trace with multiple example lengths. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import coremltools as ct |
|
|
import numpy as np |
|
|
import os |
|
|
import sys |
|
|
|
|
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
|
sys.path.insert(0, os.path.join(SCRIPT_DIR, 'NeMo')) |
|
|
|
|
|
from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
|
|
|
print("=" * 70)
print("CONVERTING SORTFORMER WITH DYNAMIC LENGTH SUPPORT")
print("=" * 70)

# Load the pretrained streaming Sortformer diarization checkpoint on CPU.
# strict=False tolerates benign state-dict key mismatches in the checkpoint.
model_path = os.path.join(SCRIPT_DIR, 'diar_streaming_sortformer_4spk-v2.nemo')
print(f"Loading model: {model_path}")
model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu', strict=False)
model.eval()

# Streaming configuration for chunked inference.
# NOTE(review): these look like frame counts at the post-subsampling
# (embedding) rate — confirm against NeMo's streaming Sortformer docs.
modules = model.sortformer_modules
modules.chunk_len = 6                # new frames processed per streaming step
modules.chunk_left_context = 1       # extra left-context frames per chunk
modules.chunk_right_context = 1      # extra right-context (lookahead) frames
modules.fifo_len = 40                # max frames held in the FIFO buffer
modules.spkcache_len = 120           # max frames held in the speaker cache
modules.spkcache_update_period = 30  # frames between speaker-cache updates

print(f"Config: chunk_len={modules.chunk_len}, left={modules.chunk_left_context}, right={modules.chunk_right_context}")
print(f"        fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

# Raw feature frames fed to the pre-encoder per chunk: (chunk + both contexts)
# scaled back up to the pre-subsampling rate.
chunk_frames = (modules.chunk_len + modules.chunk_left_context + modules.chunk_right_context) * modules.subsampling_factor
fc_d_model = modules.fc_d_model  # encoder embedding dimension (512 per the test tensors below)
feat_dim = 128                   # input feature dimension of a chunk

print(f"Chunk frames: {chunk_frames}")
|
|
|
|
|
class DynamicPreEncoderWrapper(nn.Module):
    """Pre-encode a chunk and splice it after the cached speaker/FIFO state.

    The merged buffer is always allocated at its maximum size
    (``max_spkcache + max_fifo + max_chunk`` frames); only the leading
    ``total_length`` frames are valid and the remainder stays zero, so
    downstream consumers must respect the returned length.
    """

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        self.model = model
        self.max_spkcache = max_spkcache
        self.max_fifo = max_fifo
        self.max_chunk = max_chunk
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        # Run the raw chunk through the encoder's subsampling front-end.
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        # Valid frame counts per region (0 when a length tensor is empty).
        n_spk = spkcache_lengths[0].item() if spkcache_lengths.numel() > 0 else 0
        n_fifo = fifo_lengths[0].item() if fifo_lengths.numel() > 0 else 0
        n_chunk = chunk_emb_lengths[0].item()

        batch, _, emb_dim = spkcache.shape
        merged = torch.zeros(batch, self.max_total, emb_dim, device=chunk.device, dtype=chunk.dtype)

        # Pack each region's valid prefix back-to-back.  A zero-length region
        # degenerates to an empty slice assignment, i.e. a no-op.
        cursor = 0
        for source, n_valid in ((spkcache, n_spk), (fifo, n_fifo), (chunk_embs, n_chunk)):
            merged[:, cursor:cursor + n_valid, :] = source[:, :n_valid, :]
            cursor += n_valid

        total_length = torch.tensor([cursor], dtype=torch.long)

        return merged, total_length, chunk_embs, chunk_emb_lengths
|
|
|
|
|
|
|
|
class DynamicHeadWrapper(nn.Module):
    """Run the encoder + Sortformer head and zero predictions past the valid length."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs, chunk_emb_lengths):
        # Encoder pass over the already pre-encoded frame sequence.
        encoded, encoded_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Per-frame speaker predictions from the Sortformer head.
        raw_preds = self.model.forward_infer(encoded, encoded_lengths)

        # Mask out frames beyond the valid length so zero-padding cannot leak
        # spurious speaker activity downstream.
        frame_idx = torch.arange(raw_preds.shape[1], device=raw_preds.device)
        valid = (frame_idx < pre_encoder_lengths[0]).float()
        masked_preds = raw_preds * valid[None, :, None]

        # chunk_embs / chunk_emb_lengths are passed through untouched so the
        # caller can feed them back as streaming state.
        return masked_preds, chunk_embs, chunk_emb_lengths
|
|
|
|
|
|
|
|
|
|
|
print("\n" + "=" * 70)
print("TESTING DYNAMIC WRAPPERS")
print("=" * 70)

# Build eager-mode wrappers around the loaded model so behavior can be
# verified in plain PyTorch before any CoreML conversion is attempted.
pre_encoder = DynamicPreEncoderWrapper(model)
head = DynamicHeadWrapper(model)
pre_encoder.eval()
head.eval()

# Test 1: the very first streaming chunk.  State tensors are allocated at
# their maximum sizes (spkcache 120, fifo 40 frames) but their lengths are 0,
# so only the chunk itself should contribute to the merged sequence.
print("\nTest 1: Empty state (chunk 0)")
chunk = torch.randn(1, 56, 128)  # (batch, raw frames, feat_dim)
chunk_len = torch.tensor([56], dtype=torch.long)
spkcache = torch.zeros(1, 120, 512)
spkcache_len = torch.tensor([0], dtype=torch.long)
fifo = torch.zeros(1, 40, 512)
fifo_len = torch.tensor([0], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f"  Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f"  Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f"  Predictions: {preds.shape}")
# Per-frame prediction sums; frames past the valid length should be ~0
# because the head masks them out.
sums = [f"{preds[0, i, :].sum().item():.4f}" for i in range(min(8, preds.shape[1]))]
print(f"  First 8 pred frames sum: {sums}")
|
|
|
|
|
|
|
|
# Test 2: steady-state streaming — speaker cache and FIFO are completely full,
# so all three regions contribute to the merged sequence.
print("\nTest 2: Full state")
chunk = torch.randn(1, 64, 128)
chunk_len = torch.tensor([64], dtype=torch.long)
spkcache = torch.randn(1, 120, 512)
spkcache_len = torch.tensor([120], dtype=torch.long)
fifo = torch.randn(1, 40, 512)
fifo_len = torch.tensor([40], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f"  Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f"  Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f"  Predictions: {preds.shape}")
|
|
|
|
|
print("\n" + "=" * 70)
print("ISSUE IDENTIFIED")
print("=" * 70)
# The triple-quoted text below is a runtime message printed to the console,
# not a docstring.
print("""
The problem is that the current CoreML model was traced with FIXED lengths.
When lengths change at runtime, the traced operations don't adapt.

The fix requires re-tracing with proper dynamic handling OR using coremltools
flexible shapes feature.

For now, let's try a simpler approach: always pad inputs to max size and
use the length parameters only for extracting the correct output slice.
""")

print("\nATTEMPTING CONVERSION WITH FLEXIBLE SHAPES...")
|
|
|
|
|
|
|
|
try:

    class SimplePipelineWrapper(nn.Module):
        """Single-module pipeline: pre-encode chunk, splice state, encode, predict.

        Inputs are always padded to their maximum sizes (spkcache=120, fifo=40,
        chunk embeddings=8 frames, d_model=512); the ``*_lengths`` inputs select
        the valid prefix of each region.

        The state splice is implemented with arange-based masks instead of
        per-frame Python ``if i < length`` loops.  Data-dependent Python
        branches are frozen to the example's values by ``torch.jit.trace``
        (tracing with length 0 would bake in a graph that never copies the
        state) — which is exactly the dynamic-length bug this script exists to
        fix.  Masked multiplies are recorded as tensor ops, so the traced graph
        keeps honoring the runtime lengths.
        """

        def __init__(self, model):
            super().__init__()
            self.model = model

        @staticmethod
        def _valid_prefix(src, valid_len, region_size):
            """Return src[:, :region_size, :] with frames at index >= valid_len zeroed.

            Trace-safe: the comparison is a tensor op, not a Python branch, so
            it stays symbolic in the traced graph.
            """
            idx = torch.arange(region_size, device=src.device)
            mask = (idx < valid_len).unsqueeze(0).unsqueeze(-1).to(src.dtype)
            return src[:, :region_size, :] * mask

        def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
            # Pre-encode the raw chunk features.
            chunk_embs, chunk_emb_lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

            spk_len = spkcache_lengths[0]
            fifo_len = fifo_lengths[0]
            chunk_len = chunk_emb_lens[0]

            # Fix the chunk embeddings at 8 frames so the concat below has a
            # static shape at trace time (shapes are compile-time constants).
            n_frames = chunk_embs.shape[1]
            if n_frames < 8:
                chunk_fixed = torch.nn.functional.pad(chunk_embs, (0, 0, 0, 8 - n_frames))
            else:
                chunk_fixed = chunk_embs[:, :8, :]

            # [spkcache(120) | fifo(40) | chunk(8)] -> (B, 168, 512).  Frames
            # beyond each region's valid length are zeroed, matching the
            # zeros-init + prefix-copy semantics of the eager reference.
            concat_embs = torch.cat(
                [
                    self._valid_prefix(spkcache, spk_len, 120),
                    self._valid_prefix(fifo, fifo_len, 40),
                    self._valid_prefix(chunk_fixed, chunk_len, 8),
                ],
                dim=1,
            )

            # Number of valid frames in the concatenated sequence, shape (1,).
            total_lens = (spk_len + fifo_len + chunk_len).unsqueeze(0)

            # Encoder pass over the already pre-encoded frames, then the head.
            fc_embs, fc_lens = self.model.frontend_encoder(
                processed_signal=concat_embs,
                processed_signal_length=total_lens,
                bypass_pre_encode=True,
            )
            preds = self.model.forward_infer(fc_embs, fc_lens)

            # chunk_embs / chunk_emb_lens are returned as streaming state.
            return preds, chunk_embs, chunk_emb_lens

    wrapper = SimplePipelineWrapper(model)
    wrapper.eval()

    # Trace with an empty-state example; thanks to the masked splice the
    # traced graph still honors non-zero lengths at runtime.
    print("Tracing with empty state example...")
    chunk = torch.randn(1, 64, 128)
    chunk_len = torch.tensor([56], dtype=torch.long)
    spkcache = torch.zeros(1, 120, 512)
    spkcache_len = torch.tensor([0], dtype=torch.long)
    fifo = torch.zeros(1, 40, 512)
    fifo_len = torch.tensor([0], dtype=torch.long)

    with torch.no_grad():
        traced = torch.jit.trace(wrapper, (chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len))

    print("Converting to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="chunk", shape=(1, 64, 128), dtype=np.float32),
            ct.TensorType(name="chunk_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="spkcache", shape=(1, 120, 512), dtype=np.float32),
            ct.TensorType(name="spkcache_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="fifo", shape=(1, 40, 512), dtype=np.float32),
            ct.TensorType(name="fifo_lengths", shape=(1,), dtype=np.int32),
        ],
        outputs=[
            ct.TensorType(name="speaker_preds", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
        ],
        minimum_deployment_target=ct.target.iOS16,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )

    output_path = os.path.join(SCRIPT_DIR, 'coreml_models', 'SortformerPipeline_Dynamic.mlpackage')
    mlmodel.save(output_path)
    print(f"Saved to: {output_path}")

    # Smoke-test the converted model with the same (empty-state) inputs.
    print("\nTesting new CoreML model...")
    test_output = mlmodel.predict({
        'chunk': chunk.numpy(),
        'chunk_lengths': chunk_len.numpy().astype(np.int32),
        'spkcache': spkcache.numpy(),
        'spkcache_lengths': spkcache_len.numpy().astype(np.int32),
        'fifo': fifo.numpy(),
        'fifo_lengths': fifo_len.numpy().astype(np.int32),
    })

    coreml_preds = np.array(test_output['speaker_preds'])
    print(f"CoreML predictions shape: {coreml_preds.shape}")
    print(f"CoreML first 8 frames:")
    for i in range(min(8, coreml_preds.shape[1])):
        print(f"  Frame {i}: {coreml_preds[0, i, :]}")

except Exception as e:
    print(f"Error during conversion: {e}")
    import traceback
    traceback.print_exc()
|
|
|