# Source: Qwen3-TTS-1.7B-Base-ExecuTorch / scripts / export_speaker_encoder.py
# Uploaded by acul3 via huggingface_hub (commit 24368a4, verified)
#!/usr/bin/env python3
"""
Phase 2: Export Speaker Encoder to ExecuTorch .pte
===================================================
Extracts the ECAPA-TDNN speaker encoder, wraps it for fixed-size input,
exports via torch.export, and lowers to XNNPACK .pte.

Input: mel spectrogram [1, T, 128] where T is fixed (e.g. 469 for 5s audio)
Output: x-vector [1, 2048]
"""
import sys
import os
import copy
import time
import numpy as np
import torch
import torch.nn as nn
# ── paths ────────────────────────────────────────────────────────────
# NOTE(review): machine-specific absolute paths (including a python3.10 venv);
# these must be adjusted when the script runs anywhere else.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")
# Make the project venv and source tree importable (for qwen_tts below).
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ── Configuration ────────────────────────────────────────────────────
# 5 seconds of audio at 24kHz: mel with hop_size=256 gives ~469 frames
# We fix this for export. At runtime, pad/truncate mel to this size.
FIXED_MEL_FRAMES = 469  # fixed temporal size baked into the exported graph
MEL_DIM = 128           # mel feature dimension expected by the encoder
# Only run the export pipeline when executed as a script (not on import).
# (Idiom fix: the original 4-line if/else collapses to a single comparison.)
_run_export = __name__ == "__main__"
if _run_export:
    print("=" * 70)
    # String fix: banner arrow was mojibake ("β†’") for a UTF-8 "→".
    print("PHASE 2: Export Speaker Encoder → .pte")
    print("=" * 70)
# ── 1. Load Model ───────────────────────────────────────────────────
print("\n[1/5] Loading model...")
# Project-local imports; resolvable thanks to the sys.path entries added above.
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import (
    Qwen3TTSForConditionalGeneration,
    mel_spectrogram,  # NOTE(review): imported but never used in this script
)
# Load in float32 on CPU — the torch.export / XNNPACK path below works on
# fp32 CPU tensors.
config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
model = Qwen3TTSForConditionalGeneration.from_pretrained(
    MODEL_PATH, config=config, dtype=torch.float32,
    attn_implementation="sdpa", device_map="cpu",
)
model.eval()  # inference mode for tracing (disables dropout etc.)
print(" Model loaded.")
# ── 2. Create Export-Ready Wrapper ───────────────────────────────────
print("\n[2/5] Creating export-ready speaker encoder wrapper...")
class SpeakerEncoderForExport(nn.Module):
    """Export-friendly wrapper around the ECAPA-TDNN speaker encoder.

    Takes a pre-computed, fixed-size mel spectrogram and returns the
    x-vector speaker embedding. Conv1d layers constructed with
    padding="same" are rewritten to use explicit integer padding, since
    the dynamic pad computation behind 'same' is problematic for export.
    """

    def __init__(self, original_encoder):
        super().__init__()
        # Operate on a deep copy so the source model stays untouched.
        self.encoder = copy.deepcopy(original_encoder)
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module):
        """Recursively swap padding='same' Conv1d layers for explicit padding.

        For kernel_size=k, dilation=d and stride=1, 'same' is equivalent to
        a total pad of d*(k-1), split left/right (right gets the odd frame).
        """
        for child_name, sub in module.named_children():
            if not (isinstance(sub, nn.Conv1d) and sub.padding == 'same'):
                # Not a 'same'-padded conv: descend into its children.
                self._fix_conv_padding(sub)
                continue
            kernel = sub.kernel_size[0]
            dilation = sub.dilation[0]
            stride = sub.stride[0]
            assert stride == 1, f"padding='same' with stride != 1 not handled: {child_name}"
            total_pad = dilation * (kernel - 1)
            left_pad = total_pad // 2
            right_pad = total_pad - left_pad
            # Swap in a module doing explicit pad + zero-padding conv.
            replacement = _ExplicitPadConv1d(sub, left_pad, right_pad, sub.padding_mode)
            setattr(module, child_name, replacement)

    def forward(self, mel_input: torch.Tensor) -> torch.Tensor:
        """Run the wrapped encoder.

        Args:
            mel_input: [1, FIXED_MEL_FRAMES, 128] pre-computed mel spectrogram.
        Returns:
            x_vector: [1, 2048] speaker embedding.
        """
        return self.encoder(mel_input)
class _ExplicitPadConv1d(nn.Module):
"""Conv1d with explicit padding instead of padding='same'."""
def __init__(self, original_conv: nn.Conv1d, pad_left: int, pad_right: int, pad_mode: str):
super().__init__()
# Create a new Conv1d with padding=0
self.conv = nn.Conv1d(
in_channels=original_conv.in_channels,
out_channels=original_conv.out_channels,
kernel_size=original_conv.kernel_size[0],
stride=original_conv.stride[0],
padding=0,
dilation=original_conv.dilation[0],
groups=original_conv.groups,
bias=original_conv.bias is not None,
)
# Copy weights
self.conv.weight = original_conv.weight
if original_conv.bias is not None:
self.conv.bias = original_conv.bias
self.pad_left = pad_left
self.pad_right = pad_right
self.pad_mode = pad_mode
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.pad_left > 0 or self.pad_right > 0:
x = torch.nn.functional.pad(x, (self.pad_left, self.pad_right), mode=self.pad_mode)
return self.conv(x)
# Create the wrapper
export_encoder = SpeakerEncoderForExport(model.speaker_encoder)
export_encoder.eval()
# ── 3. Validate Wrapper vs Original ─────────────────────────────────
print("\n[3/5] Validating wrapper produces same output as original...")
# Create test mel input (random is fine — we only compare the two modules
# against each other, not against a reference transcript).
test_mel = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
with torch.no_grad():
    orig_out = model.speaker_encoder(test_mel)
    wrap_out = export_encoder(test_mel)
