#!/usr/bin/env python3
"""
Convert Kanade decoder and HiFT vocoder to CoreML.

These are non-autoregressive models (single forward pass), so conversion is
simpler than the LLM — no KV cache or StateType needed.

Two models are produced by ``convert_audio``:
  - KanadeDecoder.mlpackage: audio token indices + speaker embedding → mel spectrogram
  - Vocoder.mlpackage: mel spectrogram → PCM waveform

``convert_f0_predictor`` and ``convert_hift_decode`` are additional entry
points that export the vocoder in two separate stages
(F0Predictor.mlpackage / HiFTDecode.mlpackage).

Usage:
    python scripts/convert_kanade.py [--output-dir PATH] [--num-tokens 100]
"""

import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import coremltools as ct

from kanade_tokenizer import KanadeModel, load_vocoder
import kanade_tokenizer.module.transformer as kanade_transformer


# Number of (even, odd) rotary pairs per attention head. The precomputed RoPE
# path below assumes head_dim == 64, i.e. 32 pairs, and keeps this a literal
# constant so the traced reshape uses static sizes.
_ROPE_PAIRS = 32


# ── Monkey-patch Kanade's complex RoPE with real-valued version ───────────


def _apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Real-valued RoPE replacement for Kanade's complex-number version.

    Converts complex ``freqs_cis`` to cos/sin and applies split-half rotation.

    NOTE(review): this function is not referenced elsewhere in this file; the
    active patch uses ``_apply_rotary_emb_precomputed``, which implements the
    interleaved pairing instead of the split-half pairing used here.

    Args:
        x: tensor whose last dimension is head_dim; the broadcast below
            assumes a (bsz, seq_len, n_heads, head_dim) layout.
        freqs_cis: complex tensor of shape (seq_len, head_dim/2).

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    # freqs_cis is complex: (seq_len, head_dim/2)
    cos = freqs_cis.real  # (seq_len, head_dim/2)
    sin = freqs_cis.imag

    # x has head_dim, cos/sin have head_dim/2 — duplicate them to head_dim.
    cos = torch.cat([cos, cos], dim=-1)  # (seq_len, head_dim)
    sin = torch.cat([sin, sin], dim=-1)

    # Reshape for broadcast: (1, seq_len, 1, head_dim)
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    # Split-half rotation
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)

    return (x * cos + rotated * sin).type_as(x)


