|
|
import torch |
|
|
import numpy as np |
|
|
import coremltools as ct |
|
|
import librosa |
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import math |
|
|
|
|
|
|
|
|
try: |
|
|
from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
|
|
|
from nemo.collections.asr.modules.sortformer_modules import SortformerModules |
|
|
except ImportError as e: |
|
|
print(f"Error importing NeMo: {e}") |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
def streaming_feat_loader(modules, feat_seq, feat_seq_length, feat_seq_offset):
    """Generator that slices a full feature sequence into streaming chunks.

    Adapted from NeMo's ``SortformerModules.streaming_feat_loader``. Each
    yielded chunk carries ``chunk_len * subsampling_factor`` core frames plus
    up to ``chunk_left_context``/``chunk_right_context`` (scaled by the
    subsampling factor) frames of surrounding context, clipped at the
    sequence boundaries.

    Args:
        modules: SortformerModules-like object exposing ``chunk_len``,
            ``subsampling_factor`` and (optionally) ``chunk_left_context`` /
            ``chunk_right_context``.
        feat_seq (torch.Tensor): feature sequence, shape
            (batch, feat_dim, num_frames).
        feat_seq_length (torch.Tensor): valid lengths per batch item, shape
            (batch,).
        feat_seq_offset (torch.Tensor): per-item frame offsets, shape (batch,).

    Yields:
        tuple of
            chunk_idx (int): running chunk counter,
            chunk_feat_seq_t (torch.Tensor): chunk transposed to
                (batch, num_frames, feat_dim),
            feat_lengths (torch.Tensor): valid frame counts within the chunk,
                shape (batch,),
            left_offset (int): left-context size in feature frames,
            right_offset (int): right-context size in feature frames.
    """
    feat_len = feat_seq.shape[2]
    sub = modules.subsampling_factor
    core_frames = modules.chunk_len * sub
    # Context sizes default to 0 when the attributes are absent.
    lc_frames = getattr(modules, 'chunk_left_context', 0) * sub
    rc_frames = getattr(modules, 'chunk_right_context', 0) * sub

    num_chunks = math.ceil(feat_len / core_frames)
    print(f"streaming_feat_loader: feat_len={feat_len}, num_chunks={num_chunks}, "
          f"chunk_len={modules.chunk_len}, subsampling_factor={sub}")

    chunk_idx = 0
    stt_feat = 0
    while stt_feat < feat_len:
        # Clip both contexts so the slice never leaves [0, feat_len).
        left_offset = min(lc_frames, stt_feat)
        end_feat = min(stt_feat + core_frames, feat_len)
        right_offset = min(rc_frames, feat_len - end_feat)

        window = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
        # Valid frames inside the window; zeroed for items that start at or
        # beyond this chunk's end.
        feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
            0, window.shape[2]
        )
        feat_lengths = feat_lengths * (feat_seq_offset < end_feat)

        # Hand out (batch, frames, feat_dim) as downstream consumers expect.
        chunk_feat_seq_t = window.transpose(1, 2)

        print(f" chunk_idx: {chunk_idx}, chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
              f"feat_lengths: {feat_lengths}, left_offset: {left_offset}, right_offset: {right_offset}")

        yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
        chunk_idx += 1
        stt_feat = end_feat
