#!/usr/bin/env python3
"""Get exact NeMo streaming inference output for comparison with Swift."""
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import numpy as np
import librosa
import json
from nemo.collections.asr.models import SortformerEncLabelModel


def main():
    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        'diar_streaming_sortformer_4spk-v2.nemo', map_location='cpu'
    )
    model.eval()

    # Disable dither for deterministic output
    if hasattr(model.preprocessor, 'featurizer'):
        if hasattr(model.preprocessor.featurizer, 'dither'):
            model.preprocessor.featurizer.dither = 0.0

    # Configure Gradient Descent's streaming setup (same as the Swift side)
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31
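    # These lengths are in post-subsampling encoder frames. Assuming the usual
    # 10 ms mel hop, subsampling_factor=8 (see below) gives 80 ms per frame, so
    # chunk_len=6 is ~0.48 s of new audio per step, fifo_len=40 is ~3.2 s of
    # FIFO buffer, and spkcache_len=188 is ~15 s of speaker cache.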
print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")
# Load audio
audio_path = "../audio.wav"
audio, sr = librosa.load(audio_path, sr=16000, mono=True)
print(f"Loaded audio: {len(audio)} samples ({len(audio)/16000:.2f}s)")
waveform = torch.from_numpy(audio).unsqueeze(0).float()
# Get mel features using model's preprocessor
with torch.no_grad():
audio_len = torch.tensor([waveform.shape[1]])
features, feat_len = model.process_signal(
audio_signal=waveform, audio_signal_length=audio_len
)
# features is [batch, mel, time], need [batch, time, mel] for streaming
features = features[:, :, :feat_len.max()]
print(f"Features: {features.shape} (batch, mel, time)")

    # Streaming inference using forward_streaming_step
    subsampling = modules.subsampling_factor  # 8
    chunk_len = modules.chunk_len  # 6
    left_context = modules.chunk_left_context  # 1
    right_context = modules.chunk_right_context  # 7
    core_frames = chunk_len * subsampling  # 48 mel frames
    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    # Initialize streaming state
    streaming_state = modules.init_streaming_state(device=features.device)
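    # Note: this state presumably carries the speaker-cache and FIFO buffers
    # (sized by spkcache_len/fifo_len above) that forward_streaming_step reads
    # and returns updated on every chunk.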
    # total_preds starts empty along the time axis; forward_streaming_step
    # returns it extended with each chunk's predictions
    total_preds = torch.zeros((1, 0, 4), device=features.device)
    chunk_idx = 0

    # Process chunks like streaming_feat_loader
    stt_feat = 0
    while stt_feat < total_mel_frames:
        end_feat = min(stt_feat + core_frames, total_mel_frames)

        # Calculate context (in mel frames)
        left_offset = min(left_context * subsampling, stt_feat)
        right_offset = min(right_context * subsampling, total_mel_frames - end_feat)
        chunk_start = stt_feat - left_offset
        chunk_end = end_feat + right_offset
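        # At the stream edges the min() clamps shrink the context: the first
        # chunk gets left_offset=0 and the final chunk only as much right
        # context as remains, so no frames are read out of bounds.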

        # Extract chunk - [batch, mel, time] -> [batch, time, mel]
        chunk = features[:, :, chunk_start:chunk_end]  # [1, 128, T]
        chunk_t = chunk.transpose(1, 2)  # [1, T, 128]
        chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long)

        with torch.no_grad():
            # Use forward_streaming_step
            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )
        chunk_idx += 1
        stt_feat = end_feat

    # total_preds now contains all predictions
    all_preds = total_preds[0].numpy()  # [total_frames, 4]
    print(f"\nTotal output frames: {all_preds.shape[0]}")
    print(f"Predictions shape: {all_preds.shape}")

    # Print timeline
    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame  Time  Spk0  Spk1  Spk2  Spk3 | Visual")
    print("-" * 60)
    for frame in range(all_preds.shape[0]):
        time_sec = frame * 0.08
        probs = all_preds[frame]
        visual = ['■' if p > 0.55 else '·' for p in probs]
        print(f"{frame:5d} {time_sec:5.2f}s {probs[0]:.3f} {probs[1]:.3f} {probs[2]:.3f} {probs[3]:.3f} | [{visual[0]}{visual[1]}{visual[2]}{visual[3]}]")
    print("-" * 60)

    # Speaker activity summary
    print("\n=== Speaker Activity Summary ===")
    threshold = 0.55
    for spk in range(4):
        active_frames = np.sum(all_preds[:, spk] > threshold)
        active_time = active_frames * 0.08
        percent = active_time / (all_preds.shape[0] * 0.08) * 100
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")

    # Save to JSON for comparison
    output = {
        "total_frames": int(all_preds.shape[0]),
        "frame_duration_seconds": 0.08,
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        },
    }
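    # "probabilities" is flattened in C (frame-major) order, so a consumer can
    # index it as probabilities[frame * 4 + speaker] to rebuild the matrix.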
with open("/tmp/nemo_streaming_reference.json", "w") as f:
json.dump(output, f, indent=2)
print("\nSaved to /tmp/nemo_streaming_reference.json")
if __name__ == "__main__":
main()