| |
"""
Convert Kanade decoder and HiFT vocoder to CoreML.

These are non-autoregressive models (single forward pass), so conversion
is simpler than the LLM — no KV cache or StateType needed.

Two models are produced:
- KanadeDecoder.mlpackage: audio token indices + speaker embedding → mel spectrogram
- Vocoder.mlpackage: mel spectrogram → PCM waveform

Usage:
    python scripts/convert_kanade.py [--output-dir PATH] [--num-tokens 100]
"""
|
|
| import argparse |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import coremltools as ct |
| from kanade_tokenizer import KanadeModel, load_vocoder |
| import kanade_tokenizer.module.transformer as kanade_transformer |
|
|
|
|
| |
|
|
| def _apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: |
| """Real-valued RoPE replacement for Kanade's complex-number version. |
| Converts complex freqs_cis to cos/sin and applies split-half rotation. |
| """ |
| |
| cos = freqs_cis.real |
| sin = freqs_cis.imag |
| |
| |
| cos = torch.cat([cos, cos], dim=-1) |
| sin = torch.cat([sin, sin], dim=-1) |
| |
| cos = cos.unsqueeze(0).unsqueeze(2) |
| sin = sin.unsqueeze(0).unsqueeze(2) |
| |
| half = x.shape[-1] // 2 |
| x1 = x[..., :half] |
| x2 = x[..., half:] |
| rotated = torch.cat((-x2, x1), dim=-1) |
| return (x * cos + rotated * sin).type_as(x) |
|
|
|
|
| def _apply_rotary_emb_precomputed(x: torch.Tensor, freqs_cos_sin: torch.Tensor) -> torch.Tensor: |
| """Real-valued RoPE using precomputed cos/sin stored as (seq_len, head_dim). |
| |
| Matches Kanade's INTERLEAVED complex multiplication: |
| view_as_complex(x.reshape(..., -1, 2)) * (cos + i*sin) |
| which pairs adjacent elements: (x0,x1), (x2,x3), ... |
| |
| Equivalent real-valued form: |
| out[2k] = x[2k]*cos[k] - x[2k+1]*sin[k] |
| out[2k+1] = x[2k]*sin[k] + x[2k+1]*cos[k] |
| |
| head_dim is always 64, so we have 32 pairs. |
| """ |
| |
| cos = freqs_cos_sin[..., :32] |
| sin = freqs_cos_sin[..., 32:] |
| |
| cos = cos.unsqueeze(0).unsqueeze(2) |
| sin = sin.unsqueeze(0).unsqueeze(2) |
|
|
| |
| x_pairs = x.reshape(*x.shape[:-1], 32, 2) |
| x_even = x_pairs[..., 0] |
| x_odd = x_pairs[..., 1] |
|
|
| |
| out_even = x_even * cos - x_odd * sin |
| out_odd = x_even * sin + x_odd * cos |
|
|
| |
| out = torch.stack([out_even, out_odd], dim=-1) |
| out = out.reshape(*x.shape) |
| return out.type_as(x) |
|
|
|
|
| def _patched_attention_forward_v2(self, x, freqs_cis, mask, return_kv=False): |
| """Attention forward with real-valued RoPE and explicit matmul. |
| Supports local (windowed) attention via additive -inf mask. |
| """ |
| bsz, seqlen, _ = x.shape |
| xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) |
|
|
| xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) |
| xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim) |
| xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim) |
|
|
| if freqs_cis is not None: |
| xq = _apply_rotary_emb_precomputed(xq, freqs_cis[:seqlen]) |
| xk = _apply_rotary_emb_precomputed(xk, freqs_cis[:seqlen]) |
|
|
| xq = xq.transpose(1, 2) |
| xk = xk.transpose(1, 2) |
| xv = xv.transpose(1, 2) |
|
|
| attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * self.scale |
|
|
| |
| if self.causal or self.use_local_attention or mask is not None: |
| attn_mask = torch.zeros((seqlen, seqlen), device=x.device, dtype=x.dtype) |
| if self.causal: |
| causal = torch.triu( |
| torch.full((seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype), |
| diagonal=1, |
| ) |
| attn_mask = attn_mask + causal |
| if self.use_local_attention: |
| |
| local_mask = torch.triu( |
| torch.full((seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype), |
| diagonal=self.window_per_side + 1, |
| ) + torch.tril( |
| torch.full((seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype), |
| diagonal=-(self.window_per_side + 1), |
| ) |
| attn_mask = attn_mask + local_mask |
| attn_weights = attn_weights + attn_mask |
| if mask is not None: |
| attn_weights = attn_weights + mask |
|
|
| attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype) |
| output = torch.matmul(attn_weights, xv) |
|
|
| |
| output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, 768) |
| output = self.wo(output) |
|
|
| if return_kv: |
| return output, (xk, xv) |
| return output |
|
|
|
|
| def _convert_freqs_cis_to_real(transformer_module): |
| """Replace complex freqs_cis buffer with real-valued cos/sin concatenation.""" |
| if hasattr(transformer_module, 'freqs_cis') and transformer_module.freqs_cis is not None: |
| fc = transformer_module.freqs_cis |
| cos = fc.real.float() |
| sin = fc.imag.float() |
| real_freqs = torch.cat([cos, sin], dim=-1) |
| |
| del transformer_module.freqs_cis |
| transformer_module.register_buffer('freqs_cis', real_freqs) |
|
|
|
|
def patch_kanade_for_coreml(kanade: KanadeModel):
    """Apply monkey-patches to make Kanade traceable by coremltools.

    - Swaps Attention.forward for the real-valued, matmul-based version
      (class-level patch, so it affects every attention instance).
    - Rewrites each Transformer's complex freqs_cis buffer into the
      concatenated cos/sin layout the patched forward expects.
    """
    kanade_transformer.Attention.forward = _patched_attention_forward_v2

    # modules() suffices here; the names from named_modules() were unused.
    for module in kanade.modules():
        if isinstance(module, kanade_transformer.Transformer):
            _convert_freqs_cis_to_real(module)
|
|
|
|
class KanadeDecoderWrapper(nn.Module):
    """Wraps Kanade's decode pipeline for tracing.

    Pipeline: token indices -> quantizer decode -> mel_prenet -> upsample ->
    mel_decoder (conditioned on speaker) -> mel_postnet -> mel
    """

    def __init__(self, kanade: KanadeModel, num_tokens: int):
        super().__init__()
        # Keep only the decode-side submodules; the encoder is not traced.
        self.local_quantizer = kanade.local_quantizer
        self.mel_prenet = kanade.mel_prenet
        self.mel_conv_upsample = kanade.mel_conv_upsample
        self.mel_decoder = kanade.mel_decoder
        self.mel_postnet = kanade.mel_postnet
        self.num_tokens = num_tokens
        # Resolve the fixed mel length up front so tracing sees a constant.
        audio_len = kanade._calculate_original_audio_length(num_tokens)
        self.mel_length = kanade._calculate_target_mel_length(audio_len)

    def forward(
        self,
        token_indices: torch.Tensor,
        speaker_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            token_indices: (num_tokens,) int32 Kanade codebook indices (0-12799)
            speaker_embedding: (1, 128) float32 speaker embedding

        Returns:
            mel: (1, 80, mel_length) float32
        """
        # Codebook lookup, then add the batch axis the decoder expects.
        latent = self.local_quantizer.decode(token_indices).unsqueeze(0)
        latent = self.mel_prenet(latent)

        # Optional learned upsampling operates channels-first.
        if self.mel_conv_upsample is not None:
            latent = self.mel_conv_upsample(latent.transpose(1, 2)).transpose(1, 2)
        # Stretch to the exact target frame count.
        latent = F.interpolate(
            latent.transpose(1, 2), size=self.mel_length, mode="linear"
        ).transpose(1, 2)

        # Speaker-conditioned decoding, then channels-first for the postnet.
        mel = self.mel_decoder(latent, condition=speaker_embedding.unsqueeze(1))
        mel = self.mel_postnet(mel.transpose(1, 2))
        return mel
|
|
|
|
class FullVocoderWrapper(nn.Module):
    """Complete mel -> waveform pipeline: F0 prediction + source gen + HiFT decode + iSTFT.

    Noise is replaced with zeros for deterministic tracing.

    NOTE(review): the framing/overlap-add code assumes istft_n_fft == 16 and
    istft_hop_len == 4 (reshape(NF, 4, 4), NF*4 + 12 padding, 18 = 2*(16//2+1)
    STFT channels) — confirm against the vocoder config before reusing.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        super().__init__()
        self.vocoder = vocoder
        # Static frame count: the traced CoreML model has fixed shapes.
        self.num_stft_frames = num_stft_frames
        n_fft = vocoder.istft_n_fft
        hop_len = vocoder.istft_hop_len

        # Precompute the full inverse-DFT basis (1/n_fft normalized) so the
        # iSTFT becomes two real matmuls instead of a complex irfft.
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft
        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)
        self.register_buffer("window", vocoder.stft_window.clone())

        # Sine-source generator parameters lifted from HiFT's source module.
        self.sampling_rate = vocoder.m_source.l_sin_gen.sampling_rate
        self.harmonic_num = vocoder.m_source.l_sin_gen.harmonic_num
        self.sine_amp = vocoder.m_source.l_sin_gen.sine_amp
        self.upsample_scale = vocoder.m_source.l_sin_gen.upsample_scale

        # Frequency multipliers 1..harmonic_num+1 (fundamental + harmonics).
        self.register_buffer(
            "harmonic_muls",
            torch.arange(1, self.harmonic_num + 2, dtype=torch.float32),
        )

        # Learned merge of per-harmonic sines into a single source channel.
        self.source_linear = vocoder.m_source.l_linear
        self.source_tanh = vocoder.m_source.l_tanh

        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1  # rfft bin count

    def _generate_source(self, f0: torch.Tensor) -> torch.Tensor:
        """f0: (1, mel_length) -> source_stft: (1, 18, stft_frames)

        Deterministic re-implementation of HiFT's sine generator + STFT:
        builds harmonic sines from F0, merges them, then computes the forward
        STFT via framing (identity-kernel conv) and a real-DFT matmul.
        """
        # Nearest-neighbour upsample F0 to the waveform sample rate.
        f0_up = F.interpolate(
            f0.unsqueeze(1), scale_factor=float(self.upsample_scale), mode="nearest"
        ).squeeze(1)

        # Per-harmonic instantaneous frequency: (1, T, harmonics).
        fn = f0_up.unsqueeze(-1) * self.harmonic_muls.unsqueeze(0).unsqueeze(0)

        # Integrate normalized frequency to phase (radians). No random initial
        # phase — deterministic by design for tracing.
        rad = (fn / self.sampling_rate)
        phase = torch.cumsum(rad, dim=1) * 2.0 * torch.pi

        sines = torch.sin(phase) * self.sine_amp

        # Voiced/unvoiced gate: zero the sines wherever F0 == 0.
        uv = (f0_up > 0).float().unsqueeze(-1)

        sines = sines * uv

        # Learned merge of harmonics into one channel, then squeeze to (1, T).
        source = self.source_tanh(self.source_linear(sines))
        source = source.squeeze(-1)

        # Center-pad as torch.stft(center=True) would.
        padded = F.pad(source, (self.n_fft // 2, self.n_fft // 2), mode="reflect")

        # Frame extraction via an identity-kernel conv: (1, n_fft, frames).
        eye_kernel = torch.eye(self.n_fft, dtype=source.dtype, device=source.device).unsqueeze(1)

        frames = F.conv1d(padded.unsqueeze(1), eye_kernel, stride=self.hop_len)

        # Analysis window: (n_fft,) -> (1, n_fft, 1) broadcast over frames.
        frames = frames * self.window.unsqueeze(0).unsqueeze(-1)

        frames = frames.transpose(1, 2)

        # Forward DFT basis = inverse basis with the 1/n_fft norm undone;
        # keep only the non-redundant rfft bins.
        dft_cos = self.idft_cos[:self.n_fft_half, :] * self.n_fft
        dft_sin = self.idft_sin[:self.n_fft_half, :] * self.n_fft
        s_real = torch.matmul(frames, dft_cos.T)
        s_imag = -torch.matmul(frames, dft_sin.T)
        # Stack real then imag halves channel-wise: (1, 2*n_fft_half, frames).
        source_stft = torch.cat([s_real.transpose(1, 2), s_imag.transpose(1, 2)], dim=1)
        return source_stft

    def _istft_overlap_add(self, x: torch.Tensor) -> torch.Tensor:
        """x: (1, 18, num_frames) conv_post output -> waveform (1, samples)

        HiFT parameterization: first half of the channels is log-magnitude,
        the second half is phase passed through sin (matching how the rest of
        this wrapper treats the conv_post output; not a typo).
        """
        magnitude = torch.exp(x[:, :self.n_fft_half, :])
        phase = torch.sin(x[:, self.n_fft_half:, :])

        real_half = magnitude * torch.cos(phase)
        imag_half = magnitude * torch.sin(phase)

        # Rebuild the full Hermitian spectrum by conjugate-mirroring the
        # interior bins (1 .. n_fft_half-2).
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)

        # Inverse DFT via matmul: (1, frames, n_fft) time-domain segments.
        real_t = real_full.transpose(1, 2)
        imag_t = imag_full.transpose(1, 2)
        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)

        # Synthesis window, then split each 16-sample segment into 4 chunks of
        # hop_len=4 so overlap-add reduces to four shifted adds.
        NF = self.num_stft_frames
        segments = segments * self.window.unsqueeze(0).unsqueeze(0)
        seg = segments.squeeze(0)
        seg_chunks = seg.reshape(NF, 4, 4)

        # Chunk j of frame i lands at sample offset i*4 + j*4.
        b0 = seg_chunks[:, 0, :].reshape(-1)
        b1 = seg_chunks[:, 1, :].reshape(-1)
        b2 = seg_chunks[:, 2, :].reshape(-1)
        b3 = seg_chunks[:, 3, :].reshape(-1)

        F4 = NF * 4
        padded_samples = NF * 4 + 12
        output = torch.zeros(padded_samples)
        output[0:F4] = output[0:F4] + b0
        output[4:F4 + 4] = output[4:F4 + 4] + b1
        output[8:F4 + 8] = output[8:F4 + 8] + b2
        output[12:F4 + 12] = output[12:F4 + 12] + b3

        # Window-squared overlap-add: the COLA normalization denominator.
        win_sq = self.window * self.window
        win_chunks = win_sq.reshape(4, 4)
        w0 = win_chunks[0].repeat(NF)
        w1 = win_chunks[1].repeat(NF)
        w2 = win_chunks[2].repeat(NF)
        w3 = win_chunks[3].repeat(NF)

        wnorm = torch.zeros(padded_samples)
        wnorm[0:F4] = wnorm[0:F4] + w0
        wnorm[4:F4 + 4] = wnorm[4:F4 + 4] + w1
        wnorm[8:F4 + 8] = wnorm[8:F4 + 8] + w2
        wnorm[12:F4 + 12] = wnorm[12:F4 + 12] + w3

        output = output / (wnorm + 1e-8)
        # Trim the centering pad (n_fft/2 = 8) and the ragged tail.
        pad = 8
        trimmed_len = (NF - 1) * 4
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: (1, 80, T) -> waveform: (1, samples)"""
        # Predict F0 from the mel, then synthesize the harmonic source STFT.
        f0 = self.vocoder.f0_predictor(mel)

        source_stft = self._generate_source(f0)

        # HiFT/HiFi-GAN-style upsampling stack with source injection.
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)
            # Bring the source STFT to this scale's resolution and add it in.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si
            # Average the parallel multi-kernel residual blocks.
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)

        # Spectral head -> time domain.
        return self._istft_overlap_add(x)
|
|
|
|
class F0PredictorWrapper(nn.Module):
    """Wraps HiFT's f0 predictor: mel -> f0."""

    def __init__(self, vocoder):
        super().__init__()
        # Only the predictor head is traced; the rest of HiFT stays out.
        self.f0_predictor = vocoder.f0_predictor

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: (1, 80, T) -> f0: (1, 1, T)"""
        predictor = self.f0_predictor
        return predictor(mel)
|
|
|
|
class HiFTDecodeWrapper(nn.Module):
    """Wraps HiFT's decode stage: mel + source_stft -> waveform.

    Includes a manual iSTFT implementation using matmul with a precomputed
    DFT basis matrix, so the entire pipeline runs inside CoreML.

    NOTE(review): the overlap-add math assumes istft_n_fft == 16 and
    istft_hop_len == 4 (reshape(NF, 4, 4), NF*4 + 12 padding) — confirm
    against the vocoder config before reusing elsewhere.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        super().__init__()
        self.vocoder = vocoder
        # Static frame count: the traced CoreML model has fixed shapes.
        self.num_stft_frames = num_stft_frames
        n_fft = vocoder.istft_n_fft
        hop_len = vocoder.istft_hop_len

        # Full inverse-DFT basis (1/n_fft normalized) so the iSTFT becomes two
        # real matmuls instead of a complex irfft, which CoreML cannot run.
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft

        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)

        # Synthesis window copied from the vocoder.
        self.register_buffer("window", vocoder.stft_window.clone())
        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1  # rfft bin count

    def forward(self, mel: torch.Tensor, source_stft: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (1, 80, T) float32
            source_stft: (1, 18, T') float32 — real+imag STFT of source signal

        Returns:
            waveform: (1, samples) float32
        """
        # HiFT upsampling stack with per-scale source injection.
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)

            # Bring the source STFT to this scale's resolution and add it in.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si

            # Average the parallel multi-kernel residual blocks.
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)

        # Spectral head: first half log-magnitude, second half phase passed
        # through sin (same parameterization as FullVocoderWrapper).
        magnitude = torch.exp(x[:, :self.n_fft_half, :])
        phase = torch.sin(x[:, self.n_fft_half:, :])

        real_half = magnitude * torch.cos(phase)
        imag_half = magnitude * torch.sin(phase)

        # Rebuild the full Hermitian spectrum by conjugate-mirroring the
        # interior bins (1 .. n_fft_half-2).
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)

        # Inverse DFT via matmul: (1, frames, n_fft) time-domain segments.
        real_t = real_full.transpose(1, 2)
        imag_t = imag_full.transpose(1, 2)

        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)

        # Apply the synthesis window, then split each 16-sample segment into 4
        # chunks of hop_len=4 so overlap-add reduces to four shifted adds.
        NF = self.num_stft_frames
        segments = segments * self.window.unsqueeze(0).unsqueeze(0)
        seg = segments.squeeze(0)

        seg_chunks = seg.reshape(NF, 4, 4)

        padded_samples = NF * 4 + 12

        # Chunk j of frame i lands at sample offset i*4 + j*4.
        b0 = seg_chunks[:, 0, :].reshape(-1)
        b1 = seg_chunks[:, 1, :].reshape(-1)
        b2 = seg_chunks[:, 2, :].reshape(-1)
        b3 = seg_chunks[:, 3, :].reshape(-1)

        F4 = NF * 4
        output = torch.zeros(padded_samples)
        output[0:F4] = output[0:F4] + b0
        output[4:F4 + 4] = output[4:F4 + 4] + b1
        output[8:F4 + 8] = output[8:F4 + 8] + b2
        output[12:F4 + 12] = output[12:F4 + 12] + b3

        # Window-squared overlap-add: the COLA normalization denominator.
        win_sq = self.window * self.window
        win_chunks = win_sq.reshape(4, 4)
        w0 = win_chunks[0].repeat(NF)
        w1 = win_chunks[1].repeat(NF)
        w2 = win_chunks[2].repeat(NF)
        w3 = win_chunks[3].repeat(NF)

        wnorm = torch.zeros(padded_samples)
        wnorm[0:F4] = wnorm[0:F4] + w0
        wnorm[4:F4 + 4] = wnorm[4:F4 + 4] + w1
        wnorm[8:F4 + 8] = wnorm[8:F4 + 8] + w2
        wnorm[12:F4 + 12] = wnorm[12:F4 + 12] + w3

        output = output / (wnorm + 1e-8)

        # Trim the centering pad (n_fft/2 = 8) and the ragged tail.
        pad = 8
        trimmed_len = (NF - 1) * 4
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)
|
|
|
|
def convert_kanade_decoder(kanade: KanadeModel, num_tokens: int, output_dir: Path):
    """Convert Kanade decoder to CoreML."""
    decoder = KanadeDecoderWrapper(kanade, num_tokens).eval().float()
    print(f"Tracing Kanade decoder (num_tokens={num_tokens}, mel_length={decoder.mel_length})...")

    # Representative tracing inputs: sequential token ids + a random speaker.
    sample_tokens = torch.arange(num_tokens, dtype=torch.int32)
    sample_speaker = torch.randn(1, 128, dtype=torch.float32)

    with torch.no_grad():
        # Sanity-check the eager forward before tracing.
        mel = decoder(sample_tokens, sample_speaker)
        print(f" Output mel shape: {mel.shape}")

        traced = torch.jit.trace(decoder, (sample_tokens, sample_speaker))

    print("Converting Kanade decoder to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="token_indices", shape=(num_tokens,), dtype=np.int32),
            ct.TensorType(name="speaker_embedding", shape=(1, 128), dtype=np.float32),
        ],
        outputs=[ct.TensorType(name="mel", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "KanadeDecoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved Kanade decoder to {out_path}")
|
|
|
|
def convert_f0_predictor(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT f0 predictor to CoreML."""
    predictor = F0PredictorWrapper(vocoder).eval().float()
    print(f"Tracing F0 predictor (mel_length={mel_length})...")

    # Random mel with the static shape the traced model will accept.
    sample_mel = torch.randn(1, 80, mel_length, dtype=torch.float32)

    with torch.no_grad():
        f0 = predictor(sample_mel)
        print(f" Output f0 shape: {f0.shape}")
        traced = torch.jit.trace(predictor, (sample_mel,))

    print("Converting F0 predictor to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
        ],
        outputs=[ct.TensorType(name="f0", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "F0Predictor.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved F0 predictor to {out_path}")
|
|
|
|
def convert_hift_decode(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT decode stage to CoreML.

    Source signal STFT must be computed externally (Swift side).
    """
    # Run the reference PyTorch source pipeline once to learn the static
    # STFT frame count for this mel length.
    sample_mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(sample_mel)
        src = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        src, _, _ = vocoder.m_source(src)
        src = src.transpose(1, 2)
        stft_real, stft_imag = vocoder._stft(src.squeeze(1))
        source_stft = torch.cat([stft_real, stft_imag], dim=1)
    num_stft_frames = source_stft.shape[2]
    print(f" Source STFT shape: {source_stft.shape} ({num_stft_frames} frames)")

    wrapper = HiFTDecodeWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing HiFT decode (mel_length={mel_length})...")
    with torch.no_grad():
        waveform = wrapper(sample_mel, source_stft)
        print(f" Output waveform shape: {waveform.shape}")
        traced = torch.jit.trace(wrapper, (sample_mel, source_stft))

    print("Converting HiFT decode to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
            ct.TensorType(
                name="source_stft",
                shape=(1, source_stft.shape[1], source_stft.shape[2]),
                dtype=np.float32,
            ),
        ],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "HiFTDecode.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved HiFT decode to {out_path}")
|
|
|
|
def main():
    """Parse CLI arguments and run the full audio conversion."""
    parser = argparse.ArgumentParser(description="Convert Kanade + HiFT to CoreML")
    parser.add_argument(
        "--output-dir", type=str,
        default=str(Path(__file__).parent.parent),
        help="Output directory",
    )
    parser.add_argument(
        "--num-tokens", type=int, default=100,
        help="Fixed number of audio tokens (determines mel length)",
    )
    opts = parser.parse_args()
    convert_audio(Path(opts.output_dir), opts.num_tokens)
|
|
|
|
def convert_audio(output_dir: Path, num_tokens: int = 100) -> tuple[Path, Path]:
    """Convert Kanade decoder + HiFT vocoder to CoreML.

    Args:
        output_dir: directory receiving the .mlpackage bundles (created if
            missing).
        num_tokens: fixed audio-token count; determines the static mel
            length baked into both models.

    Returns (KanadeDecoder.mlpackage, Vocoder.mlpackage) paths."""
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading Kanade model...")
    kanade = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval().float()
    patch_kanade_for_coreml(kanade)
    vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()

    # Mel length is fully determined by the token count; compute it once for
    # both conversions so they agree on the static shapes.
    mel_length = kanade._calculate_target_mel_length(
        kanade._calculate_original_audio_length(num_tokens)
    )

    # Plain strings: the previous f-strings here had no placeholders (F541),
    # and the arrow characters were mojibake.
    print("\n=== Converting Kanade decoder ===")
    convert_kanade_decoder(kanade, num_tokens, output_dir)

    print("\n=== Converting full vocoder (mel -> waveform) ===")
    convert_full_vocoder(vocoder, mel_length, output_dir)

    print("\nAudio conversion complete!")
    print(f" KanadeDecoder: {num_tokens} tokens -> mel (80, {mel_length})")
    print(f" Vocoder: mel (80, {mel_length}) -> waveform")
    return output_dir / "KanadeDecoder.mlpackage", output_dir / "Vocoder.mlpackage"
|
|
|
|
def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
    """Convert complete mel->waveform vocoder to CoreML.

    Args:
        vocoder: loaded HiFT vocoder module.
        mel_length: static mel frame count the model is traced for.
        output_dir: destination directory for Vocoder.mlpackage.
    """
    # Run the reference source pipeline once to learn the static STFT frame
    # count for this mel length.
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        sr, _ = vocoder._stft(s.squeeze(1))
        num_stft_frames = sr.shape[2]
    print(f" STFT frames: {num_stft_frames}")

    wrapper = FullVocoderWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing full vocoder (mel_length={mel_length})...")
    # Make tracing deterministic by stubbing out randn_like, and restore it
    # even if tracing raises (previously an exception left torch patched).
    orig_randn = torch.randn_like
    torch.randn_like = lambda x, **kw: torch.zeros_like(x)
    try:
        with torch.no_grad():
            wav = wrapper(mel)
            print(f" Output waveform: {wav.shape}")
            traced = torch.jit.trace(wrapper, (mel,))
    finally:
        torch.randn_like = orig_randn

    print("Converting full vocoder to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32)],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "Vocoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved vocoder to {out_path}")
|
|
|
|
# Script entry point: parse CLI arguments and run the full conversion.
if __name__ == "__main__":
    main()
|
|