#!/usr/bin/env python3
"""
Convert Sortformer to CoreML with proper dynamic length handling.

The key issue: the original conversion was traced with fixed lengths
(spkcache=120, fifo=40), but at runtime we need to handle empty state
(spkcache=0, fifo=0) for the first chunk.

``torch.jit.trace`` records a single execution path. Any Python branch on a
length tensor (e.g. ``if i < spk_len``) or ``.item()`` call is frozen to the
example value, so a model traced with empty state silently ignores
spkcache/fifo for every later call.

Solution: express every length-dependent step -- packing the valid frames of
spkcache, fifo and the new chunk into one zero-padded sequence -- with tensor
ops only (``arange`` / ``where`` / ``index_select``). The traced graph then
stays correct for any lengths up to the fixed buffer sizes.
"""

import os
import sys
import traceback
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(SCRIPT_DIR, 'diar_streaming_sortformer_4spk-v2.nemo')
OUTPUT_PATH = os.path.join(SCRIPT_DIR, 'coreml_models', 'SortformerPipeline_Dynamic.mlpackage')

# Low-latency streaming configuration applied to the restored model.
CHUNK_LEN = 6
CHUNK_LEFT_CONTEXT = 1
CHUNK_RIGHT_CONTEXT = 1
FIFO_LEN = 40
SPKCACHE_LEN = 120
SPKCACHE_UPDATE_PERIOD = 30

FEAT_DIM = 128  # feature channels per input frame (value used by the original script)
FIRST_CHUNK_FRAMES = 56  # First chunk has fewer frames than a full chunk


@dataclass(frozen=True)
class StreamingDims:
    """Static buffer sizes used for example inputs and the fixed CoreML input shapes."""

    chunk_frames: int  # input feature frames per chunk (chunk + left/right context)
    max_chunk: int  # pre-encoded frames per chunk
    max_spkcache: int  # capacity of the speaker-cache buffer (frames)
    max_fifo: int  # capacity of the FIFO buffer (frames)
    emb_dim: int  # width of spkcache / fifo / pre-encoder embeddings
    feat_dim: int = FEAT_DIM


def pack_embeddings(spkcache, spkcache_len, fifo, fifo_len, chunk_embs, chunk_len, out_len=None):
    """Concatenate the valid prefixes of three sequences and zero-pad the rest.

    Output frame order is ``spkcache[:spkcache_len]``, then ``fifo[:fifo_len]``,
    then ``chunk_embs[:chunk_len]``, then zeros. Only tensor ops depend on the
    lengths, so the function stays correct under ``torch.jit.trace`` when the
    lengths differ from the example used for tracing.

    Args:
        spkcache, fifo, chunk_embs: ``[B, T_i, D]`` buffers; concatenated along dim 1.
        spkcache_len, fifo_len, chunk_len: 0-dim integer tensors giving how many
            leading frames of each buffer are valid (shared across the batch).
        out_len: optional output length. Must be >= the summed buffer lengths;
            extra positions are zero. Defaults to the summed buffer lengths.

    Returns:
        ``(packed, total)``. ``packed`` is ``[B, out_len, D]`` and ``total`` is
        ``spkcache_len + fifo_len + chunk_len``.

    Raises:
        ValueError: if ``out_len`` is smaller than the summed buffer lengths.
    """
    stacked = torch.cat([spkcache, fifo, chunk_embs], dim=1)
    n_stacked = stacked.shape[1]
    n_out = n_stacked if out_len is None else int(out_len)
    if n_out < n_stacked:
        raise ValueError(f"out_len={n_out} is smaller than the combined buffer length {n_stacked}")
    if n_out > n_stacked:
        pad = stacked.new_zeros(stacked.shape[0], n_out - n_stacked, stacked.shape[2])
        stacked = torch.cat([stacked, pad], dim=1)

    fifo_start = spkcache_len
    chunk_start = spkcache_len + fifo_len
    total = chunk_start + chunk_len

    # For every output position, compute which row of `stacked` feeds it.
    pos = torch.arange(n_out, device=stacked.device)
    fifo_offset = spkcache.shape[1]
    chunk_offset = spkcache.shape[1] + fifo.shape[1]
    src = torch.where(
        pos < fifo_start,
        pos,
        torch.where(
            pos < chunk_start,
            pos - fifo_start + fifo_offset,
            pos - chunk_start + chunk_offset,
        ),
    )
    # Positions past `total` may point beyond the buffer. Clamp them to a legal
    # row; they are zeroed right below.
    src = src.clamp(0, n_out - 1)
    packed = stacked.index_select(1, src)
    valid = (pos < total).view(1, -1, 1)
    packed = torch.where(valid, packed, torch.zeros_like(packed))
    return packed, total


