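# Inference wrapper for YingMusic-Singer: zero-shot singing-voice synthesis and
# lyric editing conditioned on a timbre reference, a melody reference, and
# target lyrics. Published on the Hugging Face Hub as ASLP-lab/YingMusic-Singer.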
import hydra
import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from ema_pytorch import EMA
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf

from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
from src.YingMusicSinger.models.model import Singer
from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer
from src.YingMusicSinger.utils.lrc_align import (
    align_lrc_put_to_front,
    align_lrc_sentence_level,
)
from src.YingMusicSinger.utils.mel_spectrogram import MelodySpectrogram
from src.YingMusicSinger.utils.smooth_ending import smooth_ending
from src.YingMusicSinger.utils.stable_audio_tools.vae_copysyn import StableAudioInfer

class YingMusicSinger(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        model_cfg_path,
        ckpt_path=None,
        vae_config_path=None,
        vae_ckpt_path=None,
        midi_teacher_ckpt_path=None,
        is_distilled=False,
        use_ema=True,
    ):
        super().__init__()
        self.cfg = OmegaConf.load(model_cfg_path)
        model_cls = hydra.utils.get_class(
            f"src.YingMusicSinger.models.{self.cfg.model.backbone}"
        )
        self.melody_input_source = self.cfg.model.melody_input_source
        self.is_tts_pretrain = self.cfg.model.is_tts_pretrain
        self.model = Singer(
            transformer=model_cls(
                **self.cfg.model.arch,
                text_num_embeds=self.cfg.datasets_cfg.text_num_embeds,
                mel_dim=self.cfg.model.mel_spec.n_mel_channels,
                use_guidance_scale_embed=is_distilled,
            ),
            mel_spec_kwargs=self.cfg.model.mel_spec,
            is_tts_pretrain=self.is_tts_pretrain,
            melody_input_source=self.melody_input_source,
            cka_disabled=self.cfg.model.cka_disabled,
            num_channels=None,
            extra_parameters=self.cfg.extra_parameters,
            distill_stage=1,
            use_guidance_scale_embed=is_distilled,
        )
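        # Stable Audio VAE: encodes the reference/melody waveforms into latents
        # and decodes generated latents back to 44.1 kHz audio.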
        self.vae = StableAudioInfer(
            model_config_path=vae_config_path,
            model_ckpt_path=vae_ckpt_path,
        )
        self._need_midi = self.melody_input_source in {
            "some_pretrain",
            "some_pretrain_fuzzdisturb",
            "some_pretrain_postprocess_embedding",
        }
        self.midi_teacher = None
        if self._need_midi:
            self.midi_teacher = MIDIExtractor()
            if midi_teacher_ckpt_path is not None:
                self.midi_teacher._load_form_ckpt(midi_teacher_ckpt_path)
            for p in self.midi_teacher.parameters():
                p.requires_grad = False
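        # Mel-spectrogram front end; its output feeds the frozen MIDI teacher.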
        self.melody_spectrogram_extract = MelodySpectrogram()
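        # The VAE downsamples 44.1 kHz audio by a factor of 2048, so latents
        # run at roughly 21.5 frames per second.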
        self.vae_frame_rate = 44100 / 2048
        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path, map_location="cpu")
            if use_ema:
                ema_model = EMA(self.model, include_online_model=False)
                ema_model.load_state_dict(ckpt["ema_model_state_dict"])
                self.model = ema_model.ema_model
            else:
                self.model.load_state_dict(ckpt["model_state_dict"])
        self.cnen_tokenizer = CNENTokenizer()
        self.rear_silent_time = 1.0

    @property
    def device(self):
        return next(self.parameters()).device

    def prepare_input(
        self,
        ref_audio_path,
        melody_audio_path,
        ref_text,
        target_text,
        sil_len_to_end,
        lrc_align_mode,
    ):
        ref_audio, ref_audio_sr = torchaudio.load(ref_audio_path)
        silence = torch.zeros(ref_audio.shape[0], int(ref_audio_sr * sil_len_to_end))
        ref_wav = torch.cat([ref_audio, silence], dim=1)
        ref_latent = self.vae.encode_audio(ref_wav, in_sr=ref_audio_sr).transpose(
            1, 2
        )  # [B, T, D]
        melody_audio, melody_sr = torchaudio.load(melody_audio_path)
        silence = torch.zeros(
            melody_audio.shape[0], int(melody_sr * self.rear_silent_time)
        )
        melody_wav = torch.cat([melody_audio, silence], dim=1)
        melody_latent = self.vae.encode_audio(melody_wav, in_sr=melody_sr).transpose(
            1, 2
        )  # [B, T, D]
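        # Reference and melody latents are concatenated along time to form the
        # melody-conditioning stream; it is zeroed out in TTS-pretrain mode.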
        midi_in = torch.cat([ref_latent, melody_latent], dim=1)
        if self.is_tts_pretrain:
            midi_in = torch.zeros_like(midi_in)
        ref_latent_len = ref_latent.shape[1]
        total_len = int(ref_latent.shape[1] + melody_latent.shape[1])
        if self._need_midi:
            ref_mel = self.melody_spectrogram_extract(audio=ref_wav, sr=ref_audio_sr)
            melody_mel = self.melody_spectrogram_extract(audio=melody_wav, sr=melody_sr)
            melody_mel_spec = torch.cat([ref_mel, melody_mel], dim=2)
        else:
            raise NotImplementedError()
        assert isinstance(ref_text, str) and isinstance(target_text, str)
        text_list = [ref_text, target_text]
        if lrc_align_mode == "put_to_front":
            lrc_token, _ = align_lrc_put_to_front(
                tokenizer=self.cnen_tokenizer,
                lrc_start_times=None,
                lrc_lines=text_list,
                total_lens=total_len,
            )
        elif lrc_align_mode == "sentence_level":
            lrc_token, _ = align_lrc_sentence_level(
                tokenizer=self.cnen_tokenizer,
                lrc_start_times=[0.0, ref_latent_len / self.vae_frame_rate],
                lrc_lines=text_list,
                total_lens=total_len,
                vae_frame_rate=self.vae_frame_rate,
            )
        else:
            raise ValueError(f"Unsupported lrc_align_mode: {lrc_align_mode}")
        text_tokens = (
            torch.tensor(lrc_token, dtype=torch.int64).unsqueeze(0).to(self.device)
        )
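        # The frozen MIDI teacher predicts melody conditioning (midi_p, bound_p)
        # from the concatenated mel spectrogram.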
        midi_p, bound_p = None, None
        if self._need_midi:
            with torch.no_grad():
                midi_p, bound_p = self.midi_teacher(melody_mel_spec.transpose(1, 2))
        return (
            ref_latent,
            ref_latent_len,
            text_tokens,
            total_len,
            midi_in,
            midi_p,
            bound_p,
        )

    def forward(
        self,
        ref_audio_path,
        melody_audio_path,
        ref_text,
        target_text,
        lrc_align_mode: str = "sentence_level",
        sil_len_to_end: float = 0.5,
        t_shift: float = 0.5,
        nfe_step: int = 32,
        cfg_strength: float = 3.0,
        seed: int = 666,
        is_tts_pretrain: bool = False,
    ):
        """
        Args:
            ref_audio_path: Path to the reference audio (provides timbre).
            melody_audio_path: Path to the melody reference audio (provides target duration and melody information).
            ref_text: Text corresponding to the reference audio.
            target_text: Target text to be synthesized.
            lrc_align_mode: Lyric alignment mode, "sentence_level" | "put_to_front".
            sil_len_to_end: Duration of silence appended to the end of the reference audio (seconds).
            t_shift: Sampling time offset.
            nfe_step: Number of ODE sampling steps.
            cfg_strength: Classifier-free guidance strength.
            seed: Random seed.
            is_tts_pretrain: If True, no melody is provided (TTS mode).
        """
        ref_latent, ref_latent_len, text_tokens, total_len, midi_in, midi_p, bound_p = (
            self.prepare_input(
                ref_audio_path=ref_audio_path,
                melody_audio_path=melody_audio_path,
                ref_text=ref_text,
                target_text=target_text,
                sil_len_to_end=sil_len_to_end,
                lrc_align_mode=lrc_align_mode,
            )
        )
        assert midi_p is not None and bound_p is not None
        with torch.inference_mode():
            generated_latent, _ = self.model.sample(
                cond=ref_latent,
                midi_in=midi_in,
                text=text_tokens,
                duration=total_len,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=None,
                use_epss=False,
                seed=seed,
                midi_p=midi_p,
                t_shift=t_shift,
                bound_p=bound_p,
                guidance_scale=cfg_strength,
            )
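            # Post-process: cast to float32, trim the reference-prompt frames at
            # the front and the appended trailing silence, then decode to audio.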
            generated_latent = generated_latent.to(torch.float32)
            generated_latent = generated_latent[
                :, ref_latent_len : -int(self.vae_frame_rate * self.rear_silent_time), :
            ]
            generated_latent = generated_latent.permute(0, 2, 1)  # [B, D, T]
            generated_audio = self.vae.decode_audio(generated_latent)
            audio = rearrange(generated_audio, "b d n -> d (b n)")
            audio = audio.to(torch.float32).cpu()
            audio = smooth_ending(audio, 44100)
        return audio, 44100


if __name__ == "__main__":
    # === Export to HuggingFace safetensors (optional) ===
    # model = YingMusicSinger(
    #     model_cfg_path="src/YingMusicSinger/config/YingMusic_Singer.yaml",
    #     ckpt_path="ckpts/YingMusicSinger_model.pt",
    #     vae_config_path="src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json",
    #     vae_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
    #     midi_teacher_ckpt_path="ckpts/model_ckpt_steps_100000_simplified.ckpt",
    # )
    # model.save_pretrained("path/to/save")

    # === Inference Example ===
    model = YingMusicSinger.from_pretrained("ASLP-lab/YingMusic-Singer")
    model.to("cuda:0")
    model.eval()
    waveform, sample_rate = model(
        ref_audio_path="path/to/ref_audio",  # Timbre reference audio
        melody_audio_path="path/to/melody_audio",  # Melody-providing singing clip
        ref_text="oh the reason i hold on",  # Lyrics corresponding to ref_audio
        target_text="oldest book broken watch|bare feet in grassy spot",  # Modified target lyrics
        seed=42,
    )
    torchaudio.save("output.wav", waveform, sample_rate=sample_rate)
    print("Saved to output.wav")
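
    # Optional sampling knobs, a sketch using only parameters documented in
    # forward() (defaults: lrc_align_mode="sentence_level", nfe_step=32,
    # cfg_strength=3.0):
    # waveform, sample_rate = model(
    #     ref_audio_path="path/to/ref_audio",
    #     melody_audio_path="path/to/melody_audio",
    #     ref_text="oh the reason i hold on",
    #     target_text="oldest book broken watch|bare feet in grassy spot",
    #     lrc_align_mode="put_to_front",  # alternative lyric alignment mode
    #     nfe_step=16,                    # fewer ODE steps: faster, possibly rougher
    #     cfg_strength=2.0,               # weaker classifier-free guidance
    #     seed=42,
    # )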