| | |
| | """ |
| | Convert Sortformer to CoreML with proper dynamic length handling. |
| | |
| | The key issue: Original conversion traced with fixed lengths (spkcache=120, fifo=40), |
| | but at runtime we need to handle empty state (spkcache=0, fifo=0) for first chunk. |
| | |
| | Solution: Use scripting instead of tracing, or trace with multiple example lengths. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import coremltools as ct |
| | import numpy as np |
| | import os |
| | import sys |
| |
|
| | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
| | sys.path.insert(0, os.path.join(SCRIPT_DIR, 'NeMo')) |
| |
|
| | from nemo.collections.asr.models import SortformerEncLabelModel |
| |
|
print("=" * 70)
print("CONVERTING SORTFORMER WITH DYNAMIC LENGTH SUPPORT")
print("=" * 70)

# Load the pretrained streaming Sortformer diarization checkpoint on CPU.
# strict=False tolerates state-dict key drift between the checkpoint and the
# installed NeMo version.
model_path = os.path.join(SCRIPT_DIR, 'diar_streaming_sortformer_4spk-v2.nemo')
print(f"Loading model: {model_path}")
model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu', strict=False)
model.eval()

# Streaming configuration, in pre-encoder (post-subsampling) frames.
# NOTE(review): field semantics inferred from names — chunk_len = new frames
# per streaming step, left/right_context = context frames around the chunk,
# fifo_len / spkcache_len = buffer capacities, spkcache_update_period = cache
# refresh interval. Confirm against SortformerModules documentation.
modules = model.sortformer_modules
modules.chunk_len = 6
modules.chunk_left_context = 1
modules.chunk_right_context = 1
modules.fifo_len = 40
modules.spkcache_len = 120
modules.spkcache_update_period = 30

print(f"Config: chunk_len={modules.chunk_len}, left={modules.chunk_left_context}, right={modules.chunk_right_context}")
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

# Raw feature frames fed to the pre-encoder per chunk:
# (chunk + left + right context) * subsampling factor — e.g. (6+1+1)*8 = 64
# when subsampling_factor is 8 (value is read from the loaded model).
chunk_frames = (modules.chunk_len + modules.chunk_left_context + modules.chunk_right_context) * modules.subsampling_factor
fc_d_model = modules.fc_d_model  # Fast-Conformer embedding dim (512, per the tensors used below)
feat_dim = 128                   # input mel feature dimension

print(f"Chunk frames: {chunk_frames}")
| |
|
class DynamicPreEncoderWrapper(nn.Module):
    """Pre-encoder wrapper that packs spkcache / fifo / chunk embeddings into a
    fixed-size, zero-padded buffer while honouring the *actual* lengths.

    The three segments are laid out contiguously (no gaps):
    ``[spkcache[:spk_len] | fifo[:fifo_len] | chunk_embs[:chunk_len] | zeros]``.

    Args:
        model: NeMo Sortformer model providing ``encoder.pre_encode``.
        max_spkcache: capacity of the speaker-cache slot (frames).
        max_fifo: capacity of the FIFO slot (frames).
        max_chunk: capacity of the chunk slot (post-subsampling frames).
    """

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        self.model = model
        self.max_spkcache = max_spkcache
        self.max_fifo = max_fifo
        self.max_chunk = max_chunk
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        """Return (packed_embs, total_length, chunk_embs, chunk_emb_lengths).

        NOTE(review): assumes chunk is (batch, frames, feat_dim) mel features
        and spkcache/fifo are (batch, frames, d_model) embeddings — confirm.
        """
        # Conv-subsampling pre-encoder on the raw feature chunk.
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        spk_len = int(spkcache_lengths[0].item()) if spkcache_lengths.numel() > 0 else 0
        fifo_len = int(fifo_lengths[0].item()) if fifo_lengths.numel() > 0 else 0
        chunk_len = int(chunk_emb_lengths[0].item())

        # Clamp defensively: out-of-range length inputs must never write past
        # their slot or overflow the fixed-size output buffer.
        spk_len = min(max(spk_len, 0), self.max_spkcache)
        fifo_len = min(max(fifo_len, 0), self.max_fifo)
        chunk_len = min(max(chunk_len, 0), self.max_chunk)
        total_len = spk_len + fifo_len + chunk_len

        B, _, D = spkcache.shape
        output = torch.zeros(B, self.max_total, D, device=chunk.device, dtype=chunk.dtype)

        # Copy only the valid prefix of each segment, packed contiguously.
        if spk_len > 0:
            output[:, :spk_len, :] = spkcache[:, :spk_len, :]
        if fifo_len > 0:
            output[:, spk_len:spk_len + fifo_len, :] = fifo[:, :fifo_len, :]
        if chunk_len > 0:
            output[:, spk_len + fifo_len:spk_len + fifo_len + chunk_len, :] = chunk_embs[:, :chunk_len, :]

        # Keep the length tensor on the same device as the data (the original
        # created it on CPU, which breaks when the model runs on GPU).
        total_length = torch.tensor([total_len], dtype=torch.long, device=chunk.device)

        return output, total_length, chunk_embs, chunk_emb_lengths
| |
|
| |
|
class DynamicHeadWrapper(nn.Module):
    """Encoder + prediction head that zeroes predictions for padded frames.

    Args:
        model: NeMo Sortformer model providing ``frontend_encoder`` and
            ``forward_infer``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs, chunk_emb_lengths):
        """Return (masked_preds, chunk_embs, chunk_emb_lengths).

        chunk_embs / chunk_emb_lengths are passed through untouched so callers
        can carry the chunk state alongside the predictions.
        """
        # Fast-Conformer layers; bypass_pre_encode=True because the inputs are
        # already pre-encoded embeddings, not raw features.
        fc_embs, fc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )

        # Transformer head producing per-frame speaker activity predictions.
        preds = self.model.forward_infer(fc_embs, fc_lengths)

        # Zero out predictions beyond the valid length so padded frames cannot
        # leak into downstream post-processing.
        # NOTE(review): masks with the *input* length; assumes frontend_encoder
        # preserves sequence length when bypass_pre_encode=True — confirm.
        max_len = preds.shape[1]
        length = pre_encoder_lengths[0]
        mask = torch.arange(max_len, device=preds.device) < length
        # Match preds' dtype instead of hard-coding float32 (the original
        # .float() would silently upcast fp16 predictions).
        preds = preds * mask.unsqueeze(0).unsqueeze(-1).to(preds.dtype)

        return preds, chunk_embs, chunk_emb_lengths
