#!/usr/bin/env python3
"""
OmniVoice Audio Decoder Export — Model Surgery + torch.export

Extracts the HiggsAudioV2 decoder as a standalone nn.Module:
    Input:  audio_codes [B, C, T]  (C=8 codebooks, T=time steps)
    Output: waveform [B, 1, samples]  (24kHz)

Each time step = 40ms = 960 samples at 24kHz
"""

import time

import torch
import torch.nn as nn

# Values used to build the random smoke-test inputs and to convert sample
# counts to seconds in the report.
NUM_CODEBOOKS = 8
CODEBOOK_SIZE = 1024
SAMPLE_RATE_HZ = 24000

# Bytes per parameter used for the size estimate; the model is loaded as
# float32 in main().
BYTES_PER_FP32 = 4


class OmniVoiceDecoder(nn.Module):
    """Standalone audio decoder wrapper for export.

    Wraps an audio tokenizer so that ``forward`` maps a tensor of codes
    directly to a waveform tensor, which is the calling convention
    ``torch.export`` needs (tensor in, tensor out).
    """

    def __init__(self, audio_tokenizer):
        super().__init__()
        # The whole tokenizer is kept as a submodule; only its ``decode``
        # path is exercised by ``forward``. As a consequence,
        # ``parameters()`` and ``state_dict()`` of this wrapper cover every
        # parameter the tokenizer owns, under the ``tokenizer.`` prefix.
        self.tokenizer = audio_tokenizer

    def forward(self, audio_codes: torch.LongTensor) -> torch.Tensor:
        """Decode audio codes into a waveform.

        Args:
            audio_codes: [B, C, T] where C=8, T=time steps

        Returns:
            waveform: [B, 1, samples]
        """
        # ``decode`` returns a structured output; the export graph needs the
        # bare tensor, so unwrap it here.
        return self.tokenizer.decode(audio_codes).audio_values


def _random_codes(num_steps):
    """Return random int64 codes of shape [1, NUM_CODEBOOKS, num_steps]."""
    return torch.randint(
        0, CODEBOOK_SIZE, (1, NUM_CODEBOOKS, num_steps), dtype=torch.long
    )


def main():
    """Load OmniVoice, smoke-test the decoder, try torch.export, save weights.

    Side effects: prints progress to stdout, writes
    ``omnivoice_decoder_exported.pt2`` when the export succeeds, and always
    writes ``omnivoice_decoder_state.pt``.
    """
    print("Loading OmniVoice...", flush=True)
    # perf_counter is monotonic, so the reported durations cannot be skewed
    # by wall-clock adjustments. The timer deliberately includes the import.
    t0 = time.perf_counter()
    from omnivoice import OmniVoice

    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,
    )
    print(f" Loaded in {time.perf_counter()-t0:.1f}s", flush=True)

    # Extract decoder
    print("Extracting decoder...", flush=True)
    decoder = OmniVoiceDecoder(model.audio_tokenizer)
    decoder.eval()

    # Counts every parameter of the wrapped tokenizer (see OmniVoiceDecoder).
    params = sum(p.numel() for p in decoder.parameters())
    print(
        f" Decoder params: {params/1e6:.1f}M "
        f"({params*BYTES_PER_FP32/1e9:.2f}GB fp32)",
        flush=True,
    )

    # Test with different sequence lengths
    for num_steps in (50, 100, 200):
        audio_codes = _random_codes(num_steps)
        t0 = time.perf_counter()
        with torch.no_grad():
            waveform = decoder(audio_codes)
        elapsed = time.perf_counter() - t0
        duration = waveform.shape[-1] / SAMPLE_RATE_HZ
        print(
            f" T={num_steps}: output={waveform.shape}, "
            f"duration={duration:.2f}s, time={elapsed:.1f}s",
            flush=True,
        )

    # Try torch.export
    print("\nAttempting torch.export...", flush=True)
    test_codes = _random_codes(100)
    # NOTE(review): no dynamic_shapes are passed, so the exported program is
    # presumably specialised to T=100 — confirm whether a dynamic time
    # dimension is required downstream.
    try:
        from torch.export import export

        t0 = time.perf_counter()
        ep = export(decoder, (test_codes,))
        print(
            f" torch.export succeeded in {time.perf_counter()-t0:.1f}s!",
            flush=True,
        )
        print(f" Graph: {len(ep.graph.nodes)} nodes", flush=True)
        torch.export.save(ep, "omnivoice_decoder_exported.pt2")
        print(" Saved to omnivoice_decoder_exported.pt2", flush=True)
    except Exception as e:
        # Deliberate best-effort: a failed export is reported, and the
        # state dict below is still written.
        print(f" torch.export failed: {type(e).__name__}: {e}", flush=True)
        print(
            " Will need to dig into the decoder internals for fixes",
            flush=True,
        )

    # Save state dict
    torch.save(decoder.state_dict(), "omnivoice_decoder_state.pt")
    print("\nSaved decoder state dict to omnivoice_decoder_state.pt", flush=True)


if __name__ == "__main__":
    main()