File size: 5,593 Bytes
ed33fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""Get exact NeMo streaming inference output for comparison with Swift."""

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
import librosa
import json

from nemo.collections.asr.models import SortformerEncLabelModel

def main():
    """Run NeMo streaming Sortformer inference and save a reference for Swift.

    Restores ``diar_streaming_sortformer_4spk-v2.nemo`` on CPU, applies the
    streaming configuration below, feeds the mel features of ``../audio.wav``
    to ``model.forward_streaming_step`` chunk by chunk, prints a per-frame
    timeline plus a per-speaker activity summary, and writes the flattened
    probabilities and configuration to ``/tmp/nemo_streaming_reference.json``.

    Takes no arguments and returns ``None``; all results go to stdout and the
    JSON file.
    """
    # Values shared by the timeline, the summary and the JSON dump.
    sample_rate = 16000
    frame_duration = 0.08  # seconds per output frame (the "80ms" in the banner)
    threshold = 0.55  # probability above which a speaker counts as active
    num_speakers = 4
    output_path = "/tmp/nemo_streaming_reference.json"

    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(
        'diar_streaming_sortformer_4spk-v2.nemo', map_location='cpu'
    )
    model.eval()

    # Disable dither for deterministic output
    featurizer = getattr(model.preprocessor, 'featurizer', None)
    if featurizer is not None and hasattr(featurizer, 'dither'):
        featurizer.dither = 0.0

    # Configure for Gradient Descent's streaming config (same as Swift)
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print(f"        fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

    # Load audio
    audio_path = "../audio.wav"
    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/16000:.2f}s)")

    waveform = torch.from_numpy(audio).unsqueeze(0).float()

    # Get mel features using model's preprocessor
    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform, audio_signal_length=audio_len
        )

    # Drop any frames beyond the reported feature length. The layout stays
    # [batch, mel, time] here; the transpose to [batch, time, mel] happens
    # per chunk inside the streaming loop.
    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")

    # Streaming inference using forward_streaming_step
    subsampling = modules.subsampling_factor  # 8 per the original author's note
    chunk_len = modules.chunk_len
    left_context = modules.chunk_left_context
    right_context = modules.chunk_right_context
    core_frames = chunk_len * subsampling  # mel frames advanced per chunk

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    # Initialize streaming state
    streaming_state = modules.init_streaming_state(device=features.device)

    # Predictions accumulate along dim 1 as chunks are processed.
    total_preds = torch.zeros((1, 0, num_speakers), device=features.device)

    # Process chunks like streaming_feat_loader
    stt_feat = 0
    with torch.no_grad():
        while stt_feat < total_mel_frames:
            end_feat = min(stt_feat + core_frames, total_mel_frames)

            # Context (in mel frames), clipped at the start/end of the signal.
            left_offset = min(left_context * subsampling, stt_feat)
            right_offset = min(right_context * subsampling, total_mel_frames - end_feat)

            chunk_start = stt_feat - left_offset
            chunk_end = end_feat + right_offset

            # Extract chunk - [batch, mel, time] -> [batch, time, mel]
            chunk_t = features[:, :, chunk_start:chunk_end].transpose(1, 2)
            chunk_len_tensor = torch.tensor([chunk_t.shape[1]], dtype=torch.long)

            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

            stt_feat = end_feat

    # total_preds now contains all predictions
    all_preds = total_preds[0].numpy()  # [total_frames, num_speakers]
    total_frames = all_preds.shape[0]
    print(f"\nTotal output frames: {total_frames}")
    print(f"Predictions shape: {all_preds.shape}")

    # Print timeline
    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame  Time    Spk0   Spk1   Spk2   Spk3  | Visual")
    print("-" * 60)

    for frame, probs in enumerate(all_preds):
        time_sec = frame * frame_duration
        prob_cols = "  ".join(f"{p:.3f}" for p in probs)
        marks = "".join('■' if p > threshold else '·' for p in probs)
        print(f"{frame:5d}  {time_sec:5.2f}s  {prob_cols}  | [{marks}]")

    print("-" * 60)

    # Speaker activity summary
    print("\n=== Speaker Activity Summary ===")
    total_time = total_frames * frame_duration
    for spk in range(num_speakers):
        active_frames = np.sum(all_preds[:, spk] > threshold)
        active_time = active_frames * frame_duration
        # Avoid 0/0 (NaN plus a RuntimeWarning) when no frames were produced.
        percent = active_time / total_time * 100 if total_frames else 0.0
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")

    # Save to JSON for comparison
    output = {
        "total_frames": int(total_frames),
        "frame_duration_seconds": frame_duration,
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        },
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print("\nSaved to /tmp/nemo_streaming_reference.json")

# Run the reference dump only when executed as a script, not on import.
if __name__ == "__main__":
    main()