File size: 5,282 Bytes
2644f3e
 
c8c0ef5
2644f3e
 
 
 
 
 
544833b
2644f3e
c8c0ef5
2644f3e
f9e2d84
c8c0ef5
 
2644f3e
 
 
d658154
2644f3e
 
 
 
 
f6b176e
2644f3e
 
d658154
2644f3e
b000a9b
2644f3e
 
f9e2d84
c8c0ef5
 
 
 
 
 
 
 
 
 
 
 
 
eb8bfb7
c8c0ef5
c298f3c
e7ab0ec
c8c0ef5
 
 
 
2644f3e
 
 
c8c0ef5
 
 
2644f3e
 
 
 
 
 
 
d2ca82a
2644f3e
 
544833b
c8c0ef5
d658154
 
 
48275bf
 
d658154
 
 
 
 
2644f3e
e7ab0ec
57d225d
2644f3e
c8c0ef5
 
2644f3e
d658154
c8c0ef5
2644f3e
c8c0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7ab0ec
 
c8c0ef5
 
 
2644f3e
 
e55b921
c298f3c
e55b921
c298f3c
2644f3e
c8c0ef5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys
import time

sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')

import torch
import numpy as np
from omegaconf import OmegaConf
from vllm import LLM, SamplingParams

from codeclm.models import builders
from codeclm.models.codeclm_gen import CodecLM_gen
from generate import check_language_by_text, load_audio


