import os import torch import json import argparse from tqdm import tqdm import numpy as np import soundfile as sf from collections import OrderedDict from omegaconf import DictConfig from soulxsinger.utils.file_utils import load_config from soulxsinger.models.soulxsinger import SoulXSinger from soulxsinger.utils.data_processor import DataProcessor def build_model( model_path: str, config: DictConfig, device: str = "cuda", ): """ Build the model from the pre-trained model path and model configuration. Args: model_path (str): Path to the checkpoint file. config (DictConfig): Model configuration. device (str, optional): Device to use. Defaults to "cuda". Returns: Tuple[torch.nn.Module, torch.nn.Module]: The initialized model and vocoder. """ if not os.path.isfile(model_path): raise FileNotFoundError( f"Model checkpoint not found: {model_path}. " "Please download the pretrained model and place it at the path, or set --model_path." ) model = SoulXSinger(config).to(device) print("Model initialized.") print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M") checkpoint = torch.load(model_path, weights_only=False, map_location=device) if "state_dict" not in checkpoint: raise KeyError( f"Checkpoint at {model_path} has no 'state_dict' key. " "Expected a checkpoint saved with model.state_dict()." ) model.load_state_dict(checkpoint["state_dict"], strict=True) model.eval() model.to(device) print("Model checkpoint loaded.") return model def process(args, config, model: torch.nn.Module): """Run the full inference pipeline given a data_processor and model. """ if args.control not in ("melody", "score"): raise ValueError(f"control must be 'melody' or 'score', got: {args.control}") print(f"prompt_metadata_path: {args.prompt_metadata_path}") print(f"target_metadata_path: {args.target_metadata_path}") os.makedirs(args.save_dir, exist_ok=True) data_processor = DataProcessor( hop_size=config.audio.hop_size, sample_rate=config.audio.sample_rate, phoneset_path=args.phoneset_path, device=args.device, ) with open(args.prompt_metadata_path, "r", encoding="utf-8") as f: prompt_meta_list = json.load(f) if not prompt_meta_list: raise ValueError("Prompt metadata is empty. Please run preprocess on prompt audio first.") prompt_meta = prompt_meta_list[0] # load the first segment as the prompt with open(args.target_metadata_path, "r", encoding="utf-8") as f: target_meta_list = json.load(f) infer_prompt_data = data_processor.process(prompt_meta, args.prompt_wav_path) assert len(target_meta_list) > 0, "No target segments found in the target metadata." generated_len = int(target_meta_list[-1]["time"][1] / 1000 * config.audio.sample_rate) generated_merged = np.zeros(generated_len, dtype=np.float32) for idx, target_meta in enumerate( tqdm(target_meta_list, total=len(target_meta_list), desc="Inferring segments"), ): start_sample_idx = int(target_meta["time"][0] / 1000 * config.audio.sample_rate) end_sample_idx = int(target_meta["time"][1] / 1000 * config.audio.sample_rate) infer_target_data = data_processor.process(target_meta, None) infer_data = { "prompt": infer_prompt_data, "target": infer_target_data, } with torch.no_grad(): generated_audio = model.infer( infer_data, auto_shift=args.auto_shift, pitch_shift=args.pitch_shift, n_steps=config.infer.n_steps, cfg=config.infer.cfg, control=args.control, ) generated_audio = generated_audio.squeeze().cpu().numpy() generated_merged[start_sample_idx : start_sample_idx + generated_audio.shape[0]] = generated_audio merged_path = os.path.join(args.save_dir, "generated.wav") sf.write(merged_path, generated_merged, 24000) print(f"Generated audio saved to {merged_path}") def main(args, config): model = build_model( model_path=args.model_path, config=config, device=args.device, ) process(args, config, model) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--model_path", type=str, default='pretrained_models/soulx-singer/model.pt') parser.add_argument("--config", type=str, default='soulxsinger/config/soulxsinger.yaml') parser.add_argument("--prompt_wav_path", type=str, default='example/audio/zh_prompt.wav') parser.add_argument("--prompt_metadata_path", type=str, default='example/metadata/zh_prompt.json') parser.add_argument("--target_metadata_path", type=str, default='example/metadata/zh_target.json') parser.add_argument("--phoneset_path", type=str, default='soulxsinger/utils/phoneme/phone_set.json') parser.add_argument("--save_dir", type=str, default='outputs') parser.add_argument("--auto_shift", action="store_true") parser.add_argument("--pitch_shift", type=int, default=0) parser.add_argument( "--control", type=str, default="melody", choices=["melody", "score"], help="Control mode: melody or score only", ) args = parser.parse_args() config = load_config(args.config) main(args, config)