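"""MIDI-driven synthesis for the DiffSynth demo.

Parses note and tempo events from mido MIDI tracks and renders each note with
a diffusion-based sampler, pitch-shifting a cached reference note to the
target pitch with pyrubberband.
"""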

import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import pyrubberband as pyrb
from tqdm import tqdm


class NoteEvent:
    def __init__(self, note, velocity, start_time, duration):
        self.note = note
        self.velocity = velocity
        self.start_time = start_time  # In ticks
        self.duration = duration  # In ticks

    def __str__(self):
        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"


class Track:
    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        # Store a (delta_ticks, tempo) pair for every message so _get_tempo_at
        # can walk the deltas to find the tempo in effect at a given tick.
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo: 120 BPM, i.e. 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo
            tempo_events.append((msg.time, current_tempo))
        return tempo_events

    def _parse_note_events(self, track):
        events = []
        current_time = 0
        active_notes = {}  # note number -> (onset tick, note-on velocity)
        for msg in track:
            current_time += msg.time  # delta times accrue on every message, including meta
            if msg.is_meta:
                continue
            if msg.type == 'note_on' and msg.velocity > 0:
                active_notes[msg.note] = (current_time, msg.velocity)
            elif msg.type in ('note_off', 'note_on') and msg.note in active_notes:
                # A 'note_on' with velocity 0 is the conventional note-off.
                note_on_time, velocity = active_notes.pop(msg.note)
                events.append(NoteEvent(msg.note, velocity, note_on_time, current_time - note_on_time))
        return events

    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
        # One second of headroom covers notes stretched to the minimum duration below.
        track_audio = np.zeros(int((self._get_total_time() + 1.0) * sample_rate), dtype=np.float32)
        duration_note_mapping = {}  # cache: duration in seconds -> normalized note sample
        for event in tqdm(self.events[:50]):  # cap at 50 notes to bound diffusion-sampling time
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Enforce a minimum duration so very short notes still synthesize cleanly.
            duration_sec = max(event.duration * seconds_per_tick, 0.75)
            start_sample = int(start_time_sec * sample_rate)
            if duration_sec not in duration_note_mapping:
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[duration_sec] = note_sample / np.max(np.abs(note_sample))
            # Pitch-shift the cached reference note (pitched at MIDI note 52) to the target pitch.
            note_audio = pyrb.pitch_shift(duration_note_mapping[duration_sec], sample_rate, event.note - 52)
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]
        return track_audio

    def _get_tempo_at(self, time_tick):
        # Walk the (delta_ticks, tempo) list until the requested tick is passed.
        current_tempo = 500000  # Default MIDI tempo: 120 BPM
        elapsed_ticks = 0
        for delta_ticks, tempo in self.tempo_events:
            if elapsed_ticks + delta_ticks > time_tick:
                return current_tempo
            elapsed_ticks += delta_ticks
            current_tempo = tempo
        return current_tempo

    def _get_total_time(self):
        # Track length is the latest note end time, not the sum of note durations.
        total_time = 0
        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            total_time = max(total_time, (event.start_time + event.duration) * seconds_per_tick)
        return total_time
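

def _example_synthesize_first_track(midi_path="example.mid", output_path="track_preview.wav"):
    """Illustrative sketch, not part of the original demo: exercise
    Track.synthesize_track with a stand-in sampler so the MIDI plumbing can be
    tested without loading the diffusion model. The sine-wave stub, the
    soundfile dependency, and both file paths are assumptions."""
    import soundfile as sf  # assumed available; any WAV writer would do

    def sine_stub(velocity, duration_sec, sample_rate=16000):
        # Constant-pitch stand-in for a DiffSynth sampler; like the real
        # sampler below, it ignores velocity and returns mono float32 audio.
        t = np.linspace(0.0, duration_sec, int(duration_sec * sample_rate), endpoint=False)
        # E3 (~164.81 Hz), matching the MIDI-note-52 reference assumed by synthesize_track.
        return (0.5 * np.sin(2 * np.pi * 164.81 * t)).astype(np.float32)

    mid = mido.MidiFile(midi_path)
    track = Track(mid.tracks[0], mid.ticks_per_beat)
    sf.write(output_path, track.synthesize_track(sine_stub), 16000)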


class DiffSynth:
    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.height = int(freq_resolution / VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer
        # instruments_configs maps an instrument name to a config dict with keys
        # 'sample_steps', 'sampler', 'noising_strength', 'latent_representation',
        # 'attack' and 'before_release'.
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()

    def _update_instruments(self):

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                # Note: velocity is currently unused; the condition is an empty-prompt CLAP embedding.
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']
                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
                # time_resolution latent frames cover 4 seconds of audio; +1 s leaves room for the release.
                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)
                mySampler = DiffSynthSampler(self.timesteps, height=self.height, channels=self.channels, noise_strategy="repeat", mute=True)
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))
                # mask = 1 freezes a latent region: keep the attack and release from the guide, inpaint the sustain.
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release + 1) / 4) / self.VAE_scale):] = 1.0
                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler,
                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)
                # Keep only the final denoising step, then quantize and decode it to audio.
                latent_representations = latent_representations[-1]
                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                                                            quantized_latent_representations,
                                                                                                            resolution=(self.freq_resolution, width * self.VAE_scale),
                                                                                                            original_STFT_batch=None,
                                                                                                            )
                return rec_signals[0]

            return diffSynthSampler

        for name, config in self.instruments_configs.items():
            self.diffSynthSamplers[name] = diffSynthSamplerWrapper(config)

    def get_music(self, mid, instrument_names, sample_rate=16000):
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) == len(instrument_names), f"len(tracks) = {len(tracks)} != {len(instrument_names)} = len(instrument_names)"
        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]
        # Pad every track to the length of the longest one so they can be mixed.
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # mix buffer, initialized to zeros
        for audio in track_audios:
            # A shorter track is zero-padded at the end.
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix the track in
        return full_audio
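

# Illustrative usage sketch (assumptions: a fully-initialized `diff_synth`
# instance built elsewhere from this project's checkpoints, an instrument
# config named "electric_guitar", and a MIDI file "melody.mid"):
#
#     mid = mido.MidiFile("melody.mid")
#     music = diff_synth.get_music(mid, ["electric_guitar"] * len(mid.tracks))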