import argparse
import os
import random
import time

import torch
import torchaudio
from einops import rearrange
print("Current working directory:", os.getcwd()) |
|
|
|
|
|

from infer_utils import (
    decode_audio,
    get_lrc_token,
    get_negative_style_prompt,
    get_reference_latent,
    get_style_prompt,
    prepare_model,
)


def inference(
    cfm_model,
    vae_model,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    start_time,
    pred_frames,
    batch_infer_num,
    chunked=False,
):
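    """Sample latent segments with the CFM model and decode them through the VAE.

    Returns one int16 waveform tensor, shaped (channels, samples), per sampled
    candidate.
    """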
    with torch.inference_mode():
        latents, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=32,
            cfg_strength=4.0,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            batch_infer_num=batch_infer_num,
        )

        outputs = []
        for latent in latents:
            latent = latent.to(torch.float32)
            latent = latent.transpose(1, 2)

            output = decode_audio(latent, vae_model, chunked=chunked)
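
            # Concatenate the batch candidates along the time axis: (b, d, n) -> (d, b*n)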
            output = rearrange(output, "b d n -> d (b n)")
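            # Peak-normalize, clip, and convert to 16-bit PCM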
            output = (
                output.to(torch.float32)
                .div(torch.max(torch.abs(output)).clamp(min=1e-8))  # guard against all-zero output
                .clamp(-1, 1)
                .mul(32767)
                .to(torch.int16)
                .cpu()
            )
            outputs.append(output)

        return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lrc-path",
        type=str,
        help="lyrics of target song",
    )
    parser.add_argument(
        "--ref-prompt",
        type=str,
        help="reference prompt as style prompt for target song",
        required=False,
    )
    parser.add_argument(
        "--ref-audio-path",
        type=str,
        help="reference audio as style prompt for target song",
        required=False,
    )
    parser.add_argument(
        "--chunked",
        action="store_true",
        help="whether to use chunked decoding",
    )
    parser.add_argument(
        "--audio-length",
        type=int,
        default=95,
        choices=[95, 285],
        help="length of generated song in seconds",
    )
    parser.add_argument(
        "--repo-id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="infer/example/output",
        help="output directory of generated song",
    )
    parser.add_argument(
        "--edit",
        action="store_true",
        help="whether to enable edit mode",
    )
    parser.add_argument(
        "--ref-song",
        type=str,
        required=False,
        help="reference song as latent prompt for editing",
    )
    parser.add_argument(
        "--edit-segments",
        type=str,
        required=False,
        help="Time segments to edit (in seconds). Format: `[[start1,end1],...]`. "
        "Use `-1` for audio start/end (e.g., `[[-1,25], [50.0,-1]]`).",
    )
    parser.add_argument(
        "--batch-infer-num",
        type=int,
        default=1,
        required=False,
        help="number of songs per batch",
    )
    args = parser.parse_args()
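
    # Exactly one style reference (text prompt or reference audio) is required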
    assert (
        args.ref_prompt or args.ref_audio_path
    ), "either ref_prompt or ref_audio_path should be provided"
    assert not (
        args.ref_prompt and args.ref_audio_path
    ), "only one of them should be provided"
    if args.edit:
        assert (
            args.ref_song and args.edit_segments
        ), "reference song and edit segments should be provided for editing"
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    audio_length = args.audio_length
    if audio_length == 95:
        max_frames = 2048
    elif audio_length == 285:
        max_frames = 6144
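
    # Load the CFM model, lyric tokenizer, MuQ style encoder, and VAE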
    cfm, tokenizer, muq, vae = prepare_model(max_frames, device, repo_id=args.repo_id)
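
    # Read the lyrics (LRC) file if given and tokenize it into the text condition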
    if args.lrc_path:
        with open(args.lrc_path, "r", encoding="utf-8") as f:
            lrc = f.read()
    else:
        lrc = ""
    lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
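
    # Build the style embedding from reference audio or from a text prompt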
    if args.ref_audio_path:
        style_prompt = get_style_prompt(muq, args.ref_audio_path)
    else:
        style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)

    negative_style_prompt = get_negative_style_prompt(device)
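
    # Latent prompt and the frame segments to predict (non-trivial in edit mode)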
    latent_prompt, pred_frames = get_reference_latent(
        device, max_frames, args.edit, args.edit_segments, args.ref_song, vae
    )
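
    # Sample latents and decode them to waveforms, timing the full pass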
    s_t = time.time()
    generated_songs = inference(
        cfm_model=cfm,
        vae_model=vae,
        cond=latent_prompt,
        text=lrc_prompt,
        duration=max_frames,
        style_prompt=style_prompt,
        negative_style_prompt=negative_style_prompt,
        start_time=start_time,
        pred_frames=pred_frames,
        chunked=args.chunked,
        batch_infer_num=args.batch_infer_num,
    )
    e_t = time.time() - s_t
    print(f"inference cost {e_t:.2f} seconds")
    generated_song = random.choice(generated_songs)
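
    # Write the selected song as 16-bit PCM WAV at 44.1 kHz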
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "output.wav")
    torchaudio.save(output_path, generated_song, sample_rate=44100)