#!/usr/bin/env python3
"""
OmniVoice Backbone Export β Step 1: Model Surgery + torch.export
Extracts the diffusion backbone as a standalone nn.Module:
Input: input_ids [B, C, S], audio_mask [B, S]
Output: audio_logits [B, C, S, V]
Where:
B = batch size (1 for single, 2 for CFG)
C = 8 (num_audio_codebook)
S = fixed sequence length
V = 1025 (audio_vocab_size)
"""
import torch
import torch.nn as nn
import time
# ── Config ──────────────────────────────────────────────────────────────────
MAX_SEQ_LEN = 2048  # Fixed max sequence length for export
BATCH_SIZE = 1  # Single inference (no CFG batching for simplicity)
class OmniVoiceBackbone(nn.Module):
    """Export-ready diffusion backbone carved out of a full OmniVoice model.

    Holds only the submodules the forward pass needs (embeddings, LLM trunk,
    audio heads), with fixed shapes and no Python control flow in forward —
    a prerequisite for torch.export.
    """

    def __init__(self, omnivoice_model):
        super().__init__()
        config = omnivoice_model.config

        # Submodules lifted straight off the full model.
        self.text_embeddings = omnivoice_model.get_input_embeddings()
        self.audio_embeddings = omnivoice_model.audio_embeddings
        self.llm = omnivoice_model.llm
        self.audio_heads = omnivoice_model.audio_heads

        # Scalar constants baked into the exported graph.
        self.num_codebook = config.num_audio_codebook
        self.audio_vocab_size = config.audio_vocab_size

        # Per-codebook ID offsets, shaped [1, C, 1] so they broadcast over
        # [B, C, S] input IDs. Registered as a buffer so the tensor follows
        # state_dict saves and device moves.
        offsets = torch.arange(config.num_audio_codebook).view(1, -1, 1)
        self.register_buffer("codebook_offsets", offsets * config.audio_vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,  # [B, C, S]
        audio_mask: torch.BoolTensor,  # [B, S] — True at audio positions
    ) -> torch.Tensor:
        """Return audio logits of shape [B, C, S, V]."""
        batch, _, seq = input_ids.shape

        # Text path: embed only the first codebook row.
        text_embeds = self.text_embeddings(input_ids[:, 0, :])  # [B, S, H]

        # Audio path: zero out text positions, shift each codebook into its
        # own ID range, embed, then sum over the codebook axis.
        masked_ids = input_ids * audio_mask.unsqueeze(1).long()
        shifted_ids = masked_ids + self.codebook_offsets
        audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)  # [B, S, H]

        # Pick audio embeddings where the mask is set, text embeddings elsewhere.
        inputs_embeds = torch.where(audio_mask.unsqueeze(-1), audio_embeds, text_embeds)

        # Fully bidirectional (all-True) attention — deliberately NOT causal.
        attention_mask = torch.ones(
            batch, 1, seq, seq, dtype=torch.bool, device=input_ids.device
        )

        llm_out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = llm_out[0]  # [B, S, H]

        # Project to per-codebook vocabularies and reorder to [B, C, S, V].
        flat = self.audio_heads(hidden_states)  # [B, S, C*V]
        logits = flat.view(batch, seq, self.num_codebook, self.audio_vocab_size)
        return logits.permute(0, 2, 1, 3)
def main():
    """End-to-end export driver: load OmniVoice, extract the backbone,
    smoke-test a forward pass, then attempt torch.export.

    Side effects (current working directory):
      - omnivoice_backbone_exported.pt2  (only if torch.export succeeds)
      - omnivoice_backbone_state.pt      (always)
    """
    print("Loading OmniVoice...", flush=True)
    t0 = time.time()
    # Imported lazily so this module can be inspected without the package installed.
    from omnivoice import OmniVoice
    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,  # export needs float32
    )
    print(f" Loaded in {time.time()-t0:.1f}s", flush=True)

    # Extract backbone
    print("Extracting backbone...", flush=True)
    backbone = OmniVoiceBackbone(model)
    backbone.eval()

    # Count params (4 bytes per parameter at fp32)
    params = sum(p.numel() for p in backbone.parameters())
    print(f" Backbone params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    # Create example inputs
    seq_len = 512  # Start with smaller for testing
    input_ids = torch.randint(0, 1024, (BATCH_SIZE, 8, seq_len), dtype=torch.long)
    audio_mask = torch.zeros(BATCH_SIZE, seq_len, dtype=torch.bool)
    audio_mask[:, 200:] = True  # Last portion is audio

    # Test forward pass
    print("Testing forward pass...", flush=True)
    t0 = time.time()
    with torch.no_grad():
        logits = backbone(input_ids, audio_mask)
    print(f" Forward: {time.time()-t0:.1f}s", flush=True)
    print(f" Output shape: {logits.shape}", flush=True)
    print(f" Expected: [1, 8, {seq_len}, 1025]", flush=True)

    # Try torch.export
    print("\nAttempting torch.export...", flush=True)
    try:
        from torch.export import export
        t0 = time.time()
        ep = export(backbone, (input_ids, audio_mask))
        print(f" torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
        print(f" Graph: {len(ep.graph.nodes)} nodes", flush=True)
        # Save the exported program
        torch.export.save(ep, "omnivoice_backbone_exported.pt2")
        print(" Saved to omnivoice_backbone_exported.pt2", flush=True)
    except Exception as e:
        # Boundary handler: report the failure but keep going so the
        # state dict below is still saved for later conversion attempts.
        print(f" torch.export failed: {e}", flush=True)
        print(" Will need fixes — see error above", flush=True)

    # Also save the PyTorch module for later conversion
    torch.save(backbone.state_dict(), "omnivoice_backbone_state.pt")
    print("\nSaved backbone state dict to omnivoice_backbone_state.pt", flush=True)


if __name__ == "__main__":
    main()
|