class LeVoInference(torch.nn.Module):
    """Song-generation inference pipeline.

    Loads, from a checkpoint directory: the frozen audio tokenizer, the
    conditioner weights, and a vLLM language model, then exposes
    ``forward`` to synthesize a waveform from lyrics plus an optional
    style prompt (either an audio file or a pre-tokenized genre prompt).
    """

    def __init__(self, ckpt_path):
        """Build the pipeline from ``ckpt_path`` (must contain config.yaml).

        Requires a CUDA device; loads the LM into vLLM with prompt-embedding
        input enabled.
        """
        super().__init__()

        torch.backends.cudnn.enabled = False
        # NOTE(security): the "eval" resolver executes arbitrary Python found
        # in the config file — only load configs from trusted checkpoints.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x))
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
        OmegaConf.register_new_resolver("get_fname", lambda: 'default')
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # The tokenizer only encodes prompts / decodes generated tokens,
        # so it is frozen and kept in eval mode on the GPU.
        audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if audio_tokenizer is not None:
            for param in audio_tokenizer.parameters():
                param.requires_grad = False
            # Fix: move to eval/cuda only when a tokenizer was actually built.
            # The original called .eval() unconditionally and crashed with
            # AttributeError whenever the builder returned None.
            audio_tokenizer = audio_tokenizer.eval().cuda()
        print("Audio tokenizer successfully loaded!")
        self.model_condition = CodecLM_gen(cfg=self.cfg, name="tmp", audiotokenizer=audio_tokenizer, max_duration=self.max_duration)
        self.model_condition.condition_provider.conditioners.load_state_dict(
            torch.load(self.cfg.lm_checkpoint + "/conditioners_weights.pth"))
        # Precomputed end-of-prompt embedding consumed by generate_condition
        # (assumption based on the name — TODO confirm against CodecLM_gen).
        self.embeded_eosp1 = torch.load(self.cfg.lm_checkpoint + '/embeded_eosp1.pt')
        print('Conditioner successfully loaded!')
        self.llm = LLM(
            model=self.cfg.lm_checkpoint,
            trust_remote_code=True,
            tensor_parallel_size=self.cfg.vllm.device_num,
            enforce_eager=True,
            dtype="bfloat16",
            gpu_memory_utilization=0.65,
            max_num_seqs=8,
            tokenizer=None,              # prompts are fed as embeddings, not text
            skip_tokenizer_init=True,
            enable_prompt_embeds=True,
            enable_chunked_prefill=True,
        )

        # Fallback sampling/generation parameters; ``forward`` merges any
        # caller-supplied ``params`` on top of these.
        self.default_params = dict(
            cfg_coef=1.8,
            temperature=0.8,
            top_k=5000,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )

    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params=None):
        """Generate one song waveform.

        Args:
            lyric: lyric text (double spaces are collapsed to single).
            description: optional free-text style description; lower-cased
                and prefixed with the '[Musicality-very-high]' tag.
            prompt_audio_path: optional reference audio file; takes priority
                over ``genre``/``auto_prompt_path`` when it exists on disk.
            genre: key into the auto-prompt table (used with
                ``auto_prompt_path`` when no prompt audio is given).
            auto_prompt_path: torch file mapping genre -> language -> list of
                pre-tokenized prompts; one is picked at random.
            gen_type: accepted for interface compatibility; not used here.
            params: optional overrides merged over ``self.default_params``.
                Fix: default changed from a shared mutable ``dict()`` to
                ``None`` (backward compatible).

        Returns:
            The first decoded waveform produced by the codec decoder.
        """
        params = {**self.default_params, **(params or {})}

        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            # Style prompt supplied as raw audio.
            pmt_wav = load_audio(prompt_audio_path)
            melody_is_wav = True
        elif genre is not None and auto_prompt_path is not None:
            # Pick a random pre-tokenized prompt matching genre and the
            # language detected from the lyric.
            auto_prompt = torch.load(auto_prompt_path)
            lang = check_language_by_text(lyric)
            candidates = auto_prompt[genre][lang]
            chosen = candidates[np.random.randint(0, len(candidates))]
            # Keep only the first codebook level (assumption — TODO confirm).
            pmt_wav = chosen[:, [0], :]
            melody_is_wav = False
        else:
            pmt_wav = None
            melody_is_wav = True

        description = description.lower() if description else '.'
        description = '[Musicality-very-high]' + ', ' + description
        generate_inp = {
            'descriptions': [lyric.replace("  ", " ")],
            'type_info': [description],
            'melody_wavs': pmt_wav,
            'melody_is_wav': melody_is_wav,
            'embeded_eosp1': self.embeded_eosp1,
        }
        fused_input, audio_qt_embs = self.model_condition.generate_condition(**generate_inp, return_tokens=True)
        prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
        # Fix: use a set for membership — the original list scan made this
        # O(code_size * len(prompt_token)).
        banned = set(prompt_token)
        allowed_token_ids = [x for x in range(self.cfg.lm.code_size + 1) if x not in banned]
        sampling_params = SamplingParams(
            max_tokens=self.cfg.audio_tokenizer_frame_rate * self.max_duration,
            temperature=params["temperature"],
            # EOS is the id one past the codebook range.
            stop_token_ids=[self.cfg.lm.code_size],
            top_k=params["top_k"],
            frequency_penalty=0.2,
            # NOTE(review): upstream vLLM expects seed=None for random
            # sampling; confirm -1 is accepted by the patched build in use.
            seed=int(time.time() * 1000000) % (2**32) if self.cfg.vllm.cfg else -1,
            allowed_token_ids=allowed_token_ids,
            guidance_scale=params["cfg_coef"],
        )
        # Expand into the batch-of-3 CFG layout currently supported:
        # [conditional, conditional, unconditional].
        prompts = [{"prompt_embeds": embed} for embed in fused_input]
        condi, uncondi = prompts[0], prompts[1]
        outputs = self.llm.generate([condi, condi, uncondi], sampling_params=sampling_params)
        # Take the second conditional stream, drop its trailing stop token,
        # and reshape to (1, 1, T) for the codec decoder.
        token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
        token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            if melody_is_wav:
                # Pass the prompt waveform (possibly None) for continuation.
                wav_cfg = self.model_condition.generate_audio(token_ids_CFG, pmt_wav, chunked=True)
            else:
                wav_cfg = self.model_condition.generate_audio(token_ids_CFG, chunked=True)

        return wav_cfg[0]