|
|
|
|
|
|
|
|
def run_streaming_inference(model_name, coreml_dir, audio_path):
    """Run streaming speaker-diarization over one audio file.

    Uses the NeMo Sortformer model only for its Python-side streaming
    machinery (feature extraction, streaming state, cache update); the
    per-chunk neural inference itself is delegated to an exported CoreML
    model loaded from ``coreml_dir``.

    Args:
        model_name: local ``.nemo`` path or HuggingFace/NGC model name for
            SortformerEncLabelModel.
        coreml_dir: directory containing ``SortformerPreprocessor.mlpackage``
            and ``Sortformer.mlpackage``.
        audio_path: path to the audio file (resampled to 16 kHz mono).

    Returns:
        torch.Tensor of concatenated per-chunk speaker probabilities
        (concatenated along dim 1), or None if no chunks were produced.
    """
    print(f"Loading NeMo Model (for Python Streaming Logic): {model_name}")
    # A path on disk means a local .nemo checkpoint; otherwise treat the
    # string as a pretrained model identifier.
    if os.path.exists(model_name):
        nemo_model = SortformerEncLabelModel.restore_from(model_name, map_location="cpu")
    else:
        nemo_model = SortformerEncLabelModel.from_pretrained(model_name, map_location="cpu")
    nemo_model.eval()
    modules = nemo_model.sortformer_modules

    # Force the streaming hyper-parameters to the values the CoreML model
    # was exported with, regardless of what the checkpoint config says.
    print("Overriding Config (Inference) to match CoreML...")
    modules.chunk_len = 4
    modules.chunk_right_context = 1
    modules.chunk_left_context = 2

    modules.fifo_len = 63
    modules.spkcache_len = 63
    modules.spkcache_update_period = 50

    # Fixed input sizes baked into the exported CoreML graph.
    # NOTE(review): COREML_CHUNK_FRAMES=56 presumably equals
    # (chunk_left_context + chunk_len + chunk_right_context) * subsampling
    # = (2 + 4 + 1) * 8 — confirm against the export script.
    COREML_CHUNK_FRAMES = 56
    COREML_SPKCACHE_LEN = 63
    COREML_FIFO_LEN = 63

    # Disable dithering and padding so feature extraction is deterministic
    # and frame counts match the CoreML-side preprocessor.
    if hasattr(nemo_model.preprocessor, 'featurizer'):
        if hasattr(nemo_model.preprocessor.featurizer, 'dither'):
            nemo_model.preprocessor.featurizer.dither = 0.0
        if hasattr(nemo_model.preprocessor.featurizer, 'pad_to'):
            nemo_model.preprocessor.featurizer.pad_to = 0

    print(f"Loading CoreML Models from {coreml_dir}...")
    # NOTE(review): preproc_model is loaded but never used below — features
    # come from nemo_model.process_signal instead. Dead weight or intended
    # for a later code path? Confirm.
    preproc_model = ct.models.MLModel(
        os.path.join(coreml_dir, "SortformerPreprocessor.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY
    )
    main_model = ct.models.MLModel(
        os.path.join(coreml_dir, "Sortformer.mlpackage"),
        compute_units=ct.ComputeUnit.ALL
    )

    chunk_len = modules.chunk_len
    subsampling_factor = modules.subsampling_factor
    sample_rate = 16000

    print(f"Chunk Config: {chunk_len} output frames (diar), subsampling_factor={subsampling_factor}")

    print(f"Loading Audio: {audio_path}")
    full_audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    total_samples = len(full_audio)
    print(f"Total Samples: {total_samples} ({total_samples/sample_rate:.2f}s)")

    # Extract mel features for the entire file up front; chunking happens
    # afterwards on the feature sequence, not on raw audio.
    print("Extracting features for entire audio...")
    audio_tensor = torch.from_numpy(full_audio).unsqueeze(0).float()
    audio_length = torch.tensor([total_samples], dtype=torch.long)

    with torch.no_grad():
        processed_signal, processed_signal_length = nemo_model.process_signal(
            audio_signal=audio_tensor, audio_signal_length=audio_length
        )

    print(f"Processed signal shape: {processed_signal.shape}")
    print(f"Processed signal length: {processed_signal_length}")

    # Trim any trailing padding frames beyond the longest valid length.
    processed_signal = processed_signal[:, :, :processed_signal_length.max()]

    print("Initializing Streaming State...")
    # State holds the speaker cache and FIFO embeddings updated per chunk.
    state = modules.init_streaming_state(batch_size=1, device='cpu')

    batch_size = processed_signal.shape[0]
    processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long)

    all_preds = []

    feat_loader = streaming_feat_loader(
        modules=modules,
        feat_seq=processed_signal,
        feat_seq_length=processed_signal_length,
        feat_seq_offset=processed_signal_offset,
    )

    for chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in feat_loader:

        # --- Chunk input: pad/truncate to the fixed CoreML frame count. ---
        chunk_actual_len = chunk_feat_seq_t.shape[1]
        if chunk_actual_len < COREML_CHUNK_FRAMES:
            pad_len = COREML_CHUNK_FRAMES - chunk_actual_len
            # pad spec (0, 0, 0, pad_len): zero-pad the frame axis (dim 1)
            # on the right, leave the feature axis untouched.
            chunk_in = torch.nn.functional.pad(chunk_feat_seq_t, (0, 0, 0, pad_len))
        else:
            chunk_in = chunk_feat_seq_t[:, :COREML_CHUNK_FRAMES, :]
        chunk_len_in = feat_lengths.long()

        # --- Speaker cache input: same fixed-size treatment. ---
        curr_spk_len = state.spkcache.shape[1]
        curr_fifo_len = state.fifo.shape[1]

        current_spkcache = state.spkcache

        if curr_spk_len < COREML_SPKCACHE_LEN:
            pad_len = COREML_SPKCACHE_LEN - curr_spk_len
            current_spkcache = torch.nn.functional.pad(current_spkcache, (0, 0, 0, pad_len))
        elif curr_spk_len > COREML_SPKCACHE_LEN:
            current_spkcache = current_spkcache[:, :COREML_SPKCACHE_LEN, :]

        spkcache_in = current_spkcache

        # NOTE(review): if the live cache ever grows past COREML_SPKCACHE_LEN
        # the tensor is truncated above, but the reported length here is NOT
        # clamped — the CoreML model would see a length larger than its cache
        # dimension. Same applies to fifo_len_in below. Confirm the model
        # tolerates this or clamp to the COREML_* limits.
        spkcache_len_in = torch.tensor([curr_spk_len], dtype=torch.long)

        # --- FIFO input: same fixed-size treatment. ---
        current_fifo = state.fifo

        if curr_fifo_len < COREML_FIFO_LEN:
            pad_len = COREML_FIFO_LEN - curr_fifo_len
            current_fifo = torch.nn.functional.pad(current_fifo, (0, 0, 0, pad_len))
        elif curr_fifo_len > COREML_FIFO_LEN:
            current_fifo = current_fifo[:, :COREML_FIFO_LEN, :]

        fifo_in = current_fifo
        fifo_len_in = torch.tensor([curr_fifo_len], dtype=torch.long)

        # CoreML expects numpy inputs; lengths go in as int32.
        coreml_inputs = {
            "chunk": chunk_in.numpy().astype(np.float32),
            "chunk_lengths": chunk_len_in.numpy().astype(np.int32),
            "spkcache": spkcache_in.numpy().astype(np.float32),
            "spkcache_lengths": spkcache_len_in.numpy().astype(np.int32),
            "fifo": fifo_in.numpy().astype(np.float32),
            "fifo_lengths": fifo_len_in.numpy().astype(np.int32)
        }

        coreml_out = main_model.predict(coreml_inputs)

        pred_logits = torch.from_numpy(coreml_out["speaker_preds"])
        chunk_embs = torch.from_numpy(coreml_out["chunk_pre_encoder_embs"])
        chunk_emb_len = int(coreml_out["chunk_pre_encoder_lengths"][0])

        # Drop the padded tail of the pre-encoder embeddings.
        chunk_embs = chunk_embs[:, :chunk_emb_len, :]

        # Convert the feature-frame context offsets back to diarization-frame
        # units for the cache update.
        # NOTE(review): left uses round() while right uses ceil() — presumably
        # mirroring NeMo's own streaming_update contract; verify against the
        # NeMo source for this model version.
        lc = round(left_offset / subsampling_factor)
        rc = math.ceil(right_offset / subsampling_factor)

        # Python-side state update (speaker cache / FIFO bookkeeping) using
        # the CoreML outputs.
        state, chunk_probs = modules.streaming_update(
            streaming_state=state,
            chunk=chunk_embs,
            preds=pred_logits,
            lc=lc,
            rc=rc
        )

        all_preds.append(chunk_probs)

        print(f"Processed chunk {chunk_idx + 1}, chunk_probs shape: {chunk_probs.shape}", end='\r')

    print(f"\nFinished. Total Chunks: {len(all_preds)}")
    if len(all_preds) > 0:
        # Concatenate per-chunk probabilities along the time axis.
        final_probs = torch.cat(all_preds, dim=1)
        print(f"Final Predictions Shape: {final_probs.shape}")
        return final_probs
    return None
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: collect paths/identifiers and launch inference.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--model_name", "nvidia/diar_streaming_sortformer_4spk-v2.1"),
        ("--coreml_dir", "coreml_models"),
        ("--audio_path", "test2.wav"),
    ):
        cli.add_argument(flag, default=fallback)
    opts = cli.parse_args()

    run_streaming_inference(opts.model_name, opts.coreml_dir, opts.audio_path)
|
|
|