def _apply_rotary_emb_precomputed(x: torch.Tensor, freqs_cos_sin: torch.Tensor) -> torch.Tensor:
    """Real-valued RoPE using precomputed cos/sin stored as (seq_len, head_dim).

    Matches Kanade's INTERLEAVED complex multiplication:
        view_as_complex(x.reshape(..., -1, 2)) * (cos + i*sin)
    which pairs adjacent elements: (x0,x1), (x2,x3), ...

    Equivalent real-valued form:
        out[2k]   = x[2k]*cos[k] - x[2k+1]*sin[k]
        out[2k+1] = x[2k]*sin[k] + x[2k+1]*cos[k]

    head_dim is always 64, so we have 32 pairs.

    Args:
        x: (bsz, seq_len, n_heads, 64) query or key tensor.
        freqs_cos_sin: (seq_len, 64) where ``[:32]`` is cos and ``[32:]`` is sin.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    cos = freqs_cos_sin[..., :_ROPE_PAIRS]  # (seq_len, 32)
    sin = freqs_cos_sin[..., _ROPE_PAIRS:]  # (seq_len, 32)

    # Broadcast: (1, seq_len, 1, 32)
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    # Interleaved pairs: x has shape (..., 64), pair as (..., 32, 2)
    x_pairs = x.reshape(*x.shape[:-1], _ROPE_PAIRS, 2)
    x_even = x_pairs[..., 0]  # (..., 32)
    x_odd = x_pairs[..., 1]  # (..., 32)

    # Complex multiply: (x_even + i*x_odd) * (cos + i*sin)
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos

    # Interleave back: stack on last dim then flatten
    out = torch.stack([out_even, out_odd], dim=-1)  # (..., 32, 2)
    out = out.reshape(*x.shape)  # (..., 64)

    return out.type_as(x)


def _patched_attention_forward_v2(self, x, freqs_cis, mask, return_kv=False):
    """Attention forward with real-valued RoPE and explicit matmul.

    Supports local (windowed) attention via additive -inf mask.

    Args:
        x: (bsz, seqlen, dim) input activations.
        freqs_cis: real-valued (max_len, head_dim) cos/sin table as produced by
            ``_convert_freqs_cis_to_real``, or None to skip RoPE.
        mask: optional additive attention mask, added to the attention logits.
        return_kv: when True, also return the (key, value) tensors.

    Returns:
        The projected attention output, or ``(output, (xk, xv))`` when
        ``return_kv`` is True.
    """
    bsz, seqlen, _ = x.shape

    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)

    if freqs_cis is not None:
        xq = _apply_rotary_emb_precomputed(xq, freqs_cis[:seqlen])
        xk = _apply_rotary_emb_precomputed(xk, freqs_cis[:seqlen])

    xq = xq.transpose(1, 2)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * self.scale

    # Build attention mask (causal + local window)
    if self.causal or self.use_local_attention or mask is not None:
        attn_mask = torch.zeros((seqlen, seqlen), device=x.device, dtype=x.dtype)
        if self.causal:
            causal = torch.triu(
                torch.full((seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype),
                diagonal=1,
            )
            attn_mask = attn_mask + causal
        if self.use_local_attention:
            # Block positions outside the window [-window_per_side, +window_per_side]
            local_mask = torch.triu(
                torch.full((seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype),
                diagonal=self.window_per_side + 1,
            ) + torch.tril(
                torch.full((seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype),
                diagonal=-(self.window_per_side + 1),
            )
            attn_mask = attn_mask + local_mask
        attn_weights = attn_weights + attn_mask
        if mask is not None:
            attn_weights = attn_weights + mask

    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
    output = torch.matmul(attn_weights, xv)

    # Merge heads back. n_heads and head_dim are plain Python ints on the
    # module, so this stays a static size under tracing (768 for 12 x 64).
    output = output.transpose(1, 2).contiguous().reshape(
        bsz, seqlen, self.n_heads * self.head_dim
    )
    output = self.wo(output)

    if return_kv:
        return output, (xk, xv)
    return output


def _convert_freqs_cis_to_real(transformer_module):
    """Replace complex freqs_cis buffer with real-valued cos/sin concatenation.

    Does nothing when the module has no ``freqs_cis``, when it is None, or
    when it is already real-valued (so applying the patch twice is safe).
    """
    fc = getattr(transformer_module, "freqs_cis", None)  # (max_len, head_dim/2) complex
    if fc is None or not torch.is_complex(fc):
        return

    cos = fc.real.float()  # (max_len, head_dim/2)
    sin = fc.imag.float()
    real_freqs = torch.cat([cos, sin], dim=-1)  # (max_len, head_dim)

    # Replace the buffer
    del transformer_module.freqs_cis
    transformer_module.register_buffer("freqs_cis", real_freqs)


def patch_kanade_for_coreml(kanade: KanadeModel):
    """Apply monkey-patches to make Kanade traceable by coremltools.

    Side effects: replaces ``Attention.forward`` on the kanade_tokenizer
    transformer module for the whole process, and rewrites the ``freqs_cis``
    buffer of every ``Transformer`` submodule of ``kanade`` in place.
    """
    kanade_transformer.Attention.forward = _patched_attention_forward_v2

    # Convert complex freqs_cis to real in all transformers
    for module in kanade.modules():
        if isinstance(module, kanade_transformer.Transformer):
            _convert_freqs_cis_to_real(module)


class KanadeDecoderWrapper(nn.Module):
    """Wraps Kanade's decode pipeline for tracing.

    Pipeline: token indices → quantizer decode → mel_prenet → upsample →
              mel_decoder (conditioned on speaker) → mel_postnet → mel
    """

    def __init__(self, kanade: KanadeModel, num_tokens: int):
        super().__init__()
        self.local_quantizer = kanade.local_quantizer
        self.mel_prenet = kanade.mel_prenet
        self.mel_conv_upsample = kanade.mel_conv_upsample
        self.mel_decoder = kanade.mel_decoder
        self.mel_postnet = kanade.mel_postnet
        self.num_tokens = num_tokens

        # Precompute mel_length for this token count
        self.mel_length = kanade._calculate_target_mel_length(
            kanade._calculate_original_audio_length(num_tokens)
        )

    def forward(
        self,
        token_indices: torch.Tensor,
        speaker_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            token_indices: (num_tokens,) int32 — Kanade codebook indices (0-12799)
            speaker_embedding: (1, 128) float32 — speaker embedding

        Returns:
            mel: (1, 80, mel_length) float32
        """
        # Quantizer decode: indices → content embedding
        content_emb = self.local_quantizer.decode(token_indices)  # (num_tokens, 768)
        content_emb = content_emb.unsqueeze(0)  # (1, num_tokens, 768)

        # Mel prenet (transformer)
        local_latent = self.mel_prenet(content_emb)

        # Upsample to mel length
        if self.mel_conv_upsample is not None:
            local_latent = self.mel_conv_upsample(
                local_latent.transpose(1, 2)
            ).transpose(1, 2)

        local_latent = F.interpolate(
            local_latent.transpose(1, 2), size=self.mel_length, mode="linear"
        ).transpose(1, 2)

        # Mel decoder (conditioned on speaker)
        mel = self.mel_decoder(local_latent, condition=speaker_embedding.unsqueeze(1))
        mel = mel.transpose(1, 2)  # (1, 80, mel_length)

        # Postnet
        mel = self.mel_postnet(mel)
        return mel


