MySongStudio / tools /gradio /levo_inference.py
Novix's picture
Upload 6 files
8940c35 verified
import os
import re
import sys
import torch
import json
import numpy as np
from omegaconf import OmegaConf
from codeclm.trainer.codec_song_pl import CodecLM_PL
from codeclm.models import CodecLM
from separator import Separator
# Matches a single character in the CJK Unified Ideographs block.
# Compiled once at import time instead of on every call.
_CHINESE_CHAR_PATTERN = re.compile(r'[\u4e00-\u9fff]')


def check_language_by_text(text):
    """Guess whether *text* is Chinese ("zh") or English ("en").

    The text is classified as Chinese when at least 20% of its characters
    fall in the range U+4E00..U+9FFF. Everything else, including text that
    contains no Latin letters at all, falls back to "en".

    Args:
        text: The string to classify (for example, song lyrics).

    Returns:
        "zh" or "en".
    """
    # Empty (or missing) text has no characters to measure; dividing by its
    # length would raise ZeroDivisionError, so use the fallback language.
    if not text:
        return "en"

    chinese_ratio = len(_CHINESE_CHAR_PATTERN.findall(text)) / len(text)
    return "zh" if chinese_ratio >= 0.2 else "en"
class LeVoInference(torch.nn.Module):
    """Inference wrapper that turns lyrics plus an optional prompt into audio.

    The constructor loads ``config.yaml`` and ``model.pt`` from a checkpoint
    directory, builds a ``CodecLM`` generator around the loaded language
    model and tokenizers, and installs a set of default generation
    parameters. ``forward`` then runs token generation followed by audio
    decoding.
    """

    def __init__(self, ckpt_path):
        """Load the model found in *ckpt_path* and move it to CUDA.

        Args:
            ckpt_path: Directory containing ``config.yaml`` and ``model.pt``.
        """
        super().__init__()

        # NOTE: this is a process-wide switch, not scoped to this instance.
        torch.backends.cudnn.enabled = False

        # Resolvers are registered globally on OmegaConf. ``replace=True``
        # keeps a second LeVoInference in the same process from failing
        # with "resolver already registered".
        # NOTE(security): the "eval" resolver runs Python ``eval`` on values
        # taken from the config file; only load configs you trust.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x), replace=True)
        OmegaConf.register_new_resolver(
            "concat", lambda *x: [xxx for xx in x for xxx in xx], replace=True
        )
        OmegaConf.register_new_resolver("get_fname", lambda: 'default', replace=True)
        OmegaConf.register_new_resolver(
            "load_yaml", lambda x: list(OmegaConf.load(x)), replace=True
        )

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        pt_path = os.path.join(ckpt_path, 'model.pt')
        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # Define model or load pretrained model
        model_light = CodecLM_PL(self.cfg, pt_path, version="v2")
        model_light = model_light.eval().cuda()
        model_light.audiolm.cfg = self.cfg

        self.model_lm = model_light.audiolm
        self.model_audio_tokenizer = model_light.audio_tokenizer
        self.model_seperate_tokenizer = model_light.seperate_tokenizer
        self.model = CodecLM(
            name="tmp",
            lm=self.model_lm,
            audiotokenizer=self.model_audio_tokenizer,
            max_duration=self.max_duration,
            seperate_tokenizer=self.model_seperate_tokenizer,
        )
        self.separator = Separator()

        # Baseline generation settings; per-call overrides are merged on top
        # of these in ``forward``.
        self.default_params = dict(
            cfg_coef=1.5,
            temperature=1.0,
            top_k=50,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )
        self.model.set_generation_params(**self.default_params)

    def forward(
        self,
        lyric: str,
        description: str = None,
        prompt_audio_path: os.PathLike = None,
        genre: str = None,
        auto_prompt_path: os.PathLike = None,
        gen_type: str = "mixed",
        params=None,
    ):
        """Generate audio for *lyric*.

        The prompt is chosen in this order:

        1. ``prompt_audio_path`` if it is given and exists on disk: the file
           is passed through ``self.separator`` to obtain the prompt, vocal
           and bgm waveforms.
        2. Otherwise, if both ``genre`` and ``auto_prompt_path`` are given,
           a random pre-computed prompt for that genre and the detected
           lyric language is used.
        3. Otherwise no prompt is used.

        Args:
            lyric: Lyrics text. Ignored (replaced by ``'.'``) when
                ``gen_type == 'bgm'``.
            description: Optional free-text description; lower-cased and
                prefixed with control tags. ``'.'`` is used when empty.
            prompt_audio_path: Optional path to a reference audio file.
            genre: Key into the auto-prompt table.
            auto_prompt_path: Path to a ``torch.save``-d auto-prompt table,
                indexed as ``table[genre][lang][i]``.
            gen_type: Generation mode; ``'bgm'`` adds the ``[Pure-Music]``
                tag and drops the lyrics. Passed through to
                ``generate_audio``.
            params: Optional overrides merged over ``self.default_params``.

        Returns:
            The first element of the value returned by
            ``self.model.generate_audio``.
        """
        # ``None`` instead of a shared mutable ``dict()`` default.
        merged_params = {**self.default_params, **(params or {})}
        self.model.set_generation_params(**merged_params)

        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
            melody_is_wav = True
        elif genre is not None and auto_prompt_path is not None:
            # NOTE(security): torch.load unpickles the file; only use
            # auto-prompt files from a trusted source.
            auto_prompt = torch.load(auto_prompt_path)
            lang = check_language_by_text(lyric)
            candidates = auto_prompt[genre][lang]
            prompt_token = candidates[np.random.randint(0, len(candidates))]
            # Entries 0 / 1 / 2 along the second axis are used as the
            # prompt / vocal / bgm inputs respectively.
            pmt_wav = prompt_token[:, [0], :]
            vocal_wav = prompt_token[:, [1], :]
            bgm_wav = prompt_token[:, [2], :]
            melody_is_wav = False
        else:
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True

        # Build the tagged description. The tags are always prepended, also
        # when no description is supplied. (Previously, operator precedence
        # made bgm mode with an empty description collapse to a bare '.',
        # silently dropping both tags.)
        description = description.lower() if description else '.'
        tags = ['[Musicality-very-high]']
        if gen_type == 'bgm':
            tags.append('[Pure-Music]')
        description = ', '.join(tags + [description])

        # NOTE(review): as written this replace() swaps a space for a space
        # and is therefore a no-op; presumably it was meant to collapse
        # double spaces - confirm against the upstream source.
        # 'lyrics' is always a one-element list, matching 'descriptions'.
        lyrics = [lyric.replace(" ", " ")] if gen_type != 'bgm' else ['.']

        generate_inp = {
            'lyrics': lyrics,
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)

        with torch.no_grad():
            if melody_is_wav:
                wav_seperate = self.model.generate_audio(
                    tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type
                )
            else:
                wav_seperate = self.model.generate_audio(tokens, gen_type=gen_type)

        return wav_seperate[0]