def _first_or_zero(lengths):
    """Return ``lengths[0]``, or a 0-dim zero of the same dtype if ``lengths`` is empty."""
    return lengths[0] if lengths.numel() > 0 else lengths.new_zeros(())


class DynamicPreEncoderWrapper(nn.Module):
    """Pre-encoder that properly handles dynamic lengths."""

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        self.model = model
        self.max_spkcache = max_spkcache
        self.max_fifo = max_fifo
        self.max_chunk = max_chunk
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        """Pre-encode ``chunk`` and pack it after the valid spkcache/fifo frames.

        Returns:
            ``(packed [B, max_total, D], total_length [1], chunk_embs, chunk_emb_lengths)``.
            Valid frames are packed at the start of ``packed``; the rest are zeros.
        """
        chunk_embs, chunk_emb_lengths = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)
        packed, total = pack_embeddings(
            spkcache, _first_or_zero(spkcache_lengths),
            fifo, _first_or_zero(fifo_lengths),
            chunk_embs, chunk_emb_lengths[0],
            out_len=self.max_total,
        )
        return packed, total.reshape(1).to(torch.long), chunk_embs, chunk_emb_lengths


class DynamicHeadWrapper(nn.Module):
    """Head that properly handles dynamic lengths with masking."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pre_encoder_embs, pre_encoder_lengths, chunk_embs, chunk_emb_lengths):
        """Encode packed embeddings, predict, and zero predictions past the valid length."""
        fc_embs, fc_lengths = self.model.frontend_encoder(
            processed_signal=pre_encoder_embs,
            processed_signal_length=pre_encoder_lengths,
            bypass_pre_encode=True,
        )
        preds = self.model.forward_infer(fc_embs, fc_lengths)

        # preds is indexed as [B, T, num_speakers]; mask frames at or past the length.
        max_len = preds.shape[1]
        length = pre_encoder_lengths[0]
        mask = torch.arange(max_len, device=preds.device) < length
        preds = preds * mask.unsqueeze(0).unsqueeze(-1).to(preds.dtype)
        return preds, chunk_embs, chunk_emb_lengths


class SimplePipelineWrapper(nn.Module):
    """Single traceable graph: pre-encode chunk, pack with state, encode, predict.

    Unlike a Python loop over ``if i < length``, packing goes through
    ``pack_embeddings``. The lengths therefore remain live inputs of the
    traced graph instead of being frozen to the tracing example.
    """

    def __init__(self, model, max_spkcache=120, max_fifo=40, max_chunk=8):
        super().__init__()
        self.model = model
        self.max_total = max_spkcache + max_fifo + max_chunk

    def forward(self, chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths):
        chunk_embs, chunk_emb_lens = self.model.encoder.pre_encode(x=chunk, lengths=chunk_lengths)

        # Valid frames are packed contiguously at the start, followed by zero
        # padding, so `total_len` marks exactly the valid region.
        concat_embs, total_len = pack_embeddings(
            spkcache, spkcache_lengths[0],
            fifo, fifo_lengths[0],
            chunk_embs, chunk_emb_lens[0],
            out_len=self.max_total,
        )

        fc_embs, fc_lens = self.model.frontend_encoder(
            processed_signal=concat_embs,
            processed_signal_length=total_len.unsqueeze(0),
            bypass_pre_encode=True,
        )
        preds = self.model.forward_infer(fc_embs, fc_lens)
        return preds, chunk_embs, chunk_emb_lens


def load_model(model_path=MODEL_PATH):
    """Restore the Sortformer checkpoint and configure it for low-latency streaming."""
    # NeMo is vendored next to this script. It is imported lazily so the packing
    # helpers in this module can be imported without NeMo installed.
    sys.path.insert(0, os.path.join(SCRIPT_DIR, 'NeMo'))
    from nemo.collections.asr.models import SortformerEncLabelModel

    print(f"Loading model: {model_path}")
    model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu', strict=False)
    model.eval()

    modules = model.sortformer_modules
    modules.chunk_len = CHUNK_LEN
    modules.chunk_left_context = CHUNK_LEFT_CONTEXT
    modules.chunk_right_context = CHUNK_RIGHT_CONTEXT
    modules.fifo_len = FIFO_LEN
    modules.spkcache_len = SPKCACHE_LEN
    modules.spkcache_update_period = SPKCACHE_UPDATE_PERIOD

    print(f"Config: chunk_len={modules.chunk_len}, left={modules.chunk_left_context}, right={modules.chunk_right_context}")
    print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")
    return model


def streaming_dims(model):
    """Derive the static buffer sizes from the model's streaming configuration."""
    modules = model.sortformer_modules
    chunk_with_context = modules.chunk_len + modules.chunk_left_context + modules.chunk_right_context
    return StreamingDims(
        chunk_frames=chunk_with_context * modules.subsampling_factor,
        max_chunk=chunk_with_context,
        max_spkcache=modules.spkcache_len,
        max_fifo=modules.fifo_len,
        # Assumed equal to the pre-encoder output width (512 for this checkpoint).
        emb_dim=modules.fc_d_model,
    )


