OmniVoice-LiteRT / export_decoder.py
acul3's picture
Upload export_decoder.py with huggingface_hub
eba3fe7 verified
#!/usr/bin/env python3
"""
OmniVoice Audio Decoder Export — Model Surgery + torch.export

Wraps the decode path of OmniVoice's HiggsAudioV2 audio tokenizer in a
standalone nn.Module, smoke-tests it, and attempts a torch.export:
    Input:  audio_codes [B, C, T]  (C=8 codebooks, T=time steps)
    Output: waveform [B, 1, samples]  (24kHz)

Each time step = 40ms = 960 samples at 24kHz.
"""
import torch
import torch.nn as nn
import time
class OmniVoiceDecoder(nn.Module):
    """Thin ``nn.Module`` adapter around an audio tokenizer's ``decode`` call.

    The tokenizer object is kept whole under ``self.tokenizer``; when it is
    itself an ``nn.Module`` its parameters are therefore reported by
    ``parameters()`` / ``state_dict()`` under the ``tokenizer.`` prefix.
    """

    def __init__(self, audio_tokenizer):
        super().__init__()
        # The tokenizer is stored as-is (nothing is split out of it); only its
        # ``decode`` method is used by ``forward``.
        self.tokenizer = audio_tokenizer

    def forward(self, audio_codes: torch.LongTensor) -> torch.Tensor:
        """Decode discrete audio codes into a waveform tensor.

        Args:
            audio_codes: Code indices, expected layout ``[B, C, T]`` with
                ``C=8`` codebooks and ``T`` time steps.

        Returns:
            The ``audio_values`` attribute of the tokenizer's decode result,
            expected layout ``[B, 1, samples]``.
        """
        return self.tokenizer.decode(audio_codes).audio_values
def main():
    """Load OmniVoice, smoke-test the decoder wrapper, and try to export it.

    Steps, in order:
      1. Load the ``k2-fsa/OmniVoice`` checkpoint on CPU in float32.
      2. Wrap ``model.audio_tokenizer`` in ``OmniVoiceDecoder`` and print its
         parameter count.
      3. Decode random codes for a few sequence lengths, printing the output
         shape, the implied audio duration and the elapsed time.
      4. Attempt ``torch.export`` on one fixed-size example input; on success
         the program is saved to ``omnivoice_decoder_exported.pt2``. A failure
         is reported on stdout and does not abort the script.
      5. Save the wrapper's state dict to ``omnivoice_decoder_state.pt``.
    """
    # Constants shared by the smoke test and the export example input.
    num_codebooks = 8  # C in the [B, C, T] code layout
    # Exclusive upper bound for the random code ids; assumed to match the
    # tokenizer's codebook size -- TODO confirm against the model config.
    codebook_size = 1024
    sample_rate_hz = 24000  # used only to convert sample counts to seconds

    print("Loading OmniVoice...", flush=True)
    # perf_counter is monotonic, so elapsed times cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started = time.perf_counter()
    # Imported lazily so the import cost is part of the reported load time.
    from omnivoice import OmniVoice

    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,
    )
    print(f" Loaded in {time.perf_counter() - started:.1f}s", flush=True)

    # Extract decoder
    print("Extracting decoder...", flush=True)
    decoder = OmniVoiceDecoder(model.audio_tokenizer)
    decoder.eval()
    params = sum(p.numel() for p in decoder.parameters())
    # 4 bytes per parameter: the model was loaded as float32 above.
    print(
        f" Decoder params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)",
        flush=True,
    )

    # Test with different sequence lengths
    for num_steps in (50, 100, 200):
        audio_codes = torch.randint(
            0, codebook_size, (1, num_codebooks, num_steps), dtype=torch.long
        )
        started = time.perf_counter()
        with torch.no_grad():
            waveform = decoder(audio_codes)
        elapsed = time.perf_counter() - started
        duration_s = waveform.shape[-1] / sample_rate_hz
        print(
            f" T={num_steps}: output={waveform.shape}, "
            f"duration={duration_s:.2f}s, time={elapsed:.1f}s",
            flush=True,
        )

    # Try torch.export
    print("\nAttempting torch.export...", flush=True)
    test_codes = torch.randint(
        0, codebook_size, (1, num_codebooks, 100), dtype=torch.long
    )
    try:
        from torch.export import export, save as save_exported_program

        started = time.perf_counter()
        # No dynamic shapes are requested, so the program is traced for this
        # example input's shape.
        ep = export(decoder, (test_codes,))
        print(
            f" torch.export succeeded in {time.perf_counter() - started:.1f}s!",
            flush=True,
        )
        print(f" Graph: {len(ep.graph.nodes)} nodes", flush=True)
        save_exported_program(ep, "omnivoice_decoder_exported.pt2")
        print(" Saved to omnivoice_decoder_exported.pt2", flush=True)
    except Exception as e:
        # Deliberately best-effort: export is exploratory here, so report the
        # failure and still fall through to saving the state dict.
        print(f" torch.export failed: {type(e).__name__}: {e}", flush=True)
        print(" Will need to dig into the decoder internals for fixes", flush=True)

    # Save state dict
    torch.save(decoder.state_dict(), "omnivoice_decoder_state.pt")
    print("\nSaved decoder state dict to omnivoice_decoder_state.pt", flush=True)
# Run the export workflow only when this file is executed as a script, so that
# importing the module does not trigger it.
if __name__ == "__main__":
    main()