|
|
""" |
|
|
ONNX-based decoder for KaniTTS - CPU-optimized version with reduced lookback |
|
|
Use this if CUDA is not working with ONNX Runtime |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
from typing import List |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class ONNXKaniTTSDecoderOptimized:
    """
    CPU-optimized ONNX Runtime decoder for KaniTTS.

    Keeps a short rolling history of codec frames (reduced lookback) so each
    decode call gives the model a little left context for audio continuity
    while staying fast enough for CPU inference.
    """

    def __init__(self, onnx_model_path="onnx/nano_codec_decoder.onnx", device="cuda"):
        """
        Args:
            onnx_model_path: Path to the exported nano-codec decoder ONNX graph.
            device: "cuda" to try the CUDA execution provider first; anything
                else (or a failed CUDA session) runs on CPU.
        """
        self.sample_rate = 22050
        self.device = device

        # Codec layout: 4 parallel codebooks per frame.
        self.num_codebooks = 4
        self.codebook_size = 4032
        self.chunk_size = 25

        # Reduced lookback for CPU: 3 context frames normally, 1 for the
        # very first frame (there is no earlier history to use anyway).
        self.lookback_frames = 3
        self.first_frame_lookback = 1

        # Rolling history of recent frames; each entry is a list of 4 codes.
        self.history_codes: List[List[int]] = []
        self.frame_count = 0

        print(f"Loading ONNX decoder (CPU-optimized): {onnx_model_path}")
        self._setup_ort_session(onnx_model_path)

        print("Pre-warming ONNX decoder...")
        self._prewarm_decoder()

    def _setup_ort_session(self, onnx_model_path):
        """Setup ONNX Runtime session - CPU optimized"""
        providers_to_try = []

        if self.device == "cuda":
            # Request CUDA first; ORT falls back to later providers in the
            # list when CUDA is unavailable.
            providers_to_try.append(("CUDAExecutionProvider", {
                "device_id": 0,
                "arena_extend_strategy": "kNextPowerOfTwo",
                "gpu_mem_limit": 2 * 1024 * 1024 * 1024,  # cap GPU arena at 2 GiB
                "cudnn_conv_algo_search": "EXHAUSTIVE",
                "do_copy_in_default_stream": True,
            }))

        providers_to_try.append("CPUExecutionProvider")

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 4

        try:
            self.session = ort.InferenceSession(
                onnx_model_path,
                sess_options=sess_options,
                providers=providers_to_try
            )
        except Exception as e:
            # Some CUDA/cuDNN installs raise at session creation instead of
            # falling back; retry explicitly on CPU.
            print(f"Warning: Failed to create session with CUDA, using CPU only: {e}")
            self.session = ort.InferenceSession(
                onnx_model_path,
                sess_options=sess_options,
                providers=["CPUExecutionProvider"]
            )

        active_providers = self.session.get_providers()
        print(f"Active providers: {active_providers}")

        if 'CUDAExecutionProvider' not in active_providers:
            print("⚠️ WARNING: Running on CPU (slow!)")
            print(" Using reduced lookback (3 frames) for better CPU performance")
            print(" See PERFORMANCE_FIX.md for GPU acceleration instructions")

        # Cached graph I/O names; input_names kept for external introspection.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _prewarm_decoder(self):
        """Run one dummy frame so the first real decode does not pay for lazy
        kernel/graph initialization, then discard the warm-up state."""
        dummy_codes = [0] * self.num_codebooks
        _ = self.decode_frame(dummy_codes)
        # Use the public reset API instead of duplicating its logic here.
        self.reset_history()
        print("ONNX decoder pre-warmed!")

    def decode_frame(self, codes: List[int]) -> np.ndarray:
        """
        Decode a single frame (4 tokens) to audio

        Args:
            codes: List of 4 integers

        Returns:
            numpy array of int16 audio samples (only the newest frame's samples)

        Raises:
            ValueError: if `codes` does not contain exactly
                `self.num_codebooks` entries.
        """
        if len(codes) != self.num_codebooks:
            raise ValueError(f"Expected {self.num_codebooks} codes, got {len(codes)}")

        self.history_codes.append(codes)
        self.frame_count += 1

        # The first frame has no real history, so use the shorter lookback.
        effective_lookback = self.first_frame_lookback if self.frame_count == 1 else self.lookback_frames

        # Keep at most `effective_lookback` context frames plus the current one.
        max_history = effective_lookback + 1
        if len(self.history_codes) > max_history:
            self.history_codes = self.history_codes[-max_history:]

        context_codes = self.history_codes if len(self.history_codes) > 1 else [codes]

        # (1, num_frames, num_codebooks) -> transpose to (1, num_codebooks,
        # num_frames), the layout the ONNX graph expects for 'tokens'.
        codes_np = np.array(
            [[frame for frame in context_codes]],
            dtype=np.int64
        ).transpose(0, 2, 1)

        num_frames = len(context_codes)
        tokens_len_np = np.array([num_frames], dtype=np.int64)

        outputs = self.session.run(
            self.output_names,
            {'tokens': codes_np, 'tokens_len': tokens_len_np}
        )

        audio = outputs[0]

        # One codec frame covers 80 ms of audio at self.sample_rate.
        samples_per_frame = int(self.sample_rate * 0.08)

        # Context frames were decoded only for continuity; return just the
        # samples belonging to the newest frame.
        if len(self.history_codes) > 1:
            start_idx = (len(self.history_codes) - 1) * samples_per_frame
            audio = audio[:, start_idx:]

        audio_np = audio.squeeze()
        # Fix: clip to [-1, 1] before the int16 cast. Model output can
        # slightly overshoot, and out-of-range floats wrap around when cast
        # to int16, producing loud clicks.
        audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)

        return audio_int16

    def reset_history(self):
        """Reset decoder history"""
        self.history_codes = []
        self.frame_count = 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("Testing CPU-optimized ONNX decoder...")

    decoder = ONNXKaniTTSDecoderOptimized(
        onnx_model_path="nano_codec_decoder.onnx",
        device="cuda"
    )

    import time

    # Benchmark: decode 20 random frames, recording per-frame latency in ms.
    timings_ms = []
    for _ in range(20):
        frame_codes = [np.random.randint(0, 500) for _ in range(4)]
        t0 = time.time()
        audio = decoder.decode_frame(frame_codes)
        elapsed_ms = (time.time() - t0) * 1000
        timings_ms.append(elapsed_ms)

    avg_ms = np.mean(timings_ms)

    print(f"\n📊 Performance:")
    print(f" Average: {avg_ms:.2f} ms/frame")
    print(f" Min: {np.min(timings_ms):.2f} ms")
    print(f" Max: {np.max(timings_ms):.2f} ms")
    print(f"\n✓ Decoder test completed!")

    # Rough real-time threshold: under 100 ms per 80 ms frame is acceptable.
    if avg_ms < 100:
        print("✓ Performance is good!")
    else:
        print("⚠️ Performance is slow. See PERFORMANCE_FIX.md for solutions.")