#!/usr/bin/env python3
"""
Phase 2: Export Speaker Encoder to ExecuTorch .pte
===================================================
Extracts the ECAPA-TDNN speaker encoder, wraps it for fixed-size input,
exports via torch.export, and lowers to XNNPACK .pte.

Input:  mel spectrogram [1, T, 128] where T is fixed (e.g. 469 for 5s audio)
Output: x-vector [1, 2048]

The wrapper classes are defined at module level so they can be imported
without loading the model; the export pipeline itself only runs when this
file is executed as a script.
"""

import sys
import os
import copy
import time
import numpy as np
import torch
import torch.nn as nn

# ── paths ────────────────────────────────────────────────────────────
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ── Configuration ────────────────────────────────────────────────────
# 5 seconds of audio at 24kHz: mel with hop_size=256 gives ~469 frames
# We fix this for export. At runtime, pad/truncate mel to this size.
FIXED_MEL_FRAMES = 469
MEL_DIM = 128

# True only when executed as a script; importing the module must not
# trigger model loading or export.
_run_export = __name__ == "__main__"


class _ExplicitPadConv1d(nn.Module):
    """Conv1d with explicit padding instead of padding='same'.

    The input is padded with ``torch.nn.functional.pad`` and then passed to
    a Conv1d built with ``padding=0`` that reuses the original layer's
    weight and bias parameters.

    Args:
        original_conv: The Conv1d whose hyper-parameters and parameters
            are reused.
        pad_left: Number of elements to pad on the left of the last dim.
        pad_right: Number of elements to pad on the right of the last dim.
        pad_mode: Padding mode. Accepts the ``nn.Conv1d.padding_mode``
            spelling, so ``'zeros'`` is translated to the ``'constant'``
            mode that ``F.pad`` understands.
    """

    def __init__(self, original_conv: nn.Conv1d, pad_left: int, pad_right: int, pad_mode: str):
        super().__init__()
        # Create a new Conv1d with padding=0
        self.conv = nn.Conv1d(
            in_channels=original_conv.in_channels,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size[0],
            stride=original_conv.stride[0],
            padding=0,
            dilation=original_conv.dilation[0],
            groups=original_conv.groups,
            bias=original_conv.bias is not None,
        )
        # Reuse the original parameters rather than copying their values
        self.conv.weight = original_conv.weight
        if original_conv.bias is not None:
            self.conv.bias = original_conv.bias
        self.pad_left = pad_left
        self.pad_right = pad_right
        # nn.Conv1d names zero padding 'zeros', F.pad names it 'constant';
        # passing 'zeros' straight to F.pad raises at forward time.
        self.pad_mode = "constant" if pad_mode == "zeros" else pad_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the last dimension of ``x`` explicitly, then convolve."""
        if self.pad_left > 0 or self.pad_right > 0:
            x = torch.nn.functional.pad(x, (self.pad_left, self.pad_right), mode=self.pad_mode)
        return self.conv(x)


class SpeakerEncoderForExport(nn.Module):
    """
    Wrapper around the ECAPA-TDNN speaker encoder for ExecuTorch export.

    Takes a pre-computed mel spectrogram of fixed size and returns x-vector.

    The wrapped encoder is a deep copy of the one passed in, with every
    padding="same" Conv1d replaced by an explicit-padding equivalent to
    avoid dynamic pad calculation issues during export.
    """

    def __init__(self, original_encoder):
        super().__init__()
        # Deep copy to avoid modifying the original
        self.encoder = copy.deepcopy(original_encoder)
        # Replace all Conv1d with padding="same" to use explicit integer padding
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module):
        """
        Recursively replace padding='same' Conv1d layers under ``module``.

        For kernel_size=k and dilation=d the total 'same' padding at
        stride=1 is d * (k - 1); it is split as floor(total / 2) on the
        left and the remainder on the right. The layer's own padding mode
        is kept and applied through F.pad by _ExplicitPadConv1d.

        Raises:
            AssertionError: if a padding='same' layer has stride != 1.
        """
        # Snapshot the children: the loop replaces entries as it goes.
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Conv1d) and child.padding == "same":
                k = child.kernel_size[0]
                d = child.dilation[0]
                s = child.stride[0]
                # Raised explicitly (not via `assert`) so the guard
                # survives `python -O`.
                if s != 1:
                    raise AssertionError(
                        f"padding='same' with stride != 1 not handled: {name}"
                    )
                pad_total = d * (k - 1)
                pad_left = pad_total // 2
                pad_right = pad_total - pad_left
                # Explicit padding + conv with no padding
                new_conv = _ExplicitPadConv1d(child, pad_left, pad_right, child.padding_mode)
                setattr(module, name, new_conv)
            else:
                self._fix_conv_padding(child)

    def forward(self, mel_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel_input: [1, FIXED_MEL_FRAMES, 128] — pre-computed mel spectrogram
        Returns:
            x_vector: [1, 2048] — speaker embedding
        """
        return self.encoder(mel_input)


def _load_model():
    """Load the Qwen3-TTS model from MODEL_PATH on CPU in float32, in eval mode."""
    print("\n[1/5] Loading model...")
    # Imported lazily: qwen_tts is only reachable through the sys.path
    # entries added above and is only needed when actually exporting.
    from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
    from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

    config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
    model = Qwen3TTSForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        config=config,
        dtype=torch.float32,
        attn_implementation="sdpa",
        device_map="cpu",
    )
    model.eval()
    print(" Model loaded.")
    return model


