File size: 1,890 Bytes
c3c1782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Example usage of the ONNX NanoCodec decoder
"""

import numpy as np
import onnxruntime as ort

def example_basic_inference(
    model_path="nano_codec_decoder.onnx",
    num_frames=10,
    sample_rate=22050,
):
    """Run the NanoCodec ONNX decoder once on random tokens.

    Builds a random ``int64`` token tensor of shape ``(1, 4, num_frames)``
    with values in ``[0, 500)``, runs it through the model together with its
    length, and prints the output shape and the decoded audio duration.

    Args:
        model_path: Path to the exported decoder ``.onnx`` file.
        num_frames: Number of token frames to feed the decoder.
        sample_rate: Sample rate in Hz used to convert the sample count into
            seconds. 22050 is the value the original example used; confirm
            it matches the exported codec.

    Returns:
        The first model output (``audio``) exactly as returned by
        ``session.run``.
    """
    print("Loading ONNX model...")

    # Keep the original preference order (CUDA first, then CPU), but only
    # request providers this onnxruntime build actually ships. Asking for an
    # unavailable provider makes onnxruntime emit a warning before it falls
    # back to CPU anyway.
    available = set(ort.get_available_providers())
    providers = [
        name
        for name in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if name in available
    ] or ["CPUExecutionProvider"]

    session = ort.InferenceSession(model_path, providers=providers)

    print(f"Providers: {session.get_providers()}")

    # Create sample input
    tokens = np.random.randint(0, 500, (1, 4, num_frames), dtype=np.int64)
    tokens_len = np.array([num_frames], dtype=np.int64)

    print(f"\nInput tokens: {tokens.shape}")

    # Run inference
    audio, audio_len = session.run(
        None,
        {"tokens": tokens, "tokens_len": tokens_len},
    )
    print(f"Output audio: {audio.shape}")

    # Prefer the model-reported valid length over the array width, which may
    # include padding. This assumes audio_len holds one sample count per
    # batch item, mirroring tokens_len — TODO confirm against the export.
    # Fall back to the array width if the length output is empty.
    lengths = np.asarray(audio_len).reshape(-1)
    num_samples = int(lengths[0]) if lengths.size else audio.shape[1]
    print(f"Audio duration: {num_samples / sample_rate:.2f} seconds")

    return audio


def example_with_decoder_class():
    """Decode a handful of random frames with ``ONNXKaniTTSDecoderOptimized``.

    Creates the decoder from ``nano_codec_decoder.onnx`` with
    ``device="cuda"``, then decodes five frames. Each frame holds four
    random codes in ``[0, 500)``, and the shape of each decoded output is
    printed. Finally the function calls ``reset_history()`` on the decoder.
    """
    # Imported lazily so the basic example can run without this module.
    from onnx_decoder_optimized import ONNXKaniTTSDecoderOptimized

    print("Initializing decoder...")
    codec_decoder = ONNXKaniTTSDecoderOptimized(
        onnx_model_path="nano_codec_decoder.onnx",
        device="cuda",
    )

    print("\nDecoding frames...")
    # 1-based frame numbers for display; five frames in total.
    for frame_no in range(1, 6):
        frame_codes = [np.random.randint(0, 500) for _codebook in range(4)]
        frame_audio = codec_decoder.decode_frame(frame_codes)
        print(f"  Frame {frame_no}: {frame_audio.shape} samples")

    codec_decoder.reset_history()
    print("✓ Decoding complete")


if __name__ == "__main__":
    # Horizontal rule used to frame the console output.
    rule = "=" * 60

    print(rule)
    print("ONNX NanoCodec Decoder Examples")
    print(rule)

    # Announce and run each example in order.
    for banner, run_example in (
        ("\n[1/2] Basic inference...", example_basic_inference),
        ("\n[2/2] Using decoder class...", example_with_decoder_class),
    ):
        print(banner)
        run_example()

    print("\n" + rule)
    print("Examples complete!")
    print(rule)