def _run_eager_case(pre_encoder, head, chunk, chunk_len, spkcache, spk_len, fifo, fifo_len):
    """Run pre-encoder + head once in eager mode, print shapes, and return the predictions."""
    def as_len(value):
        return torch.tensor([value], dtype=torch.long)

    with torch.no_grad():
        pre_out, pre_len, chunk_embs, chunk_emb_len = pre_encoder(
            chunk, as_len(chunk_len), spkcache, as_len(spk_len), fifo, as_len(fifo_len)
        )
        preds, _, _ = head(pre_out, pre_len, chunk_embs, chunk_emb_len)

    print(f" Pre-encoder output: {pre_out.shape}, length={pre_len.item()}")
    print(f" Chunk embeddings: {chunk_embs.shape}, length={chunk_emb_len.item()}")
    print(f" Predictions: {preds.shape}")
    return preds


def run_eager_tests(model, dims):
    """Exercise the eager wrappers with empty and full streaming state."""
    print("\n" + "=" * 70)
    print("TESTING DYNAMIC WRAPPERS")
    print("=" * 70)

    pre_encoder = DynamicPreEncoderWrapper(model, dims.max_spkcache, dims.max_fifo, dims.max_chunk)
    head = DynamicHeadWrapper(model)
    pre_encoder.eval()
    head.eval()

    # Test 1: Empty state (like chunk 0)
    print("\nTest 1: Empty state (chunk 0)")
    preds = _run_eager_case(
        pre_encoder, head,
        chunk=torch.randn(1, FIRST_CHUNK_FRAMES, dims.feat_dim),  # First chunk has fewer frames
        chunk_len=FIRST_CHUNK_FRAMES,
        spkcache=torch.zeros(1, dims.max_spkcache, dims.emb_dim), spk_len=0,
        fifo=torch.zeros(1, dims.max_fifo, dims.emb_dim), fifo_len=0,
    )
    sums = [f"{preds[0, i, :].sum().item():.4f}" for i in range(min(8, preds.shape[1]))]
    print(f" First 8 pred frames sum: {sums}")

    # Test 2: Full state
    print("\nTest 2: Full state")
    _run_eager_case(
        pre_encoder, head,
        chunk=torch.randn(1, dims.chunk_frames, dims.feat_dim),
        chunk_len=dims.chunk_frames,
        spkcache=torch.randn(1, dims.max_spkcache, dims.emb_dim), spk_len=dims.max_spkcache,
        fifo=torch.randn(1, dims.max_fifo, dims.emb_dim), fifo_len=dims.max_fifo,
    )


