plapre-pico-coreml / scripts /convert_audio.py
Daniel Rothmann
Tidy up repo and conversion scripts
e79fa0a
#!/usr/bin/env python3
"""
Convert Kanade decoder and HiFT vocoder to CoreML.
These are non-autoregressive models (single forward pass), so conversion
is simpler than the LLM β€” no KV cache or StateType needed.
Two models are produced:
- KanadeDecoder.mlpackage: audio token indices + speaker embedding β†’ mel spectrogram
- HiFTVocoder.mlpackage: mel spectrogram β†’ PCM waveform
Usage:
python scripts/convert_kanade.py [--output-dir PATH] [--num-tokens 100]
"""
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import coremltools as ct
from kanade_tokenizer import KanadeModel, load_vocoder
import kanade_tokenizer.module.transformer as kanade_transformer
# ── Monkey-patch Kanade's complex RoPE with a real-valued version ──────────
def _apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Real-valued RoPE replacement for Kanade's complex-number version.

    Reproduces ``view_as_complex(x.reshape(..., -1, 2)) * freqs_cis`` without
    complex arithmetic on ``x``. That product rotates ADJACENT pairs
    (x0, x1), (x2, x3), ..., so the rotation here is interleaved too:

        out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
        out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]

    A split-half rotation (pairing x[k] with x[k + head_dim/2]) gives
    different values and is therefore not a drop-in replacement.

    Args:
        x: (bsz, seq_len, n_heads, head_dim) tensor; head_dim must be even.
        freqs_cis: (seq_len, head_dim/2) complex tensor holding cos + i*sin.

    Returns:
        The rotated tensor, with the same shape and dtype as ``x``.
    """
    # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2): broadcast over batch and heads.
    cos = freqs_cis.real.unsqueeze(0).unsqueeze(2)
    sin = freqs_cis.imag.unsqueeze(0).unsqueeze(2)

    # View the last dimension as (head_dim/2) adjacent (even, odd) pairs.
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x_even = pairs[..., 0]
    x_odd = pairs[..., 1]

    # Complex multiply: (x_even + i*x_odd) * (cos + i*sin)
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos

    # Re-interleave the pairs and restore the original layout.
    out = torch.stack([out_even, out_odd], dim=-1).reshape(*x.shape)
    return out.type_as(x)
def _apply_rotary_emb_precomputed(x: torch.Tensor, freqs_cos_sin: torch.Tensor) -> torch.Tensor:
    """Apply interleaved RoPE from a precomputed real-valued cos/sin table.

    Equivalent to the complex product
        view_as_complex(x.reshape(..., -1, 2)) * (cos + i*sin)
    which treats adjacent elements (x0,x1), (x2,x3), ... as complex numbers:
        out[2k]   = x[2k]*cos[k] - x[2k+1]*sin[k]
        out[2k+1] = x[2k]*sin[k] + x[2k+1]*cos[k]

    The pair count is fixed at 32 (head_dim 64) so that the traced graph
    contains only static shapes.

    Args:
        x: tensor whose last dimension is 64; the table is broadcast as
            (1, seq_len, 1, 32) against it.
        freqs_cos_sin: (seq_len, 64) table, cos in [:32] and sin in [32:].

    Returns:
        Rotated tensor with the shape and dtype of ``x``.
    """
    num_pairs = 32  # head_dim / 2

    # Split the table, then lift each half to (1, seq_len, 1, 32) for broadcasting.
    cos_part = freqs_cos_sin[..., :num_pairs]
    sin_part = freqs_cos_sin[..., num_pairs:]
    cos_part = cos_part.unsqueeze(0).unsqueeze(2)
    sin_part = sin_part.unsqueeze(0).unsqueeze(2)

    # (..., 64) -> (..., 32, 2): last axis holds the (real, imag) slot of each pair.
    paired = x.reshape(*x.shape[:-1], num_pairs, 2)
    re = paired[..., 0]
    im = paired[..., 1]

    # (re + i*im) * (cos + i*sin)
    rot_re = re * cos_part - im * sin_part
    rot_im = re * sin_part + im * cos_part

    # Stack on a new trailing axis and flatten it away to re-interleave.
    rotated = torch.stack([rot_re, rot_im], dim=-1)
    rotated = rotated.reshape(*x.shape)
    return rotated.type_as(x)
def _patched_attention_forward_v2(self, x, freqs_cis, mask, return_kv=False):
    """Attention forward with real-valued RoPE and explicit matmul.

    Intended to be installed as ``Attention.forward``. Causal and local
    (windowed) attention are implemented with an additive -inf mask.

    Args:
        self: attention module providing ``wq``/``wk``/``wv``/``wo``,
            ``n_heads``, ``head_dim``, ``scale``, ``causal``,
            ``use_local_attention`` and ``window_per_side``.
        x: (bsz, seqlen, dim) input.
        freqs_cis: real-valued cos/sin table consumed by
            ``_apply_rotary_emb_precomputed`` (sliced to ``seqlen`` rows),
            or None to skip RoPE.
        mask: optional tensor added to the attention logits, or None.
        return_kv: when True, also return the (rotated) keys and values.

    Returns:
        ``output`` of shape (bsz, seqlen, n_heads * head_dim) after ``wo``;
        or ``(output, (xk, xv))`` with xk/xv laid out as
        (bsz, n_heads, seqlen, head_dim) when ``return_kv`` is True.
    """
    bsz, seqlen, _ = x.shape

    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)

    if freqs_cis is not None:
        xq = _apply_rotary_emb_precomputed(xq, freqs_cis[:seqlen])
        xk = _apply_rotary_emb_precomputed(xk, freqs_cis[:seqlen])

    # (bsz, seqlen, n_heads, head_dim) -> (bsz, n_heads, seqlen, head_dim)
    xq = xq.transpose(1, 2)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * self.scale

    # Build attention mask (causal + local window)
    if self.causal or self.use_local_attention or mask is not None:
        attn_mask = torch.zeros((seqlen, seqlen), device=x.device, dtype=x.dtype)
        if self.causal or self.use_local_attention:
            # One -inf template; triu/tril are out-of-place so it can be reused.
            neg_inf = torch.full(
                (seqlen, seqlen), float("-inf"), device=x.device, dtype=x.dtype
            )
        if self.causal:
            # Block every key position after the query position.
            attn_mask = attn_mask + torch.triu(neg_inf, diagonal=1)
        if self.use_local_attention:
            # Block positions outside the window [-window_per_side, +window_per_side]
            local_mask = torch.triu(
                neg_inf, diagonal=self.window_per_side + 1
            ) + torch.tril(
                neg_inf, diagonal=-(self.window_per_side + 1)
            )
            attn_mask = attn_mask + local_mask
        attn_weights = attn_weights + attn_mask
        if mask is not None:
            attn_weights = attn_weights + mask

    # Softmax in float32, then cast back to the query dtype.
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
    output = torch.matmul(attn_weights, xv)

    # Merge heads. n_heads * head_dim is a plain Python int, so this stays a
    # static shape in the traced graph (768 for 12 heads of 64).
    model_dim = self.n_heads * self.head_dim
    output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, model_dim)
    output = self.wo(output)

    if return_kv:
        return output, (xk, xv)
    return output
def _convert_freqs_cis_to_real(transformer_module):
    """Replace a complex ``freqs_cis`` buffer with a real cos/sin concatenation.

    The complex (max_len, head_dim/2) table becomes a float32
    (max_len, head_dim) table laid out as ``[cos | sin]`` along the last
    dimension, and is re-registered under the same buffer name.

    The call is a no-op when the module has no ``freqs_cis``, when it is
    None, or when it is already real-valued, so patching the same module
    twice is safe.

    Args:
        transformer_module: module that may carry a ``freqs_cis`` attribute.
    """
    fc = getattr(transformer_module, "freqs_cis", None)
    if fc is None or not torch.is_complex(fc):
        # Nothing to convert (missing, disabled, or already converted).
        return

    cos = fc.real.float()  # (max_len, head_dim/2)
    sin = fc.imag.float()
    real_freqs = torch.cat([cos, sin], dim=-1)  # (max_len, head_dim)

    # Drop the old entry first: register_buffer refuses to overwrite an
    # existing non-buffer attribute of the same name.
    del transformer_module.freqs_cis
    transformer_module.register_buffer("freqs_cis", real_freqs)
def patch_kanade_for_coreml(kanade: KanadeModel):
    """Monkey-patch Kanade so that coremltools can trace it.

    Two things happen:
      1. ``Attention.forward`` is replaced class-wide with the real-valued,
         explicit-matmul implementation.
      2. Every ``Transformer`` submodule of ``kanade`` has its complex
         ``freqs_cis`` buffer swapped for the real cos/sin table.

    Args:
        kanade: the loaded Kanade model; modified in place.
    """
    # Class-level patch: affects every Attention instance.
    kanade_transformer.Attention.forward = _patched_attention_forward_v2

    # Collect the transformers first, then rewrite their buffers.
    transformers = [
        submodule
        for _, submodule in kanade.named_modules()
        if isinstance(submodule, kanade_transformer.Transformer)
    ]
    for transformer in transformers:
        _convert_freqs_cis_to_real(transformer)
class KanadeDecoderWrapper(nn.Module):
    """Traceable wrapper around Kanade's token → mel decode path.

    Stages: token indices → quantizer decode → mel_prenet → upsample →
    mel_decoder (conditioned on speaker) → mel_postnet → mel.
    """

    def __init__(self, kanade: KanadeModel, num_tokens: int):
        """
        Args:
            kanade: loaded Kanade model whose submodules are reused.
            num_tokens: fixed token count; determines the mel length.
        """
        super().__init__()
        self.local_quantizer = kanade.local_quantizer
        self.mel_prenet = kanade.mel_prenet
        self.mel_conv_upsample = kanade.mel_conv_upsample
        self.mel_decoder = kanade.mel_decoder
        self.mel_postnet = kanade.mel_postnet
        self.num_tokens = num_tokens

        # The target mel length is fixed per token count, so compute it once.
        audio_length = kanade._calculate_original_audio_length(num_tokens)
        self.mel_length = kanade._calculate_target_mel_length(audio_length)

    def forward(
        self,
        token_indices: torch.Tensor,
        speaker_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """Decode token indices into a mel spectrogram.

        Args:
            token_indices: (num_tokens,) int32 — Kanade codebook indices (0-12799)
            speaker_embedding: (1, 128) float32 — speaker embedding

        Returns:
            mel: (1, 80, mel_length) float32
        """
        # Indices → content embeddings, with a leading batch axis:
        # (num_tokens, 768) → (1, num_tokens, 768)
        embedded = self.local_quantizer.decode(token_indices).unsqueeze(0)

        # Transformer prenet.
        hidden = self.mel_prenet(embedded)

        # Optional conv upsampling; it operates channels-first.
        upsampler = self.mel_conv_upsample
        if upsampler is not None:
            hidden = upsampler(hidden.transpose(1, 2)).transpose(1, 2)

        # Resample along time to exactly mel_length frames.
        channels_first = hidden.transpose(1, 2)
        resized = F.interpolate(channels_first, size=self.mel_length, mode="linear")
        hidden = resized.transpose(1, 2)

        # Speaker-conditioned decoder, then back to (1, 80, mel_length).
        decoded = self.mel_decoder(hidden, condition=speaker_embedding.unsqueeze(1))
        decoded = decoded.transpose(1, 2)

        return self.mel_postnet(decoded)
class FullVocoderWrapper(nn.Module):
    """Complete mel → waveform pipeline: F0 prediction + source gen + HiFT decode + iSTFT.

    STFT and iSTFT are written with conv1d / matmul / static slicing so the
    whole module can be traced. Noise is replaced with zeros for
    deterministic tracing.

    The overlap-add geometry is derived from the vocoder's ``istft_n_fft``
    and ``istft_hop_len`` (16 and 4 in the shipped vocoder) rather than
    being hard-coded.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        """
        Args:
            vocoder: HiFT vocoder; supplies ``istft_n_fft``, ``istft_hop_len``,
                ``stft_window``, ``m_source`` and the decode layers.
            num_stft_frames: number of STFT frames for the traced mel length;
                baked into the graph as a constant.

        Raises:
            ValueError: if n_fft is odd or not a multiple of the hop length
                (the mirror/overlap-add layout below needs both).
        """
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames

        n_fft = vocoder.istft_n_fft
        hop_len = vocoder.istft_hop_len
        if n_fft % 2 != 0 or n_fft % hop_len != 0:
            raise ValueError(
                f"n_fft ({n_fft}) must be even and a multiple of hop_len ({hop_len})"
            )

        # iDFT basis: idft_*[n, k] = cos|sin(2*pi*n*k/N) / N
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft
        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)
        self.register_buffer("window", vocoder.stft_window.clone())

        # Source generation constants
        sin_gen = vocoder.m_source.l_sin_gen
        self.sampling_rate = sin_gen.sampling_rate
        self.harmonic_num = sin_gen.harmonic_num
        self.sine_amp = sin_gen.sine_amp
        self.upsample_scale = sin_gen.upsample_scale

        # Harmonic multipliers: [1, 2, ..., harmonic_num + 1]
        self.register_buffer(
            "harmonic_muls",
            torch.arange(1, self.harmonic_num + 2, dtype=torch.float32),
        )

        # l_linear and l_tanh from m_source
        self.source_linear = vocoder.m_source.l_linear
        self.source_tanh = vocoder.m_source.l_tanh

        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1

    def _generate_source(self, f0: torch.Tensor) -> torch.Tensor:
        """Turn an F0 track into the real+imag STFT of the harmonic source.

        Args:
            f0: (1, mel_length) fundamental frequency per mel frame.

        Returns:
            (1, 2 * n_fft_half, stft_frames): real bins stacked on imag bins.
        """
        # Upsample f0: (1, mel_length) → (1, 1, mel_length) → nearest → (1, audio_length)
        f0_up = F.interpolate(
            f0.unsqueeze(1), scale_factor=float(self.upsample_scale), mode="nearest"
        ).squeeze(1)

        # Harmonics: (1, L, 1) * (1, 1, H) → (1, L, H)
        fn = f0_up.unsqueeze(-1) * self.harmonic_muls.unsqueeze(0).unsqueeze(0)

        # Phase accumulation: cumsum(f / sr) * 2*pi
        rad = fn / self.sampling_rate  # cycles per sample
        phase = torch.cumsum(rad, dim=1) * 2.0 * torch.pi

        sines = torch.sin(phase) * self.sine_amp

        # Voiced/unvoiced mask; unvoiced samples get zeros (no noise, for tracing).
        uv = (f0_up > 0).float().unsqueeze(-1)
        sines = sines * uv

        # Merge harmonics: linear → tanh → (1, L)
        source = self.source_tanh(self.source_linear(sines))
        source = source.squeeze(-1)

        # Manual STFT (torch.stft/unfold not CoreML-compatible).
        # Center padding: reflect-pad n_fft // 2 on each side.
        padded = F.pad(source, (self.n_fft // 2, self.n_fft // 2), mode="reflect")

        # Frame extraction via conv1d with an identity kernel and stride = hop:
        # frames[:, :, i] = padded[i*hop : i*hop + n_fft]
        eye_kernel = torch.eye(
            self.n_fft, dtype=source.dtype, device=source.device
        ).unsqueeze(1)
        frames = F.conv1d(padded.unsqueeze(1), eye_kernel, stride=self.hop_len)
        frames = frames * self.window.unsqueeze(0).unsqueeze(-1)  # window each frame
        frames = frames.transpose(1, 2)  # (1, num_frames, n_fft)

        # Forward DFT via matmul; multiply by N to undo the 1/N in the iDFT basis.
        dft_cos = self.idft_cos[:self.n_fft_half, :] * self.n_fft
        dft_sin = self.idft_sin[:self.n_fft_half, :] * self.n_fft
        s_real = torch.matmul(frames, dft_cos.T)
        s_imag = -torch.matmul(frames, dft_sin.T)

        source_stft = torch.cat([s_real.transpose(1, 2), s_imag.transpose(1, 2)], dim=1)
        return source_stft

    def _istft_overlap_add(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse STFT of the conv_post output using matmul + static slicing.

        Args:
            x: (1, 2 * n_fft_half, num_stft_frames); the first n_fft_half
                channels are log-magnitude, the rest feed sin() to give phase.

        Returns:
            (1, (num_stft_frames - 1) * hop_len) waveform clamped to ±0.99.
        """
        magnitude = torch.exp(x[:, :self.n_fft_half, :])
        phase = torch.sin(x[:, self.n_fft_half:, :])
        real_half = magnitude * torch.cos(phase)
        imag_half = magnitude * torch.sin(phase)

        # Hermitian mirror of the interior bins to rebuild the full spectrum.
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)

        # iDFT: segments[f, n] = sum_k real[f,k]*cos[n,k] - imag[f,k]*sin[n,k]
        real_t = real_full.transpose(1, 2)
        imag_t = imag_full.transpose(1, 2)
        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)

        num_frames = self.num_stft_frames
        hop = self.hop_len
        blocks = self.n_fft // hop                      # hop-sized blocks per frame
        span = num_frames * hop                         # samples contributed per block index
        padded_samples = (num_frames - 1) * hop + self.n_fft

        segments = segments * self.window.unsqueeze(0).unsqueeze(0)
        seg = segments.squeeze(0)
        seg_chunks = seg.reshape(num_frames, blocks, hop)
        win_chunks = (self.window * self.window).reshape(blocks, hop)

        output = torch.zeros(padded_samples, dtype=seg.dtype, device=seg.device)
        wnorm = torch.zeros(padded_samples, dtype=seg.dtype, device=seg.device)
        # Block b of frame f lands at (f + b) * hop, so block b of ALL frames is
        # one contiguous run starting at b * hop. The Python loop is unrolled
        # at trace time into static slices.
        for b in range(blocks):
            start = b * hop
            stop = start + span
            output[start:stop] = output[start:stop] + seg_chunks[:, b, :].reshape(-1)
            wnorm[start:stop] = wnorm[start:stop] + win_chunks[b].repeat(num_frames)

        # Normalize by the summed squared window.
        output = output / (wnorm + 1e-8)

        # Trim the center padding (n_fft // 2 on each side).
        pad = self.n_fft // 2
        trimmed_len = (num_frames - 1) * hop
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: (1, 80, T) → waveform: (1, samples)"""
        # F0 prediction
        f0 = self.vocoder.f0_predictor(mel)

        # Source generation
        source_stft = self._generate_source(f0)

        # HiFT decode
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)

            # Fuse the source signal at this resolution.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si

            # Average the parallel residual blocks of this stage.
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)
        return self._istft_overlap_add(x)
class F0PredictorWrapper(nn.Module):
    """Standalone, traceable view of HiFT's f0 predictor (mel → f0)."""

    def __init__(self, vocoder):
        """Keep a reference to ``vocoder.f0_predictor``; nothing else is used."""
        super().__init__()
        self.f0_predictor = vocoder.f0_predictor

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Run the predictor on a mel spectrogram.

        Args:
            mel: (1, 80, T) mel spectrogram.

        Returns:
            Whatever ``f0_predictor`` returns for ``mel``.
            NOTE(review): other code in this file treats that result as
            (1, T) and adds the channel axis itself — confirm the shape.
        """
        predictor = self.f0_predictor
        f0 = predictor(mel)
        return f0
class HiFTDecodeWrapper(nn.Module):
    """Wraps HiFT's decode stage: mel + source_stft → waveform.

    Includes a manual iSTFT implementation using matmul with a precomputed
    DFT basis matrix, so the entire pipeline runs inside CoreML.

    The overlap-add geometry is derived from the vocoder's ``istft_n_fft``
    and ``istft_hop_len`` (16 and 4 in the shipped vocoder) rather than
    being hard-coded.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        """
        Args:
            vocoder: HiFT vocoder; supplies ``istft_n_fft``, ``istft_hop_len``,
                ``stft_window`` and the decode layers.
            num_stft_frames: number of STFT frames for the traced input size;
                baked into the graph as a constant.

        Raises:
            ValueError: if n_fft is odd or not a multiple of the hop length
                (the mirror/overlap-add layout needs both).
        """
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames  # hardcoded for tracing

        n_fft = vocoder.istft_n_fft
        hop_len = vocoder.istft_hop_len
        if n_fft % 2 != 0 or n_fft % hop_len != 0:
            raise ValueError(
                f"n_fft ({n_fft}) must be even and a multiple of hop_len ({hop_len})"
            )

        # Real-valued iDFT basis: x[n] = (1/N) * sum_k X[k] * exp(j*2*pi*n*k/N)
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft  # (n_fft, n_fft)
        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)

        # Window for overlap-add
        self.register_buffer("window", vocoder.stft_window.clone())

        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1

    def _istft_overlap_add(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse STFT of the conv_post output using matmul + static slicing.

        Args:
            x: (1, 2 * n_fft_half, num_stft_frames); the first n_fft_half
                channels are log-magnitude, the rest feed sin() to give phase.

        Returns:
            (1, (num_stft_frames - 1) * hop_len) waveform clamped to ±0.99.
        """
        # Split into magnitude and phase, then convert to real/imag.
        magnitude = torch.exp(x[:, :self.n_fft_half, :])
        phase = torch.sin(x[:, self.n_fft_half:, :])
        real_half = magnitude * torch.cos(phase)
        imag_half = magnitude * torch.sin(phase)

        # Mirror to full spectrum (Hermitian symmetry)
        # real: [r0, r1, ..., rH, rH-1, ..., r1]
        # imag: [i0, i1, ..., iH, -iH-1, ..., -i1]
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)

        # iDFT via matmul:
        # segments[f, n] = sum_k real[f,k]*idft_cos[n,k] - imag[f,k]*idft_sin[n,k]
        real_t = real_full.transpose(1, 2)  # (1, num_frames, n_fft)
        imag_t = imag_full.transpose(1, 2)
        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)

        num_frames = self.num_stft_frames  # constant for tracing
        hop = self.hop_len
        blocks = self.n_fft // hop                      # hop-sized blocks per frame
        span = num_frames * hop                         # samples contributed per block index
        padded_samples = (num_frames - 1) * hop + self.n_fft

        segments = segments * self.window.unsqueeze(0).unsqueeze(0)
        seg = segments.squeeze(0)  # (num_frames, n_fft)
        seg_chunks = seg.reshape(num_frames, blocks, hop)
        win_chunks = (self.window * self.window).reshape(blocks, hop)

        output = torch.zeros(padded_samples, dtype=seg.dtype, device=seg.device)
        wnorm = torch.zeros(padded_samples, dtype=seg.dtype, device=seg.device)
        # Block b of frame f lands at (f + b) * hop, so block b of ALL frames is
        # one contiguous run starting at b * hop. Static slicing only — the
        # Python loop is unrolled at trace time.
        for b in range(blocks):
            start = b * hop
            stop = start + span
            output[start:stop] = output[start:stop] + seg_chunks[:, b, :].reshape(-1)
            wnorm[start:stop] = wnorm[start:stop] + win_chunks[b].repeat(num_frames)

        # Window normalization
        output = output / (wnorm + 1e-8)

        # Trim center padding: n_fft // 2 from each side.
        pad = self.n_fft // 2
        trimmed_len = (num_frames - 1) * hop
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)  # (1, samples)

    def forward(self, mel: torch.Tensor, source_stft: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (1, 80, T) float32
            source_stft: (1, 2 * n_fft_half, T') float32 — real+imag STFT of source signal

        Returns:
            waveform: (1, samples) float32
        """
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)

            # Fuse the source signal at this resolution.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si

            # Average the parallel residual blocks of this stage.
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)  # (1, 2 * n_fft_half, num_frames)
        return self._istft_overlap_add(x)
