"""
Phase 2: Export Speaker Encoder to ExecuTorch .pte
===================================================
Extracts the ECAPA-TDNN speaker encoder, wraps it for fixed-size input,
exports via torch.export, and lowers to XNNPACK .pte.

Input:  mel spectrogram [1, T, 128] where T is fixed (e.g. 469 for 5s audio)
Output: x-vector [1, 2048]
"""
|
|
import copy
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
|
|
| |
# --- Paths ------------------------------------------------------------------
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the Qwen3-TTS virtualenv packages and source tree importable.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fixed export-time input geometry: 469 mel frames (per the header, ~5 s of
# audio) by 128 mel bins. The export below uses static shapes only.
FIXED_MEL_FRAMES = 469
MEL_DIM = 128

# Run the export pipeline only when executed as a script, so importing this
# module (e.g. to reuse the wrapper classes) has no heavy side effects.
_run_export = __name__ == "__main__"
|
|
if _run_export:
    print("=" * 70)
    # NOTE: the original banner contained a mojibake character ("β") where an
    # arrow belonged; restored here.
    print("PHASE 2: Export Speaker Encoder → .pte")
    print("=" * 70)

    # ------------------------------------------------------------------
    # [1/5] Load the full TTS model on CPU in fp32. Only speaker_encoder
    # is exported, but from_pretrained is the supported way to obtain
    # correctly initialized weights.
    # ------------------------------------------------------------------
    print("\n[1/5] Loading model...")
    from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
    from qwen_tts.core.models.modeling_qwen3_tts import (
        Qwen3TTSForConditionalGeneration,
        mel_spectrogram,
    )

    config = Qwen3TTSConfig.from_pretrained(MODEL_PATH)
    model = Qwen3TTSForConditionalGeneration.from_pretrained(
        MODEL_PATH, config=config, dtype=torch.float32,
        attn_implementation="sdpa", device_map="cpu",
    )
    model.eval()
    print(" Model loaded.")

    print("\n[2/5] Creating export-ready speaker encoder wrapper...")
|
|
class SpeakerEncoderForExport(nn.Module):
    """
    Wrapper around the ECAPA-TDNN speaker encoder for ExecuTorch export.

    Takes a pre-computed mel spectrogram of fixed size and returns the
    x-vector speaker embedding.

    The original speaker_encoder.forward expects [B, T, 128] (mel) and
    transposes internally to [B, 128, T] for Conv1d processing.

    We keep the architecture unchanged but replace padding="same" Conv1d
    layers with explicit padding to avoid dynamic pad calculation issues
    during export.
    """

    def __init__(self, original_encoder):
        """
        Args:
            original_encoder: pretrained speaker-encoder nn.Module. It is
                deep-copied so the source model is left untouched.
        """
        super().__init__()
        self.encoder = copy.deepcopy(original_encoder)
        # Rewrite every padding='same' Conv1d in-place on the copy.
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module):
        """
        Recursively replace padding='same' Conv1d layers with explicit
        integer padding. For kernel_size=k and dilation=d and stride=1,
        'same' pads a total of d * (k - 1) samples, split left/right;
        _ExplicitPadConv1d applies that via F.pad before an unpadded conv.

        Raises:
            ValueError: if a padding='same' conv has stride != 1 (the pad
                amount would then depend on the input length).
        """
        for name, child in module.named_children():
            if isinstance(child, nn.Conv1d) and child.padding == 'same':
                k = child.kernel_size[0]
                d = child.dilation[0]
                s = child.stride[0]
                # Raise instead of assert: asserts vanish under `python -O`.
                if s != 1:
                    raise ValueError(
                        f"padding='same' with stride != 1 not handled: {name}"
                    )

                pad_total = d * (k - 1)
                pad_left = pad_total // 2
                pad_right = pad_total - pad_left

                # Replacing the value under an existing key keeps the
                # _modules dict size constant, so mutating while iterating
                # named_children() is safe here.
                new_conv = _ExplicitPadConv1d(child, pad_left, pad_right, child.padding_mode)
                setattr(module, name, new_conv)
            else:
                self._fix_conv_padding(child)

    def forward(self, mel_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel_input: [1, FIXED_MEL_FRAMES, 128] — pre-computed mel
                spectrogram.
        Returns:
            x_vector: [1, 2048] — speaker embedding.
        """
        return self.encoder(mel_input)
|
|
|
|
| class _ExplicitPadConv1d(nn.Module): |
| """Conv1d with explicit padding instead of padding='same'.""" |
|
|
| def __init__(self, original_conv: nn.Conv1d, pad_left: int, pad_right: int, pad_mode: str): |
| super().__init__() |
| |
| self.conv = nn.Conv1d( |
| in_channels=original_conv.in_channels, |
| out_channels=original_conv.out_channels, |
| kernel_size=original_conv.kernel_size[0], |
| stride=original_conv.stride[0], |
| padding=0, |
| dilation=original_conv.dilation[0], |
| groups=original_conv.groups, |
| bias=original_conv.bias is not None, |
| ) |
| |
| self.conv.weight = original_conv.weight |
| if original_conv.bias is not None: |
| self.conv.bias = original_conv.bias |
|
|
| self.pad_left = pad_left |
| self.pad_right = pad_right |
| self.pad_mode = pad_mode |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| if self.pad_left > 0 or self.pad_right > 0: |
| x = torch.nn.functional.pad(x, (self.pad_left, self.pad_right), mode=self.pad_mode) |
| return self.conv(x) |
|
|
|
|
| |
if _run_export:
    # Instantiate the export wrapper around a deep copy of the encoder.
    export_encoder = SpeakerEncoderForExport(model.speaker_encoder)
    export_encoder.eval()

    # ------------------------------------------------------------------
    # [3/5] Sanity check: the rewritten conv padding must reproduce the
    # original encoder's output on a random mel input.
    # ------------------------------------------------------------------
    print("\n[3/5] Validating wrapper produces same output as original...")

    test_mel = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)

    with torch.no_grad():
        orig_out = model.speaker_encoder(test_mel)
        wrap_out = export_encoder(test_mel)

    cos_sim = torch.nn.functional.cosine_similarity(orig_out, wrap_out, dim=-1).item()
    max_diff = (orig_out - wrap_out).abs().max().item()
    print(f" Original output shape: {list(orig_out.shape)}")
    print(f" Wrapper output shape: {list(wrap_out.shape)}")
    print(f" Cosine similarity: {cos_sim:.6f}")
    print(f" Max abs difference: {max_diff:.2e}")
    # Raise instead of assert so the gate survives `python -O`.
    if not cos_sim > 0.999:
        raise RuntimeError(f"Wrapper diverged from original! cos_sim={cos_sim}")
    print(" PASS — wrapper matches original")
|
|
| |
|
|
if _run_export:
    # ------------------------------------------------------------------
    # [4/5] Capture the wrapper as an ExportedProgram with a fixed input
    # shape (no dynamic dims). strict=False uses the more permissive
    # non-strict tracer.
    # ------------------------------------------------------------------
    print("\n[4/5] Running torch.export...")
    t0 = time.time()

    example_input = (torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM),)

    try:
        exported = torch.export.export(
            export_encoder,
            example_input,
            strict=False,
        )
        print(f" torch.export succeeded in {time.time() - t0:.1f}s")
        print(f" Graph has {len(exported.graph.nodes)} nodes")
    except Exception as e:
        print(f" torch.export FAILED: {e}")
        print(" Trying with torch.export.export(..., strict=False) already set.")
        print(" Attempting torch.jit.trace as fallback...")
        try:
            traced = torch.jit.trace(export_encoder, example_input)
            traced.save(os.path.join(OUTPUT_DIR, "speaker_encoder_traced.pt"))
            print(" torch.jit.trace succeeded (saved as .pt, not .pte)")
            # BUG FIX: without an ExportedProgram the lowering stage below
            # cannot run ('exported' would be undefined and crash with a
            # NameError), so stop cleanly after the fallback save.
            sys.exit(0)
        except Exception as e2:
            print(f" torch.jit.trace also failed: {e2}")
            sys.exit(1)
|
|
| |
|
|
if _run_export:
    # ------------------------------------------------------------------
    # [5/5] Lower the exported program to a .pte flatbuffer, delegating
    # supported subgraphs to the XNNPACK CPU backend.
    # ------------------------------------------------------------------
    print("\n[5/5] Lowering to ExecuTorch .pte (XNNPACK)...")
    t0 = time.time()

    try:
        from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

        edge_program = to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        executorch_program = edge_program.to_executorch()

        pte_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pte")
        with open(pte_path, "wb") as f:
            f.write(executorch_program.buffer)

        pte_size_mb = os.path.getsize(pte_path) / 1e6
        print(f" .pte saved: {pte_path}")
        print(f" .pte size: {pte_size_mb:.1f} MB")
        print(f" Lowered in {time.time() - t0:.1f}s")

    except Exception as e:
        # Lowering is environment-sensitive; keep the exported program
        # around as .pt2 so the torch.export work is not lost.
        print(f" ExecuTorch lowering failed: {e}")
        print(" Saving exported program as .pt2 instead...")
        pt2_path = os.path.join(OUTPUT_DIR, "speaker_encoder.pt2")
        torch.export.save(exported, pt2_path)
        print(f" Saved: {pt2_path}")
|
|
| |
|
|
if _run_export:
    # Optional host-side smoke test of the generated .pte. Failure here is
    # non-fatal: on-device validation is the real gate.
    if os.path.exists(os.path.join(OUTPUT_DIR, "speaker_encoder.pte")):
        print("\n Validating .pte execution...")
        try:
            # Only Runtime is needed; the unused Program/Method imports
            # were removed.
            from executorch.runtime import Runtime

            runtime = Runtime.get()
            # BUG FIX: read the flatbuffer via a context manager so the
            # file handle is closed deterministically (it was previously
            # opened inline and never closed).
            with open(os.path.join(OUTPUT_DIR, "speaker_encoder.pte"), "rb") as f:
                program = runtime.load_program(f.read())
            method = program.load_method("forward")

            test_input = torch.randn(1, FIXED_MEL_FRAMES, MEL_DIM)
            pte_out = method.execute([test_input])

            with torch.no_grad():
                ref_out = export_encoder(test_input)

            # execute() may return a sequence of outputs; take the first.
            if isinstance(pte_out, (list, tuple)):
                pte_out = pte_out[0]

            cos_sim_pte = torch.nn.functional.cosine_similarity(
                ref_out.flatten().unsqueeze(0),
                pte_out.flatten().unsqueeze(0)
            ).item()
            print(f" .pte vs PyTorch cosine sim: {cos_sim_pte:.6f}")

        except Exception as e:
            print(f" .pte validation failed: {e}")
            print(" (This may be OK — runtime validation can be done on target device)")

    print("\n" + "=" * 70)
    print("Phase 2 complete!")
    print(f" Input: mel spectrogram [1, {FIXED_MEL_FRAMES}, {MEL_DIM}]")
    print(f" Output: x-vector [1, 2048]")
    print("=" * 70)
|
|