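"""Standalone inference script for SoulX-Singer singing voice conversion (SVC).

Loads a pretrained SoulXSingerSVC checkpoint, runs voice conversion from a prompt
recording and a target recording (each paired with a precomputed F0 contour saved
as a .npy file), and writes the converted audio to `<save_dir>/generated.wav`.

Example invocation (the script filename is assumed; all paths are the CLI defaults
defined below):

    python infer.py \
        --model_path pretrained_models/soulx-singer/model.pt \
        --config soulxsinger/config/soulxsinger.yaml \
        --prompt_wav_path example/audio/zh_prompt.wav \
        --target_wav_path example/audio/zh_target.wav \
        --prompt_f0_path example/audio/zh_prompt_f0.npy \
        --target_f0_path example/audio/zh_target_f0.npy \
        --save_dir outputs \
        --auto_shift --fp16
"""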
import argparse
import os

import numpy as np
import soundfile as sf
import torch
from omegaconf import DictConfig

from soulxsinger.utils.file_utils import load_config
from soulxsinger.models.soulxsinger_svc import SoulXSingerSVC
from soulxsinger.utils.audio_utils import load_wav


def build_model(
    model_path: str,
    config: DictConfig,
    device: str = "cuda",
):
    """
    Build the model from the pre-trained model path and model configuration.

    Args:
        model_path (str): Path to the checkpoint file.
        config (DictConfig): Model configuration.
        device (str, optional): Device to use. Defaults to "cuda".

    Returns:
        torch.nn.Module: The model with checkpoint weights loaded, in eval mode.
    """

    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Model checkpoint not found: {model_path}. "
            "Please download the pretrained model and place it at this path, or pass --model_path."
        )
    model = SoulXSingerSVC(config).to(device)
    print("Model initialized.")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # weights_only=False: the checkpoint may contain non-tensor objects (e.g. config).
    checkpoint = torch.load(model_path, weights_only=False, map_location=device)
    if "state_dict" not in checkpoint:
        raise KeyError(
            f"Checkpoint at {model_path} has no 'state_dict' key. "
            "Expected a checkpoint saved with model.state_dict()."
        )
    model.load_state_dict(checkpoint["state_dict"], strict=True)

    model.eval()
    print("Model checkpoint loaded.")

    return model

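# The following helper is NOT part of the original pipeline: it is a hedged,
# illustrative sketch of how the F0 .npy inputs consumed below could be produced
# with librosa. The project's own F0 extraction tooling should be preferred,
# since hop size and frame alignment must match the model config.
def extract_f0_sketch(wav_path: str, sample_rate: int, hop_length: int = 256) -> np.ndarray:
    """Estimate an F0 contour with pYIN (float32, 0 Hz marks unvoiced frames)."""
    import librosa  # optional dependency, used only by this sketch

    wav, _ = librosa.load(wav_path, sr=sample_rate)
    f0, _, _ = librosa.pyin(
        wav,
        fmin=librosa.note_to_hz("C2"),  # assumed lower bound for singing
        fmax=librosa.note_to_hz("C6"),  # assumed upper bound for singing
        sr=sample_rate,
        hop_length=hop_length,  # assumption: must match the model's hop size
    )
    # pyin marks unvoiced frames as NaN; map them to 0 Hz, a common SVC convention.
    return np.nan_to_num(f0, nan=0.0).astype(np.float32)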

def process(args, config, model: torch.nn.Module):
    """Run the full inference pipeline given a data_processor and model.
    """

    os.makedirs(args.save_dir, exist_ok=True)

    # Load the prompt/target waveforms and their precomputed F0 contours,
    # adding a batch dimension to the F0 arrays.
    pt_wav = load_wav(args.prompt_wav_path, config.audio.sample_rate).to(args.device)
    gt_wav = load_wav(args.target_wav_path, config.audio.sample_rate).to(args.device)
    pt_f0 = torch.from_numpy(np.load(args.prompt_f0_path)).unsqueeze(0).to(args.device)
    gt_f0 = torch.from_numpy(np.load(args.target_f0_path)).unsqueeze(0).to(args.device)

    # CLI values take precedence; fall back to the config when they are absent.
    n_steps = getattr(args, "n_steps", config.infer.n_steps)
    cfg = getattr(args, "cfg", config.infer.cfg)

    generated_audio, generated_shift = model.infer(
        pt_wav=pt_wav,
        gt_wav=gt_wav,
        pt_f0=pt_f0,
        gt_f0=gt_f0,
        auto_shift=args.auto_shift,
        pitch_shift=args.pitch_shift,
        n_steps=n_steps,
        cfg=cfg,
        use_fp16=args.use_fp16,
    )
    generated_audio = generated_audio.squeeze().float().cpu().numpy()
    if args.pitch_shift != generated_shift:
        print(f"Applied pitch shift of {generated_shift} semitones to match the target F0 contour.")

    out_path = os.path.join(args.save_dir, "generated.wav")
    sf.write(out_path, generated_audio, config.audio.sample_rate)
    print(f"Generated audio saved to {out_path}")


def main(args, config):
    model = build_model(
        model_path=args.model_path,
        config=config,
        device=args.device,
    )
    process(args, config, model)

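# Hedged sketch of programmatic (non-CLI) usage; it mirrors the argparse defaults
# below and assumes the default checkpoint and example assets exist on disk:
#
#     from argparse import Namespace
#     config = load_config("soulxsinger/config/soulxsinger.yaml")
#     main(
#         Namespace(
#             device="cuda",
#             model_path="pretrained_models/soulx-singer/model.pt",
#             prompt_wav_path="example/audio/zh_prompt.wav",
#             target_wav_path="example/audio/zh_target.wav",
#             prompt_f0_path="example/audio/zh_prompt_f0.npy",
#             target_f0_path="example/audio/zh_target_f0.npy",
#             save_dir="outputs",
#             auto_shift=False,
#             pitch_shift=0,
#             n_steps=32,
#             cfg=3.0,
#             use_fp16=False,
#         ),
#         config,
#     )
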
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--model_path", type=str, default="pretrained_models/soulx-singer/model.pt")
    parser.add_argument("--config", type=str, default="soulxsinger/config/soulxsinger.yaml")
    parser.add_argument("--prompt_wav_path", type=str, default="example/audio/zh_prompt.wav")
    parser.add_argument("--target_wav_path", type=str, default="example/audio/zh_target.wav")
    parser.add_argument("--prompt_f0_path", type=str, default="example/audio/zh_prompt_f0.npy")
    parser.add_argument("--target_f0_path", type=str, default="example/audio/zh_target_f0.npy")
    parser.add_argument("--save_dir", type=str, default="outputs")
    parser.add_argument("--auto_shift", action="store_true",
                        help="Automatically choose the pitch shift during inference")
    parser.add_argument("--pitch_shift", type=int, default=0,
                        help="Manual pitch shift in semitones")
    parser.add_argument("--n_steps", type=int, default=32,
                        help="Number of inference steps")
    parser.add_argument("--cfg", type=float, default=3.0,
                        help="Classifier-free guidance scale")
    parser.add_argument("--fp16", dest="use_fp16", action="store_true",
                        help="Use FP16 inference (faster on GPU)")
    args = parser.parse_args()

    config = load_config(args.config)
    main(args, config)