import glob
import time
import torch
from codeclm.models.codeclm_gen import CodecLM_gen
from codeclm.models import builders
import sys
import os
import torchaudio
import numpy as np
import json
from vllm import LLM, SamplingParams
import re
import argparse
import librosa
# Allowed values for an item's "auto_prompt_audio_type" field; membership is
# asserted in main() before indexing the preset prompt table.
auto_prompt_type = ['Pop', 'Latin', 'Rock', 'Electronic', 'Metal', 'Country', 'R&B/Soul', 'Ballad', 'Jazz', 'World', 'Hip-Hop', 'Funk', 'Soundtrack','Auto']
def check_language_by_text(text):
    """Classify lyric text as Chinese ("zh") or English ("en").

    The text is "zh" when at least 20% of its characters are CJK ideographs,
    otherwise "en". "en" is also the fallback for empty input (the original
    divided by ``len(text)`` and crashed on "").

    Args:
        text: lyric string to classify.

    Returns:
        "zh" or "en".
    """
    if not text:
        return "en"
    chinese_count = len(re.findall(r'[\u4e00-\u9fff]', text))
    chinese_ratio = chinese_count / len(text)
    # The original also computed an ASCII-letter ratio, but both its >=0.5
    # branch and its else branch returned "en", so only the CJK ratio matters.
    return "zh" if chinese_ratio >= 0.2 else "en"
def load_audio(f):
    """Load an audio file as a mono 48 kHz tensor, capped at 10 seconds.

    Args:
        f: path to the audio file.

    Returns:
        torch.Tensor of shape [1, T] with T <= 48000 * 10.
    """
    max_len = 48000 * 10
    # librosa.load(sr=48000) already resamples, so the returned rate is always
    # 48000 — the original's torchaudio resample branch was dead code.
    a, _ = librosa.load(f, sr=48000)
    a = torch.tensor(a).unsqueeze(0)
    # Truncate once (the original sliced twice); a no-op for shorter clips.
    return a[:, :max_len]
def parse_args():
    """Parse the command-line arguments for the song-generation script."""
    parser = argparse.ArgumentParser(description='Song Generation Script')
    # Required arguments, declared data-driven to keep flag/help pairs together.
    required_options = [
        ('--input_jsonl', 'Path to input JSONL file containing generation tasks'),
        ('--save_dir', 'Directory to save generated audio files and results'),
        ('--config_path', 'Path to the config file'),
    ]
    for flag, help_text in required_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()