| |
|
| |
|
| | |
print("\n" + "=" * 70)
print("TESTING DYNAMIC WRAPPERS")
print("=" * 70)

# Instantiate both wrappers in eval mode so dropout/normalization layers
# behave deterministically during the sanity checks below.
pre_encoder = DynamicPreEncoderWrapper(model)
head = DynamicHeadWrapper(model)
pre_encoder.eval()
head.eval()

# Test 1: first streaming chunk — state tensors are allocated at max size but
# their lengths are 0, i.e. only the chunk itself carries real frames.
# NOTE(review): 56 raw frames here vs chunk_frames=64 above — presumably the
# first chunk has no left context ((6+1)*8 = 56); confirm.
print("\nTest 1: Empty state (chunk 0)")
chunk = torch.randn(1, 56, 128)          # (batch, raw frames, mel features)
chunk_len = torch.tensor([56], dtype=torch.long)
spkcache = torch.zeros(1, 120, 512)      # padded to max; length says 0 valid frames
spkcache_len = torch.tensor([0], dtype=torch.long)
fifo = torch.zeros(1, 40, 512)
fifo_len = torch.tensor([0], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")
# Per-frame prediction sums: a quick signal that the masked frames are zero
# and the valid frames carry non-trivial values.
sums = [f"{preds[0, i, :].sum().item():.4f}" for i in range(min(8, preds.shape[1]))]
print(f" First 8 pred frames sum: {sums}")

# Test 2: steady-state streaming — both state buffers completely full.
print("\nTest 2: Full state")
chunk = torch.randn(1, 64, 128)
chunk_len = torch.tensor([64], dtype=torch.long)
spkcache = torch.randn(1, 120, 512)
spkcache_len = torch.tensor([120], dtype=torch.long)
fifo = torch.randn(1, 40, 512)
fifo_len = torch.tensor([40], dtype=torch.long)

with torch.no_grad():
    pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
        chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len
    )
    preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
print(f" Predictions: {preds.shape}")

# Diagnostic summary for the reader: why the previously-exported CoreML model
# misbehaves on the first chunk, and the strategy pursued below.
print("\n" + "=" * 70)
print("ISSUE IDENTIFIED")
print("=" * 70)
print("""
The problem is that the current CoreML model was traced with FIXED lengths.
When lengths change at runtime, the traced operations don't adapt.

The fix requires re-tracing with proper dynamic handling OR using coremltools
flexible shapes feature.

For now, let's try a simpler approach: always pad inputs to max size and
use the length parameters only for extracting the correct output slice.
""")
| |
|
| | |
| | |
| | |
| |
|
print("\nATTEMPTING CONVERSION WITH FLEXIBLE SHAPES...")