def _validate_wrapper(original_encoder, export_encoder):
    """Check that the export wrapper reproduces the original encoder's output.

    Raises:
        AssertionError: if the cosine similarity on a random input is not
            above 0.999.
    """
    print("\n[3/5] Validating wrapper produces same output as original...")
    test_mel = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
    with torch.no_grad():
        orig_out = original_encoder(test_mel)
        wrap_out = export_encoder(test_mel)

    cos_sim = torch.nn.functional.cosine_similarity(orig_out, wrap_out, dim=-1).item()
    max_diff = (orig_out - wrap_out).abs().max().item()
    print(f" Original output shape: {list(orig_out.shape)}")
    print(f" Wrapper output shape: {list(wrap_out.shape)}")
    print(f" Cosine similarity: {cos_sim:.6f}")
    print(f" Max abs difference: {max_diff:.2e}")
    # Explicit raise so the check is not stripped under `python -O`.
    if not cos_sim > 0.999:
        raise AssertionError(f"Wrapper diverged from original! cos_sim={cos_sim}")
    print(" PASS — wrapper matches original")


def _export_program(export_encoder):
    """Run torch.export on the wrapper.

    Returns:
        The exported program, or None when torch.export failed and only
        the torch.jit.trace fallback (.pt file) could be saved.

    Exits the process with status 1 if both torch.export and the
    torch.jit.trace fallback fail.
    """
    print("\n[4/5] Running torch.export...")
    t0 = time.time()
    example_input = (torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM),)
    try:
        exported = torch.export.export(
            export_encoder,
            example_input,
            strict=False,  # Allow some Python dynamism to be traced
        )
    except Exception as e:
        print(f" torch.export FAILED: {e}")
        print(" Trying with torch.export.export(..., strict=False) already set.")
        print(" Attempting torch.jit.trace as fallback...")
        try:
            traced = torch.jit.trace(export_encoder, example_input)
            traced.save(os.path.join(OUTPUT_DIR, "speaker_encoder_traced.pt"))
            print(" torch.jit.trace succeeded (saved as .pt, not .pte)")
        except Exception as e2:
            print(f" torch.jit.trace also failed: {e2}")
            sys.exit(1)
        return None

    print(f" torch.export succeeded in {time.time() - t0:.1f}s")
    print(f" Graph has {len(exported.graph.nodes)} nodes")
    return exported


def _lower_to_pte(exported):
    """Lower the exported program to an XNNPACK-delegated .pte file.

    Returns:
        Path of the written .pte file, or None if lowering failed (in
        which case the exported program is saved as .pt2 instead).
    """
    print("\n[5/5] Lowering to ExecuTorch .pte (XNNPACK)...")
    t0 = time.time()
    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et_program = edge.to_executorch()
        pte_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pte")
        with open(pte_path, "wb") as f:
            f.write(et_program.buffer)
        pte_size_mb = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size_mb:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")
        return pte_path
    except Exception as e:
        print(f" ExecuTorch lowering failed: {e}")
        print(" Saving exported program as .pt2 instead...")
        pt2_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved: {pt2_path}")
        return None


def _validate_pte(pte_path, export_encoder):
    """Best-effort check: run the .pte and compare against the PyTorch wrapper.

    Failures are reported but not raised, since runtime validation can
    also be done on the target device.
    """
    print("\n Validating .pte execution...")
    try:
        from executorch.runtime import Runtime

        runtime = Runtime.get()
        # Context manager so the file handle is closed deterministically.
        with open(pte_path, "rb") as f:
            program = runtime.load_program(f.read())
        method = program.load_method("forward")

        test_input = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
        pte_out = method.execute([test_input])
        with torch.no_grad():
            ref_out = export_encoder(test_input)

        if isinstance(pte_out, (list, tuple)):
            pte_out = pte_out[0]
        cos_sim_pte = torch.nn.functional.cosine_similarity(
            ref_out.flatten().unsqueeze(0), pte_out.flatten().unsqueeze(0)
        ).item()
        print(f" .pte vs PyTorch cosine sim: {cos_sim_pte:.6f}")
    except Exception as e:
        print(f" .pte validation failed: {e}")
        print(" (This may be OK — runtime validation can be done on target device)")


def main() -> None:
    """Run the full export pipeline: load, wrap, validate, export, lower."""
    print("=" * 70)
    print("PHASE 2: Export Speaker Encoder → .pte")
    print("=" * 70)

    # ── 1. Load Model ───────────────────────────────────────────────
    model = _load_model()

    # ── 2. Create Export-Ready Wrapper ──────────────────────────────
    print("\n[2/5] Creating export-ready speaker encoder wrapper...")
    export_encoder = SpeakerEncoderForExport(model.speaker_encoder)
    export_encoder.eval()

    # ── 3. Validate Wrapper vs Original ─────────────────────────────
    _validate_wrapper(model.speaker_encoder, export_encoder)

    # ── 4. torch.export ─────────────────────────────────────────────
    exported = _export_program(export_encoder)
    if exported is None:
        # Only the TorchScript fallback exists; there is no exported
        # program to lower, so stop here instead of failing on an
        # undefined variable in the lowering step.
        print(" No exported program available; skipping ExecuTorch lowering.")
        sys.exit(1)

    # ── 5. Lower to ExecuTorch .pte ─────────────────────────────────
    pte_path = _lower_to_pte(exported)

    # ── Validate .pte output (if available) ─────────────────────────
    # Only validate a file written by this run, never a stale one left
    # in OUTPUT_DIR by an earlier run.
    if pte_path is not None:
        _validate_pte(pte_path, export_encoder)

    print("\n" + "=" * 70)
    print("Phase 2 complete!")
    print(f" Input: mel spectrogram [1, {FIXED_MEL_FRAMES}, {MEL_DIM}]")
    print(f" Output: x-vector [1, 2048]")
    print("=" * 70)


if _run_export:
    main()