# ── Shared HiFT helpers (used by both vocoder wrappers) ───────────────────


def _idft_basis(n_fft: int):
    """Return the real-valued inverse-DFT basis as ``(cos, sin)`` matrices.

    With X[k] = sum_n x[n] * exp(-j*2pi*n*k/N), the inverse is
    x[n] = (1/N) * sum_k X[k] * exp(j*2pi*n*k/N). Both returned matrices are
    (n_fft, n_fft), indexed [n, k], and already include the 1/N factor.
    """
    n = torch.arange(n_fft, dtype=torch.float32)
    k = torch.arange(n_fft, dtype=torch.float32)
    angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft  # (n_fft, n_fft)
    return torch.cos(angles) / n_fft, torch.sin(angles) / n_fft


def _hift_backbone(vocoder, mel: torch.Tensor, source_stft: torch.Tensor) -> torch.Tensor:
    """Run HiFT's upsampling stack: mel + source STFT → conv_post output.

    Args:
        vocoder: the loaded HiFT vocoder module.
        mel: (1, 80, T) mel spectrogram.
        source_stft: (1, 18, T') real+imag STFT of the source signal.

    Returns:
        conv_post output, (1, 18, num_frames): log-magnitude and phase channels.
    """
    num_kernels = vocoder.num_kernels
    last = vocoder.num_upsamples - 1

    x = vocoder.conv_pre(mel)
    for i in range(vocoder.num_upsamples):
        x = F.leaky_relu(x, vocoder.lrelu_slope)
        x = vocoder.ups[i](x)
        if i == last:
            x = vocoder.reflection_pad(x)

        # Fuse the (downsampled) source signal into this resolution.
        si = vocoder.source_downs[i](source_stft)
        si = vocoder.source_resblocks[i](si)
        x = x + si

        # Average the parallel residual blocks of this stage.
        xs = None
        for j in range(num_kernels):
            branch = vocoder.resblocks[i * num_kernels + j](x)
            xs = branch if xs is None else xs + branch
        x = xs / num_kernels

    # Default negative slope here (not lrelu_slope), as in the original code.
    x = F.leaky_relu(x)
    return vocoder.conv_post(x)


