# cfm_svc/infer.py
# Author: Hector Li
# Initial commit for Hugging Face (df93d13)
import torch
from models.cond_encoder import CondEncoder
from models.codec_wrapper import CodecWrapper
from models.cfm import DiT
from samplers.ode import ODESampler
import argparse
@torch.no_grad()
def infer_pipeline(wave_path=None, epoch=None, target_spk=None, hubert_index_path=None, hubert_blend=0.75, use_ema=True):
    """Run the full CFM singing-voice-conversion inference pipeline.

    Loads trained DiT / conditioning-encoder / projector checkpoints
    (preferring EMA weights when ``use_ema`` is True), slices PPG / HuBERT /
    F0 / speaker conditioning features from ``data_svc_infer``, optionally
    blends HuBERT features with nearest neighbours retrieved from a FAISS
    index, samples latents chunk-by-chunk with a Heun ODE solver, decodes
    each chunk through the codec, and crossfades (overlap-add) the chunks
    into one waveform written to ``output_sample.wav``.

    Args:
        wave_path: Source wave file; its duration sets the number of latent
            frames. When None, a 200-frame mock run with random conditioning
            is performed.
        epoch: Checkpoint epoch to load; None selects ``*_final.pt`` files.
        target_spk: Optional path to a target-speaker embedding ``.npy`` for
            voice conversion; defaults to the source speaker embedding.
        hubert_index_path: Optional FAISS ``.index`` path for HuBERT
            retrieval (expects a sibling ``*_vectors.npy`` file).
        hubert_blend: Retrieval blend ratio (0.0 = source, 1.0 = target).
        use_ema: Prefer EMA-averaged weights when present on disk.

    Raises:
        RuntimeError: If no audio chunks were produced (zero-length input).
    """
    import math
    import os

    import numpy as np
    # soundfile is required unconditionally for the final write; import early
    # so a missing dependency fails before any expensive inference work.
    import soundfile as sf

    # Device selection: CUDA > Apple MPS > CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device {device}")

    # 1. Instantiate models (weights loaded below).
    codec_wrapper = CodecWrapper(latent_dim=1024).to(device)
    cond_enc = CondEncoder(ppg_dim=1280, hubert_dim=256, f0_dim=1, spk_dim=256, cond_out_dim=1024).to(device)
    dit = DiT(in_channels=1024, cond_dim=1024, hidden_dim=512, depth=8).to(device)

    # Resolve checkpoint paths (prefer EMA weights when --ema is set).
    if epoch is not None:
        dit_path = f"chkpt_cfm/dit_epoch_{epoch}.pt"
        cond_path = f"chkpt_cfm/cond_encoder_epoch_{epoch}.pt"
        proj_path = f"chkpt_cfm/projector_epoch_{epoch}.pt"
        ema_dit_path = f"chkpt_cfm/ema_dit_epoch_{epoch}.pt"
        ema_cond_path = f"chkpt_cfm/ema_cond_encoder_epoch_{epoch}.pt"
        ema_proj_path = f"chkpt_cfm/ema_projector_epoch_{epoch}.pt"
    else:
        dit_path = "chkpt_cfm/dit_final.pt"
        cond_path = "chkpt_cfm/cond_encoder_final.pt"
        proj_path = "chkpt_cfm/projector_final.pt"
        ema_dit_path = "chkpt_cfm/ema_dit_final.pt"
        ema_cond_path = "chkpt_cfm/ema_cond_encoder_final.pt"
        ema_proj_path = "chkpt_cfm/ema_projector_final.pt"

    # Select which weights to load. NOTE(review): EMA presence is keyed on
    # the DiT file only — assumes cond/proj EMA files exist alongside it.
    if use_ema and os.path.exists(ema_dit_path):
        load_dit = ema_dit_path
        load_cond = ema_cond_path
        load_proj = ema_proj_path
        print(f"Loading EMA checkpoints from {ema_dit_path}...")
    elif os.path.exists(dit_path):
        load_dit = dit_path
        load_cond = cond_path
        load_proj = proj_path
        if use_ema:
            print("WARNING: EMA checkpoints not found, falling back to regular weights.")
        print(f"Loading trained checkpoints from {dit_path}...")
    else:
        load_dit = None
        print(f"WARNING: Checkpoints not found at {dit_path}. Using untrained weights! Output will be noisy.")

    if load_dit is not None:
        # Strip the `_orig_mod.` prefix torch.compile adds to state-dict keys.
        def clean_sd(sd):
            return {k.replace('_orig_mod.', ''): v for k, v in sd.items()}
        dit.load_state_dict(clean_sd(torch.load(load_dit, map_location=device, weights_only=True)))
        cond_enc.load_state_dict(clean_sd(torch.load(load_cond, map_location=device, weights_only=True)))
        codec_wrapper.projector.load_state_dict(clean_sd(torch.load(load_proj, map_location=device, weights_only=True)))

    # Latent normalization stats; identity transform when the file is missing.
    if os.path.exists("chkpt_cfm/latent_norm.pt"):
        norm_data = torch.load("chkpt_cfm/latent_norm.pt", map_location=device, weights_only=True)
        z_mean = norm_data['mean'].to(device)
        z_std = norm_data['std'].to(device)
    else:
        z_mean = torch.zeros(1024).to(device)
        z_std = torch.ones(1024).to(device)

    codec_wrapper.eval()
    cond_enc.eval()
    dit.eval()

    # Optional FAISS index for HuBERT retrieval-based timbre conversion.
    hubert_index = None
    target_hubert_vectors = None
    if hubert_index_path and os.path.exists(hubert_index_path):
        import faiss
        print(f"Loading target FAISS index from {hubert_index_path} with blend ratio {hubert_blend}...")
        hubert_index = faiss.read_index(hubert_index_path)
        vectors_path = hubert_index_path.replace(".index", "_vectors.npy")
        if os.path.exists(vectors_path):
            target_hubert_vectors = np.load(vectors_path)
            print(f"Loaded {target_hubert_vectors.shape[0]} target vectors for real-time inference retrieval.")
        else:
            print(f"WARNING: Source vectors {vectors_path} missing. FAISS disabled.")
            hubert_index = None

    # Total latent frames: hop size 512 at 44.1 kHz (see per-chunk timing below).
    total_T_latent = 200
    if wave_path:
        try:
            wav_data, sr = sf.read(wave_path)
            total_T_latent = math.ceil(len(wav_data) / sr * 44100 / 512)
            print(f"Loaded {wave_path}, calculating total T_latent={total_T_latent}")
        except Exception as e:
            print(f"Could not load wave: {e}")
            total_T_latent = 200

    # Chunking parameters. At 512/44100, 400 frames ~= 4.6 s per chunk and
    # 50 frames ~= 0.58 s of crossfade overlap between consecutive chunks.
    max_frames = 400
    overlap_frames = 50
    step_frames = max_frames - overlap_frames
    final_audio = None
    print("Starting chunked inference pipeline (Heun, Overlap-Add)...")

    if wave_path:
        file_id = os.path.basename(wave_path).replace('.wav', '')
    else:
        file_id = "mock"

    # The sampler is chunk-invariant; construct it once outside the loop.
    sampler = ODESampler(dit, steps=12, solver='heun')

    for start_idx in range(0, total_T_latent, step_frames):
        T_latent = min(max_frames, total_T_latent - start_idx)
        print(f"--- Processing chunk from frame {start_idx} to {start_idx + T_latent} (Length: {T_latent}) ---")

        # 2. Extract conditioning for this chunk's time window.
        time_start = start_idx * 512 / 44100.0
        time_end = (start_idx + T_latent) * 512 / 44100.0
        try:
            ppg_full = np.load(f"data_svc_infer/whisper/speaker0/{file_id}.ppg.npy")
            hubert_full = np.load(f"data_svc_infer/hubert/speaker0/{file_id}.vec.npy")
            f0_full = np.load(f"data_svc_infer/pitch/speaker0/{file_id}.pit.npy")
            if target_spk is not None:
                spk_full = np.load(target_spk)
            else:
                spk_full = np.load(f"data_svc_infer/speaker/speaker0/{file_id}.spk.npy")
            # Feature frame rates: PPG/HuBERT at 50 Hz, F0 at 100 Hz.
            ppg_start, ppg_end = int(time_start * 50), int(time_end * 50)
            hubert_start, hubert_end = int(time_start * 50), int(time_end * 50)
            f0_start, f0_end = int(time_start * 100), int(time_end * 100)
            ppg = torch.tensor(ppg_full[max(0, ppg_start) : max(1, ppg_end)]).float().unsqueeze(0).to(device)
            hubert = torch.tensor(hubert_full[max(0, hubert_start) : max(1, hubert_end)]).float().unsqueeze(0).to(device)
            f0_raw = torch.tensor(f0_full[max(0, f0_start) : max(1, f0_end)]).float()
            # Log-F0 for voiced frames; zero for unvoiced (f0 <= 0).
            f0 = torch.where(f0_raw > 0, torch.log(f0_raw.clamp(min=1.0)), torch.zeros_like(f0_raw)).unsqueeze(-1).unsqueeze(0).to(device)
            spk = torch.tensor(spk_full).float().unsqueeze(0).to(device)
            # Failsafe: never feed zero-length sequences to the encoder.
            if ppg.shape[1] == 0: ppg = torch.randn(1, max(1, T_latent // 2), 1280).to(device)
            if hubert.shape[1] == 0: hubert = torch.randn(1, max(1, T_latent // 2), 256).to(device)
            if f0.shape[1] == 0: f0 = torch.randn(1, T_latent, 1).to(device)

            # --- FAISS HUBERT BLEND ---
            if hubert_index is not None and target_hubert_vectors is not None:
                source_hubert_np = hubert.squeeze(0).cpu().numpy().astype(np.float32)  # (T, 256)
                # Search FAISS index top k=4, then average the neighbours.
                _, I = hubert_index.search(source_hubert_np, 4)
                nn_hubert = target_hubert_vectors[I].mean(axis=1)  # (T, 256)
                # Soft blend between retrieved target and source features.
                blended_hubert = hubert_blend * nn_hubert + (1.0 - hubert_blend) * source_hubert_np
                hubert = torch.tensor(blended_hubert).float().unsqueeze(0).to(device)
        except FileNotFoundError:
            # No extracted features on disk: random conditioning (smoke-test mode).
            ppg = torch.randn(1, max(1, T_latent // 2), 1280).to(device)
            hubert = torch.randn(1, max(1, T_latent // 2), 256).to(device)
            f0 = torch.randn(1, T_latent, 1).to(device)
            spk = torch.randn(1, 256).to(device)

        c = cond_enc(ppg, hubert, f0, spk, target_seq_len=T_latent)

        # 3. Sample latents from Gaussian noise via the ODE solver.
        z_noise = torch.randn(1, T_latent, 1024).to(device)
        u_hat = sampler.sample(z_noise, c)  # (1, T_latent, 1024)

        # 4. Project (expecting normalized input to match training).
        u_hat_transposed = u_hat.transpose(1, 2)
        z_hat_norm = codec_wrapper.forward_project(u_hat_transposed)  # (1, 1024, T_latent)

        # 5. Invert normalization back to the raw latent space before decoding.
        z_hat_norm_transposed = z_hat_norm.transpose(1, 2)
        z_hat_denorm = (z_hat_norm_transposed * z_std) + z_mean
        z_hat = z_hat_denorm.transpose(1, 2)

        # 6. Decode latents to a waveform chunk.
        wav_chunk = codec_wrapper.decode(z_hat).cpu().squeeze().numpy()

        # Overlap-add with a linear crossfade between consecutive chunks.
        if final_audio is None:
            final_audio = wav_chunk
        else:
            overlap_samples = overlap_frames * 512
            if len(wav_chunk) >= overlap_samples and len(final_audio) >= overlap_samples:
                fade_in = np.linspace(0, 1, overlap_samples)
                fade_out = 1 - fade_in
                final_audio[-overlap_samples:] = final_audio[-overlap_samples:] * fade_out + wav_chunk[:overlap_samples] * fade_in
                final_audio = np.concatenate([final_audio, wav_chunk[overlap_samples:]])
            else:
                # Chunk too short to crossfade; append directly.
                final_audio = np.concatenate([final_audio, wav_chunk])

        if T_latent < max_frames:
            break  # last (short) chunk has been processed

    if final_audio is None:
        # Zero-length input produced no chunks; fail loudly instead of
        # crashing on `None.shape` below.
        raise RuntimeError("No audio was generated (input had zero latent frames).")

    wav_out = final_audio
    print(f"Inference complete! Final output waveform shape: {wav_out.shape}")
    out_file = "output_sample.wav"
    sf.write(out_file, wav_out, 44100)
    print(f"Saved output to {out_file}")
if __name__ == "__main__":
    # Command-line entry point for the CFM-SVC inference pipeline.
    cli = argparse.ArgumentParser()
    cli.add_argument("--wave", type=str, default=None, help="Path to input wave file")
    cli.add_argument("--epoch", type=int, default=None, help="Epoch number to load checkpoints from (e.g., 30)")
    cli.add_argument("--target_spk", type=str, default=None, help="Path to target speaker .npy array for voice conversion")
    cli.add_argument("--hubert_index", type=str, default=None, help="Path to FAISS .index file for HuBERT retrieval")
    cli.add_argument("--hubert_blend", type=float, default=0.75, help="Blend ratio for HuBERT retrieval (0.0=source, 1.0=target)")
    cli.add_argument("--ema", action="store_true", default=True, help="Use EMA-averaged weights (default: True, use --no-ema to disable)")
    cli.add_argument("--no-ema", dest="ema", action="store_false", help="Use raw training weights instead of EMA")
    opts = cli.parse_args()
    infer_pipeline(
        opts.wave,
        opts.epoch,
        opts.target_spk,
        opts.hubert_index,
        opts.hubert_blend,
        use_ema=opts.ema,
    )