File size: 5,353 Bytes
0b88b36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python3
"""
OmniVoice Backbone Export — Step 1: Model Surgery + torch.export

Extracts the diffusion backbone as a standalone nn.Module:
  Input:  input_ids [B, C, S], audio_mask [B, S]
  Output: audio_logits [B, C, S, V]

Where:
  B = batch size (1 for single, 2 for CFG)
  C = 8 (num_audio_codebook)
  S = fixed sequence length
  V = 1025 (audio_vocab_size)
"""

import torch
import torch.nn as nn
import time


# ── Config ──────────────────────────────────────────────────────────────────
MAX_SEQ_LEN = 2048  # Fixed max sequence length for export (NOTE(review): not referenced in this file — presumably for a later full-length export; confirm)
BATCH_SIZE = 1       # Single inference (no CFG batching for simplicity)


class OmniVoiceBackbone(nn.Module):
    """Standalone backbone for export — no Python control flow, fixed shapes.

    Wraps the diffusion backbone of an OmniVoice model as a pure
    tensor-in / tensor-out module suitable for ``torch.export``:

      Input:  input_ids [B, C, S] (long), audio_mask [B, S] (bool)
      Output: audio_logits [B, C, S, V]
    """

    def __init__(self, omnivoice_model):
        """Borrow the needed submodules from a loaded OmniVoice model.

        Args:
            omnivoice_model: Loaded OmniVoice instance exposing ``config``,
                ``get_input_embeddings()``, ``audio_embeddings``, ``llm``
                and ``audio_heads``.
        """
        super().__init__()
        cfg = omnivoice_model.config

        # Core components (shared references — weights are not copied)
        self.text_embeddings = omnivoice_model.get_input_embeddings()
        self.audio_embeddings = omnivoice_model.audio_embeddings
        self.llm = omnivoice_model.llm
        self.audio_heads = omnivoice_model.audio_heads

        # Constants
        self.num_codebook = cfg.num_audio_codebook
        self.audio_vocab_size = cfg.audio_vocab_size

        # Per-codebook index offsets, shape [1, C, 1]. Registered as a buffer
        # so it follows .to(device) and is captured by export/state_dict.
        self.register_buffer(
            "codebook_offsets",
            torch.arange(cfg.num_audio_codebook).view(1, -1, 1) * cfg.audio_vocab_size,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,     # [B, C, S]
        audio_mask: torch.BoolTensor,    # [B, S] — True at audio positions
    ) -> torch.Tensor:
        """Return audio logits of shape [B, C, S, V]."""
        B, C, S = input_ids.shape

        # 1. Text embeddings come from the first codebook layer only.
        text_embeds = self.text_embeddings(input_ids[:, 0, :])  # [B, S, H]

        # 2. Audio embeddings: zero out text positions (keeps indices in
        #    range after offsetting), shift each codebook into its own vocab
        #    slice, embed, then sum across codebooks.
        shifted_ids = (input_ids * audio_mask.unsqueeze(1).long()) + self.codebook_offsets
        audio_embeds = self.audio_embeddings(shifted_ids).sum(dim=1)  # [B, S, H]

        # 3. Merge: use audio embedding where mask is True, text otherwise.
        inputs_embeds = torch.where(
            audio_mask.unsqueeze(-1),
            audio_embeds,
            text_embeds,
        )

        # 4. BIDIRECTIONAL attention mask (all True — NOT causal!).
        #    NOTE(review): assumes self.llm accepts a 4D bool mask — confirm
        #    against the wrapped model's expected attention_mask format.
        attention_mask = torch.ones(B, 1, S, S, dtype=torch.bool, device=input_ids.device)

        # 5. LLM forward with bidirectional attention.
        llm_out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = llm_out[0]  # [B, S, H]

        # 6. Audio prediction heads: one V-way slice per codebook.
        logits_flat = self.audio_heads(hidden_states)  # [B, S, C*V]
        audio_logits = logits_flat.view(B, S, self.num_codebook, self.audio_vocab_size)
        audio_logits = audio_logits.permute(0, 2, 1, 3)  # [B, C, S, V]

        return audio_logits


def main():
    """Load OmniVoice, wrap its backbone, smoke-test a forward pass,
    attempt torch.export, and persist the exported program + state dict."""
    print("Loading OmniVoice...", flush=True)
    t0 = time.time()

    from omnivoice import OmniVoice
    model = OmniVoice.from_pretrained(
        "k2-fsa/OmniVoice",
        device_map="cpu",
        dtype=torch.float32,  # export needs float32
    )
    print(f"  Loaded in {time.time()-t0:.1f}s", flush=True)

    # Wrap the pretrained model in the export-friendly module.
    print("Extracting backbone...", flush=True)
    backbone = OmniVoiceBackbone(model)
    backbone.eval()

    # Count params (fp32 = 4 bytes per parameter).
    params = sum(p.numel() for p in backbone.parameters())
    print(f"  Backbone params: {params/1e6:.1f}M ({params*4/1e9:.2f}GB fp32)", flush=True)

    # Create example inputs: a text prefix followed by an audio region.
    seq_len = 512  # Start with smaller for testing
    input_ids = torch.randint(0, 1024, (BATCH_SIZE, 8, seq_len), dtype=torch.long)
    audio_mask = torch.zeros(BATCH_SIZE, seq_len, dtype=torch.bool)
    audio_mask[:, 200:] = True  # Last portion is audio

    # Smoke-test a forward pass before attempting export.
    print("Testing forward pass...", flush=True)
    t0 = time.time()
    with torch.no_grad():
        logits = backbone(input_ids, audio_mask)
    print(f"  Forward: {time.time()-t0:.1f}s", flush=True)
    print(f"  Output shape: {logits.shape}", flush=True)
    print(f"  Expected: [1, 8, {seq_len}, 1025]", flush=True)

    # Export is best-effort: report failures instead of crashing so the
    # state dict below is still saved for later conversion paths.
    print("\nAttempting torch.export...", flush=True)
    try:
        from torch.export import export
        t0 = time.time()
        ep = export(backbone, (input_ids, audio_mask))
        print(f"  torch.export succeeded in {time.time()-t0:.1f}s!", flush=True)
        print(f"  Graph: {len(ep.graph.nodes)} nodes", flush=True)

        # Save the exported program
        torch.export.save(ep, "omnivoice_backbone_exported.pt2")
        print("  Saved to omnivoice_backbone_exported.pt2", flush=True)

    except Exception as e:
        # Fix: message previously contained a mojibake em-dash.
        print(f"  torch.export failed: {e}", flush=True)
        print("  Will need fixes — see error above", flush=True)

    # Also save the PyTorch module for later conversion
    torch.save(backbone.state_dict(), "omnivoice_backbone_state.pt")
    print("\nSaved backbone state dict to omnivoice_backbone_state.pt", flush=True)


if __name__ == "__main__":
    main()  # run the full load → extract → export pipeline when invoked as a script