File size: 10,090 Bytes
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e0878a
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e0878a
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e0878a
 
 
 
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e0878a
64ec292
 
 
 
 
 
1e0878a
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import hydra
import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from ema_pytorch import EMA
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf

from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
from src.YingMusicSinger.models.model import Singer
from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer
from src.YingMusicSinger.utils.lrc_align import (
    align_lrc_put_to_front,
    align_lrc_sentence_level,
)
from src.YingMusicSinger.utils.mel_spectrogram import MelodySpectrogram
from src.YingMusicSinger.utils.stable_audio_tools.vae_copysyn import StableAudioInfer
from src.YingMusicSinger.utils.smooth_ending import smooth_ending

class YingMusicSinger(nn.Module, PyTorchModelHubMixin):
    """Singing-voice synthesis pipeline.

    Combines timbre from a reference audio clip, melody from a second audio
    clip, and lyrics from text. Internally wraps the ``Singer`` flow-matching
    model, a Stable-Audio VAE for latent <-> waveform conversion, and
    (optionally) a frozen MIDI-extractor teacher that supplies melody
    conditioning. ``PyTorchModelHubMixin`` provides ``from_pretrained`` /
    ``save_pretrained``.
    """

    def __init__(
        self,
        model_cfg_path,
        ckpt_path=None,
        vae_config_path=None,
        vae_ckpt_path=None,
        midi_teacher_ckpt_path=None,
        is_distilled=False,
        use_ema=True,
    ):
        """
        Args:
            model_cfg_path:         Path to the OmegaConf YAML describing the
                                    backbone, mel-spec, and dataset settings.
            ckpt_path:              Optional Singer checkpoint (``torch.load``-able).
            vae_config_path:        Stable-Audio VAE config path.
            vae_ckpt_path:          Stable-Audio VAE checkpoint path.
            midi_teacher_ckpt_path: Optional checkpoint for the frozen MIDI teacher.
            is_distilled:           If True, enable the guidance-scale embedding
                                    used by distilled models.
            use_ema:                If True, load EMA weights from the checkpoint;
                                    otherwise load the raw model weights.
        """
        super().__init__()
        self.cfg = OmegaConf.load(model_cfg_path)
        # The transformer backbone class is resolved dynamically from the
        # config so different architectures can share this wrapper.
        model_cls = hydra.utils.get_class(
            f"src.YingMusicSinger.models.{self.cfg.model.backbone}"
        )
        self.melody_input_source = self.cfg.model.melody_input_source
        self.is_tts_pretrain = self.cfg.model.is_tts_pretrain

        self.model = Singer(
            transformer=model_cls(
                **self.cfg.model.arch,
                text_num_embeds=self.cfg.datasets_cfg.text_num_embeds,
                mel_dim=self.cfg.model.mel_spec.n_mel_channels,
                use_guidance_scale_embed=is_distilled,
            ),
            mel_spec_kwargs=self.cfg.model.mel_spec,
            is_tts_pretrain=self.is_tts_pretrain,
            melody_input_source=self.melody_input_source,
            cka_disabled=self.cfg.model.cka_disabled,
            num_channels=None,
            extra_parameters=self.cfg.extra_parameters,
            distill_stage=1,
            use_guidance_scale_embed=is_distilled,
        )

        # VAE used both to encode conditioning audio and decode generated latents.
        self.vae = StableAudioInfer(
            model_config_path=vae_config_path,
            model_ckpt_path=vae_ckpt_path,
        )

        # These melody sources require MIDI features predicted by a frozen teacher.
        self._need_midi = self.melody_input_source in {
            "some_pretrain",
            "some_pretrain_fuzzdisturb",
            "some_pretrain_postprocess_embedding",
        }
        self.midi_teacher = None
        if self._need_midi:
            self.midi_teacher = MIDIExtractor()
            if midi_teacher_ckpt_path is not None:
                self.midi_teacher._load_form_ckpt(midi_teacher_ckpt_path)
            # Teacher is inference-only; freeze all of its parameters.
            for p in self.midi_teacher.parameters():
                p.requires_grad = False

            self.melody_spectrogram_extract = MelodySpectrogram()

        # Latent frames per second of audio (44.1 kHz, 2048-sample hop).
        self.vae_frame_rate = 44100 / 2048

        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path, map_location="cpu")
            if use_ema:
                # Wrap in EMA purely to reuse its state-dict layout, then keep
                # only the averaged weights as the live model.
                ema_model = EMA(self.model, include_online_model=False)
                ema_model.load_state_dict(ckpt["ema_model_state_dict"])

                self.model = ema_model.ema_model
            else:
                self.model.load_state_dict(ckpt["model_state_dict"])

        self.cnen_tokenizer = CNENTokenizer()
        # Seconds of silence appended after the melody clip (trimmed off again
        # from the generated latent before decoding).
        self.rear_silent_time = 1.0

    @property
    def device(self):
        """Device of the module's parameters (assumes a single device)."""
        return next(self.parameters()).device

    def prepare_input(
        self,
        ref_audio_path,
        melody_audio_path,
        ref_text,
        target_text,
        sil_len_to_end,
        lrc_align_mode,
    ):
        """Load and encode both audio clips and align the lyrics.

        Args:
            ref_audio_path:    Timbre reference audio file.
            melody_audio_path: Melody reference audio file (sets target duration).
            ref_text:          Lyrics spoken/sung in the reference audio.
            target_text:       Lyrics to synthesize.
            sil_len_to_end:    Seconds of silence appended to the reference audio.
            lrc_align_mode:    "put_to_front" or "sentence_level".

        Returns:
            Tuple of (ref_latent, ref_latent_len, text_tokens, total_len,
            midi_in, midi_p, bound_p).

        Raises:
            TypeError:           If ref_text/target_text are not strings.
            ValueError:          On an unsupported lrc_align_mode.
            NotImplementedError: If the configured melody source needs no MIDI
                                 teacher (path not implemented here).
        """
        ref_audio, ref_audio_sr = torchaudio.load(ref_audio_path)
        # Pad the reference with silence so the model sees a clean boundary
        # before the target segment begins.
        silence = torch.zeros(ref_audio.shape[0], int(ref_audio_sr * sil_len_to_end))
        ref_wav = torch.cat([ref_audio, silence], dim=1)
        ref_latent = self.vae.encode_audio(ref_wav, in_sr=ref_audio_sr).transpose(
            1, 2
        )  # [B, T, D]

        melody_audio, melody_sr = torchaudio.load(melody_audio_path)
        # Rear silence is trimmed from the generated latent in forward().
        silence = torch.zeros(melody_audio.shape[0], int(melody_sr * self.rear_silent_time))
        melody_wav = torch.cat([melody_audio, silence], dim=1)
        melody_latent = self.vae.encode_audio(melody_wav, in_sr=melody_sr).transpose(
            1, 2
        )  # [B, T, D]

        midi_in = torch.cat([ref_latent, melody_latent], dim=1)
        if self.is_tts_pretrain:
            # TTS mode: melody conditioning is zeroed out entirely.
            midi_in = torch.zeros_like(midi_in)

        ref_latent_len = ref_latent.shape[1]
        total_len = int(ref_latent.shape[1] + melody_latent.shape[1])

        if self._need_midi:
            ref_mel = self.melody_spectrogram_extract(audio=ref_wav, sr=ref_audio_sr)
            melody_mel = self.melody_spectrogram_extract(audio=melody_wav, sr=melody_sr)
            melody_mel_spec = torch.cat([ref_mel, melody_mel], dim=2)
        else:
            raise NotImplementedError()

        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not (isinstance(ref_text, str) and isinstance(target_text, str)):
            raise TypeError("ref_text and target_text must both be strings")
        text_list = [ref_text] + [target_text]

        if lrc_align_mode == "put_to_front":
            lrc_token, _ = align_lrc_put_to_front(
                tokenizer=self.cnen_tokenizer,
                lrc_start_times=None,
                lrc_lines=text_list,
                total_lens=total_len,
            )
        elif lrc_align_mode == "sentence_level":
            # Target lyrics start exactly where the reference latent ends.
            lrc_token, _ = align_lrc_sentence_level(
                tokenizer=self.cnen_tokenizer,
                lrc_start_times=[0.0, ref_latent_len / self.vae_frame_rate],
                lrc_lines=text_list,
                total_lens=total_len,
                vae_frame_rate=self.vae_frame_rate,
            )
        else:
            raise ValueError(f"Unsupported lrc_align_mode: {lrc_align_mode}")

        text_tokens = (
            torch.tensor(lrc_token, dtype=torch.int64).unsqueeze(0).to(self.device)
        )

        midi_p, bound_p = None, None
        if self._need_midi:
            # Frozen teacher: no gradients needed for MIDI/boundary prediction.
            with torch.no_grad():
                midi_p, bound_p = self.midi_teacher(melody_mel_spec.transpose(1, 2))

        return (
            ref_latent,
            ref_latent_len,
            text_tokens,
            total_len,
            midi_in,
            midi_p,
            bound_p,
        )

    def forward(
        self,
        ref_audio_path,
        melody_audio_path,
        ref_text,
        target_text,
        lrc_align_mode: str = "sentence_level",
        sil_len_to_end: float = 0.5,
        t_shift: float = 0.5,
        nfe_step: int = 32,
        cfg_strength: float = 3.0,
        seed: int = 666,
        is_tts_pretrain: bool = False,
    ):
        """Synthesize singing audio and return ``(waveform, 44100)``.

        Args:
            ref_audio_path:    Path to the reference audio (for timbre)
            melody_audio_path: Path to the melody reference audio (provides target duration and melody information)
            ref_text:          Text corresponding to the reference audio
            target_text:       Target text to be synthesized
            lrc_align_mode:    Lyric alignment mode "sentence_level" | "put_to_front"
            sil_len_to_end:    Duration of silence appended to the end of the reference audio (seconds)
            t_shift:           Sampling time offset
            nfe_step:          ODE sampling steps
            cfg_strength:      CFG strength
            seed:              Random seed
            is_tts_pretrain:   If True, melody is not provided (TTS mode).
                               NOTE(review): this parameter is currently unused;
                               TTS mode is driven by the config flag read in
                               __init__. Kept for interface compatibility.

        Raises:
            RuntimeError: If the MIDI teacher produced no predictions (should
                          not happen for the supported melody sources).
        """
        ref_latent, ref_latent_len, text_tokens, total_len, midi_in, midi_p, bound_p = (
            self.prepare_input(
                ref_audio_path=ref_audio_path,
                melody_audio_path=melody_audio_path,
                ref_text=ref_text,
                target_text=target_text,
                sil_len_to_end=sil_len_to_end,
                lrc_align_mode=lrc_align_mode,
            )
        )

        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if midi_p is None or bound_p is None:
            raise RuntimeError(
                "MIDI teacher predictions are missing; sampling requires midi_p and bound_p"
            )
        with torch.inference_mode():
            generated_latent, _ = self.model.sample(
                cond=ref_latent,
                midi_in=midi_in,
                text=text_tokens,
                duration=total_len,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=None,
                use_epss=False,
                seed=seed,
                midi_p=midi_p,
                t_shift=t_shift,
                bound_p=bound_p,
                guidance_scale=cfg_strength,
            )
        generated_latent = generated_latent.to(torch.float32)
        # Trim the reference-audio prefix and the appended rear silence so the
        # decoded waveform covers only the synthesized target segment.
        generated_latent = generated_latent[:, ref_latent_len: -int(self.vae_frame_rate*self.rear_silent_time), :]
        generated_latent = generated_latent.permute(0, 2, 1)  # [B, D, T]

        generated_audio = self.vae.decode_audio(generated_latent)
        # Fold any batch dimension into time: [B, D, N] -> [D, B*N].
        audio = rearrange(generated_audio, "b d n -> d (b n)")

        audio = audio.to(torch.float32).cpu()
        # Apply a short fade-out to avoid an abrupt click at the end.
        audio = smooth_ending(audio, 44100)
        return audio, 44100


