# Source header (model-hub page residue): WeixuanYuan — "Upload 49 files", commit 2b389c5 (verified).
import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import pyrubberband as pyrb
from tqdm import tqdm
class NoteEvent:
    """A single MIDI note, with all timing expressed in MIDI ticks."""

    def __init__(self, note, velocity, start_time, duration):
        # note: MIDI note number; velocity: note-on velocity.
        self.note, self.velocity = note, velocity
        # start_time: absolute tick of note-on; duration: length in ticks.
        self.start_time, self.duration = start_time, duration

    def __str__(self):
        """Human-readable one-line summary of the note."""
        return "Note {}, velocity {}, start_time {}, duration {}".format(
            self.note, self.velocity, self.start_time, self.duration)
class Track:
    """One MIDI track: tempo map, parsed note events, and audio rendering.

    `track` is a mido track (a sequence of MIDI messages with delta times in
    ticks); `ticks_per_beat` comes from the containing MidiFile.
    """

    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        """Return a list of (delta_ticks, tempo_us_per_beat) covering the track."""
        tempo_events = []
        current_tempo = 500000  # MIDI default: 120 BPM == 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                # Bug fix: remember the new tempo so subsequent messages carry
                # it (the original kept reporting the 120 BPM default).
                current_tempo = msg.tempo
                tempo_events.append((msg.time, msg.tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events

    def _parse_note_events(self, track):
        """Pair note-on/note-off messages into NoteEvent objects.

        Handles polyphony (one pending start per note number) and both
        note-off encodings ('note_off' and 'note_on' with velocity 0).
        """
        events = []
        abs_ticks = 0
        pending = {}  # note number -> (start_tick, note-on velocity)
        for msg in track:
            if msg.is_meta:
                continue
            abs_ticks += msg.time
            if msg.type == 'note_on' and msg.velocity > 0:
                pending[msg.note] = (abs_ticks, msg.velocity)
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                if msg.note in pending:
                    start_tick, velocity = pending.pop(msg.note)
                    # Bug fix: record the note-on velocity — the original stored
                    # msg.velocity, which in this branch is always 0.
                    events.append(NoteEvent(msg.note, velocity, start_tick, abs_ticks - start_tick))
        return events

    def synthesize_track(self, diffSynthSampler, sample_rate=16000, max_events=50):
        """Render up to `max_events` note events to a mono float32 array.

        diffSynthSampler(velocity, duration_sec) must return a 1-D audio array
        at `sample_rate`; samples are cached per duration, peak-normalized, and
        pitch-shifted to each note (52 is the sampler's reference note).
        """
        track_audio = np.zeros(max(1, int(self._get_total_time() * sample_rate)), dtype=np.float32)
        duration_note_mapping = {}  # str(duration_sec) -> normalized sample, reused across notes
        for event in tqdm(self.events[:max_events]):
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Enforce a minimum duration so very short notes still get a usable sample.
            duration_sec = max(event.duration * seconds_per_tick, 0.75)
            start_sample = int(start_time_sec * sample_rate)
            if str(duration_sec) not in duration_note_mapping:
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
            note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            end_sample = start_sample + len(note_audio)
            # Bug fix: grow the buffer when a (minimum-length) note runs past the
            # estimated total time, instead of raising a broadcast error.
            if end_sample > len(track_audio):
                track_audio = np.pad(track_audio, (0, end_sample - len(track_audio)))
            track_audio[start_sample:end_sample] += note_audio
        return track_audio

    def _get_tempo_at(self, time_tick):
        """Return the tempo (us per beat) in effect at absolute tick `time_tick`."""
        current_tempo = 500000  # MIDI default: 120 BPM
        elapsed_ticks = 0
        for delta_ticks, tempo in self.tempo_events:
            if elapsed_ticks + delta_ticks > time_tick:
                return current_tempo
            elapsed_ticks += delta_ticks
            current_tempo = tempo
        return current_tempo

    def _get_total_time(self):
        """Estimate the track length in seconds as the latest note end time."""
        total_time = 0.0
        for event in self.events:
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, self._get_tempo_at(event.start_time))
            # Bug fix: take the maximum end time; summing durations mis-estimates
            # whenever notes overlap or there are gaps between them.
            total_time = max(total_time, (event.start_time + event.duration) * seconds_per_tick)
        return total_time
class DiffSynth:
    """Render MIDI files to audio with a text-conditioned latent-diffusion synth.

    Builds one sampler closure per instrument name (`_update_instruments`) and
    mixes the per-track renderings in `get_music`.
    """

    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
        """Store model components and build the per-instrument samplers.

        The sampler closure reads 'sample_steps', 'sampler', 'noising_strength',
        'latent_representation', 'attack' and 'before_release' from each
        instrument's config dict.
        """
        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        # Latent height = spectrogram frequency bins / VAE downsampling factor.
        self.height = int(freq_resolution/VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer
        # instruments_configs is a dict: string -> (condition, negative_condition, guidance_scale, sample_steps, seed, initial_noise, sampler)
        # NOTE(review): the closure below actually indexes each config by string
        # keys, so in practice each value is a dict, not the tuple described above.
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()

    def _update_instruments(self):
        """(Re)build self.diffSynthSamplers: one sampler closure per instrument name."""

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                # NOTE(review): `velocity` is unused here — the text condition is the
                # embedding of the empty string, not a velocity/instrument prompt.
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']
                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
                # Latent width scales with duration; the (duration + 1)/4 mapping
                # presumably matches a 4-second training window — TODO confirm.
                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)
                # NOTE(review): height=128 and channels=4 are hard-coded and may
                # disagree with self.height / self.channels for non-default args.
                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
                # Evenly-spaced subset of the diffusion timesteps for faster sampling.
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))
                # mask = 1, freeze: the attack (start) and the release tail are
                # kept from the guide latent during inpainting.
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release+1) / 4) / self.VAE_scale):] = 1.0
                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler,
                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)
                # Keep only the final denoising step's latent.
                latent_representations = latent_representations[-1]
                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                # Todo: remove hard-coding
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                                                           quantized_latent_representations,
                                                                                                           resolution=(
                                                                                                               512,
                                                                                                               width * self.VAE_scale),
                                                                                                           original_STFT_batch=None,
                                                                                                           )
                # Return the first (and only) reconstructed waveform in the batch.
                return rec_signals[0]

            return diffSynthSampler

        for key in self.instruments_configs.keys():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])

    def get_music(self, mid, instrument_names, sample_rate=16000):
        """Render a mido MidiFile: one instrument name per track, mixed by summation."""
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) == len(instrument_names), f"len(tracks) = {len(tracks)} != {len(instrument_names)} = len(instrument_names)"
        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]
        # Pad every track to the longest track's length so they can be summed.
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # mix buffer, initialized to zeros
        for audio in track_audios:
            # A track may be shorter than the longest one; zero-pad its tail.
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # accumulate this track into the mix
        return full_audio