def _istft_from_conv_post(
    x: torch.Tensor,
    idft_cos: torch.Tensor,
    idft_sin: torch.Tensor,
    window: torch.Tensor,
    n_fft: int,
    hop_len: int,
    num_frames: int,
) -> torch.Tensor:
    """Manual iSTFT (matmul iDFT + static overlap-add) of the conv_post output.

    Only static slicing is used — no dynamic indexing — so the whole thing
    traces into a fixed CoreML graph.

    Args:
        x: (1, 2 * (n_fft // 2 + 1), num_frames) conv_post output; the first
            half of the channels is log-magnitude, the second half phase.
        idft_cos / idft_sin: (n_fft, n_fft) basis from ``_idft_basis``.
        window: (n_fft,) synthesis window.
        n_fft: FFT size (16 for the models this script targets).
        hop_len: hop size (4 for the models this script targets).
        num_frames: number of STFT frames, fixed at trace time.

    Returns:
        waveform: (1, (num_frames - 1) * hop_len), clamped to [-0.99, 0.99].

    Raises:
        ValueError: if ``n_fft`` is not a multiple of ``hop_len``.
    """
    if n_fft % hop_len != 0:
        raise ValueError(
            f"n_fft ({n_fft}) must be a multiple of hop_len ({hop_len}) for block overlap-add"
        )
    n_bins = n_fft // 2 + 1

    # Split into magnitude and phase. The phase channel goes through sin()
    # before being used as an angle, exactly as in the original code.
    magnitude = torch.exp(x[:, :n_bins, :])  # (1, n_bins, num_frames)
    phase = torch.sin(x[:, n_bins:, :])  # (1, n_bins, num_frames)

    # Convert to real/imag
    real_half = magnitude * torch.cos(phase)
    imag_half = magnitude * torch.sin(phase)

    # Mirror to full spectrum (Hermitian symmetry)
    # real: [r0, r1, ..., r8,  r7,  r6, ...,  r1]
    # imag: [i0, i1, ..., i8, -i7, -i6, ..., -i1]
    real_mirror = torch.flip(real_half[:, 1:n_bins - 1, :], dims=[1])
    imag_mirror = -torch.flip(imag_half[:, 1:n_bins - 1, :], dims=[1])
    real_full = torch.cat([real_half, real_mirror], dim=1)  # (1, n_fft, num_frames)
    imag_full = torch.cat([imag_half, imag_mirror], dim=1)

    # iDFT via matmul:
    # segments[frame, n] = sum_k (real[frame,k]*cos[n,k] - imag[frame,k]*sin[n,k])
    real_t = real_full.transpose(1, 2)  # (1, num_frames, n_fft)
    imag_t = imag_full.transpose(1, 2)
    segments = torch.matmul(real_t, idft_cos.T) - torch.matmul(imag_t, idft_sin.T)

    # Window each segment before overlap-add.
    segments = segments * window.unsqueeze(0).unsqueeze(0)  # (1, num_frames, n_fft)

    # Split each n_fft-sample segment into blocks of hop_len samples. Block b
    # of frame f lands at output[(f + b) * hop_len : ...], so block b of ALL
    # frames, flattened, is one contiguous run starting at b * hop_len.
    blocks = n_fft // hop_len
    span = num_frames * hop_len
    padded_samples = (num_frames - 1) * hop_len + n_fft
    seg_chunks = segments.squeeze(0).reshape(num_frames, blocks, hop_len)

    output = torch.zeros(padded_samples, dtype=segments.dtype, device=segments.device)
    for b in range(blocks):
        start = b * hop_len
        output[start:start + span] = (
            output[start:start + span] + seg_chunks[:, b, :].reshape(-1)
        )

    # Window normalization — same structure, using the squared window.
    win_chunks = (window * window).reshape(blocks, hop_len)
    wnorm = torch.zeros(padded_samples, dtype=segments.dtype, device=segments.device)
    for b in range(blocks):
        start = b * hop_len
        wnorm[start:start + span] = wnorm[start:start + span] + win_chunks[b].repeat(num_frames)

    output = output / (wnorm + 1e-8)

    # Trim center padding: n_fft // 2 samples from the start.
    pad = n_fft // 2
    trimmed_len = (num_frames - 1) * hop_len  # expected output length
    output = output[pad:pad + trimmed_len]
    output = torch.clamp(output, -0.99, 0.99)
    return output.unsqueeze(0)  # (1, samples)