if __name__ == "__main__":
    # Optional one-off: build the model from local config/checkpoints and
    # export it as a HuggingFace repo (safetensors).
    # singer = YingMusicSinger(
    #     model_cfg_path="src/YingMusicSinger/config/YingMusic_Singer.yaml",
    #     ckpt_path="ckpts/YingMusicSinger_model.pt",
    #     vae_config_path="src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json",
    #     vae_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
    #     midi_teacher_ckpt_path="ckpts/model_ckpt_steps_100000_simplified.ckpt",
    # )
    # singer.save_pretrained("path/to/save")

    # Inference demo: pull pretrained weights from the Hub, move to GPU,
    # and synthesize new lyrics over the melody clip.
    singer = YingMusicSinger.from_pretrained("ASLP-lab/YingMusic-Singer")
    singer.to("cuda:0")
    singer.eval()

    waveform, sample_rate = singer(
        ref_audio_path="path/to/ref_audio",  # timbre reference clip
        melody_audio_path="path/to/melody_audio",  # singing clip supplying the melody
        ref_text="oh the reason i hold on",  # lyrics of the reference clip
        target_text="oldest book broken watch|bare feet in grassy spot",  # lyrics to synthesize
        seed=42,
    )

    torchaudio.save("output.wav", waveform, sample_rate=sample_rate)
    print("Saved to output.wav")