#!/usr/bin/env python3
"""
OmniVoice Backbone Export — Step 1: Model Surgery + torch.export

Extracts the diffusion backbone as a standalone nn.Module:
    Input:  input_ids    [B, C, S], audio_mask [B, S]
    Output: audio_logits [B, C, S, V]
Where:
    B = batch size (1 for single, 2 for CFG)
    C = 8 (num_audio_codebook)
    S = fixed sequence length
    V = 1025 (audio_vocab_size)
"""
import time

import torch
import torch.nn as nn

# ── Config ──────────────────────────────────────────────────────────────────
MAX_SEQ_LEN = 2048  # Fixed max sequence length for export
BATCH_SIZE = 1      # Single inference (no CFG batching for simplicity)


class OmniVoiceBackbone(nn.Module):
    """Standalone backbone for export — no Python control flow, fixed shapes.

    Wraps the embedding layers, LLM trunk, and audio prediction heads of a
    loaded OmniVoice model so the whole forward pass is a single traceable
    graph suitable for ``torch.export``.
    """

    def __init__(self, omnivoice_model):
        """Borrow the relevant submodules from a loaded OmniVoice model.

        Args:
            omnivoice_model: a loaded OmniVoice instance exposing
                ``config``, ``get_input_embeddings()``, ``audio_embeddings``,
                ``llm``, and ``audio_heads``.
        """
        super().__init__()
        cfg = omnivoice_model.config

        # Core components (shared parameters, not copies)
        self.text_embeddings = omnivoice_model.get_input_embeddings()
        self.audio_embeddings = omnivoice_model.audio_embeddings
        self.llm = omnivoice_model.llm
        self.audio_heads = omnivoice_model.audio_heads

        # Constants baked into the graph
        self.num_codebook = cfg.num_audio_codebook
        self.audio_vocab_size = cfg.audio_vocab_size

        # Per-codebook ID offsets, shaped [1, C, 1] so they broadcast against
        # input_ids [B, C, S]. A buffer (not a parameter) so it follows
        # .to()/.eval() and is captured by torch.export.
        self.register_buffer(
            "codebook_offsets",
            torch.arange(cfg.num_audio_codebook).view(1, -1, 1) * cfg.audio_vocab_size,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,   # [B, C, S]
        audio_mask: torch.BoolTensor,  # [B, S] — True where the token is audio
    ) -> torch.Tensor:
        """Run the backbone.

        Returns:
            audio_logits [B, C, S, V]
        """
        B, C, S = input_ids.shape

        # 1. Text embeddings from the first codebook layer
        text_embeds = self.text_embeddings(input_ids[:, 0, :])  # [B, S, H]

        # 2. Audio embeddings: zero out text positions, shift each codebook's
        #    IDs into its own vocab range, sum across codebooks.
        shifted_ids = (input_ids * audio_mask.unsqueeze(1).long()) + self.codebook_offsets
        audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)  # [B, S, H]

        # 3. Merge: use audio embeddings where mask is True, text otherwise
        inputs_embeds = torch.where(
            audio_mask.unsqueeze(-1),
            audio_embeds,
            text_embeds,
        )

        # 4. BIDIRECTIONAL attention mask (all True — NOT causal!)
        #    NOTE(review): a 4D bool mask is only honored by some HF attention
        #    implementations — confirm self.llm accepts this format.
        attention_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=input_ids.device)

        # 5. LLM forward with bidirectional attention
        llm_out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Index [0] works for both ModelOutput and plain tuples
        hidden_states = llm_out[0]  # [B, S, H]

        # 6. Audio prediction heads: one flat projection, then split per codebook
        logits_flat = self.audio_heads(hidden_states)  # [B, S, C*V]
        audio_logits = logits_flat.view(B, S, self.num_codebook, self.audio_vocab_size)
        audio_logits = audio_logits.permute(0, 2, 1, 3)  # [B, C, S, V]

        return audio_logits


def main():
    """Load OmniVoice, extract the backbone, smoke-test it, and export."""
    print("Loading OmniVoice...", flush=True)
    t0 = time.time()
    from omnivoice import OmniVoice
    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,  # export needs float32
    )
    print(f"  Loaded in {time.time()-t0:.1f}s", flush=True)

    # Extract backbone
    print("Extracting backbone...", flush=True)
    backbone = OmniVoiceBackbone(model)
    backbone.eval()

    # Count params
    params = sum(p.numel() for p in backbone.parameters())
    print(f"  Backbone params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    # Create example inputs
    seq_len = 512  # Start with smaller for testing
    input_ids = torch.randint(0, 1024, (BATCH_SIZE, 8, seq_len), dtype=torch.long)
    audio_mask = torch.zeros(BATCH_SIZE, seq_len, dtype=torch.bool)
    audio_mask[:, 200:] = True  # Last portion is audio

    # Test forward pass
    print("Testing forward pass...", flush=True)
    t0 = time.time()
    with torch.no_grad():
        logits = backbone(input_ids, audio_mask)
    print(f"  Forward: {time.time()-t0:.1f}s", flush=True)
    print(f"  Output shape: {logits.shape}", flush=True)
    print(f"  Expected: [1, 8, {seq_len}, 1025]", flush=True)

    # Try torch.export
    print("\nAttempting torch.export...", flush=True)
    try:
        from torch.export import export
        t0 = time.time()
        ep = export(backbone, (input_ids, audio_mask))
        print(f"  torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
        print(f"  Graph: {len(ep.graph.nodes)} nodes", flush=True)

        # Save the exported program
        torch.export.save(ep, "omnivoice_backbone_exported.pt2")
        print("  Saved to omnivoice_backbone_exported.pt2", flush=True)
    except Exception as e:
        print(f"  torch.export failed: {e}", flush=True)
        print("  Will need fixes — see error above", flush=True)

    # Also save the PyTorch module for later conversion
    torch.save(backbone.state_dict(), "omnivoice_backbone_state.pt")
    # NOTE: was an f-string with no placeholders (F541); plain string is identical
    print("\nSaved backbone state dict to omnivoice_backbone_state.pt", flush=True)


if __name__ == "__main__":
    main()