import os
import torch
import json
import argparse
from tqdm import tqdm
import numpy as np
import soundfile as sf
from collections import OrderedDict
from omegaconf import DictConfig
from soulxsinger.utils.file_utils import load_config
from soulxsinger.models.soulxsinger_svc import SoulXSingerSVC
from soulxsinger.utils.audio_utils import load_wav
def build_model(
    model_path: str,
    config: "DictConfig",
    device: str = "cuda",
) -> torch.nn.Module:
    """Build the SVC model from a pre-trained checkpoint and configuration.

    Args:
        model_path (str): Path to the checkpoint file.
        config (DictConfig): Model configuration.
        device (str, optional): Device to use. Defaults to "cuda".

    Returns:
        torch.nn.Module: The initialized model, in eval mode, on `device`.

    Raises:
        FileNotFoundError: If `model_path` does not point to an existing file.
        KeyError: If the checkpoint lacks a 'state_dict' entry.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Model checkpoint not found: {model_path}. "
            "Please download the pretrained model and place it at the path, or set --model_path."
        )
    model = SoulXSingerSVC(config).to(device)
    print("Model initialized.")
    print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, weights_only=False, map_location=device)
    if "state_dict" not in checkpoint:
        raise KeyError(
            f"Checkpoint at {model_path} has no 'state_dict' key. "
            "Expected a checkpoint saved with model.state_dict()."
        )
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    # Inference only: disable dropout / batch-norm updates. The model is
    # already on `device` from construction above (redundant .to() removed).
    model.eval()
    print("Model checkpoint loaded.")
    return model
def process(args, config, model: torch.nn.Module):
    """Run one prompt/target voice-conversion inference and save the result.

    Loads the prompt and target waveforms plus their precomputed F0 contours,
    runs `model.infer`, and writes the generated audio to
    `<args.save_dir>/generated.wav`.

    Args:
        args: Parsed CLI namespace (paths, device, pitch/step/cfg options).
        config: Loaded configuration providing `audio.sample_rate` and
            `infer.n_steps` / `infer.cfg` fallbacks.
        model (torch.nn.Module): Model from `build_model`, already in eval mode.
    """
    os.makedirs(args.save_dir, exist_ok=True)

    # Load audio at the model's sample rate and move everything to the device.
    pt_wav = load_wav(args.prompt_wav_path, config.audio.sample_rate).to(args.device)
    gt_wav = load_wav(args.target_wav_path, config.audio.sample_rate).to(args.device)
    # F0 contours are precomputed .npy arrays; unsqueeze adds a batch dim.
    pt_f0 = torch.from_numpy(np.load(args.prompt_f0_path)).unsqueeze(0).to(args.device)
    gt_f0 = torch.from_numpy(np.load(args.target_f0_path)).unsqueeze(0).to(args.device)

    # CLI values take precedence; fall back to config defaults when absent.
    n_steps = getattr(args, "n_steps", config.infer.n_steps)
    cfg = getattr(args, "cfg", config.infer.cfg)

    generated_audio, generated_shift = model.infer(
        pt_wav=pt_wav,
        gt_wav=gt_wav,
        pt_f0=pt_f0,
        gt_f0=gt_f0,
        auto_shift=args.auto_shift,
        pitch_shift=args.pitch_shift,
        n_steps=n_steps,
        cfg=cfg,
        use_fp16=args.use_fp16,
    )
    generated_audio = generated_audio.squeeze().float().cpu().numpy()

    # auto_shift may override the requested shift; report the value actually
    # applied (the old dead assignment to args.pitch_shift had no effect).
    if args.pitch_shift != generated_shift:
        print(f"Applied pitch shift of {generated_shift} semitones to match GT F0 contour.")

    out_path = os.path.join(args.save_dir, "generated.wav")
    sf.write(out_path, generated_audio, config.audio.sample_rate)
    print(f"Generated audio saved to {out_path}")
def main(args, config):
    """Top-level entry: construct the model, then run inference with it."""
    process(
        args,
        config,
        build_model(model_path=args.model_path, config=config, device=args.device),
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SoulX-Singer SVC inference")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--model_path", type=str, default='pretrained_models/soulx-singer/model.pt')
    parser.add_argument("--config", type=str, default='soulxsinger/config/soulxsinger.yaml')
    parser.add_argument("--prompt_wav_path", type=str, default='example/audio/zh_prompt.wav')
    parser.add_argument("--target_wav_path", type=str, default='example/audio/zh_target.wav')
    parser.add_argument("--prompt_f0_path", type=str, default='example/audio/zh_prompt_f0.npy')
    parser.add_argument("--target_f0_path", type=str, default='example/audio/zh_target_f0.npy')
    parser.add_argument("--save_dir", type=str, default='outputs')
    parser.add_argument("--auto_shift", action="store_true")
    parser.add_argument("--pitch_shift", type=int, default=0)
    parser.add_argument("--n_steps", type=int, default=32)
    parser.add_argument("--cfg", type=float, default=3.0)
    # dest="use_fp16" stores the flag directly under the attribute the
    # pipeline reads, replacing the old post-parse `args.use_fp16 = args.fp16`.
    parser.add_argument(
        "--fp16",
        dest="use_fp16",
        action="store_true",
        default=False,
        help="Use FP16 inference (faster on GPU)",
    )
    args = parser.parse_args()
    config = load_config(args.config)
    main(args, config)