acul3 commited on
Commit
eba3fe7
·
verified ·
1 Parent(s): 0b88b36

Upload export_decoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. export_decoder.py +89 -0
export_decoder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OmniVoice Audio Decoder Export — Model Surgery + torch.export
4
+
5
+ Extracts the HiggsAudioV2 decoder as a standalone nn.Module:
6
+ Input: audio_codes [B, C, T] (C=8 codebooks, T=time steps)
7
+ Output: waveform [B, 1, samples] (24kHz)
8
+
9
+ Each time step = 40ms = 960 samples at 24kHz
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import time
15
+
16
+
17
+ class OmniVoiceDecoder(nn.Module):
18
+ """Standalone audio decoder wrapper for export."""
19
+
20
+ def __init__(self, audio_tokenizer):
21
+ super().__init__()
22
+ # Extract the actual decoder submodules
23
+ self.tokenizer = audio_tokenizer
24
+
25
+ def forward(self, audio_codes: torch.LongTensor) -> torch.Tensor:
26
+ """
27
+ Args:
28
+ audio_codes: [B, C, T] where C=8, T=time steps
29
+ Returns:
30
+ waveform: [B, 1, samples]
31
+ """
32
+ out = self.tokenizer.decode(audio_codes)
33
+ return out.audio_values
34
+
35
+
36
+ def main():
37
+ print("Loading OmniVoice...", flush=True)
38
+ t0 = time.time()
39
+
40
+ from omnivoice import OmniVoice
41
+ model = OmniVoice.from_pretrained(
42
+ "k2-fsa/OmniVoice",
43
+ device_map="cpu",
44
+ dtype=torch.float32,
45
+ )
46
+ print(f" Loaded in {time.time()-t0:.1f}s", flush=True)
47
+
48
+ # Extract decoder
49
+ print("Extracting decoder...", flush=True)
50
+ decoder = OmniVoiceDecoder(model.audio_tokenizer)
51
+ decoder.eval()
52
+
53
+ params = sum(p.numel() for p in decoder.parameters())
54
+ print(f" Decoder params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)
55
+
56
+ # Test with different sequence lengths
57
+ for T in [50, 100, 200]:
58
+ audio_codes = torch.randint(0, 1024, (1, 8, T), dtype=torch.long)
59
+ t0 = time.time()
60
+ with torch.no_grad():
61
+ waveform = decoder(audio_codes)
62
+ elapsed = time.time() - t0
63
+ print(f" T={T}: output={waveform.shape}, duration={waveform.shape[-1]/24000:.2f}s, time={elapsed:.1f}s", flush=True)
64
+
65
+ # Try torch.export
66
+ print("\nAttempting torch.export...", flush=True)
67
+ test_codes = torch.randint(0, 1024, (1, 8, 100), dtype=torch.long)
68
+
69
+ try:
70
+ from torch.export import export
71
+ t0 = time.time()
72
+ ep = export(decoder, (test_codes,))
73
+ print(f" torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
74
+ print(f" Graph: {len(ep.graph.nodes)} nodes", flush=True)
75
+
76
+ torch.export.save(ep, "omnivoice_decoder_exported.pt2")
77
+ print(" Saved to omnivoice_decoder_exported.pt2", flush=True)
78
+
79
+ except Exception as e:
80
+ print(f" torch.export failed: {type(e).__name__}: {e}", flush=True)
81
+ print(" Will need to dig into the decoder internals for fixes", flush=True)
82
+
83
+ # Save state dict
84
+ torch.save(decoder.state_dict(), "omnivoice_decoder_state.pt")
85
+ print(f"\nSaved decoder state dict to omnivoice_decoder_state.pt", flush=True)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ main()