try:
    class SimplePipelineWrapper(nn.Module):
        """End-to-end pipeline (pre-encode -> concat -> encoder -> head) for tracing.

        State tensors are always padded to their maximum sizes and occupy FIXED
        slots of the concat buffer: spkcache -> [0:120), fifo -> [120:160),
        chunk -> [160:168). The *_lengths inputs select how many frames of each
        slot are valid. All length handling uses tensor ops (arange-vs-length
        comparisons), so torch.jit.trace records length-dependent graph ops.
        The original per-index Python ``if i < length`` loops were evaluated
        eagerly at trace time, freezing the example lengths into the graph —
        exactly the fixed-length bug this script is trying to fix.
        """

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
            # Conv-subsampling pre-encoder over the raw feature chunk.
            chunk_embs, chunk_emb_lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

            # 0-dim length tensors; kept as tensors so the trace stays dynamic.
            spk_len = spkcache_lengths[0]
            fifo_len = fifo_lengths[0]
            chunk_len = chunk_emb_lens[0]

            B = chunk.shape[0]
            max_out = 168  # 120 (spkcache) + 40 (fifo) + 8 (chunk)
            D = 512

            concat_embs = torch.zeros(B, max_out, D, device=chunk.device, dtype=chunk.dtype)

            # Masked, vectorized copies into the fixed slots. Unlike Python
            # conditionals, these comparisons become graph operations and
            # therefore adapt to the runtime length inputs after tracing.
            spk_mask = (torch.arange(120, device=chunk.device) < spk_len).to(chunk.dtype)
            concat_embs[:, :120, :] = spkcache * spk_mask.view(1, 120, 1)

            fifo_mask = (torch.arange(40, device=chunk.device) < fifo_len).to(chunk.dtype)
            concat_embs[:, 120:160, :] = fifo * fifo_mask.view(1, 40, 1)

            n_chunk = chunk_embs.shape[1]  # 8 for the fixed 64-frame chunk input
            chunk_mask = (torch.arange(n_chunk, device=chunk.device) < chunk_len).to(chunk.dtype)
            concat_embs[:, 160:160 + n_chunk, :] = chunk_embs * chunk_mask.view(1, n_chunk, 1)

            # NOTE(review): with partially-empty states the valid frames are NOT
            # contiguous (chunk frames sit at offset 160) while total_lens only
            # reports a count — confirm the frontend encoder's length masking is
            # compatible with this fixed-slot layout.
            total_len = spk_len + fifo_len + chunk_len
            total_lens = total_len.unsqueeze(0)

            # Fast-Conformer layers, skipping their own pre-encoder.
            fc_embs, fc_lens = self.model.frontend_encoder(
                processed_signal=concat_embs,
                processed_signal_length=total_lens,
                bypass_pre_encode=True,
            )

            # Transformer head: per-frame speaker activity predictions.
            preds = self.model.forward_infer(fc_embs, fc_lens)

            return preds, chunk_embs, chunk_emb_lens

    wrapper = SimplePipelineWrapper(model)
    wrapper.eval()

    # Trace with the hardest case (empty state) as the example; thanks to the
    # masked copies above, the trace still generalizes to non-zero lengths.
    print("Tracing with empty state example...")
    chunk = torch.randn(1, 64, 128)
    chunk_len = torch.tensor([56], dtype=torch.long)
    spkcache = torch.zeros(1, 120, 512)
    spkcache_len = torch.tensor([0], dtype=torch.long)
    fifo = torch.zeros(1, 40, 512)
    fifo_len = torch.tensor([0], dtype=torch.long)

    with torch.no_grad():
        traced = torch.jit.trace(wrapper, (chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len))

    # Fixed input shapes are fine here: states are always padded to max size,
    # and only the *_lengths inputs vary between calls.
    print("Converting to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="chunk", shape=(1, 64, 128), dtype=np.float32),
            ct.TensorType(name="chunk_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="spkcache", shape=(1, 120, 512), dtype=np.float32),
            ct.TensorType(name="spkcache_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="fifo", shape=(1, 40, 512), dtype=np.float32),
            ct.TensorType(name="fifo_lengths", shape=(1,), dtype=np.int32),
        ],
        outputs=[
            ct.TensorType(name="speaker_preds", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
        ],
        minimum_deployment_target=ct.target.iOS16,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )

    output_path = os.path.join(SCRIPT_DIR, 'coreml_models', 'SortformerPipeline_Dynamic.mlpackage')
    # Create the output directory up front so save() cannot fail on a fresh checkout.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    mlmodel.save(output_path)
    print(f"Saved to: {output_path}")

    # Smoke-test the converted model on the same (empty-state) inputs.
    print("\nTesting new CoreML model...")
    test_output = mlmodel.predict({
        'chunk': chunk.numpy(),
        'chunk_lengths': chunk_len.numpy().astype(np.int32),
        'spkcache': spkcache.numpy(),
        'spkcache_lengths': spkcache_len.numpy().astype(np.int32),
        'fifo': fifo.numpy(),
        'fifo_lengths': fifo_len.numpy().astype(np.int32),
    })

    coreml_preds = np.array(test_output['speaker_preds'])
    print(f"CoreML predictions shape: {coreml_preds.shape}")
    print(f"CoreML first 8 frames:")
    for i in range(min(8, coreml_preds.shape[1])):
        print(f" Frame {i}: {coreml_preds[0, i, :]}")

except Exception as e:
    # Best-effort experiment: report the failure with a traceback but let the
    # script finish so earlier diagnostics are still visible.
    print(f"Error during conversion: {e}")
    import traceback
    traceback.print_exc()
| |
|