# OmniVoice-LiteRT / export_backbone.py
# (uploaded via huggingface_hub by acul3 — commit 0b88b36, verified)
#!/usr/bin/env python3
"""
OmniVoice Backbone Export β€” Step 1: Model Surgery + torch.export
Extracts the diffusion backbone as a standalone nn.Module:
Input: input_ids [B, C, S], audio_mask [B, S]
Output: audio_logits [B, C, S, V]
Where:
B = batch size (1 for single, 2 for CFG)
C = 8 (num_audio_codebook)
S = fixed sequence length
V = 1025 (audio_vocab_size)
"""
import torch
import torch.nn as nn
import time
# ── Config ──────────────────────────────────────────────────────────────────
MAX_SEQ_LEN = 2048 # Fixed max sequence length for export
BATCH_SIZE = 1 # Single inference (no CFG batching for simplicity)
class OmniVoiceBackbone(nn.Module):
    """Export-friendly wrapper around the OmniVoice diffusion backbone.

    Fixed shapes, no Python control flow in ``forward`` — suitable for
    ``torch.export``. The pretrained submodules are reused directly (shared
    weights, not copies), so the state dict matches the original components.
    """

    def __init__(self, omnivoice_model):
        """Pull the embedding layers, LLM trunk and audio heads out of *omnivoice_model*.

        Args:
            omnivoice_model: a loaded OmniVoice model exposing ``config``,
                ``get_input_embeddings()``, ``audio_embeddings``, ``llm`` and
                ``audio_heads``.
        """
        super().__init__()
        config = omnivoice_model.config

        # Reuse the pretrained submodules as-is.
        self.text_embeddings = omnivoice_model.get_input_embeddings()
        self.audio_embeddings = omnivoice_model.audio_embeddings
        self.llm = omnivoice_model.llm
        self.audio_heads = omnivoice_model.audio_heads

        # Export-time constants.
        self.num_codebook = config.num_audio_codebook
        self.audio_vocab_size = config.audio_vocab_size

        # Per-codebook ID offsets, shaped [1, C, 1] to broadcast over [B, C, S];
        # registered as a buffer so it travels with .to()/state_dict.
        offsets = torch.arange(config.num_audio_codebook).view(1, -1, 1)
        self.register_buffer("codebook_offsets", offsets * config.audio_vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,   # [B, C, S]
        audio_mask: torch.BoolTensor,  # [B, S] — True where the position is audio
    ) -> torch.Tensor:
        """Run one bidirectional backbone pass.

        Returns:
            audio_logits of shape [B, C, S, V].
        """
        batch, _, seq = input_ids.shape

        # Text path: only the first codebook row carries text token IDs.
        text_path = self.text_embeddings(input_ids[:, 0, :])  # [B, S, H]

        # Audio path: zero out non-audio positions, shift each codebook into its
        # own ID range, embed, then sum over the codebook axis.
        masked_ids = input_ids * audio_mask.unsqueeze(1).long()
        audio_path = self.audio_embeddings(masked_ids + self.codebook_offsets).sum(dim=1)

        # Per-position select: audio embedding where masked, text embedding elsewhere.
        merged = torch.where(audio_mask.unsqueeze(-1), audio_path, text_path)

        # All-True [B, 1, S, S] mask — the backbone attends bidirectionally, NOT causally.
        full_attention = torch.ones(
            batch, 1, seq, seq, dtype=torch.bool, device=input_ids.device
        )

        hidden = self.llm(
            inputs_embeds=merged,
            attention_mask=full_attention,
            return_dict=True,
        )[0]  # [B, S, H]

        # Project to per-codebook vocabularies and reorder to [B, C, S, V].
        per_position = self.audio_heads(hidden)  # [B, S, C*V]
        stacked = per_position.view(batch, seq, self.num_codebook, self.audio_vocab_size)
        return stacked.permute(0, 2, 1, 3)
def main():
    """Load OmniVoice, extract the backbone, smoke-test it, and try torch.export.

    Artifacts written to the working directory:
      - omnivoice_backbone_exported.pt2  (only if torch.export succeeds)
      - omnivoice_backbone_state.pt      (always)
    """
    print("Loading OmniVoice...", flush=True)
    t0 = time.time()
    from omnivoice import OmniVoice
    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,  # export needs float32
    )
    print(f"  Loaded in {time.time()-t0:.1f}s", flush=True)

    # Extract backbone
    print("Extracting backbone...", flush=True)
    backbone = OmniVoiceBackbone(model)
    backbone.eval()

    # Count params
    params = sum(p.numel() for p in backbone.parameters())
    print(f"  Backbone params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    # Example inputs: 8 codebooks, with the tail of the sequence marked as audio.
    seq_len = 512  # smaller than MAX_SEQ_LEN for a quick smoke test
    input_ids = torch.randint(0, 1024, (BATCH_SIZE, 8, seq_len), dtype=torch.long)
    audio_mask = torch.zeros(BATCH_SIZE, seq_len, dtype=torch.bool)
    audio_mask[:, 200:] = True  # last portion is audio

    # Test forward pass
    print("Testing forward pass...", flush=True)
    t0 = time.time()
    with torch.no_grad():
        logits = backbone(input_ids, audio_mask)
    print(f"  Forward: {time.time()-t0:.1f}s", flush=True)
    print(f"  Output shape: {logits.shape}", flush=True)
    # Derive the expected shape from the actual config instead of hardcoding batch=1.
    print(f"  Expected: [{BATCH_SIZE}, 8, {seq_len}, 1025]", flush=True)

    # Try torch.export
    print("\nAttempting torch.export...", flush=True)
    try:
        from torch.export import export
        t0 = time.time()
        ep = export(backbone, (input_ids, audio_mask))
        print(f"  torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
        print(f"  Graph: {len(ep.graph.nodes)} nodes", flush=True)
        # Save the exported program
        torch.export.save(ep, "omnivoice_backbone_exported.pt2")
        print("  Saved to omnivoice_backbone_exported.pt2", flush=True)
    except Exception as e:
        # Print the full traceback — the message alone rarely pinpoints an
        # export failure (graph breaks report the offending source line).
        import traceback
        traceback.print_exc()
        print(f"  torch.export failed: {e}", flush=True)
        print("  Will need fixes — see error above", flush=True)

    # Also save the PyTorch module for later conversion
    torch.save(backbone.state_dict(), "omnivoice_backbone_state.pt")
    print("\nSaved backbone state dict to omnivoice_backbone_state.pt", flush=True)


if __name__ == "__main__":
    main()