def convert_pipeline(model, dims, output_path=OUTPUT_PATH):
    """Trace SimplePipelineWrapper, convert it to CoreML, save it, and smoke-test it."""
    # Imported lazily for the same reason as NeMo in load_model().
    import coremltools as ct

    wrapper = SimplePipelineWrapper(model, dims.max_spkcache, dims.max_fifo, dims.max_chunk)
    wrapper.eval()

    # Trace with empty state example
    print("Tracing with empty state example...")
    example = (
        torch.randn(1, dims.chunk_frames, dims.feat_dim),  # chunk padded to full size
        torch.tensor([FIRST_CHUNK_FRAMES], dtype=torch.long),  # Actual length
        torch.zeros(1, dims.max_spkcache, dims.emb_dim),
        torch.tensor([0], dtype=torch.long),
        torch.zeros(1, dims.max_fifo, dims.emb_dim),
        torch.tensor([0], dtype=torch.long),
    )
    full_state = (
        torch.randn(1, dims.chunk_frames, dims.feat_dim),
        torch.tensor([dims.chunk_frames], dtype=torch.long),
        torch.randn(1, dims.max_spkcache, dims.emb_dim),
        torch.tensor([dims.max_spkcache], dtype=torch.long),
        torch.randn(1, dims.max_fifo, dims.emb_dim),
        torch.tensor([dims.max_fifo], dtype=torch.long),
    )
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, example)
        # Re-run with a *different* state: any length frozen into the trace by
        # NeMo internals would show up as a large difference here.
        trace_diff = (traced(*full_state)[0] - wrapper(*full_state)[0]).abs().max().item()
        torch_preds = wrapper(*example)[0].numpy()
    print(f"Trace check (full state): max |traced - eager| = {trace_diff:.6f}")

    print("Converting to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="chunk", shape=(1, dims.chunk_frames, dims.feat_dim), dtype=np.float32),
            ct.TensorType(name="chunk_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="spkcache", shape=(1, dims.max_spkcache, dims.emb_dim), dtype=np.float32),
            ct.TensorType(name="spkcache_lengths", shape=(1,), dtype=np.int32),
            ct.TensorType(name="fifo", shape=(1, dims.max_fifo, dims.emb_dim), dtype=np.float32),
            ct.TensorType(name="fifo_lengths", shape=(1,), dtype=np.int32),
        ],
        outputs=[
            ct.TensorType(name="speaker_preds", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_embs", dtype=np.float32),
            ct.TensorType(name="chunk_pre_encoder_lengths", dtype=np.int32),
        ],
        minimum_deployment_target=ct.target.iOS16,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_ONLY,  # Start with CPU for debugging
    )

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    mlmodel.save(output_path)
    print(f"Saved to: {output_path}")

    # Test the new model on the same empty-state example used for tracing.
    print("\nTesting new CoreML model...")
    input_names = ('chunk', 'chunk_lengths', 'spkcache', 'spkcache_lengths', 'fifo', 'fifo_lengths')
    feed = {
        name: tensor.numpy() if tensor.is_floating_point() else tensor.numpy().astype(np.int32)
        for name, tensor in zip(input_names, example)
    }
    test_output = mlmodel.predict(feed)
    coreml_preds = np.array(test_output['speaker_preds'])
    print(f"CoreML predictions shape: {coreml_preds.shape}")
    print("CoreML first 8 frames:")
    for i in range(min(8, coreml_preds.shape[1])):
        print(f" Frame {i}: {coreml_preds[0, i, :]}")
    if coreml_preds.shape == torch_preds.shape:
        print(f"Max |CoreML - PyTorch|: {np.abs(coreml_preds - torch_preds).max():.6f}")


def main():
    print("=" * 70)
    print("CONVERTING SORTFORMER WITH DYNAMIC LENGTH SUPPORT")
    print("=" * 70)

    model = load_model()
    dims = streaming_dims(model)
    print(f"Chunk frames: {dims.chunk_frames}")

    run_eager_tests(model, dims)

    print("\n" + "=" * 70)
    print("ISSUE IDENTIFIED")
    print("=" * 70)
    print("""
The problem is that the current CoreML model was traced with FIXED lengths.
When lengths change at runtime, the traced operations don't adapt.

The fix requires re-tracing with proper dynamic handling OR using
coremltools flexible shapes feature.

For now, let's try a simpler approach: always pad inputs to max size
and use the length parameters only for extracting the correct output slice.
""")

    # torch.jit.trace captures concrete values for Python-level control flow, and
    # many NeMo operations do not support torch.jit.script. pack_embeddings keeps
    # the length handling in tensor ops so tracing stays dynamic.
    print("\nATTEMPTING CONVERSION WITH FLEXIBLE SHAPES...")
    try:
        convert_pipeline(model, dims)
    except Exception as e:  # top-level boundary: report, then fail the process
        print(f"Error during conversion: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()