# Source snapshot: Hugging Face commit 1e0878a
import hydra
import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from ema_pytorch import EMA
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf
from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
from src.YingMusicSinger.models.model import Singer
from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer
from src.YingMusicSinger.utils.lrc_align import (
align_lrc_put_to_front,
align_lrc_sentence_level,
)
from src.YingMusicSinger.utils.mel_spectrogram import MelodySpectrogram
from src.YingMusicSinger.utils.stable_audio_tools.vae_copysyn import StableAudioInfer
from src.YingMusicSinger.utils.smooth_ending import smooth_ending
class YingMusicSinger(nn.Module, PyTorchModelHubMixin):
    """End-to-end inference wrapper for the YingMusic singing-voice model.

    Combines the Singer backbone (latent sampler), a Stable-Audio VAE for
    waveform <-> latent conversion, an optional frozen MIDI-extractor
    "teacher", and a CN/EN lyric tokenizer.  Inheriting
    ``PyTorchModelHubMixin`` lets the assembled module be saved to / loaded
    from the Hugging Face Hub via ``save_pretrained`` / ``from_pretrained``.
    """

    def __init__(
        self,
        model_cfg_path,
        ckpt_path=None,
        vae_config_path=None,
        vae_ckpt_path=None,
        midi_teacher_ckpt_path=None,
        is_distilled=False,
        use_ema=True,
    ):
        """Build the singer backbone, VAE, optional MIDI teacher, tokenizer.

        Args:
            model_cfg_path: OmegaConf YAML describing the model (backbone
                class name, architecture kwargs, mel-spec settings, ...).
            ckpt_path: Optional Singer checkpoint loaded via ``torch.load``.
            vae_config_path: Stable-Audio VAE config (JSON) path.
            vae_ckpt_path: Stable-Audio VAE checkpoint path.
            midi_teacher_ckpt_path: Optional checkpoint for the frozen MIDI
                teacher (only relevant for melody input sources that need
                MIDI; see ``_need_midi`` below).
            is_distilled: If True, enable the guidance-scale embedding in
                both the transformer and the Singer (distilled-model mode).
            use_ema: If True, load the EMA weights ("ema_model_state_dict")
                from ``ckpt_path``; otherwise load "model_state_dict".
        """
        super().__init__()
        self.cfg = OmegaConf.load(model_cfg_path)
        # Backbone transformer class is resolved dynamically from the config.
        model_cls = hydra.utils.get_class(
            f"src.YingMusicSinger.models.{self.cfg.model.backbone}"
        )
        self.melody_input_source = self.cfg.model.melody_input_source
        self.is_tts_pretrain = self.cfg.model.is_tts_pretrain
        self.model = Singer(
            transformer=model_cls(
                **self.cfg.model.arch,
                text_num_embeds=self.cfg.datasets_cfg.text_num_embeds,
                mel_dim=self.cfg.model.mel_spec.n_mel_channels,
                use_guidance_scale_embed=is_distilled,
            ),
            mel_spec_kwargs=self.cfg.model.mel_spec,
            is_tts_pretrain=self.is_tts_pretrain,
            melody_input_source=self.melody_input_source,
            cka_disabled=self.cfg.model.cka_disabled,
            num_channels=None,
            extra_parameters=self.cfg.extra_parameters,
            distill_stage=1,
            use_guidance_scale_embed=is_distilled,
        )
        # VAE used both to encode reference/melody audio into latents and to
        # decode the generated latents back into a waveform.
        self.vae = StableAudioInfer(
            model_config_path=vae_config_path,
            model_ckpt_path=vae_ckpt_path,
        )
        # These melody input sources require MIDI/boundary predictions from
        # the frozen teacher; any other source makes prepare_input() raise
        # NotImplementedError.
        self._need_midi = self.melody_input_source in {
            "some_pretrain",
            "some_pretrain_fuzzdisturb",
            "some_pretrain_postprocess_embedding",
        }
        self.midi_teacher = None
        if self._need_midi:
            self.midi_teacher = MIDIExtractor()
            if midi_teacher_ckpt_path is not None:
                # NOTE(review): "_load_form_ckpt" looks like a typo for
                # "_load_from_ckpt" — the name is defined by MIDIExtractor,
                # so it cannot be corrected here.
                self.midi_teacher._load_form_ckpt(midi_teacher_ckpt_path)
            # The teacher is a frozen feature extractor; never trained here.
            for p in self.midi_teacher.parameters():
                p.requires_grad = False
        # Mel-style spectrogram front-end feeding the MIDI teacher.
        self.melody_spectrogram_extract = MelodySpectrogram()
        # Latent frames per second; presumably 44.1 kHz audio with a
        # 2048-sample hop in the VAE (~21.5 fps) — TODO confirm hop size.
        self.vae_frame_rate = 44100 / 2048
        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path, map_location="cpu")
            if use_ema:
                # Rebuild the EMA wrapper so state-dict keys line up, then
                # keep only the smoothed (EMA) copy of the weights.
                ema_model = EMA(self.model, include_online_model=False)
                ema_model.load_state_dict(ckpt["ema_model_state_dict"])
                self.model = ema_model.ema_model
            else:
                self.model.load_state_dict(ckpt["model_state_dict"])
        self.cnen_tokenizer = CNENTokenizer()
        # Seconds of silence appended after the melody audio; the matching
        # span of generated latents is trimmed back off in forward().
        self.rear_silent_time = 1.0

    @property
    def device(self) -> torch.device:
        """Device of the module's parameters (follows .to()/.cuda() moves)."""
        return next(self.parameters()).device

    def prepare_input(
        self,
        ref_audio_path,
        melody_audio_path,
        ref_text,
        target_text,
        sil_len_to_end,
        lrc_align_mode,
    ):
        """Encode audio to VAE latents and align lyric tokens for sampling.

        Args:
            ref_audio_path: Timbre-reference audio file.
            melody_audio_path: Melody-reference audio file (sets duration).
            ref_text: Lyrics of the reference audio.
            target_text: Target lyrics to synthesize.
            sil_len_to_end: Seconds of silence appended after the reference
                audio before VAE encoding.
            lrc_align_mode: "put_to_front" or "sentence_level" lyric
                alignment strategy.

        Returns:
            Tuple of (ref_latent, ref_latent_len, text_tokens, total_len,
            midi_in, midi_p, bound_p); midi_p/bound_p are the frozen
            teacher's predictions.

        Raises:
            NotImplementedError: If the configured melody input source does
                not use the MIDI teacher.
            ValueError: For an unsupported ``lrc_align_mode``.
        """
        # Reference audio + trailing silence -> VAE latents [B, T, D].
        ref_audio, ref_audio_sr = torchaudio.load(ref_audio_path)
        silence = torch.zeros(ref_audio.shape[0], int(ref_audio_sr * sil_len_to_end))
        ref_wav = torch.cat([ref_audio, silence], dim=1)
        ref_latent = self.vae.encode_audio(ref_wav, in_sr=ref_audio_sr).transpose(
            1, 2
        )  # [B, T, D]
        # Melody audio + fixed rear silence -> VAE latents [B, T, D].
        melody_audio, melody_sr = torchaudio.load(melody_audio_path)
        silence = torch.zeros(melody_audio.shape[0], int(melody_sr * self.rear_silent_time))
        melody_wav = torch.cat([melody_audio, silence], dim=1)
        melody_latent = self.vae.encode_audio(melody_wav, in_sr=melody_sr).transpose(
            1, 2
        )  # [B, T, D]
        # Melody conditioning covers the full (reference + target) span;
        # zeroed out entirely in TTS-pretrain mode (no melody guidance).
        midi_in = torch.cat([ref_latent, melody_latent], dim=1)
        if self.is_tts_pretrain:
            midi_in = torch.zeros_like(midi_in)
        ref_latent_len = ref_latent.shape[1]
        total_len = int(ref_latent.shape[1] + melody_latent.shape[1])
        if self._need_midi:
            # Spectrograms for the MIDI teacher, concatenated along time.
            ref_mel = self.melody_spectrogram_extract(audio=ref_wav, sr=ref_audio_sr)
            melody_mel = self.melody_spectrogram_extract(audio=melody_wav, sr=melody_sr)
            melody_mel_spec = torch.cat([ref_mel, melody_mel], dim=2)
        else:
            raise NotImplementedError()
        assert isinstance(ref_text, str) and isinstance(target_text, str)
        text_list = [ref_text] + [target_text]
        if lrc_align_mode == "put_to_front":
            # All lyrics packed at the start of the token timeline.
            lrc_token, _ = align_lrc_put_to_front(
                tokenizer=self.cnen_tokenizer,
                lrc_start_times=None,
                lrc_lines=text_list,
                total_lens=total_len,
            )
        elif lrc_align_mode == "sentence_level":
            # Reference lyrics start at t=0, target lyrics at the end of the
            # reference latents (converted from frames to seconds).
            lrc_token, _ = align_lrc_sentence_level(
                tokenizer=self.cnen_tokenizer,
                lrc_start_times=[0.0, ref_latent_len / self.vae_frame_rate],
                lrc_lines=text_list,
                total_lens=total_len,
                vae_frame_rate=self.vae_frame_rate,
            )
        else:
            raise ValueError(f"Unsupported lrc_align_mode: {lrc_align_mode}")
        text_tokens = (
            torch.tensor(lrc_token, dtype=torch.int64).unsqueeze(0).to(self.device)
        )
        midi_p, bound_p = None, None
        if self._need_midi:
            # Frozen teacher: MIDI pitch and boundary predictions, no grads.
            with torch.no_grad():
                midi_p, bound_p = self.midi_teacher(melody_mel_spec.transpose(1, 2))
        return (
            ref_latent,
            ref_latent_len,
            text_tokens,
            total_len,
            midi_in,
            midi_p,
            bound_p,
        )

    def forward(
        self,
        ref_audio_path,
        melody_audio_path,
        ref_text,
        target_text,
        lrc_align_mode: str = "sentence_level",
        sil_len_to_end: float = 0.5,
        t_shift: float = 0.5,
        nfe_step: int = 32,
        cfg_strength: float = 3.0,
        seed: int = 666,
        is_tts_pretrain: bool = False,
    ):
        """Synthesize singing audio for ``target_text``.

        Args:
            ref_audio_path: Path to the reference audio (for timbre)
            melody_audio_path: Path to the melody reference audio (provides target duration and melody information)
            ref_text: Text corresponding to the reference audio
            target_text: Target text to be synthesized
            lrc_align_mode: Lyric alignment mode "sentence_level" | "put_to_front"
            sil_len_to_end: Duration of silence appended to the end of the reference audio (seconds)
            t_shift: Sampling time offset
            nfe_step: ODE sampling steps
            cfg_strength: CFG strength
            seed: Random seed
            is_tts_pretrain: If True, melody is not provided (TTS mode)
                NOTE(review): this parameter is never read in this method —
                TTS mode actually comes from ``self.is_tts_pretrain`` (set
                from the config in __init__); verify intent.

        Returns:
            Tuple of (waveform tensor on CPU as float32, sample rate 44100).
        """
        ref_latent, ref_latent_len, text_tokens, total_len, midi_in, midi_p, bound_p = (
            self.prepare_input(
                ref_audio_path=ref_audio_path,
                melody_audio_path=melody_audio_path,
                ref_text=ref_text,
                target_text=target_text,
                sil_len_to_end=sil_len_to_end,
                lrc_align_mode=lrc_align_mode,
            )
        )
        # NOTE(review): assert is stripped under ``python -O``; an explicit
        # raise would be more robust for input validation.
        assert midi_p is not None and bound_p is not None
        with torch.inference_mode():
            generated_latent, _ = self.model.sample(
                cond=ref_latent,
                midi_in=midi_in,
                text=text_tokens,
                duration=total_len,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=None,
                use_epss=False,
                seed=seed,
                midi_p=midi_p,
                t_shift=t_shift,
                bound_p=bound_p,
                guidance_scale=cfg_strength,
            )
        generated_latent = generated_latent.to(torch.float32)
        # Drop the reference prefix and the appended rear-silence frames so
        # only the newly generated target span remains.
        generated_latent = generated_latent[:, ref_latent_len: -int(self.vae_frame_rate*self.rear_silent_time), :]
        generated_latent = generated_latent.permute(0, 2, 1)  # [B, D, T]
        generated_audio = self.vae.decode_audio(generated_latent)
        # Flatten batch into the time axis -> [channels, samples].
        audio = rearrange(generated_audio, "b d n -> d (b n)")
        audio = audio.to(torch.float32).cpu()
        # Fade the tail to avoid an audible click at the cut point.
        audio = smooth_ending(audio, 44100)
        return audio, 44100
if __name__ == "__main__":
    # --- Optional: export the assembled model as HF safetensors ---
    # model = YingMusicSinger(
    #     model_cfg_path="src/YingMusicSinger/config/YingMusic_Singer.yaml",
    #     ckpt_path="ckpts/YingMusicSinger_model.pt",
    #     vae_config_path="src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json",
    #     vae_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
    #     midi_teacher_ckpt_path="ckpts/model_ckpt_steps_100000_simplified.ckpt",
    # )
    # model.save_pretrained("path/to/save")

    # --- Inference example: pull pretrained weights from the Hub ---
    singer = YingMusicSinger.from_pretrained("ASLP-lab/YingMusic-Singer").to("cuda:0")
    singer.eval()
    audio_out, sr = singer(
        ref_audio_path="path/to/ref_audio",  # timbre reference audio
        melody_audio_path="path/to/melody_audio",  # melody-providing singing clip
        ref_text="oh the reason i hold on",  # lyrics of the reference clip
        target_text="oldest book broken watch|bare feet in grassy spot",  # new target lyrics
        seed=42,
    )
    torchaudio.save("output.wav", audio_out, sample_rate=sr)
    print("Saved to output.wav")