import os
import sys
import time
sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')
import torch
import numpy as np
from omegaconf import OmegaConf
from vllm import LLM, SamplingParams
from codeclm.models import builders
from codeclm.models.codeclm_gen import CodecLM_gen
from generate import check_language_by_text, load_audio
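
# Overview (added note): LeVoInference wires three pieces together: a frozen
# audio tokenizer and conditioner (CodecLM_gen) that turn lyric/description/
# audio prompts into LM input embeddings, a vLLM-served language model that
# generates audio codec tokens under classifier-free guidance, and a decoder
# that turns those tokens back into a waveform.
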
class LeVoInference(torch.nn.Module):
    def __init__(self, ckpt_path):
        super().__init__()

        torch.backends.cudnn.enabled = False

        # Resolvers referenced by the YAML config shipped with the checkpoint.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x))
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
        OmegaConf.register_new_resolver("get_fname", lambda: 'default')
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # Frozen audio tokenizer, used for prompt tokenization and for decoding
        # generated tokens back to audio.
        audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if audio_tokenizer is not None:
            for param in audio_tokenizer.parameters():
                param.requires_grad = False
            print("Audio tokenizer successfully loaded!")
            audio_tokenizer = audio_tokenizer.eval().cuda()

        # Conditioner that turns lyric/description/audio prompts into LM input
        # embeddings; its weights live next to the LM checkpoint.
        self.model_condition = CodecLM_gen(cfg=self.cfg, name="tmp", audiotokenizer=audio_tokenizer, max_duration=self.max_duration)
        self.model_condition.condition_provider.conditioners.load_state_dict(
            torch.load(self.cfg.lm_checkpoint + "/conditioners_weights.pth"))
        self.embeded_eosp1 = torch.load(self.cfg.lm_checkpoint + '/embeded_eosp1.pt')
        print('Conditioner successfully loaded!')
        # vLLM engine serving the language model; the tokenizer is skipped
        # because prompts are passed in as precomputed embeddings.
        self.llm = LLM(
            model=self.cfg.lm_checkpoint,
            trust_remote_code=True,
            tensor_parallel_size=self.cfg.vllm.device_num,
            enforce_eager=True,
            dtype="bfloat16",
            gpu_memory_utilization=0.65,
            max_num_seqs=8,
            tokenizer=None,
            skip_tokenizer_init=True,
            enable_prompt_embeds=True,
            enable_chunked_prefill=True,
        )
        # Default sampling/CFG parameters; callers may override any of them
        # via the `params` argument of forward().
        self.default_params = dict(
            cfg_coef=1.8,
            temperature=0.8,
            top_k=5000,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )
    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None,
                genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed",
                params: dict = None):
        params = {**self.default_params, **(params or {})}

        # Pick the melody prompt: a reference audio file if one is given,
        # otherwise a pre-extracted token prompt chosen by genre and lyric
        # language, otherwise no prompt at all.
        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            pmt_wav = load_audio(prompt_audio_path)
            melody_is_wav = True
        elif genre is not None and auto_prompt_path is not None:
            auto_prompt = torch.load(auto_prompt_path)
            lang = check_language_by_text(lyric)
            prompt_token = auto_prompt[genre][lang][np.random.randint(0, len(auto_prompt[genre][lang]))]
            pmt_wav = prompt_token[:, [0], :]
            melody_is_wav = False
        else:
            pmt_wav = None
            melody_is_wav = True

        description = description.lower() if description else '.'
        description = '[Musicality-very-high]' + ', ' + description
        generate_inp = {
            # Collapse double spaces in the lyric. The replacement characters
            # were garbled in this copy of the file; double-space -> single
            # space is an assumption.
            'descriptions': [lyric.replace("  ", " ")],
            'type_info': [description],
            'melody_wavs': pmt_wav,
            'melody_is_wav': melody_is_wav,
            'embeded_eosp1': self.embeded_eosp1,
        }
        fused_input, audio_qt_embs = self.model_condition.generate_condition(**generate_inp, return_tokens=True)

        # Forbid codec ids that already occur in the audio prompt tokens;
        # self.cfg.lm.code_size doubles as the EOS id.
        prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
        allowed_token_ids = [x for x in range(self.cfg.lm.code_size + 1) if x not in prompt_token]

        sampling_params = SamplingParams(
            max_tokens=self.cfg.audio_tokenizer_frame_rate * self.max_duration,
            temperature=params["temperature"],
            stop_token_ids=[self.cfg.lm.code_size],
            top_k=params["top_k"],
            frequency_penalty=0.2,
            seed=int(time.time() * 1000000) % (2**32) if self.cfg.vllm.cfg else -1,
            allowed_token_ids=allowed_token_ids,
            guidance_scale=params["cfg_coef"],
        )
        # Split into the batch-of-3 (cond, cond, uncond) layout that the
        # current CFG support expects.
        prompts = [{"prompt_embeds": embed} for embed in fused_input]
        condi, uncondi = prompts[0], prompts[1]
        promptss = [condi, condi, uncondi]
        outputs = self.llm.generate(promptss, sampling_params=sampling_params)

        # Take the guided conditional sequence, drop the trailing EOS token,
        # and reshape to (1, 1, T) for the decoder.
        token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
        token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)
        # Decode the codec tokens back into a waveform, passing the melody
        # audio through only when the prompt was raw audio.
        with torch.no_grad():
            if melody_is_wav:
                wav_cfg = self.model_condition.generate_audio(token_ids_CFG, pmt_wav, chunked=True)
            else:
                wav_cfg = self.model_condition.generate_audio(token_ids_CFG, chunked=True)

        return wav_cfg[0]
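

# --- Minimal usage sketch (added; not part of the original file) ---
# Assumes a checkpoint directory laid out as this class expects (config.yaml
# plus the LM, conditioner, and tokenizer weights it references). The paths,
# the lyric format, and the cfg.sample_rate field used below are assumptions.
if __name__ == "__main__":
    import torchaudio

    model = LeVoInference("./ckpt/songgeneration_base")  # hypothetical path
    wav = model(
        lyric="[verse] first line of the song ; second line",  # hypothetical lyric format
        description="pop, female vocal, upbeat",
        genre="Pop",                          # only used together with auto_prompt_path
        auto_prompt_path="./ckpt/prompt.pt",  # hypothetical prompt-token bank
    )
    # wav is expected to be a (channels, samples) tensor on GPU.
    torchaudio.save("output.flac", wav.detach().cpu().float(), model.cfg.sample_rate)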