def main():
    """Run batch song generation from a JSONL task file.

    Per task: build lyric/style conditioning (optionally with a melody prompt
    from a reference audio file or a preset genre prompt), sample audio tokens
    from a vLLM-served LM under classifier-free guidance (CFG), decode them to
    a waveform, and save FLAC output under ``<save_dir>/audios`` plus a result
    JSONL under ``<save_dir>/jsonl``.
    """
    torch.set_num_threads(1)
    torch.backends.cudnn.enabled = False  # some cluster nodes raise odd cuDNN errors otherwise
    from omegaconf import OmegaConf
    # NOTE(review): the "eval" resolver executes arbitrary Python from the
    # config file — only load trusted configs.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    args = parse_args()
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    cfg_path = args.config_path
    cfg = OmegaConf.load(cfg_path)
    cfg.mode = 'inference'
    max_duration = cfg.max_dur
    # Frozen audio tokenizer: encodes prompt waveforms into discrete tokens.
    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
    if audio_tokenizer is not None:
        for param in audio_tokenizer.parameters():
            param.requires_grad = False
        print("Audio tokenizer successfully loaded!")
        audio_tokenizer = audio_tokenizer.eval().cuda()
    model_condition = CodecLM_gen(cfg=cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = max_duration)
    model_condition.condition_provider.conditioners.load_state_dict(torch.load(cfg.lm_checkpoint+"/conditioners_weights.pth"))
    print('Conditioner successfully loaded!')
    # vLLM engine for the token LM; the tokenizer is skipped because prompts
    # are passed in as pre-computed embeddings.
    llm = LLM(
        model=cfg.lm_checkpoint,
        trust_remote_code=True,
        tensor_parallel_size=cfg.vllm.device_num,
        enforce_eager=False,
        dtype="bfloat16",
        gpu_memory_utilization=cfg.vllm.gpu_memory_utilization,
        tokenizer=None,
        skip_tokenizer_init=True,
        enable_prompt_embeds=True,
        enable_chunked_prefill=True,
    )
    print("LLM 初始化成功")
    # Preset genre prompt tokens, keyed by genre then language ("zh"/"en").
    auto_prompt = torch.load('tools/new_prompt.pt')
    guidance_scale = cfg.vllm.guidance_scale
    temp = cfg.vllm.temp
    top_k = cfg.vllm.top_k
    sum_time = 0
    sum_wav_len = 0
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()
    new_items = []
    for line in lines:
        item = json.loads(line)
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"].lower() if "descriptions" in item else '.'
        descriptions = '[Musicality-very-high]' + ', ' + descriptions
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        # Skip tasks whose output already exists, so an interrupted run can resume.
        if os.path.exists(target_wav_name):
            continue
        if "prompt_audio_path" in item:
            # Melody prompt from a reference audio file (mutually exclusive
            # with the preset genre prompt).
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav = load_audio(item['prompt_audio_path'])
                # Keep the raw waveform: it is passed to generate_audio() later.
                item['raw_pmt_wav'] = pmt_wav
                if pmt_wav.dim() == 2:
                    pmt_wav = pmt_wav[None]
                if pmt_wav.dim() != 3:
                    raise ValueError("Melody wavs should have a shape [B, C, T].")
                pmt_wav = list(pmt_wav)
            if type(pmt_wav) == list:
                pmt_wav = torch.stack(pmt_wav, dim=0)
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            print(pmt_wav.shape)
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            # Pick a random preset prompt matching the requested genre and the
            # lyric's language.
            lang = check_language_by_text(item['gt_lyric'])
            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][lang][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]][lang]))]
            pmt_wav = prompt_token[:,[0],:]
            melody_is_wav = False
        else:
            # No melody prompt at all.
            pmt_wav = None
            melody_is_wav = True
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        embeded_eosp1 = torch.load(cfg.lm_checkpoint+'/embeded_eosp1.pt')
        generate_inp = {
            # NOTE(review): this replace looks like a no-op as written — it may
            # have lost a non-ASCII/double space in transit; verify.
            'descriptions': [lyric.replace(" ", " ")],
            'type_info': [descriptions],
            'melody_wavs': pmt_wav,
            'melody_is_wav': melody_is_wav,
            'embeded_eosp1': embeded_eosp1,
        }
        fused_input, audio_qt_embs = model_condition.generate_condition(**generate_inp, return_tokens=True)
        # Forbid the sampler from re-emitting the prompt's own tokens.
        prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
        allowed_token_ids = [x for x in range(cfg.lm.code_size+1) if x not in prompt_token]
        sampling_params = SamplingParams(
            max_tokens=cfg.audio_tokenizer_frame_rate*cfg.max_dur,
            temperature=temp,
            # cfg.lm.code_size doubles as the end-of-audio stop token id.
            stop_token_ids=[cfg.lm.code_size],
            top_k=top_k,
            frequency_penalty=0.2,
            seed=int(time.time() * 1000000) % (2**32) if cfg.vllm.cfg else -1,
            allowed_token_ids=allowed_token_ids,
            guidance_scale=guidance_scale
        )
        # Split into the batch-of-3 CFG layout currently supported.
        prompts = [{"prompt_embeds": embed} for embed in fused_input]
        promptss = []
        for _ in range(2):
            promptss+=prompts
        uncondi = prompts[1]
        promptss = promptss[::2] + [uncondi]
        start_time = time.time()
        outputs = llm.generate(promptss, sampling_params=sampling_params)
        mid_time = time.time()
        # NOTE(review): assumes outputs[1] carries the CFG-guided sequence —
        # verify against the prompt batch layout above. The final (stop) token
        # is dropped before decoding.
        token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
        token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            # wav_nocfg = model_condition.generate_audio(token_ids)
            if 'raw_pmt_wav' in item:
                wav_cfg = model_condition.generate_audio(token_ids_CFG, item['raw_pmt_wav'])
                # Tensors are not JSON-serializable; remove before writing the result JSONL.
                del item['raw_pmt_wav']
            else:
                wav_cfg = model_condition.generate_audio(token_ids_CFG)
        end_time = time.time()
        torchaudio.save(target_wav_name, wav_cfg[0].cpu().float(), cfg.sample_rate)
        sum_time += end_time - start_time
        # 25 is presumably the token frame rate (tokens/second) — confirm it
        # matches cfg.audio_tokenizer_frame_rate.
        sum_wav_len += (token_ids_CFG.shape[-1] / 25)
        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}, rtf {(end_time - start_time) / token_ids_CFG.shape[-1] * 25:.2f}")
        new_items.append(item)
    print(f"Total time: {sum_time:.4f} seconds, total wav length: {sum_wav_len:.4f} seconds, rtf {sum_time/sum_wav_len:.2f}")
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
# Script entry point: run the full generation pipeline when executed directly.
if __name__ == "__main__":
    main()