Spaces:
Running
on
L40S
Running
on
L40S
File size: 4,560 Bytes
2644f3e 544833b 2644f3e f9e2d84 2644f3e 57d225d 2644f3e d658154 2644f3e f6b176e 2644f3e d658154 2644f3e b000a9b 2644f3e f9e2d84 2644f3e f9e2d84 2644f3e f9e2d84 2644f3e f9e2d84 2644f3e f9e2d84 2644f3e d2ca82a 2644f3e 544833b 2644f3e d658154 57d225d d658154 2644f3e 57d225d 2644f3e d658154 2644f3e e55b921 3779445 e55b921 3779445 2644f3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
import sys
# Extend the module search path before the project imports below run.
# These are relative paths, so they resolve against the process's current
# working directory rather than against this file's location.
# NOTE(review): presumably required by the `codeclm`, `separator` and
# `generate` imports that follow — confirm which entry serves which import.
sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')
import torch
import json
import numpy as np
from omegaconf import OmegaConf
from codeclm.models import builders
from codeclm.models import CodecLM
from separator import Separator
from generate import check_language_by_text
class LeVoInference(torch.nn.Module):
    """Inference wrapper that assembles a ``CodecLM`` from a checkpoint directory.

    The constructor reads ``config.yaml`` and ``model.pt`` from ``ckpt_path``,
    builds the language model and two audio tokenizers through
    ``codeclm.models.builders`` and stores the resulting ``CodecLM`` in
    ``self.model``.  ``forward`` runs one generation and returns the first
    element of whatever ``CodecLM.generate_audio`` produces.
    """

    def __init__(self, ckpt_path):
        """Load configuration and weights from ``ckpt_path`` and build the model.

        Args:
            ckpt_path: Directory containing ``config.yaml`` and ``model.pt``.
        """
        super().__init__()

        # NOTE: this is a process-wide switch, not scoped to this instance.
        torch.backends.cudnn.enabled = False

        # Resolvers live in a global OmegaConf registry.  ``replace=True`` keeps
        # a second instantiation in the same process from raising ValueError
        # because the names are already registered.
        # SECURITY: the "eval" resolver executes arbitrary Python found in the
        # config file; only load configs from trusted sources.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x), replace=True)
        OmegaConf.register_new_resolver(
            "concat", lambda *x: [xxx for xx in x for xxx in xx], replace=True
        )
        OmegaConf.register_new_resolver("get_fname", lambda: 'default', replace=True)
        OmegaConf.register_new_resolver(
            "load_yaml", lambda x: list(OmegaConf.load(x)), replace=True
        )

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        pt_path = os.path.join(ckpt_path, 'model.pt')
        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # Define model or load pretrained model
        audiolm = builders.get_lm_model(self.cfg, version='v1.5')
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(pt_path, map_location='cpu')
        # Keep only the entries whose key starts with 'audiolm' and drop the
        # 'audiolm.' component so the keys line up with the bare LM module.
        audiolm_state_dict = {
            k.replace('audiolm.', ''): v
            for k, v in checkpoint.items()
            if k.startswith('audiolm')
        }
        # strict=False: missing or unexpected keys are ignored silently.
        audiolm.load_state_dict(audiolm_state_dict, strict=False)
        audiolm = audiolm.eval()
        audiolm = audiolm.cuda().to(torch.float16)

        audio_tokenizer = builders.get_audio_tokenizer_model(
            self.cfg.audio_tokenizer_checkpoint, self.cfg
        )
        audio_tokenizer = audio_tokenizer.eval()

        seperate_tokenizer = builders.get_audio_tokenizer_model(
            self.cfg.audio_tokenizer_checkpoint_sep, self.cfg
        )
        seperate_tokenizer = seperate_tokenizer.eval()

        self.model = CodecLM(
            name="tmp",
            lm=audiolm,
            audiotokenizer=audio_tokenizer,
            max_duration=self.max_duration,
            seperate_tokenizer=seperate_tokenizer,
        )
        self.separator = Separator()

        # Baseline generation settings; per-call overrides are merged on top
        # of these in ``forward``.
        self.default_params = dict(
            cfg_coef=1.5,
            temperature=1.0,
            top_k=50,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )
        self.model.set_generation_params(**self.default_params)

    def _resolve_prompt(self, lyric, prompt_audio_path, genre, auto_prompt_path):
        """Return ``(pmt_wav, vocal_wav, bgm_wav, melody_is_wav)`` for one request."""
        # 1) An existing prompt audio file wins: split it with the separator.
        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
            return pmt_wav, vocal_wav, bgm_wav, True

        # 2) Otherwise pick a random pre-computed prompt for the genre.
        if genre is not None and auto_prompt_path is not None:
            # SECURITY: torch.load unpickles the file; only load trusted files.
            auto_prompt = torch.load(auto_prompt_path)
            if genre == 'Auto':
                # 'Auto' entries are additionally keyed by the lyric language.
                candidates = auto_prompt['Auto'][check_language_by_text(lyric)]
            else:
                candidates = auto_prompt[genre]
            prompt_token = candidates[np.random.randint(0, len(candidates))]
            # Indices 0 / 1 / 2 along the second axis are used as the melody,
            # vocal and bgm prompts respectively.
            return (
                prompt_token[:, [0], :],
                prompt_token[:, [1], :],
                prompt_token[:, [2], :],
                False,
            )

        # 3) No prompt at all.
        return None, None, None, True

    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params=None):
        """Generate audio for ``lyric`` and return the first generated item.

        Args:
            lyric: Lyrics text passed to the model.
            description: Optional text description; ``'.'`` is used when empty.
                It is always prefixed with ``'[Musicality-very-high], '``.
            prompt_audio_path: Optional audio prompt file.  Used only if the
                path exists; it is then split by ``self.separator``.
            genre: Key into the auto-prompt file (``'Auto'`` selects by the
                detected lyric language).  Used only together with
                ``auto_prompt_path`` and when no prompt audio file is used.
            auto_prompt_path: File loaded with ``torch.load`` holding the
                pre-computed prompts.
            gen_type: Forwarded unchanged to ``CodecLM.generate_audio``.
            params: Optional overrides merged on top of ``self.default_params``.
                ``None`` (the default) means no overrides.

        Returns:
            ``wav_seperate[0]``, the first element of the ``generate_audio``
            result.
        """
        # ``None`` replaces the former shared mutable ``dict()`` default.
        overrides = params if params is not None else {}
        params = {**self.default_params, **overrides}
        self.model.set_generation_params(**params)

        pmt_wav, vocal_wav, bgm_wav, melody_is_wav = self._resolve_prompt(
            lyric, prompt_audio_path, genre, auto_prompt_path
        )

        description = description if description else '.'
        description = '[Musicality-very-high]' + ', ' + description

        generate_inp = {
            # NOTE(review): as written this replace is a no-op (same string on
            # both sides); presumably it was meant to collapse repeated
            # spaces — confirm against the upstream source.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)

        with torch.no_grad():
            if melody_is_wav:
                wav_seperate = self.model.generate_audio(
                    tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type
                )
            else:
                wav_seperate = self.model.generate_audio(tokens, gen_type=gen_type)

        return wav_seperate[0]
|