def convert_kanade_decoder(kanade: KanadeModel, num_tokens: int, output_dir: Path):
    """Trace the Kanade decoder and save it as ``KanadeDecoder.mlpackage``.

    Args:
        kanade: loaded (and already patched) Kanade model.
        num_tokens: fixed number of input tokens baked into the model.
        output_dir: directory that receives the ``.mlpackage``.
    """
    decoder = KanadeDecoderWrapper(kanade, num_tokens).eval().float()
    print(f"Tracing Kanade decoder (num_tokens={num_tokens}, mel_length={decoder.mel_length})...")

    # Example inputs: sequential token ids plus a random speaker embedding.
    example_inputs = (
        torch.arange(num_tokens, dtype=torch.int32),
        torch.randn(1, 128, dtype=torch.float32),
    )

    with torch.no_grad():
        # Eager run first, as a sanity check before tracing.
        mel = decoder(*example_inputs)
        print(f" Output mel shape: {mel.shape}")
        traced = torch.jit.trace(decoder, example_inputs)

    print("Converting Kanade decoder to CoreML...")
    input_specs = [
        ct.TensorType(name="token_indices", shape=(num_tokens,), dtype=np.int32),
        ct.TensorType(name="speaker_embedding", shape=(1, 128), dtype=np.float32),
    ]
    output_specs = [ct.TensorType(name="mel", dtype=np.float32)]
    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        outputs=output_specs,
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "KanadeDecoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved Kanade decoder to {out_path}")
def convert_f0_predictor(vocoder, mel_length: int, output_dir: Path):
    """Trace HiFT's f0 predictor and save it as ``F0Predictor.mlpackage``.

    Args:
        vocoder: loaded HiFT vocoder.
        mel_length: fixed number of mel frames baked into the model.
        output_dir: directory that receives the ``.mlpackage``.
    """
    predictor = F0PredictorWrapper(vocoder).eval().float()
    print(f"Tracing F0 predictor (mel_length={mel_length})...")

    mel_shape = (1, 80, mel_length)
    example_mel = torch.randn(*mel_shape, dtype=torch.float32)

    with torch.no_grad():
        # Eager run first, as a sanity check before tracing.
        f0 = predictor(example_mel)
        print(f" Output f0 shape: {f0.shape}")
        traced = torch.jit.trace(predictor, (example_mel,))

    print("Converting F0 predictor to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="mel", shape=mel_shape, dtype=np.float32)],
        outputs=[ct.TensorType(name="f0", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "F0Predictor.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved F0 predictor to {out_path}")
def convert_hift_decode(vocoder, mel_length: int, output_dir: Path):
    """Trace HiFT's decode stage and save it as ``HiFTDecode.mlpackage``.

    The resulting model takes the source-signal STFT as an input, so that
    STFT must be computed externally (Swift side).

    Args:
        vocoder: loaded HiFT vocoder.
        mel_length: fixed number of mel frames baked into the model.
        output_dir: directory that receives the ``.mlpackage``.
    """
    mel_shape = (1, 80, mel_length)
    mel = torch.randn(*mel_shape, dtype=torch.float32)

    # Run the original f0 → source → STFT path once, to get an example
    # source_stft and learn its shape.
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        excitation = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        excitation, _, _ = vocoder.m_source(excitation)
        excitation = excitation.transpose(1, 2)
        stft_real, stft_imag = vocoder._stft(excitation.squeeze(1))
        source_stft = torch.cat([stft_real, stft_imag], dim=1)

    stft_channels = source_stft.shape[1]
    num_stft_frames = source_stft.shape[2]
    print(f" Source STFT shape: {source_stft.shape} ({num_stft_frames} frames)")

    decoder = HiFTDecodeWrapper(vocoder, num_stft_frames).eval().float()
    print(f"Tracing HiFT decode (mel_length={mel_length})...")

    example_inputs = (mel, source_stft)
    with torch.no_grad():
        # Eager run first, as a sanity check before tracing.
        waveform = decoder(*example_inputs)
        print(f" Output waveform shape: {waveform.shape}")
        traced = torch.jit.trace(decoder, example_inputs)

    print("Converting HiFT decode to CoreML...")
    input_specs = [
        ct.TensorType(name="mel", shape=mel_shape, dtype=np.float32),
        ct.TensorType(
            name="source_stft",
            shape=(1, stft_channels, num_stft_frames),
            dtype=np.float32,
        ),
    ]
    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "HiFTDecode.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved HiFT decode to {out_path}")
def main():
    """Command-line entry point: parse the options and run the conversion."""
    repo_root = Path(__file__).parent.parent

    # (flag, add_argument keyword arguments)
    cli_options = (
        (
            "--output-dir",
            {"type": str, "default": str(repo_root), "help": "Output directory"},
        ),
        (
            "--num-tokens",
            {
                "type": int,
                "default": 100,
                "help": "Fixed number of audio tokens (determines mel length)",
            },
        ),
    )

    parser = argparse.ArgumentParser(description="Convert Kanade + HiFT to CoreML")
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    options = parser.parse_args()
    convert_audio(Path(options.output_dir), options.num_tokens)
def convert_audio(output_dir: Path, num_tokens: int = 100) -> tuple[Path, Path]:
    """Convert Kanade decoder + HiFT vocoder to CoreML.

    Loads the pretrained Kanade model and its vocoder, patches Kanade for
    tracing, then converts both models for a fixed token count.

    Args:
        output_dir: directory that receives the ``.mlpackage`` outputs;
            created if missing.
        num_tokens: fixed number of audio tokens (determines the mel length).

    Returns:
        (KanadeDecoder.mlpackage, Vocoder.mlpackage) paths.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading Kanade model...")
    kanade = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval().float()
    patch_kanade_for_coreml(kanade)
    vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()

    # The vocoder is traced for the same mel length the decoder produces.
    mel_length = kanade._calculate_target_mel_length(
        kanade._calculate_original_audio_length(num_tokens)
    )

    print("\n=== Converting Kanade decoder ===")
    convert_kanade_decoder(kanade, num_tokens, output_dir)

    print("\n=== Converting full vocoder (mel → waveform) ===")
    convert_full_vocoder(vocoder, mel_length, output_dir)

    print("\nAudio conversion complete!")
    print(f" KanadeDecoder: {num_tokens} tokens → mel (80, {mel_length})")
    print(f" Vocoder: mel (80, {mel_length}) → waveform")
    return output_dir / "KanadeDecoder.mlpackage", output_dir / "Vocoder.mlpackage"
def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
    """Convert complete mel→waveform vocoder to CoreML.

    Saves the result as ``Vocoder.mlpackage`` in ``output_dir``.

    Args:
        vocoder: loaded HiFT vocoder.
        mel_length: fixed number of mel frames baked into the model.
        output_dir: directory that receives the ``.mlpackage``.
    """
    # Get num_stft_frames by running the original source path on a dummy mel.
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        sr, si = vocoder._stft(s.squeeze(1))
    num_stft_frames = sr.shape[2]
    print(f" STFT frames: {num_stft_frames}")

    wrapper = FullVocoderWrapper(vocoder, num_stft_frames).eval().float()
    print(f"Tracing full vocoder (mel_length={mel_length})...")

    # Replace randn_like with zeros for tracing. This is a process-wide
    # patch, so restore it even when the forward pass or the trace raises.
    orig_randn = torch.randn_like
    torch.randn_like = lambda x, **kw: torch.zeros_like(x)
    try:
        with torch.no_grad():
            wav = wrapper(mel)
            print(f" Output waveform: {wav.shape}")
            traced = torch.jit.trace(wrapper, (mel,))
    finally:
        torch.randn_like = orig_randn

    print("Converting full vocoder to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32)],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )
    out_path = output_dir / "Vocoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved vocoder to {out_path}")
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()