# nano-codec-decoder-onnx / onnx_decoder_optimized.py
# (Hugging Face repo page residue — author: Prasanna05, commit f38c279 "Add Python decoder class")
"""
ONNX-based decoder for KaniTTS - CPU-optimized version with reduced lookback
Use this if CUDA is not working with ONNX Runtime
"""
import numpy as np
import onnxruntime as ort
from typing import List
from pathlib import Path
class ONNXKaniTTSDecoderOptimized:
    """
    CPU-optimized ONNX Runtime decoder for KaniTTS.

    Keeps a small sliding window of recently decoded frames so each new
    frame is decoded with a few frames of left context, then strips the
    context audio and returns only the newly generated samples. Lookback
    is reduced (3 frames vs 15) to make CPU inference practical.
    """

    def __init__(self, onnx_model_path="onnx/nano_codec_decoder.onnx", device="cuda"):
        """
        Load the ONNX decoder and pre-warm it.

        Args:
            onnx_model_path: Path to the exported nano-codec decoder ONNX model.
            device: "cuda" to try the GPU provider first (falls back to CPU
                automatically); any other value uses CPU only.
        """
        self.sample_rate = 22050
        self.device = device
        # Decoder settings - OPTIMIZED FOR CPU
        self.num_codebooks = 4
        self.codebook_size = 4032
        self.chunk_size = 25
        # REDUCED lookback for CPU performance
        self.lookback_frames = 3  # Reduced from 15 to 3 (5x faster on CPU!)
        self.first_frame_lookback = 1  # Reduced from 3 to 1
        # Each frame corresponds to 80 ms of audio at the codec sample rate.
        # Hoisted here so decode_frame() doesn't recompute it per call.
        self.samples_per_frame = int(self.sample_rate * 0.08)
        # Sliding-window buffer of recently decoded frames (lists of 4 ints)
        self.history_codes = []
        self.frame_count = 0
        # Setup ONNX Runtime
        print(f"Loading ONNX decoder (CPU-optimized): {onnx_model_path}")
        self._setup_ort_session(onnx_model_path)
        # Pre-warm decoder so the first real frame isn't slowed by lazy init
        print("Pre-warming ONNX decoder...")
        self._prewarm_decoder()

    def _setup_ort_session(self, onnx_model_path):
        """Create the ONNX Runtime session, preferring CUDA with a CPU fallback."""
        # Try CUDA first, fall back to CPU
        providers_to_try = []
        if self.device == "cuda":
            # Try CUDA with error handling
            providers_to_try.append(("CUDAExecutionProvider", {
                "device_id": 0,
                "arena_extend_strategy": "kNextPowerOfTwo",
                "gpu_mem_limit": 2 * 1024 * 1024 * 1024,  # cap GPU arena at 2 GiB
                "cudnn_conv_algo_search": "EXHAUSTIVE",
                "do_copy_in_default_stream": True,
            }))
        # Always include CPU as fallback
        providers_to_try.append("CPUExecutionProvider")
        # Session options optimized for CPU
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4  # Use multiple CPU threads
        sess_options.inter_op_num_threads = 4
        # Create session; ORT can raise if the CUDA provider is requested but
        # the runtime libraries are missing, so retry with CPU only.
        try:
            self.session = ort.InferenceSession(
                onnx_model_path,
                sess_options=sess_options,
                providers=providers_to_try
            )
        except Exception as e:
            print(f"Warning: Failed to create session with CUDA, using CPU only: {e}")
            self.session = ort.InferenceSession(
                onnx_model_path,
                sess_options=sess_options,
                providers=["CPUExecutionProvider"]
            )
        # Verify providers
        active_providers = self.session.get_providers()
        print(f"Active providers: {active_providers}")
        if 'CUDAExecutionProvider' not in active_providers:
            print("⚠️ WARNING: Running on CPU (slow!)")
            print("   Using reduced lookback (3 frames) for better CPU performance")
            print("   See PERFORMANCE_FIX.md for GPU acceleration instructions")
        # Cache input/output names for session.run()
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _prewarm_decoder(self):
        """Run one dummy frame through the model, then reset state."""
        dummy_codes = [0, 0, 0, 0]
        _ = self.decode_frame(dummy_codes)
        # Discard the dummy frame so real decoding starts with clean history
        self.history_codes = []
        self.frame_count = 0
        print("ONNX decoder pre-warmed!")

    def decode_frame(self, codes: List[int]) -> np.ndarray:
        """
        Decode a single frame (4 tokens) to audio.

        Args:
            codes: List of 4 integers, one token per codebook.

        Returns:
            numpy array of int16 audio samples containing only the audio
            generated for this frame (context audio is stripped).

        Raises:
            ValueError: If ``codes`` does not have exactly ``num_codebooks`` entries.
        """
        if len(codes) != self.num_codebooks:
            raise ValueError(f"Expected {self.num_codebooks} codes, got {len(codes)}")
        # Add to history
        self.history_codes.append(codes)
        self.frame_count += 1
        # First frame gets a shorter lookback since there is no real context yet
        effective_lookback = (
            self.first_frame_lookback if self.frame_count == 1 else self.lookback_frames
        )
        # Keep only the most recent lookback+1 frames (context + current)
        max_history = effective_lookback + 1
        if len(self.history_codes) > max_history:
            self.history_codes = self.history_codes[-max_history:]
        # `codes` was appended above, so the history always already contains
        # the current frame — no special case needed for the first frame.
        context_codes = self.history_codes
        # Shape (1, num_codebooks, num_frames): batch of 1, codebook-major layout
        codes_np = np.array([context_codes], dtype=np.int64).transpose(0, 2, 1)
        num_frames = len(context_codes)
        tokens_len_np = np.array([num_frames], dtype=np.int64)
        # Run ONNX inference
        outputs = self.session.run(
            self.output_names,
            {'tokens': codes_np, 'tokens_len': tokens_len_np}
        )
        audio = outputs[0]
        # Keep only the newest frame's audio; everything before start_idx
        # belongs to context frames that were already emitted earlier.
        # NOTE(review): assumes the model output is (batch, samples) — confirm
        # against the exported graph.
        if len(self.history_codes) > 1:
            start_idx = (len(self.history_codes) - 1) * self.samples_per_frame
            audio = audio[:, start_idx:]
        # Clip to [-1, 1] before scaling: values outside that range would
        # otherwise wrap around when cast to int16, producing loud clicks.
        audio_np = np.clip(audio.squeeze(), -1.0, 1.0)
        audio_int16 = (audio_np * 32767).astype(np.int16)
        return audio_int16

    def reset_history(self):
        """Reset decoder history so the next frame starts a fresh context window."""
        self.history_codes = []
        self.frame_count = 0
if __name__ == "__main__":
    # Smoke test + micro-benchmark for the CPU-optimized decoder.
    print("Testing CPU-optimized ONNX decoder...")
    decoder = ONNXKaniTTSDecoderOptimized(
        onnx_model_path="nano_codec_decoder.onnx",
        device="cuda"  # Will fall back to CPU if CUDA not available
    )
    # Quick benchmark — use perf_counter: it is monotonic and high-resolution,
    # unlike time.time(), which can be coarse (and can step backwards).
    import time
    times = []
    for _ in range(20):
        codes = [np.random.randint(0, 500) for _ in range(4)]
        start = time.perf_counter()
        audio = decoder.decode_frame(codes)
        times.append((time.perf_counter() - start) * 1000)
    print(f"\n📊 Performance:")
    print(f"  Average: {np.mean(times):.2f} ms/frame")
    print(f"  Min: {np.min(times):.2f} ms")
    print(f"  Max: {np.max(times):.2f} ms")
    print(f"\n✓ Decoder test completed!")
    # ~100 ms/frame is the rough real-time threshold for 80 ms frames
    if np.mean(times) < 100:
        print("✓ Performance is good!")
    else:
        print("⚠️ Performance is slow. See PERFORMANCE_FIX.md for solutions.")