class FullVocoderWrapper(nn.Module):
    """Complete mel → waveform pipeline: F0 prediction + source gen + HiFT decode + iSTFT.

    Noise is replaced with zeros for deterministic tracing.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames  # hardcoded for tracing

        n_fft = vocoder.istft_n_fft  # 16
        hop_len = vocoder.istft_hop_len  # 4

        # iDFT basis (also reused, rescaled, as the forward DFT basis)
        idft_cos, idft_sin = _idft_basis(n_fft)
        self.register_buffer("idft_cos", idft_cos)
        self.register_buffer("idft_sin", idft_sin)
        self.register_buffer("window", vocoder.stft_window.clone())

        # Source generation constants
        sin_gen = vocoder.m_source.l_sin_gen
        self.sampling_rate = sin_gen.sampling_rate
        self.harmonic_num = sin_gen.harmonic_num  # 8
        self.sine_amp = sin_gen.sine_amp  # 0.1
        self.upsample_scale = sin_gen.upsample_scale  # 480

        # Harmonic multipliers: [1, 2, ..., harmonic_num + 1]
        self.register_buffer(
            "harmonic_muls",
            torch.arange(1, self.harmonic_num + 2, dtype=torch.float32),
        )

        # l_linear and l_tanh from m_source
        self.source_linear = vocoder.m_source.l_linear
        self.source_tanh = vocoder.m_source.l_tanh

        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1

    def _generate_source(self, f0: torch.Tensor) -> torch.Tensor:
        """f0: (1, mel_length) → source_stft: (1, 18, stft_frames)"""
        # Upsample f0: (1, mel_length) → (1, 1, mel_length) → nearest → (1, 1, audio_length)
        f0_up = F.interpolate(
            f0.unsqueeze(1), scale_factor=float(self.upsample_scale), mode="nearest"
        ).squeeze(1)  # (1, audio_length)

        # Generate harmonics: f0 * [1..9]
        # f0_up: (1, L) → (1, L, 1) * (9,) → (1, L, 9)
        fn = f0_up.unsqueeze(-1) * self.harmonic_muls.unsqueeze(0).unsqueeze(0)

        # Phase accumulation: cumsum(f/sr) * 2pi
        rad = fn / self.sampling_rate  # instantaneous frequency in cycles per sample
        phase = torch.cumsum(rad, dim=1) * 2.0 * torch.pi  # (1, L, 9)

        # Sine waves
        sines = torch.sin(phase) * self.sine_amp  # (1, L, 9)

        # UV mask (voiced/unvoiced)
        uv = (f0_up > 0).float().unsqueeze(-1)  # (1, L, 1)

        # Apply UV (no noise — zeros instead of randn for tracing)
        sines = sines * uv  # (1, L, 9)

        # l_linear + tanh: (1, L, 9) → linear → (1, L, 1) → tanh
        source = self.source_tanh(self.source_linear(sines))  # (1, L, 1)
        source = source.squeeze(-1)  # (1, L)

        # Manual STFT (torch.stft/unfold not CoreML-compatible).
        # Reflect-pad n_fft // 2 on each side (center padding).
        padded = F.pad(source, (self.n_fft // 2, self.n_fft // 2), mode="reflect")
        # padded: (1, L + n_fft) where L = audio_length

        # Extract overlapping frames using conv1d with an identity kernel.
        # This replaces unfold: an (n_fft, 1, n_fft) identity kernel with
        # stride=hop gives frames[i] = padded[i*hop : i*hop + n_fft].
        eye_kernel = torch.eye(
            self.n_fft, dtype=source.dtype, device=source.device
        ).unsqueeze(1)
        frames = F.conv1d(padded.unsqueeze(1), eye_kernel, stride=self.hop_len)
        # frames: (1, n_fft, num_frames)
        frames = frames * self.window.unsqueeze(0).unsqueeze(-1)  # window each frame

        # Transpose to (1, num_frames, n_fft) for matmul
        frames = frames.transpose(1, 2)

        # DFT via matmul; multiply by n_fft to undo the 1/N of the iDFT basis.
        dft_cos = self.idft_cos[:self.n_fft_half, :] * self.n_fft
        dft_sin = self.idft_sin[:self.n_fft_half, :] * self.n_fft
        s_real = torch.matmul(frames, dft_cos.T)  # (1, NF, 9)
        s_imag = -torch.matmul(frames, dft_sin.T)  # (1, NF, 9)

        source_stft = torch.cat([s_real.transpose(1, 2), s_imag.transpose(1, 2)], dim=1)
        return source_stft

    def _istft_overlap_add(self, x: torch.Tensor) -> torch.Tensor:
        """x: (1, 18, num_frames) conv_post output → waveform (1, samples)"""
        return _istft_from_conv_post(
            x,
            self.idft_cos,
            self.idft_sin,
            self.window,
            self.n_fft,
            self.hop_len,
            self.num_stft_frames,
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: (1, 80, T) → waveform: (1, samples)"""
        # F0 prediction
        f0 = self.vocoder.f0_predictor(mel)  # (1, T)

        # Source generation
        source_stft = self._generate_source(f0)

        # HiFT decode + iSTFT
        x = _hift_backbone(self.vocoder, mel, source_stft)
        return self._istft_overlap_add(x)


