Spaces:
Running
Running
File size: 5,728 Bytes
1b8b9eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
import argparse
import torch
import yaml
import soundfile as sf
import time
from modules.commons import str2bool
# Choose the compute device once at import time: CUDA when present,
# otherwise Apple MPS, otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Dtype handed to the wrapper for AR cache setup and for conversion calls.
dtype = torch.float16

# Wrapper instance shared across conversions; created on first use.
vc_wrapper_v2 = None
def load_v2_models(args):
    """Build the V2 voice-conversion wrapper and prepare it for inference.

    The wrapper is instantiated with Hydra from ``configs/v2/vc_wrapper.yaml``
    (path relative to the current working directory), its checkpoints are
    loaded, and it is moved to the module-level ``device`` in eval mode with
    AR caches set up for a batch size of 1 and up to 4096 positions.

    Args:
        args: Parsed CLI namespace. Reads ``ar_checkpoint_path``,
            ``cfm_checkpoint_path`` and ``compile``.

    Returns:
        The configured wrapper object.
    """
    # Function-local imports: these packages are only used while loading.
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    # Context manager so the config file handle is closed as soon as it has
    # been parsed instead of being left open until garbage collection.
    with open("configs/v2/vc_wrapper.yaml", "r") as config_file:
        cfg = DictConfig(yaml.safe_load(config_file))

    vc_wrapper = instantiate(cfg)
    vc_wrapper.load_checkpoints(
        ar_checkpoint_path=args.ar_checkpoint_path,
        cfm_checkpoint_path=args.cfm_checkpoint_path,
    )
    vc_wrapper.to(device)
    vc_wrapper.eval()

    vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)

    if args.compile:
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True

        if hasattr(torch._inductor.config, "fx_graph_cache"):
            # Experimental feature to reduce compilation times, will be on by default in future
            torch._inductor.config.fx_graph_cache = True
        vc_wrapper.compile_ar()
        # NOTE: only the AR model is compiled here; the compile_cfm() call is
        # left disabled, as in the original script.
        # vc_wrapper.compile_cfm()

    return vc_wrapper
def convert_voice_v2(source_audio_path, target_audio_path, args):
    """Convert voice using the V2 model and return the final audio result.

    The wrapper is loaded on first use and cached in the module-level
    ``vc_wrapper_v2`` so later calls reuse it. The streaming generator is
    drained and only the ``full_audio`` element of the last yielded item is
    kept.

    Args:
        source_audio_path: Path to the source audio file.
        target_audio_path: Path to the target/reference audio file.
        args: Parsed CLI namespace supplying the sampling/diffusion options
            (and the model-loading options on the first call).

    Returns:
        The ``full_audio`` value from the last item the generator yielded
        (the caller unpacks it as ``(sample_rate, audio)``), or ``None`` when
        the generator yields nothing.
    """
    global vc_wrapper_v2
    if vc_wrapper_v2 is None:
        vc_wrapper_v2 = load_v2_models(args)

    # Use the generator function but collect all outputs
    generator = vc_wrapper_v2.convert_voice_with_streaming(
        source_audio_path=source_audio_path,
        target_audio_path=target_audio_path,
        diffusion_steps=args.diffusion_steps,
        length_adjust=args.length_adjust,
        # NOTE(review): keyword spelling kept exactly as in the original; it
        # presumably matches the wrapper's parameter name - confirm before renaming.
        intelligebility_cfg_rate=args.intelligibility_cfg_rate,
        similarity_cfg_rate=args.similarity_cfg_rate,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        convert_style=args.convert_style,
        anonymization_only=args.anonymization_only,
        device=device,
        dtype=dtype,
        stream_output=True,
    )

    # Start from None so an empty generator yields the "failed" sentinel the
    # caller checks for, rather than raising UnboundLocalError on return.
    full_audio = None
    for _, full_audio in generator:
        pass

    return full_audio
def main(args):
    """Run one voice conversion and write the result as a WAV file.

    Creates ``args.output`` if needed, converts ``args.source`` towards
    ``args.target``, and saves the audio under a filename built from both
    input names plus the length-adjust, diffusion-steps and similarity
    settings. Prints the elapsed time and the output path.

    Args:
        args: Parsed CLI namespace. Reads ``output``, ``source``, ``target``,
            ``length_adjust``, ``diffusion_steps`` and ``similarity_cfg_rate``
            here; the whole namespace is forwarded to ``convert_voice_v2``.

    Returns:
        None. Returns early (after printing an error) when the conversion
        produced no result.
    """
    # Create output directory if it doesn't exist
    os.makedirs(args.output, exist_ok=True)

    # The measured time includes model loading when this is the first call,
    # because convert_voice_v2 loads the wrapper on first use.
    start_time = time.time()
    converted_audio = convert_voice_v2(args.source, args.target, args)
    end_time = time.time()

    if converted_audio is None:
        print("Error: Failed to convert voice")
        return

    # splitext removes only the final extension, so "my.voice.wav" becomes
    # "my.voice" instead of being truncated at the first dot.
    source_name = os.path.splitext(os.path.basename(args.source))[0]
    target_name = os.path.splitext(os.path.basename(args.target))[0]

    # Create a descriptive filename
    filename = f"vc_v2_{source_name}_{target_name}_{args.length_adjust}_{args.diffusion_steps}_{args.similarity_cfg_rate}.wav"
    output_path = os.path.join(args.output, filename)

    # Save the converted audio
    save_sr, audio = converted_audio
    sf.write(output_path, audio, save_sr)

    print(f"Voice conversion completed in {end_time - start_time:.2f} seconds")
    print(f"Output saved to: {output_path}")
# Command-line entry point: parse the options and run a single conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Voice Conversion Inference Script")
    parser.add_argument("--source", type=str, required=True,
                        help="Path to source audio file")
    parser.add_argument("--target", type=str, required=True,
                        help="Path to target/reference audio file")
    parser.add_argument("--output", type=str, default="./output",
                        help="Output directory for converted audio")
    parser.add_argument("--diffusion-steps", type=int, default=30,
                        help="Number of diffusion steps")
    parser.add_argument("--length-adjust", type=float, default=1.0,
                        help="Length adjustment factor (<1.0 for speed-up, >1.0 for slow-down)")
    # str2bool instead of bool: with type=bool argparse calls bool() on the raw
    # string, so any non-empty value (including "False") would enable
    # compilation. This matches the other boolean options below.
    parser.add_argument("--compile", type=str2bool, default=False,
                        help="Whether to compile the model for faster inference")

    # V2 specific arguments
    parser.add_argument("--intelligibility-cfg-rate", type=float, default=0.7,
                        help="Intelligibility CFG rate for V2 model")
    parser.add_argument("--similarity-cfg-rate", type=float, default=0.7,
                        help="Similarity CFG rate for V2 model")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Top-p sampling parameter for V2 model")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Temperature sampling parameter for V2 model")
    parser.add_argument("--repetition-penalty", type=float, default=1.0,
                        help="Repetition penalty for V2 model")
    parser.add_argument("--convert-style", type=str2bool, default=False,
                        help="Convert style/emotion/accent for V2 model")
    parser.add_argument("--anonymization-only", type=str2bool, default=False,
                        help="Anonymization only mode for V2 model")

    # V2 custom checkpoints
    parser.add_argument("--ar-checkpoint-path", type=str, default=None,
                        help="Path to custom checkpoint file")
    parser.add_argument("--cfm-checkpoint-path", type=str, default=None,
                        help="Path to custom checkpoint file")

    args = parser.parse_args()
    main(args)