#!/usr/bin/env python3
"""
Phase 2: Export Speaker Encoder to ExecuTorch .pte
===================================================
Extracts the ECAPA-TDNN speaker encoder, wraps it for fixed-size input,
exports via torch.export, and lowers to XNNPACK .pte.
Input: mel spectrogram [1, T, 128] where T is fixed (e.g. 469 for 5s audio)
Output: x-vector [1, 2048]
"""
import sys
import os
import copy
import time
import numpy as np
import torch
import torch.nn as nn
# ── Paths ────────────────────────────────────────────────────────────
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the project venv and source tree importable.  Inserting VENV_SITE
# first and QWEN_TTS_SRC second leaves QWEN_TTS_SRC at the front of
# sys.path, exactly as the two original inserts did.
for _extra_path in (VENV_SITE, QWEN_TTS_SRC):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ── Configuration ────────────────────────────────────────────────────
# 5 seconds of audio at 24kHz: mel with hop_size=256 gives ~469 frames.
# We fix this size for export; at runtime, pad/truncate mel to match.
FIXED_MEL_FRAMES = 469
MEL_DIM = 128
# Run the export only when executed as a script, never as an import
# side effect.  (Was a four-line if/else computing a boolean.)
_run_export = __name__ == "__main__"

if _run_export:
    print("=" * 70)
    print("PHASE 2: Export Speaker Encoder β .pte")
    print("=" * 70)
# ── 1. Load Model ────────────────────────────────────────────────────
print("\n[1/5] Loading model...")
# Project-local imports: resolvable only because QWEN_TTS_SRC was pushed
# onto sys.path earlier in this script.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import (
    Qwen3TTSForConditionalGeneration,
    # NOTE(review): mel_spectrogram is not used in the visible portion of
    # this phase — presumably kept for parity with other phase scripts;
    # confirm before removing.
    mel_spectrogram,
)
# Load the full TTS model on CPU in fp32 with SDPA attention; only the
# speaker_encoder submodule is extracted and exported below.
config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH, config=config, dtype=torch.float32,
    attn_implementation="sdpa", device_map="cpu",
)
model.eval()  # inference mode: freezes dropout/batch-norm behavior
print(" Model loaded.")
# ── 2. Create Export-Ready Wrapper ───────────────────────────────────
print("\n[2/5] Creating export-ready speaker encoder wrapper...")


class SpeakerEncoderForExport(nn.Module):
    """Wrapper around the ECAPA-TDNN speaker encoder for ExecuTorch export.

    Takes a pre-computed mel spectrogram of fixed size and returns the
    x-vector.  The original speaker_encoder.forward expects [B, T, 128]
    (mel) and transposes internally to [B, 128, T] for Conv1d processing
    (per the upstream model — confirm against modeling_qwen3_tts).  We keep
    the same architecture but replace padding="same" Conv1d layers with
    explicit integer padding so export does not hit dynamic pad
    calculations.
    """

    def __init__(self, original_encoder: nn.Module) -> None:
        super().__init__()
        # Deep copy so the original model is left untouched.
        self.encoder = copy.deepcopy(original_encoder)
        # Swap every padding="same" Conv1d for an explicit-padding variant.
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module: nn.Module) -> None:
        """Recursively replace padding='same' Conv1d layers.

        For kernel_size=k and dilation=d with stride=1, 'same' padding
        totals d*(k-1), split left = total//2 and right = total - left —
        the same split PyTorch uses internally for padding='same'.

        Raises:
            ValueError: if a padding='same' conv has stride != 1 (the
                'same' total above is only valid for stride 1).  Was an
                `assert`, which is stripped under `python -O`.
        """
        for name, child in module.named_children():
            if isinstance(child, nn.Conv1d) and child.padding == "same":
                k = child.kernel_size[0]
                d = child.dilation[0]
                s = child.stride[0]
                if s != 1:
                    raise ValueError(
                        f"padding='same' with stride != 1 not handled: {name}"
                    )
                pad_total = d * (k - 1)
                pad_left = pad_total // 2
                pad_right = pad_total - pad_left
                # Replace with explicit padding + a zero-padding conv.
                new_conv = _ExplicitPadConv1d(
                    child, pad_left, pad_right, child.padding_mode
                )
                setattr(module, name, new_conv)
            else:
                self._fix_conv_padding(child)

    def forward(self, mel_input: torch.Tensor) -> torch.Tensor:
        """Run the (patched) encoder.

        Args:
            mel_input: [1, FIXED_MEL_FRAMES, 128] pre-computed mel
                spectrogram (shape per this script's export contract).
        Returns:
            x_vector: speaker embedding ([1, 2048] per the script header —
                determined by the wrapped encoder, not by this wrapper).
        """
        return self.encoder(mel_input)


class _ExplicitPadConv1d(nn.Module):
    """Conv1d with explicit F.pad padding instead of padding='same'.

    BUGFIX: Conv1d's default padding_mode is 'zeros', but F.pad's
    equivalent mode is named 'constant' (F.pad accepts constant / reflect /
    replicate / circular).  Passing 'zeros' straight through raised at
    runtime; it is translated to 'constant' here.  Non-default modes
    (e.g. 'reflect') pass through unchanged, as before.
    """

    def __init__(self, original_conv: nn.Conv1d, pad_left: int,
                 pad_right: int, pad_mode: str) -> None:
        super().__init__()
        # Same conv, but with padding=0: padding is applied explicitly
        # in forward() instead.
        self.conv = nn.Conv1d(
            in_channels=original_conv.in_channels,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size[0],
            stride=original_conv.stride[0],
            padding=0,
            dilation=original_conv.dilation[0],
            groups=original_conv.groups,
            bias=original_conv.bias is not None,
        )
        # Share the trained parameters (assignment, not copy).
        self.conv.weight = original_conv.weight
        if original_conv.bias is not None:
            self.conv.bias = original_conv.bias
        self.pad_left = pad_left
        self.pad_right = pad_right
        # Map Conv1d's 'zeros' to F.pad's 'constant' (zero fill).
        self.pad_mode = "constant" if pad_mode == "zeros" else pad_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pad_left > 0 or self.pad_right > 0:
            x = torch.nn.functional.pad(
                x, (self.pad_left, self.pad_right), mode=self.pad_mode
            )
        return self.conv(x)
