"""
OmniVoice Backbone Export — Step 1: Model Surgery + torch.export

Extracts the diffusion backbone as a standalone nn.Module:
    Input:  input_ids [B, C, S], audio_mask [B, S]
    Output: audio_logits [B, C, S, V]

Where:
    B = batch size (1 for single, 2 for CFG)
    C = 8 (num_audio_codebook)
    S = fixed sequence length
    V = 1025 (audio_vocab_size)
"""
|
|
import time

import torch
import torch.nn as nn

# Fixed export-time shapes (torch.export traces with static dimensions;
# see module docstring: S is a fixed sequence length).
MAX_SEQ_LEN = 2048
BATCH_SIZE = 1
|
|
|
|
class OmniVoiceBackbone(nn.Module):
    """Standalone backbone for export — no Python control flow, fixed shapes.

    Wraps the pretrained embedding layers, LLM trunk, and audio heads of an
    OmniVoice model so a single forward pass maps token ids to audio logits.
    """

    def __init__(self, omnivoice_model):
        """
        Args:
            omnivoice_model: loaded OmniVoice model exposing
                ``get_input_embeddings()``, ``audio_embeddings``, ``llm``,
                ``audio_heads``, and a ``config`` with ``num_audio_codebook``
                and ``audio_vocab_size``.
        """
        super().__init__()
        cfg = omnivoice_model.config

        # Reuse the pretrained submodules directly (shared weights, no copy).
        self.text_embeddings = omnivoice_model.get_input_embeddings()
        self.audio_embeddings = omnivoice_model.audio_embeddings
        self.llm = omnivoice_model.llm
        self.audio_heads = omnivoice_model.audio_heads

        self.num_codebook = cfg.num_audio_codebook
        self.audio_vocab_size = cfg.audio_vocab_size

        # Per-codebook id offsets, shape [1, C, 1]: codebook c indexes its own
        # V-sized slice of the flat audio embedding table. Registered as a
        # buffer so it follows .to(device) and is captured by torch.export.
        self.register_buffer(
            "codebook_offsets",
            torch.arange(cfg.num_audio_codebook).view(1, -1, 1) * cfg.audio_vocab_size,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        audio_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        Args:
            input_ids: [B, C, S] token ids; codebook row 0 doubles as the
                text-token stream at non-audio positions.
            audio_mask: [B, S], True where the position holds audio tokens.

        Returns:
            audio_logits: [B, C, S, V]
        """
        B, C, S = input_ids.shape

        # Text path: only codebook row 0 carries text ids.
        text_embeds = self.text_embeddings(input_ids[:, 0, :])

        # Audio path: zero the ids at non-audio positions (keeps the offset
        # lookup in each codebook's [c*V, (c+1)*V) range even where row 0
        # holds text ids), then sum the per-codebook embeddings.
        shifted_ids = (input_ids * audio_mask.unsqueeze(1).long()) + self.codebook_offsets
        audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)

        # Per-position select between audio and text embeddings via
        # torch.where — no Python branching, so the traced graph stays static.
        inputs_embeds = torch.where(
            audio_mask.unsqueeze(-1),
            audio_embeds,
            text_embeds,
        )

        # All-ones [B, 1, S, S] mask: every position attends to every other.
        # NOTE(review): assumes the wrapped LLM treats True as "attend" in a
        # 4D mask — confirm against the llm implementation.
        attention_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=input_ids.device)

        llm_out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = llm_out[0]  # first output: last hidden state [B, S, D]

        # Project to all codebooks at once, then split and reorder:
        # [B, S, C*V] -> [B, S, C, V] -> [B, C, S, V].
        logits_flat = self.audio_heads(hidden_states)
        audio_logits = logits_flat.view(B, S, self.num_codebook, self.audio_vocab_size)
        audio_logits = audio_logits.permute(0, 2, 1, 3)

        return audio_logits
|
|
|
|
def main():
    """Load OmniVoice, extract the backbone, smoke-test it, and try torch.export.

    Side effects: downloads/loads the pretrained model, writes
    ``omnivoice_backbone_exported.pt2`` (on export success) and
    ``omnivoice_backbone_state.pt``.
    """
    print("Loading OmniVoice...", flush=True)
    t0 = time.time()

    # Deferred import so the module can be inspected without the package.
    from omnivoice import OmniVoice
    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        # NOTE(review): HF APIs historically spell this `torch_dtype=`;
        # confirm OmniVoice.from_pretrained accepts `dtype=`.
        dtype=torch.float32,
    )
    print(f" Loaded in {time.time()-t0:.1f}s", flush=True)

    print("Extracting backbone...", flush=True)
    backbone = OmniVoiceBackbone(model)
    backbone.eval()

    params = sum(p.numel() for p in backbone.parameters())
    print(f" Backbone params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    # Smoke-test input: text prefix (positions < 200), audio tokens after.
    seq_len = 512
    input_ids = torch.randint(0, 1024, (BATCH_SIZE, 8, seq_len), dtype=torch.long)
    audio_mask = torch.zeros(BATCH_SIZE, seq_len, dtype=torch.bool)
    audio_mask[:, 200:] = True

    print("Testing forward pass...", flush=True)
    t0 = time.time()
    with torch.no_grad():
        logits = backbone(input_ids, audio_mask)
    print(f" Forward: {time.time()-t0:.1f}s", flush=True)
    print(f" Output shape: {logits.shape}", flush=True)
    print(f" Expected: [1, 8, {seq_len}, 1025]", flush=True)

    print("\nAttempting torch.export...", flush=True)
    # Broad except is deliberate: export is expected to fail on unsupported
    # ops during step 1; we report the error and still save the state dict.
    try:
        from torch.export import export
        t0 = time.time()
        ep = export(backbone, (input_ids, audio_mask))
        print(f" torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
        print(f" Graph: {len(ep.graph.nodes)} nodes", flush=True)

        torch.export.save(ep, "omnivoice_backbone_exported.pt2")
        print(" Saved to omnivoice_backbone_exported.pt2", flush=True)

    except Exception as e:
        print(f" torch.export failed: {e}", flush=True)
        print(" Will need fixes — see error above", flush=True)

    # Always persist the weights, even if export failed.
    torch.save(backbone.state_dict(), "omnivoice_backbone_state.pt")
    print("\nSaved backbone state dict to omnivoice_backbone_state.pt", flush=True)


if __name__ == "__main__":
    main()
|
|