# Cosine similarity over the embedding dim plus max abs diff together catch
# both directional and magnitude divergence.
cos_sim = torch.nn.functional.cosine_similarity(orig_out, wrap_out, dim=-1).item()
max_diff = (orig_out - wrap_out).abs().max().item()
print(f" Original output shape: {list(orig_out.shape)}")
print(f" Wrapper output shape: {list(wrap_out.shape)}")
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" Max abs difference: {max_diff:.2e}")
# NOTE(review): bare assert is stripped under `python -O`; acceptable for a
# one-off export script, but a raise would be more robust.
assert cos_sim > 0.999, f"Wrapper diverged from original! cos_sim={cos_sim}"
print(" PASS β€” wrapper matches original")
# ── 4. torch.export ─────────────────────────────────────────────────
print("\n[4/5] Running torch.export...")
t0 = time.time()
# Fixed-shape example input; the exported graph is specialized to it.
example_input = (torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM),)
try:
    exported = torch.export.export(
        export_encoder,
        example_input,
        strict=False,  # Allow some Python dynamism to be traced
    )
    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph has {len(exported.graph.nodes)} nodes")
except Exception as e:
    print(f" torch.export FAILED: {e}")
    print(" Attempting torch.jit.trace as fallback...")
    try:
        traced = torch.jit.trace(export_encoder, example_input)
        traced.save(os.path.join(OUTPUT_DIR, "speaker_encoder_traced.pt"))
        print(" torch.jit.trace succeeded (saved as .pt, not .pte)")
        # BUG FIX: the traced module cannot feed the ExecuTorch lowering in
        # step 5, and `exported` is undefined on this path — falling through
        # would raise a confusing NameError. Stop cleanly here instead.
        sys.exit(0)
    except Exception as e2:
        print(f" torch.jit.trace also failed: {e2}")
        sys.exit(1)
# ── 5. Lower to ExecuTorch .pte ─────────────────────────────────────
print("\n[5/5] Lowering to ExecuTorch .pte (XNNPACK)...")
t0 = time.time()
try:
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    # Lower the exported program to the Edge dialect and delegate supported
    # subgraphs to the XNNPACK CPU backend.
    # NOTE(review): _check_ir_validity=False skips Edge IR validation —
    # confirm this workaround is still required with current executorch.
    edge = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et_program = edge.to_executorch()
    pte_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pte")
    with open(pte_path, "wb") as f:
        f.write(et_program.buffer)
    pte_size_mb = os.path.getsize(pte_path) / 1e6
    print(f" .pte saved: {pte_path}")
    print(f" .pte size: {pte_size_mb:.1f} MB")
    print(f" Lowered in {time.time() - t0:.1f}s")
except Exception as e:
    # Best-effort fallback: persist the raw torch.export program so the
    # lowering failure can be debugged offline.
    print(f" ExecuTorch lowering failed: {e}")
    print(" Saving exported program as .pt2 instead...")
    pt2_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pt2")
    torch.export.save(exported, pt2_path)
    print(f" Saved: {pt2_path}")
# ── Validate .pte output (if available) ──────────────────────────────
# Sanity-check that the lowered program runs and matches the PyTorch wrapper.
if os.path.exists(os.path.join(OUTPUT_DIR, "speaker_encoder.pte")):
    print("\n Validating .pte execution...")
    try:
        # Only Runtime is needed (original also imported unused Program, Method).
        from executorch.runtime import Runtime
        runtime = Runtime.get()
        # BUG FIX: read the .pte via a context manager so the file handle is
        # closed deterministically (original leaked an open handle).
        with open(os.path.join(OUTPUT_DIR, "speaker_encoder.pte"), "rb") as f:
            program = runtime.load_program(f.read())
        method = program.load_method("forward")
        test_input = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
        pte_out = method.execute([test_input])
        with torch.no_grad():
            ref_out = export_encoder(test_input)
        # Runtime may return a list/tuple of outputs; take the first tensor.
        if isinstance(pte_out, (list, tuple)):
            pte_out = pte_out[0]
        cos_sim_pte = torch.nn.functional.cosine_similarity(
            ref_out.flatten().unsqueeze(0),
            pte_out.flatten().unsqueeze(0)
        ).item()
        print(f" .pte vs PyTorch cosine sim: {cos_sim_pte:.6f}")
    except Exception as e:
        print(f" .pte validation failed: {e}")
        # String fix: em dash was mojibake ("β€”") for a UTF-8 "—".
        print(" (This may be OK — runtime validation can be done on target device)")
print("\n" + "=" * 70)
print("Phase 2 complete!")
print(f" Input: mel spectrogram [1, {FIXED_MEL_FRAMES}, {MEL_DIM}]")
print(f" Output: x-vector [1, 2048]")
print("=" * 70)