from __future__ import annotations from typing import Callable import torch import torch.nn.functional as F from torch import nn from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor from src.YingMusicSinger.utils.common import ( default, exists, get_epss_timesteps, lens_to_mask, ) def interpolation_midi_continuous(midi_p, bound_p, total_len): """Temporally interpolate 3D melody latent to match target length.""" if midi_p.shape[1] != total_len: midi = ( F.interpolate( midi_p.clone().detach().transpose(1, 2), size=total_len, mode="linear", align_corners=False, ) .transpose(1, 2) .clone() .detach() ) if bound_p is not None: midi_bound = ( F.interpolate( bound_p.clone().detach().transpose(1, 2), size=total_len, mode="linear", align_corners=False, ) .transpose(1, 2) .clone() .detach() ) else: midi = midi_p.clone().detach() if bound_p is not None: midi_bound = bound_p.clone().detach() if bound_p is not None: return midi, midi_bound else: return midi def interpolation_midi_continuous_2_dim(midi_p, bound_p, total_len): """Temporally interpolate 2D melody latent to match target length.""" assert len(midi_p.shape) == 2 if midi_p.shape[1] != total_len: midi = ( F.interpolate( midi_p.unsqueeze(2).clone().detach().transpose(1, 2), size=total_len, mode="linear", align_corners=False, ) .transpose(1, 2) .clone() .detach() ) if bound_p: midi_bound = ( F.interpolate( bound_p.unsqueeze(2).clone().detach().transpose(1, 2), size=total_len, mode="linear", align_corners=False, ) .transpose(1, 2) .clone() .detach() ) else: midi = midi_p.clone().detach() if bound_p: midi_bound = bound_p.clone().detach() if bound_p: return midi.squeeze(2), midi_bound.squeeze(2) else: return midi.squeeze(2) class Singer(nn.Module): def __init__( self, transformer: nn.Module, is_tts_pretrain, melody_input_source, cka_disabled, distill_stage, use_guidance_scale_embed, sigma=0.0, odeint_kwargs: dict = dict(method="euler"), audio_drop_prob=0.3, cond_drop_prob=0.2, num_channels=None, mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), frac_lengths_mask: tuple[float, float] = (0.7, 1.0), extra_parameters=None, ): super().__init__() self.is_tts_pretrain = is_tts_pretrain if distill_stage is None: self.distill_stage = 0 else: self.distill_stage = int(distill_stage) self.use_guidance_scale_embed = use_guidance_scale_embed assert melody_input_source in { "student_model", "some_pretrain", "some_pretrain_fuzzdisturb", "some_pretrain_postprocess_embedding", "none", } from src.YingMusicSinger.melody.SmoothMelody import MIDIFuzzDisturb if melody_input_source == "some_pretrain_fuzzdisturb": self.smoothMelody_MIDIFuzzDisturb = MIDIFuzzDisturb( dim=extra_parameters.some_pretrain_fuzzdisturb.dim, drop_prob=extra_parameters.some_pretrain_fuzzdisturb.drop_prob, noise_scale=extra_parameters.some_pretrain_fuzzdisturb.noise_scale, blur_kernel=extra_parameters.some_pretrain_fuzzdisturb.blur_kernel, drop_type=extra_parameters.some_pretrain_fuzzdisturb.drop_type, ) from src.YingMusicSinger.melody.SmoothMelody import MIDIDigitalEmbedding if melody_input_source == "some_pretrain_postprocess_embedding": self.smoothMelody_MIDIDigitalEmbedding = MIDIDigitalEmbedding( embed_dim=extra_parameters.some_pretrain_postprocess_embedding.embed_dim, num_classes=extra_parameters.some_pretrain_postprocess_embedding.num_classes, mark_distinguish_scale=extra_parameters.some_pretrain_postprocess_embedding.mark_distinguish_scale, ) self.melody_input_source = melody_input_source self.cka_disabled = cka_disabled self.frac_lengths_mask = frac_lengths_mask num_channels = default(num_channels, mel_spec_kwargs.n_mel_channels) self.num_channels = num_channels # Classifier-free guidance drop probabilities self.audio_drop_prob = audio_drop_prob self.cond_drop_prob = cond_drop_prob # Transformer backbone self.transformer = transformer dim = transformer.dim self.dim = dim # Conditional flow matching self.sigma = sigma self.odeint_kwargs = odeint_kwargs # Melody extractor self.midi_extractor = MIDIExtractor(in_dim=num_channels) @property def device(self): return next(self.parameters()).device @torch.no_grad() def sample( self, cond: float["b n d"] | float["b nw"], # noqa: F722 text: int["b nt"] | list[str], # noqa: F722 duration: int | int["b"] | None = None, # noqa: F821 *, midi_in: float["b n d"] | None = None, lens: int["b"] | None = None, # noqa: F821 steps=32, cfg_strength=1.0, sway_sampling_coef=None, seed: int | None = None, max_duration=4096, # Maximum total length (including ICL prompt), ~190s vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 use_epss=True, no_ref_audio=False, duplicate_test=False, t_inter=0.1, t_shift=1.0, # Sampling timestep shift (ZipVoice-style) guidance_scale=None, edit_mask=None, midi_p=None, bound_p=None, enable_melody_control=True, ): self.eval() assert isinstance(cond, torch.Tensor) assert not edit_mask, "edit_mask is not supported in this mode" assert not duplicate_test, "duplicate_test is not supported in this mode" if self.melody_input_source == "student_model": assert midi_p is None and bound_p is None elif self.melody_input_source in { "some_pretrain", "some_pretrain_fuzzdisturb", "some_pretrain_postprocess_embedding", }: assert midi_p is not None and bound_p is not None elif self.melody_input_source == "none": assert midi_p is None and bound_p is None else: raise ValueError( f"Unsupported melody_input_source: {self.melody_input_source}" ) # duration is the total latent sequence length assert duration cond = cond.to(next(self.parameters()).dtype) # Extract or interpolate melody representation if self.melody_input_source == "student_model": midi, midi_bound = self.midi_extractor(midi_in) elif self.melody_input_source == "some_pretrain": midi, midi_bound = interpolation_midi_continuous( midi_p=midi_p, bound_p=bound_p, total_len=text.shape[1] ) elif self.melody_input_source == "some_pretrain_fuzzdisturb": midi, midi_bound = interpolation_midi_continuous( midi_p=midi_p, bound_p=bound_p, total_len=text.shape[1] ) midi = self.smoothMelody_MIDIFuzzDisturb(midi) elif self.melody_input_source == "some_pretrain_postprocess_embedding": midi_after_postprocess, _ = self.midi_extractor.postprocess( midi=midi_p, bounds=bound_p, with_expand=True ) midi = interpolation_midi_continuous_2_dim( midi_p=midi_after_postprocess, bound_p=None, total_len=text.shape[1] ) midi = self.smoothMelody_MIDIDigitalEmbedding(midi) midi_bound = None elif self.melody_input_source == "none": midi = torch.zeros( text.shape[0], text.shape[1], 128, dtype=cond.dtype, device=text.device ) midi_bound = None else: raise NotImplementedError() batch, cond_seq_len, device = *cond.shape[:2], cond.device if not exists(lens): lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) assert isinstance(text, torch.Tensor) cond_mask = lens_to_mask(lens) if edit_mask is not None: cond_mask = cond_mask & edit_mask if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) # Duration must be at least max(text_len, audio_prompt_len) + 1 duration = torch.maximum( torch.maximum((text != 0).sum(dim=-1), lens) + 1, duration ) duration = duration.clamp(max=max_duration) max_duration = duration.amax() # Duplicate test: interpolate between noise and conditioning if duplicate_test: test_cond = F.pad( cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 ) # Zero-pad conditioning latent to max_duration cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) if no_ref_audio: cond = torch.zeros_like(cond) cond_mask = F.pad( cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False ) cond_mask = cond_mask.unsqueeze(-1) step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) assert max_duration == midi.shape[1] # Zero out melody in prompt region; optionally disable melody control entirely if enable_melody_control: midi = torch.where(cond_mask, torch.zeros_like(midi), midi) else: midi = torch.zeros_like(midi) if self.is_tts_pretrain: midi = torch.zeros_like(midi) # For batched inference, explicit mask prevents causal attention fallback if batch > 1: mask = lens_to_mask(duration) else: mask = None # ODE velocity function def fn(t, x): if cfg_strength < 1e-5: # No classifier-free guidance pred, _ = self.transformer( x=x, cond=step_cond, text=text, midi=midi, time=t, mask=mask, drop_audio_cond=False, drop_text=False, drop_midi=not enable_melody_control, cache=False, ) return pred else: if self.use_guidance_scale_embed: # Distilled model with built-in CFG assert enable_melody_control pred_cfg, _ = self.transformer( x=x, cond=step_cond, text=text, midi=midi, time=t, mask=mask, drop_audio_cond=False, drop_text=False, drop_midi=not enable_melody_control, cache=False, guidance_scale=torch.tensor([guidance_scale], device=device), ) print( f"CFG 参数调节无作用! 蒸馏之后的,输入CFG为 guidance_scale={guidance_scale}" ) return pred_cfg else: # Standard CFG: cond + uncond forward # BUG If enable_melody_control is False, there might be a slight issue here assert guidance_scale is not None pred_cfg, _ = self.transformer( x=x, cond=step_cond, text=text, midi=midi, time=t, mask=mask, cfg_infer=True, cache=False, cfg_infer_ids=(True, False, False, True), ) pred, pred_drop_all_cond = torch.chunk(pred_cfg, 2, dim=0) return pred + (pred - pred_drop_all_cond) * float(guidance_scale) # Generate initial noise (per-sample seeding for batch consistency) y0 = [] for dur in duration: if exists(seed): torch.manual_seed(seed) y0.append( torch.randn( dur, self.num_channels, device=self.device, dtype=step_cond.dtype ) ) y0 = pad_sequence(y0, padding_value=0, batch_first=True) t_start = 0 if duplicate_test: t_start = t_inter y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) # Build timestep schedule assert not use_epss and sway_sampling_coef is None, ( "Use timestep shift instead of the strategy in F5" ) if t_start == 0 and use_epss: # Empirically Pruned Step Sampling for low NFE t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) else: t = torch.linspace( t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype ) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) # Apply timestep shift t = t_shift * t / (1 + (t_shift - 1) * t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) self.transformer.clear_cache() sampled = trajectory[-1] out = sampled if exists(vocoder): out = out.permute(0, 2, 1) out = vocoder(out) return out, trajectory