class F0PredictorWrapper(nn.Module):
    """Wraps HiFT's f0 predictor: mel → f0."""

    def __init__(self, vocoder):
        super().__init__()
        self.f0_predictor = vocoder.f0_predictor

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: (1, 80, T) → f0.

        NOTE(review): the rest of this file treats the predictor output as
        (1, T) (it is unsqueezed before upsampling) — confirm against the
        vocoder implementation.
        """
        return self.f0_predictor(mel)


class HiFTDecodeWrapper(nn.Module):
    """Wraps HiFT's decode stage: mel + source_stft → waveform.

    Includes a manual iSTFT implementation using matmul with a precomputed
    DFT basis matrix, so the entire pipeline runs inside CoreML.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames  # hardcoded for tracing

        n_fft = vocoder.istft_n_fft  # 16
        hop_len = vocoder.istft_hop_len  # 4

        # Precomputed (n_fft, n_fft) real-valued IDFT basis for the iSTFT.
        idft_cos, idft_sin = _idft_basis(n_fft)
        self.register_buffer("idft_cos", idft_cos)
        self.register_buffer("idft_sin", idft_sin)

        # Window for overlap-add
        self.register_buffer("window", vocoder.stft_window.clone())

        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1  # 9

    def forward(self, mel: torch.Tensor, source_stft: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (1, 80, T) float32
            source_stft: (1, 18, T') float32 — real+imag STFT of source signal

        Returns:
            waveform: (1, samples) float32
        """
        x = _hift_backbone(self.vocoder, mel, source_stft)  # (1, 18, num_frames)
        return _istft_from_conv_post(
            x,
            self.idft_cos,
            self.idft_sin,
            self.window,
            self.n_fft,
            self.hop_len,
            self.num_stft_frames,
        )


# ── Conversion entry points ───────────────────────────────────────────────


def _convert_and_save(traced, inputs, output_name: str, out_path: Path):
    """Convert a traced module to a float32 CoreML package and save it.

    Creates the parent directory of ``out_path`` if needed.
    """
    mlmodel = ct.convert(
        traced,
        inputs=inputs,
        outputs=[ct.TensorType(name=output_name, dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )
    out_path.parent.mkdir(parents=True, exist_ok=True)
    mlmodel.save(str(out_path))


def _probe_source_stft(vocoder, mel: torch.Tensor) -> torch.Tensor:
    """Run the vocoder's own source path once to get a source STFT.

    Returns the concatenated real+imag STFT, (1, channels, num_stft_frames).
    Callers use it for its shape and, in ``convert_hift_decode``, as the
    example input for tracing.
    """
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        stft_real, stft_imag = vocoder._stft(s.squeeze(1))
        return torch.cat([stft_real, stft_imag], dim=1)


def convert_kanade_decoder(kanade: KanadeModel, num_tokens: int, output_dir: Path):
    """Convert Kanade decoder to CoreML (saved as KanadeDecoder.mlpackage)."""
    wrapper = KanadeDecoderWrapper(kanade, num_tokens).eval().float()

    print(f"Tracing Kanade decoder (num_tokens={num_tokens}, mel_length={wrapper.mel_length})...")

    token_indices = torch.arange(num_tokens, dtype=torch.int32)
    speaker_embedding = torch.randn(1, 128, dtype=torch.float32)

    with torch.no_grad():
        # Test forward
        mel = wrapper(token_indices, speaker_embedding)
        print(f"  Output mel shape: {mel.shape}")

        traced = torch.jit.trace(wrapper, (token_indices, speaker_embedding))

    print("Converting Kanade decoder to CoreML...")
    out_path = output_dir / "KanadeDecoder.mlpackage"
    _convert_and_save(
        traced,
        inputs=[
            ct.TensorType(name="token_indices", shape=(num_tokens,), dtype=np.int32),
            ct.TensorType(name="speaker_embedding", shape=(1, 128), dtype=np.float32),
        ],
        output_name="mel",
        out_path=out_path,
    )
    print(f"Saved Kanade decoder to {out_path}")


def convert_f0_predictor(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT f0 predictor to CoreML (saved as F0Predictor.mlpackage)."""
    wrapper = F0PredictorWrapper(vocoder).eval().float()

    print(f"Tracing F0 predictor (mel_length={mel_length})...")

    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)

    with torch.no_grad():
        f0 = wrapper(mel)
        print(f"  Output f0 shape: {f0.shape}")

        traced = torch.jit.trace(wrapper, (mel,))

    print("Converting F0 predictor to CoreML...")
    out_path = output_dir / "F0Predictor.mlpackage"
    _convert_and_save(
        traced,
        inputs=[ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32)],
        output_name="f0",
        out_path=out_path,
    )
    print(f"Saved F0 predictor to {out_path}")


