#!/usr/bin/env python3
"""
OmniVoice Audio Decoder Export — Model Surgery + torch.export

Extracts the HiggsAudioV2 decoder as a standalone nn.Module:
  Input:  audio_codes [B, C, T] (C=8 codebooks, T=time steps)
  Output: waveform [B, 1, samples] (24kHz)

Each time step = 40ms = 960 samples at 24kHz
"""

import time

import torch
import torch.nn as nn


class OmniVoiceDecoder(nn.Module):
    """Standalone audio decoder wrapper for export."""

    def __init__(self, audio_tokenizer):
        super().__init__()
        # Keep the full audio tokenizer; forward() routes through its decode() API
        self.tokenizer = audio_tokenizer

    def forward(self, audio_codes: torch.LongTensor) -> torch.Tensor:
        """
        Args:
            audio_codes: [B, C, T] where C=8, T=time steps
        Returns:
            waveform: [B, 1, samples]
        """
        out = self.tokenizer.decode(audio_codes)
        return out.audio_values


def main():
    print("Loading OmniVoice...", flush=True)
    t0 = time.time()

    from omnivoice import OmniVoice
    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,
    )
    print(f"  Loaded in {time.time()-t0:.1f}s", flush=True)

    # Extract decoder
    print("Extracting decoder...", flush=True)
    decoder = OmniVoiceDecoder(model.audio_tokenizer)
    decoder.eval()

    params = sum(p.numel() for p in decoder.parameters())
    print(f"  Decoder params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    # Test with different sequence lengths
    for T in [50, 100, 200]:
        audio_codes = torch.randint(0, 1024, (1, 8, T), dtype=torch.long)
        t0 = time.time()
        with torch.no_grad():
            waveform = decoder(audio_codes)
        elapsed = time.time() - t0
        print(f"  T={T}: output={waveform.shape}, duration={waveform.shape[-1]/24000:.2f}s, time={elapsed:.1f}s", flush=True)

    # Try torch.export
    print("\nAttempting torch.export...", flush=True)
    test_codes = torch.randint(0, 1024, (1, 8, 100), dtype=torch.long)

    try:
        from torch.export import export
        t0 = time.time()
        ep = export(decoder, (test_codes,))
        print(f"  torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
        print(f"  Graph: {len(ep.graph.nodes)} nodes", flush=True)

        torch.export.save(ep, "omnivoice_decoder_exported.pt2")
        print("  Saved to omnivoice_decoder_exported.pt2", flush=True)

    except Exception as e:
        print(f"  torch.export failed: {type(e).__name__}: {e}", flush=True)
        print("  Will need to dig into the decoder internals for fixes", flush=True)

    # Save state dict
    torch.save(decoder.state_dict(), "omnivoice_decoder_state.pt")
    print("\nSaved decoder state dict to omnivoice_decoder_state.pt", flush=True)


if __name__ == "__main__":
    main()