| import torch |
| from models.cond_encoder import CondEncoder |
| from models.codec_wrapper import CodecWrapper |
| from models.cfm import DiT |
| from samplers.ode import ODESampler |
|
|
| import argparse |
|
|
@torch.no_grad()
def infer_pipeline(wave_path=None, epoch=None, target_spk=None, hubert_index_path=None, hubert_blend=0.75, use_ema=True):
    """Run chunked CFM voice-conversion inference and write ``output_sample.wav``.

    Loads DiT / conditioner / codec-projector checkpoints from ``chkpt_cfm/``,
    reads precomputed features from ``data_svc_infer/`` (keyed by the wave file's
    basename), samples latents chunk-by-chunk with a Heun ODE solver, decodes to
    audio, and stitches chunks with a linear overlap-add crossfade.

    Args:
        wave_path: Source .wav path. Used for duration (to size the latent
            sequence) and to derive the feature file id; ``None`` falls back to
            the "mock" id and a 200-frame latent length.
        epoch: Checkpoint epoch to load (``*_epoch_{epoch}.pt``); ``None`` loads
            the ``*_final.pt`` files.
        target_spk: Optional path to a target-speaker ``.npy`` embedding; when
            ``None`` the source speaker embedding is used.
        hubert_index_path: Optional FAISS ``.index`` file for HuBERT retrieval;
            requires a sibling ``*_vectors.npy`` next to it.
        hubert_blend: Retrieval blend ratio — 1.0 = fully retrieved target
            vectors, 0.0 = untouched source HuBERT features.
        use_ema: Prefer EMA-averaged checkpoints when they exist on disk.
    """
    # Device priority: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f"Using device {device}")

    # Build the three trainable components; dims must match the checkpoints.
    codec_wrapper = CodecWrapper(latent_dim=1024).to(device)
    cond_enc = CondEncoder(ppg_dim=1280, hubert_dim=256, f0_dim=1, spk_dim=256, cond_out_dim=1024).to(device)
    dit = DiT(in_channels=1024, cond_dim=1024, hidden_dim=512, depth=8).to(device)

    import os

    # Resolve checkpoint paths: per-epoch files when `epoch` is given,
    # otherwise the final snapshots. EMA variants share the same naming scheme.
    if epoch is not None:
        dit_path = f"chkpt_cfm/dit_epoch_{epoch}.pt"
        cond_path = f"chkpt_cfm/cond_encoder_epoch_{epoch}.pt"
        proj_path = f"chkpt_cfm/projector_epoch_{epoch}.pt"
        ema_dit_path = f"chkpt_cfm/ema_dit_epoch_{epoch}.pt"
        ema_cond_path = f"chkpt_cfm/ema_cond_encoder_epoch_{epoch}.pt"
        ema_proj_path = f"chkpt_cfm/ema_projector_epoch_{epoch}.pt"
    else:
        dit_path = "chkpt_cfm/dit_final.pt"
        cond_path = "chkpt_cfm/cond_encoder_final.pt"
        proj_path = "chkpt_cfm/projector_final.pt"
        ema_dit_path = "chkpt_cfm/ema_dit_final.pt"
        ema_cond_path = "chkpt_cfm/ema_cond_encoder_final.pt"
        ema_proj_path = "chkpt_cfm/ema_projector_final.pt"

    # Checkpoint selection: EMA if requested and present, else raw weights,
    # else run with untrained (randomly initialized) models.
    # NOTE(review): only the DiT path's existence is checked — if the EMA DiT
    # exists but the EMA cond/proj files do not, the loads below will raise.
    if use_ema and os.path.exists(ema_dit_path):
        load_dit = ema_dit_path
        load_cond = ema_cond_path
        load_proj = ema_proj_path
        print(f"Loading EMA checkpoints from {ema_dit_path}...")
    elif os.path.exists(dit_path):
        load_dit = dit_path
        load_cond = cond_path
        load_proj = proj_path
        if use_ema:
            print(f"WARNING: EMA checkpoints not found, falling back to regular weights.")
        print(f"Loading trained checkpoints from {dit_path}...")
    else:
        load_dit = None
        print(f"WARNING: Checkpoints not found at {dit_path}. Using untrained weights! Output will be noisy.")

    if load_dit is not None:

        def clean_sd(sd):
            # Strip torch.compile-style '_orig_mod.' prefixes so state dicts
            # saved from compiled models load into plain modules.
            return {k.replace('_orig_mod.', ''): v for k, v in sd.items()}

        dit.load_state_dict(clean_sd(torch.load(load_dit, map_location=device, weights_only=True)))
        cond_enc.load_state_dict(clean_sd(torch.load(load_cond, map_location=device, weights_only=True)))
        codec_wrapper.projector.load_state_dict(clean_sd(torch.load(load_proj, map_location=device, weights_only=True)))

    # Latent normalization stats saved at training time; identity (0 mean,
    # unit std) when the stats file is missing.
    if os.path.exists("chkpt_cfm/latent_norm.pt"):
        norm_data = torch.load("chkpt_cfm/latent_norm.pt", map_location=device, weights_only=True)
        z_mean = norm_data['mean'].to(device)
        z_std = norm_data['std'].to(device)
    else:
        z_mean = torch.zeros(1024).to(device)
        z_std = torch.ones(1024).to(device)

    codec_wrapper.eval()
    cond_enc.eval()
    dit.eval()

    # Optional retrieval setup: the FAISS index maps source HuBERT frames to
    # nearest target frames; the raw target vectors (for reconstruction) must
    # sit next to the index as '<name>_vectors.npy', otherwise retrieval is
    # disabled entirely.
    hubert_index = None
    target_hubert_vectors = None
    if hubert_index_path and os.path.exists(hubert_index_path):
        import faiss
        import numpy as np
        print(f"Loading target FAISS index from {hubert_index_path} with blend ratio {hubert_blend}...")
        hubert_index = faiss.read_index(hubert_index_path)
        vectors_path = hubert_index_path.replace(".index", "_vectors.npy")
        if os.path.exists(vectors_path):
            target_hubert_vectors = np.load(vectors_path)
            print(f"Loaded {target_hubert_vectors.shape[0]} target vectors for real-time inference retrieval.")
        else:
            print(f"WARNING: Source vectors {vectors_path} missing. FAISS disabled.")
            hubert_index = None

    # Total latent length: duration in seconds * 44100 Hz sample rate / 512
    # samples per latent frame. Falls back to 200 frames without a wave file.
    total_T_latent = 200
    if wave_path:
        import soundfile as sf
        try:
            wav_data, sr = sf.read(wave_path)
            import math
            total_T_latent = math.ceil(len(wav_data) / sr * 44100 / 512)
            print(f"Loaded {wave_path}, calculating total T_latent={total_T_latent}")
        except Exception as e:
            print(f"Could not load wave: {e}")
            total_T_latent = 200

    # Chunking: 400-frame windows, consecutive windows share 50 frames that
    # are crossfaded in the waveform domain.
    max_frames = 400
    overlap_frames = 50
    step_frames = max_frames - overlap_frames

    final_audio = None

    print("Starting chunked inference pipeline (Heun, Overlap-Add)...")
    import numpy as np

    # The feature files under data_svc_infer/ are keyed by the wave basename.
    if wave_path:
        file_id = os.path.basename(wave_path).replace('.wav', '')
    else:
        file_id = "mock"

    for start_idx in range(0, total_T_latent, step_frames):
        T_latent = min(max_frames, total_T_latent - start_idx)
        print(f"--- Processing chunk from frame {start_idx} to {start_idx + T_latent} (Length: {T_latent}) ---")

        # Chunk boundaries in seconds (latent frame -> sample -> time).
        time_start = start_idx * 512 / 44100.0
        time_end = (start_idx + T_latent) * 512 / 44100.0

        try:
            ppg_full = np.load(f"data_svc_infer/whisper/speaker0/{file_id}.ppg.npy")
            hubert_full = np.load(f"data_svc_infer/hubert/speaker0/{file_id}.vec.npy")
            f0_full = np.load(f"data_svc_infer/pitch/speaker0/{file_id}.pit.npy")

            if target_spk is not None:
                spk_full = np.load(target_spk)
            else:
                spk_full = np.load(f"data_svc_infer/speaker/speaker0/{file_id}.spk.npy")

            # Slice each feature stream to the chunk window. Assumes ppg and
            # hubert are 50 frames/sec and f0 is 100 frames/sec — TODO confirm
            # against the feature-extraction code.
            ppg_start, ppg_end = int(time_start * 50), int(time_end * 50)
            hubert_start, hubert_end = int(time_start * 50), int(time_end * 50)
            f0_start, f0_end = int(time_start * 100), int(time_end * 100)

            # max(0,...)/max(1,...) guards keep the slices non-negative and
            # at least 1 frame wide from the end side.
            ppg = torch.tensor(ppg_full[max(0, ppg_start) : max(1, ppg_end)]).float().unsqueeze(0).to(device)
            hubert = torch.tensor(hubert_full[max(0, hubert_start) : max(1, hubert_end)]).float().unsqueeze(0).to(device)

            # Voiced frames (f0 > 0) go to log-Hz; zeros (presumably unvoiced)
            # stay 0. Shape becomes (1, T_f0, 1).
            f0_raw = torch.tensor(f0_full[max(0, f0_start) : max(1, f0_end)]).float()
            f0 = torch.where(f0_raw > 0, torch.log(f0_raw.clamp(min=1.0)), torch.zeros_like(f0_raw)).unsqueeze(-1).unsqueeze(0).to(device)

            spk = torch.tensor(spk_full).float().unsqueeze(0).to(device)

            # If a slice came back empty, substitute random features so the
            # pipeline still runs (output will be garbage for that stream).
            if ppg.shape[1] == 0: ppg = torch.randn(1, max(1, T_latent // 2), 1280).to(device)
            if hubert.shape[1] == 0: hubert = torch.randn(1, max(1, T_latent // 2), 256).to(device)
            if f0.shape[1] == 0: f0 = torch.randn(1, T_latent, 1).to(device)

            # HuBERT retrieval: replace each source frame by the mean of its
            # 4 nearest target frames, then linearly blend with the source.
            if hubert_index is not None and target_hubert_vectors is not None:
                source_hubert_np = hubert.squeeze(0).cpu().numpy().astype(np.float32)

                _, I = hubert_index.search(source_hubert_np, 4)

                nn_hubert = target_hubert_vectors[I].mean(axis=1)

                blended_hubert = hubert_blend * nn_hubert + (1.0 - hubert_blend) * source_hubert_np
                hubert = torch.tensor(blended_hubert).float().unsqueeze(0).to(device)

        except FileNotFoundError:
            # No precomputed features on disk: run on random conditioning.
            ppg = torch.randn(1, max(1, T_latent // 2), 1280).to(device)
            hubert = torch.randn(1, max(1, T_latent // 2), 256).to(device)
            f0 = torch.randn(1, T_latent, 1).to(device)
            spk = torch.randn(1, 256).to(device)

        # Fuse all conditioning streams to one sequence of length T_latent.
        c = cond_enc(ppg, hubert, f0, spk, target_seq_len=T_latent)

        # Sample latents from noise with a 12-step Heun ODE solver.
        sampler = ODESampler(dit, steps=12, solver='heun')
        z_noise = torch.randn(1, T_latent, 1024).to(device)
        u_hat = sampler.sample(z_noise, c)

        # Project model output into the codec's normalized latent space
        # (projector works channel-first, hence the transposes).
        u_hat_transposed = u_hat.transpose(1, 2)
        z_hat_norm = codec_wrapper.forward_project(u_hat_transposed)

        # Undo the training-time latent normalization before decoding.
        z_hat_norm_transposed = z_hat_norm.transpose(1, 2)
        z_hat_denorm = (z_hat_norm_transposed * z_std) + z_mean
        z_hat = z_hat_denorm.transpose(1, 2)

        wav_chunk = codec_wrapper.decode(z_hat).cpu().squeeze().numpy()

        # Overlap-add: linearly crossfade the 50-frame (x512-sample) overlap
        # between the tail of the accumulated audio and the new chunk's head.
        if final_audio is None:
            final_audio = wav_chunk
        else:

            overlap_samples = overlap_frames * 512
            if len(wav_chunk) >= overlap_samples and len(final_audio) >= overlap_samples:
                fade_in = np.linspace(0, 1, overlap_samples)
                fade_out = 1 - fade_in

                # In-place crossfade on the tail, then append the rest.
                final_audio[-overlap_samples:] = final_audio[-overlap_samples:] * fade_out + wav_chunk[:overlap_samples] * fade_in

                final_audio = np.concatenate([final_audio, wav_chunk[overlap_samples:]])
            else:
                # Chunk shorter than the overlap window: plain concatenation.
                final_audio = np.concatenate([final_audio, wav_chunk])

        # A short chunk means we just processed the final window.
        if T_latent < max_frames:
            break

    wav_out = final_audio
    print(f"Inference complete! Final output waveform shape: {wav_out.shape}")

    import soundfile as sf
    out_file = "output_sample.wav"
    sf.write(out_file, wav_out, 44100)
    print(f"Saved output to {out_file}")
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--wave", type=str, default=None, help="Path to input wave file") |
| parser.add_argument("--epoch", type=int, default=None, help="Epoch number to load checkpoints from (e.g., 30)") |
| parser.add_argument("--target_spk", type=str, default=None, help="Path to target speaker .npy array for voice conversion") |
| parser.add_argument("--hubert_index", type=str, default=None, help="Path to FAISS .index file for HuBERT retrieval") |
| parser.add_argument("--hubert_blend", type=float, default=0.75, help="Blend ratio for HuBERT retrieval (0.0=source, 1.0=target)") |
| parser.add_argument("--ema", action="store_true", default=True, help="Use EMA-averaged weights (default: True, use --no-ema to disable)") |
| parser.add_argument("--no-ema", dest="ema", action="store_false", help="Use raw training weights instead of EMA") |
| args = parser.parse_args() |
| |
| infer_pipeline(args.wave, args.epoch, args.target_spk, args.hubert_index, args.hubert_blend, use_ema=args.ema) |
|
|
|
|