| |
"""
OmniVoice Audio Decoder Export — Model Surgery + torch.export

Extracts the HiggsAudioV2 decoder as a standalone nn.Module:
    Input:  audio_codes [B, C, T] (C=8 codebooks, T=time steps)
    Output: waveform [B, 1, samples] (24kHz)

Each time step = 40ms = 960 samples at 24kHz
"""
|
|
| import torch |
| import torch.nn as nn |
| import time |
|
|
|
|
class OmniVoiceDecoder(nn.Module):
    """Export-friendly module that exposes only the tokenizer's decode path.

    The wrapped audio tokenizer is kept as ``self.tokenizer``; calling the
    module forwards the codes to ``tokenizer.decode`` and returns the
    ``audio_values`` attribute of whatever that call produces.
    """

    def __init__(self, audio_tokenizer):
        """Store the audio tokenizer whose ``decode`` method will be invoked."""
        super().__init__()
        self.tokenizer = audio_tokenizer

    def forward(self, audio_codes: torch.LongTensor) -> torch.Tensor:
        """Decode discrete audio codes into a waveform.

        Args:
            audio_codes: [B, C, T] where C=8, T=time steps

        Returns:
            waveform: [B, 1, samples]
        """
        decoded = self.tokenizer.decode(audio_codes)
        waveform = decoded.audio_values
        return waveform
|
|
|
|
# Shape/value parameters of the random codes used for smoke testing and as
# the example input for export.
_NUM_CODEBOOKS = 8      # C dimension of audio_codes
_CODE_ID_LIMIT = 1024   # exclusive upper bound passed to torch.randint
_SAMPLE_RATE_HZ = 24000  # only used to report the output duration in seconds

_EXPORT_PATH = "omnivoice_decoder_exported.pt2"
_STATE_DICT_PATH = "omnivoice_decoder_state.pt"


def _random_codes(num_steps):
    """Return random int64 codes of shape [1, _NUM_CODEBOOKS, num_steps]."""
    return torch.randint(
        0, _CODE_ID_LIMIT, (1, _NUM_CODEBOOKS, num_steps), dtype=torch.long
    )


def _smoke_test(decoder, step_counts=(50, 100, 200)):
    """Run the decoder on random codes of several lengths and print timings."""
    for num_steps in step_counts:
        audio_codes = _random_codes(num_steps)
        started = time.perf_counter()
        with torch.no_grad():
            waveform = decoder(audio_codes)
        elapsed = time.perf_counter() - started
        print(f" T={num_steps}: output={waveform.shape}, duration={waveform.shape[-1]/_SAMPLE_RATE_HZ:.2f}s, time={elapsed:.1f}s", flush=True)


def _try_export(decoder, example_codes):
    """Attempt torch.export on the decoder; report (never raise) on failure."""
    # Best-effort by design: any failure (including a missing torch.export on
    # older PyTorch builds) is reported so the caller can still save weights.
    try:
        from torch.export import export

        started = time.perf_counter()
        # NOTE(review): no dynamic-shape spec is passed, so the exported
        # program is presumably specialized to this example's shape — confirm.
        ep = export(decoder, (example_codes,))
        print(f" torch.export succeeded in {time.perf_counter()-started:.1f}s!", flush=True)
        print(f" Graph: {len(ep.graph.nodes)} nodes", flush=True)

        torch.export.save(ep, _EXPORT_PATH)
        print(" Saved to omnivoice_decoder_exported.pt2", flush=True)
    except Exception as e:
        print(f" torch.export failed: {type(e).__name__}: {e}", flush=True)
        print(" Will need to dig into the decoder internals for fixes", flush=True)


def main():
    """Load OmniVoice, smoke-test its audio decoder, export it, save weights.

    Steps, in order:
      1. Load the ``k2-fsa/OmniVoice`` checkpoint on CPU in float32.
      2. Wrap ``model.audio_tokenizer`` in ``OmniVoiceDecoder`` (eval mode)
         and print its parameter count.
      3. Decode random codes for T in (50, 100, 200) and print shape/timing.
      4. Try ``torch.export`` with a T=100 example; on success write
         ``omnivoice_decoder_exported.pt2``. Failures are printed, not raised.
      5. Always write the decoder state dict to ``omnivoice_decoder_state.pt``.

    Elapsed times use ``time.perf_counter()`` (monotonic) rather than the
    wall clock, so they are not distorted by system clock adjustments.
    """
    print("Loading OmniVoice...", flush=True)
    # The timer starts before the import so the reported load time includes it.
    started = time.perf_counter()

    from omnivoice import OmniVoice

    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,
    )
    print(f" Loaded in {time.perf_counter()-started:.1f}s", flush=True)

    print("Extracting decoder...", flush=True)
    decoder = OmniVoiceDecoder(model.audio_tokenizer)
    decoder.eval()

    params = sum(p.numel() for p in decoder.parameters())
    # 4 bytes per parameter: size estimate for float32 storage.
    print(f" Decoder params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    _smoke_test(decoder)

    print("\nAttempting torch.export...", flush=True)
    _try_export(decoder, _random_codes(100))

    # Saved regardless of the export outcome.
    torch.save(decoder.state_dict(), _STATE_DICT_PATH)
    print("\nSaved decoder state dict to omnivoice_decoder_state.pt", flush=True)
|
|
|
|
# Script entry point: run the export pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|