# SongGeneration / levo_inference.py
# root
# norm description
# e7ab0ec
import os
import sys
import time
sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')
import torch
import numpy as np
from omegaconf import OmegaConf
from vllm import LLM, SamplingParams
from codeclm.models import builders
from codeclm.models.codeclm_gen import CodecLM_gen
from generate import check_language_by_text, load_audio
class LeVoInference(torch.nn.Module):
    """Song-generation inference wrapper.

    Construction loads the OmegaConf config found in ``ckpt_path``, builds the
    audio tokenizer and the ``CodecLM_gen`` conditioning model, and starts a
    vLLM ``LLM`` engine on ``cfg.lm_checkpoint``. ``forward`` turns lyrics plus
    an optional prompt into conditioning embeddings, samples audio tokens with
    the vLLM engine, and decodes them with ``CodecLM_gen.generate_audio``.
    """

    def __init__(self, ckpt_path):
        """Load the configuration and every model needed for inference.

        Args:
            ckpt_path: Directory that contains ``config.yaml``.
        """
        super().__init__()

        torch.backends.cudnn.enabled = False
        self._register_resolvers()

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if audio_tokenizer is not None:
            # Inference only: freeze all tokenizer weights.
            for param in audio_tokenizer.parameters():
                param.requires_grad = False
        print("Audio tokenizer successfully loaded!")
        # NOTE(review): the guard above allows None, but this call does not;
        # a None tokenizer fails here with AttributeError. Confirm whether the
        # builder can actually return None.
        audio_tokenizer = audio_tokenizer.eval().cuda()

        self.model_condition = CodecLM_gen(cfg=self.cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = self.max_duration)
        self.model_condition.condition_provider.conditioners.load_state_dict(torch.load(self.cfg.lm_checkpoint+"/conditioners_weights.pth"))
        self.embeded_eosp1 = torch.load(self.cfg.lm_checkpoint+'/embeded_eosp1.pt')
        print('Conditioner successfully loaded!')

        # The engine is fed prompt embeddings only, so tokenizer
        # initialisation is skipped.
        self.llm = LLM(
            model=self.cfg.lm_checkpoint,
            trust_remote_code=True,
            tensor_parallel_size=self.cfg.vllm.device_num,
            enforce_eager=True,
            dtype="bfloat16",
            gpu_memory_utilization=0.65,
            max_num_seqs=8,
            tokenizer=None,
            skip_tokenizer_init=True,
            enable_prompt_embeds=True,
            enable_chunked_prefill=True,
        )

        # Defaults merged with the per-call ``params`` in ``forward``.
        # Only ``cfg_coef``, ``temperature`` and ``top_k`` are read there.
        self.default_params = dict(
            cfg_coef = 1.8,
            temperature = 0.8,
            top_k = 5000,
            top_p = 0.0,
            record_tokens = True,
            record_window = 50,
            extend_stride = 5,
            duration = self.max_duration,
        )

    @staticmethod
    def _register_resolvers():
        """Register the custom OmegaConf resolvers used when loading the config.

        ``replace=True`` keeps a second instantiation in the same process from
        failing on an already-registered resolver name.
        """
        resolvers = {
            # SECURITY: evaluates arbitrary Python taken from the config file;
            # only load configs from trusted checkpoints.
            "eval": lambda x: eval(x),
            "concat": lambda *x: [xxx for xx in x for xxx in xx],
            "get_fname": lambda: 'default',
            "load_yaml": lambda x: list(OmegaConf.load(x)),
        }
        for name, resolver in resolvers.items():
            OmegaConf.register_new_resolver(name, resolver, replace=True)

    @staticmethod
    def _normalize_description(description):
        """Lower-case ``description`` (or use ``'.'`` when empty/None) and add the quality tag."""
        description = description.lower() if description else '.'
        return '[Musicality-very-high]' + ', ' + description

    @staticmethod
    def _allowed_token_ids(code_size, prompt_token):
        """Return the ids in ``0..code_size`` (inclusive) that are absent from ``prompt_token``."""
        # Build the set once so each membership test is O(1).
        excluded = set(prompt_token)
        return [x for x in range(code_size + 1) if x not in excluded]

    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params=None):
        """Generate audio for ``lyric``.

        Args:
            lyric: Lyrics text used as the conditioning description.
            description: Optional free-text description; lower-cased and
                prefixed with ``[Musicality-very-high]``. ``'.'`` when absent.
            prompt_audio_path: Optional path to a prompt audio file. Used when
                it is given and exists on disk.
            genre: Key into the auto-prompt table; only used together with
                ``auto_prompt_path`` when no prompt audio is used.
            auto_prompt_path: Path of a ``torch.load``-able auto-prompt table
                indexed as ``table[genre][lang]``.
            gen_type: Accepted for interface compatibility; not read here.
            params: Optional overrides for ``self.default_params``.

        Returns:
            The first element of the value returned by
            ``self.model_condition.generate_audio``.
        """
        params = {**self.default_params, **(params or {})}

        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            pmt_wav = load_audio(prompt_audio_path)
            melody_is_wav = True
        elif genre is not None and auto_prompt_path is not None:
            auto_prompt = torch.load(auto_prompt_path)
            lang = check_language_by_text(lyric)
            candidates = auto_prompt[genre][lang]
            # Pick one stored prompt at random for this genre/language.
            chosen_prompt = candidates[np.random.randint(0, len(candidates))]
            pmt_wav = chosen_prompt[:,[0],:]
            melody_is_wav = False
        else:
            pmt_wav = None
            melody_is_wav = True

        description = self._normalize_description(description)

        generate_inp = {
            # NOTE(review): as written this replace maps a space to a space;
            # possibly whitespace normalisation that was lost — confirm.
            'descriptions': [lyric.replace(" ", " ")],
            'type_info': [description],
            'melody_wavs': pmt_wav,
            'melody_is_wav': melody_is_wav,
            'embeded_eosp1': self.embeded_eosp1,
        }
        fused_input, audio_qt_embs = self.model_condition.generate_condition(**generate_inp, return_tokens=True)

        # Token ids that occur in the prompt are excluded from sampling.
        prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
        allowed_token_ids = self._allowed_token_ids(self.cfg.lm.code_size, prompt_token)

        # NOTE(review): ``guidance_scale`` presumably requires the vLLM build
        # this project ships with — confirm.
        sampling_params = SamplingParams(
            max_tokens=self.cfg.audio_tokenizer_frame_rate*self.max_duration,
            temperature=params["temperature"],
            stop_token_ids=[self.cfg.lm.code_size],
            top_k=params["top_k"],
            frequency_penalty=0.2,
            seed=int(time.time() * 1000000) % (2**32) if self.cfg.vllm.cfg else -1,
            allowed_token_ids=allowed_token_ids,
            guidance_scale=params["cfg_coef"]
        )

        # Split into the currently supported batch-of-3 CFG form:
        # conditional, conditional, unconditional.
        prompts = [{"prompt_embeds": embed} for embed in fused_input]
        condi, uncondi = prompts[0], prompts[1]
        promptss = [condi, condi, uncondi]
        outputs = self.llm.generate(promptss, sampling_params=sampling_params)

        # The second request's output is the one that gets decoded.
        token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
        # Drop the final token and add two leading singleton dimensions.
        # NOTE(review): presumably the final token is the stop token; when
        # generation ends on max_tokens this drops a regular token — confirm.
        token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            if melody_is_wav:
                wav_cfg = self.model_condition.generate_audio(token_ids_CFG, pmt_wav, chunked=True)
            else:
                wav_cfg = self.model_condition.generate_audio(token_ids_CFG, chunked=True)

        return wav_cfg[0]