# Instantiate the export wrapper around the trained speaker encoder.
export_encoder = SpeakerEncoderForExport(model.speaker_encoder)
export_encoder.eval()

# ── 3. Validate Wrapper vs Original ──────────────────────────────────
print("\n[3/5] Validating wrapper produces same output as original...")

# A random mel-shaped probe is enough: both paths share the same weights.
probe_mel = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
with torch.no_grad():
    orig_out = model.speaker_encoder(probe_mel)
    wrap_out = export_encoder(probe_mel)

cos_sim = torch.nn.functional.cosine_similarity(orig_out, wrap_out, dim=-1).item()
max_diff = (orig_out - wrap_out).abs().max().item()
print(f" Original output shape: {list(orig_out.shape)}")
print(f" Wrapper output shape: {list(wrap_out.shape)}")
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" Max abs difference: {max_diff:.2e}")
assert cos_sim > 0.999, f"Wrapper diverged from original! cos_sim={cos_sim}"
print(" PASS β wrapper matches original")
# ── 4. torch.export ──────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()
example_input = (torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM),)
try:
    exported = torch.export.export(
        export_encoder,
        example_input,
        strict=False,  # Allow some Python dynamism to be traced
    )
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph has {len(exported.graph.nodes)} nodes")
except Exception as e:
    print(f" torch.export FAILED: {e}")
    print(" Trying with torch.export.export(..., strict=False) already set.")
    print(" Attempting torch.jit.trace as fallback...")
    try:
        traced = torch.jit.trace(export_encoder, example_input)
        traced.save(os.path.join(OUTPUT_DIR, "speaker_encoder_traced.pt"))
        print(" torch.jit.trace succeeded (saved as .pt, not .pte)")
    except Exception as e2:
        print(f" torch.jit.trace also failed: {e2}")
        sys.exit(1)
    # BUGFIX: the lowering/validation sections below require `exported`,
    # which is never bound on this fallback path — previously execution
    # fell through and crashed with a NameError.  The traced module was
    # saved successfully, so stop cleanly here instead.
    sys.exit(0)
# ── 5. Lower to ExecuTorch .pte ──────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte (XNNPACK)...")
t0 = time.time()

# Destination paths (pure computation, safe to pre-compute).
pte_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pte")
pt2_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pt2")

try:
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    # Partition onto XNNPACK and lower the exported program.
    edge = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et_program = edge.to_executorch()
    with open(pte_path, "wb") as f:
        f.write(et_program.buffer)
    pte_size_mb = os.path.getsize(pte_path) / 1e6
    print(f" .pte saved: {pte_path}")
    print(f" .pte size: {pte_size_mb:.1f} MB")
    print(f" Lowered in {time.time() - t0:.1f}s")
except Exception as e:
    # Best-effort fallback: keep the exported program in .pt2 form.
    print(f" ExecuTorch lowering failed: {e}")
    print(" Saving exported program as .pt2 instead...")
    torch.export.save(exported, pt2_path)
    print(f" Saved: {pt2_path}")
# ── Validate .pte output (if available) ──────────────────────────────
if os.path.exists(os.path.join(OUTPUT_DIR, "speaker_encoder.pte")):
    print("\n Validating .pte execution...")
    try:
        # `Program` and `Method` were imported but never used; dropped.
        from executorch.runtime import Runtime

        runtime = Runtime.get()
        # BUGFIX: read the program bytes via a context manager so the
        # file handle is closed deterministically (was previously an
        # unclosed `open(...).read()`).
        with open(os.path.join(OUTPUT_DIR, "speaker_encoder.pte"), "rb") as f:
            program = runtime.load_program(f.read())
        method = program.load_method("forward")
        test_input = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
        pte_out = method.execute([test_input])
        with torch.no_grad():
            ref_out = export_encoder(test_input)
        # method.execute may return a list/tuple of outputs; take the first.
        if isinstance(pte_out, (list, tuple)):
            pte_out = pte_out[0]
        # Flatten both sides so cosine similarity is a single scalar.
        cos_sim_pte = torch.nn.functional.cosine_similarity(
            ref_out.flatten().unsqueeze(0),
            pte_out.flatten().unsqueeze(0)
        ).item()
        print(f" .pte vs PyTorch cosine sim: {cos_sim_pte:.6f}")
    except Exception as e:
        # Best-effort check: the Python ExecuTorch runtime may be absent.
        print(f" .pte validation failed: {e}")
        print(" (This may be OK β runtime validation can be done on target device)")
# ── Summary ──────────────────────────────────────────────────────────
print("\n" + "=" * 70)
print("Phase 2 complete!")
print(f" Input: mel spectrogram [1, {FIXED_MEL_FRAMES}, {MEL_DIM}]")
# Plain string: the previous f-prefix had no placeholders (lint F541).
print(" Output: x-vector [1, 2048]")
print("=" * 70)