def convert_hift_decode(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT decode stage to CoreML (saved as HiFTDecode.mlpackage).

    Source signal STFT must be computed externally (Swift side).
    """
    # Compute source_stft shape: run f0 predictor + source module to get it
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    source_stft = _probe_source_stft(vocoder, mel)

    num_stft_frames = source_stft.shape[2]
    print(f"  Source STFT shape: {source_stft.shape} ({num_stft_frames} frames)")

    wrapper = HiFTDecodeWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing HiFT decode (mel_length={mel_length})...")

    with torch.no_grad():
        waveform = wrapper(mel, source_stft)
        print(f"  Output waveform shape: {waveform.shape}")

        traced = torch.jit.trace(wrapper, (mel, source_stft))

    print("Converting HiFT decode to CoreML...")
    out_path = output_dir / "HiFTDecode.mlpackage"
    _convert_and_save(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
            ct.TensorType(
                name="source_stft",
                shape=(1, source_stft.shape[1], source_stft.shape[2]),
                dtype=np.float32,
            ),
        ],
        output_name="waveform",
        out_path=out_path,
    )
    print(f"Saved HiFT decode to {out_path}")


def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
    """Convert complete mel→waveform vocoder to CoreML (saved as Vocoder.mlpackage)."""
    # Get num_stft_frames by running a dummy forward
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    num_stft_frames = _probe_source_stft(vocoder, mel).shape[2]
    print(f"  STFT frames: {num_stft_frames}")

    wrapper = FullVocoderWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing full vocoder (mel_length={mel_length})...")

    # Replace randn_like with zeros for tracing. This is a process-wide patch,
    # so restore it in `finally` even when the forward pass or trace raises.
    orig_randn = torch.randn_like
    torch.randn_like = lambda x, **kw: torch.zeros_like(x)
    try:
        with torch.no_grad():
            wav = wrapper(mel)
            print(f"  Output waveform: {wav.shape}")
            traced = torch.jit.trace(wrapper, (mel,))
    finally:
        torch.randn_like = orig_randn

    print("Converting full vocoder to CoreML...")
    out_path = output_dir / "Vocoder.mlpackage"
    _convert_and_save(
        traced,
        inputs=[ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32)],
        output_name="waveform",
        out_path=out_path,
    )
    print(f"Saved vocoder to {out_path}")


def convert_audio(output_dir: Path, num_tokens: int = 100) -> tuple[Path, Path]:
    """Convert Kanade decoder + HiFT vocoder to CoreML.

    Args:
        output_dir: directory to write the packages to (created if missing).
        num_tokens: fixed number of audio tokens; determines the mel length.

    Returns:
        (KanadeDecoder.mlpackage, Vocoder.mlpackage) paths.

    Raises:
        ValueError: if ``num_tokens`` is not positive.
    """
    if num_tokens <= 0:
        raise ValueError(f"num_tokens must be positive, got {num_tokens}")

    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading Kanade model...")
    kanade = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval().float()
    patch_kanade_for_coreml(kanade)
    vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()

    mel_length = kanade._calculate_target_mel_length(
        kanade._calculate_original_audio_length(num_tokens)
    )

    print("\n=== Converting Kanade decoder ===")
    convert_kanade_decoder(kanade, num_tokens, output_dir)

    print("\n=== Converting full vocoder (mel → waveform) ===")
    convert_full_vocoder(vocoder, mel_length, output_dir)

    print("\nAudio conversion complete!")
    print(f"  KanadeDecoder: {num_tokens} tokens → mel (80, {mel_length})")
    print(f"  Vocoder: mel (80, {mel_length}) → waveform")

    return output_dir / "KanadeDecoder.mlpackage", output_dir / "Vocoder.mlpackage"


def main():
    """Parse command-line arguments and run the audio conversion."""
    parser = argparse.ArgumentParser(description="Convert Kanade + HiFT to CoreML")
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path(__file__).parent.parent,
        help="Output directory",
    )
    parser.add_argument(
        "--num-tokens",
        type=int,
        default=100,
        help="Fixed number of audio tokens (determines mel length)",
    )
    args = parser.parse_args()

    convert_audio(args.output_dir, args.num_tokens)


if __name__ == "__main__":
    main()