File size: 6,424 Bytes
f38c279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""
ONNX-based decoder for KaniTTS - CPU-optimized version with reduced lookback
Use this if CUDA is not working with ONNX Runtime
"""

import numpy as np
import onnxruntime as ort
from typing import List
from pathlib import Path


class ONNXKaniTTSDecoderOptimized:
    """
    CPU-optimized ONNX Runtime decoder for KaniTTS.

    Streams audio one codec frame (4 codebook tokens) at a time. A short
    sliding window of previous frames is passed to the decoder as context
    ("lookback") so consecutive frames join seamlessly; the lookback is
    deliberately small so per-frame inference stays fast on CPU.
    """

    def __init__(self, onnx_model_path="onnx/nano_codec_decoder.onnx", device="cuda"):
        """
        Args:
            onnx_model_path: Path to the exported nano codec decoder ONNX graph.
            device: "cuda" to try the CUDA execution provider first; falls
                back to CPU if session creation with CUDA fails.
        """
        self.sample_rate = 22050
        self.device = device

        # Decoder settings - OPTIMIZED FOR CPU
        self.num_codebooks = 4  # tokens expected per frame
        self.codebook_size = 4032
        self.chunk_size = 25

        # REDUCED lookback for CPU performance
        self.lookback_frames = 3  # Reduced from 15 to 3 (5x faster on CPU!)
        self.first_frame_lookback = 1  # Reduced from 3 to 1

        # Sliding-window buffer of recently decoded frames (context).
        self.history_codes = []
        self.frame_count = 0

        # Samples emitted per codec frame (80 ms at 22.05 kHz); hoisted here
        # so decode_frame does not recompute it on every call.
        self.samples_per_frame = int(self.sample_rate * 0.08)

        # Setup ONNX Runtime
        print(f"Loading ONNX decoder (CPU-optimized): {onnx_model_path}")
        self._setup_ort_session(onnx_model_path)

        # Pre-warm decoder
        print("Pre-warming ONNX decoder...")
        self._prewarm_decoder()

    def _setup_ort_session(self, onnx_model_path):
        """Create the ONNX Runtime session, preferring CUDA with CPU fallback."""
        providers_to_try = []

        if self.device == "cuda":
            # CUDA provider with explicit options; if creation fails below,
            # we retry with CPU only.
            providers_to_try.append(("CUDAExecutionProvider", {
                "device_id": 0,
                "arena_extend_strategy": "kNextPowerOfTwo",
                "gpu_mem_limit": 2 * 1024 * 1024 * 1024,
                "cudnn_conv_algo_search": "EXHAUSTIVE",
                "do_copy_in_default_stream": True,
            }))

        # Always include CPU as fallback
        providers_to_try.append("CPUExecutionProvider")

        # Session options optimized for CPU
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4  # Use multiple CPU threads
        sess_options.inter_op_num_threads = 4

        # Create session; on any failure retry with CPU-only providers.
        try:
            self.session = ort.InferenceSession(
                onnx_model_path,
                sess_options=sess_options,
                providers=providers_to_try
            )
        except Exception as e:
            print(f"Warning: Failed to create session with CUDA, using CPU only: {e}")
            self.session = ort.InferenceSession(
                onnx_model_path,
                sess_options=sess_options,
                providers=["CPUExecutionProvider"]
            )

        # Report which providers actually became active.
        active_providers = self.session.get_providers()
        print(f"Active providers: {active_providers}")

        if 'CUDAExecutionProvider' not in active_providers:
            print("⚠️  WARNING: Running on CPU (slow!)")
            print("   Using reduced lookback (3 frames) for better CPU performance")
            print("   See PERFORMANCE_FIX.md for GPU acceleration instructions")

        # Cache graph input/output names for session.run calls.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _prewarm_decoder(self):
        """Run one dummy frame through the graph, then reset decoder state."""
        dummy_codes = [0, 0, 0, 0]
        _ = self.decode_frame(dummy_codes)
        # Discard the dummy frame so real decoding starts from clean state.
        self.history_codes = []
        self.frame_count = 0
        print("ONNX decoder pre-warmed!")

    def decode_frame(self, codes: List[int]) -> np.ndarray:
        """
        Decode a single frame (4 tokens) to audio.

        Args:
            codes: List of exactly `num_codebooks` (4) integer token ids,
                one per codebook.

        Returns:
            numpy array of int16 audio samples containing only the NEW
            audio for this frame (context audio is trimmed off).

        Raises:
            ValueError: If `codes` does not contain exactly 4 entries.
        """
        if len(codes) != self.num_codebooks:
            raise ValueError(f"Expected {self.num_codebooks} codes, got {len(codes)}")

        # Add to history
        self.history_codes.append(codes)
        self.frame_count += 1

        # The very first frame has no useful history, so use the shorter lookback.
        effective_lookback = (
            self.first_frame_lookback if self.frame_count == 1 else self.lookback_frames
        )

        # Keep only the most recent (lookback + current) frames.
        max_history = effective_lookback + 1
        if len(self.history_codes) > max_history:
            self.history_codes = self.history_codes[-max_history:]

        # Context window always contains at least the frame just appended.
        context_codes = self.history_codes if len(self.history_codes) > 1 else [codes]

        # Build (1, num_codebooks, num_frames) int64 tensor — the transpose
        # matches the layout the exported graph consumed in the original code.
        codes_np = np.asarray([context_codes], dtype=np.int64).transpose(0, 2, 1)

        num_frames = len(context_codes)
        tokens_len_np = np.array([num_frames], dtype=np.int64)

        # Run ONNX inference
        outputs = self.session.run(
            self.output_names,
            {'tokens': codes_np, 'tokens_len': tokens_len_np}
        )

        audio = outputs[0]

        # Drop the samples that belong to the context frames; keep only the
        # audio generated for the newest frame.
        if len(self.history_codes) > 1:
            start_idx = (len(self.history_codes) - 1) * self.samples_per_frame
            audio = audio[:, start_idx:]

        # Convert to int16. Clip to [-1, 1] first: scaling an out-of-range
        # float sample and casting would otherwise wrap around, producing
        # loud pops in the output audio.
        audio_np = audio.squeeze()
        audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)

        return audio_int16

    def reset_history(self):
        """Reset decoder history so the next frame starts a fresh utterance."""
        self.history_codes = []
        self.frame_count = 0


if __name__ == "__main__":
    print("Testing CPU-optimized ONNX decoder...")

    decoder = ONNXKaniTTSDecoderOptimized(
        onnx_model_path="nano_codec_decoder.onnx",
        device="cuda"  # Will fall back to CPU if CUDA not available
    )

    # Quick latency benchmark: time 20 random single-frame decodes.
    import time

    frame_ms = []
    for _ in range(20):
        random_codes = [np.random.randint(0, 500) for _ in range(4)]
        t0 = time.time()
        audio = decoder.decode_frame(random_codes)
        frame_ms.append((time.time() - t0) * 1000)

    avg_ms = np.mean(frame_ms)

    print(f"\n📊 Performance:")
    print(f"  Average: {avg_ms:.2f} ms/frame")
    print(f"  Min: {np.min(frame_ms):.2f} ms")
    print(f"  Max: {np.max(frame_ms):.2f} ms")
    print(f"\n✓ Decoder test completed!")

    if avg_ms < 100:
        print("✓ Performance is good!")
    else:
        print("⚠️  Performance is slow. See PERFORMANCE_FIX.md for solutions.")