import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import pyrubberband as pyrb
from tqdm import tqdm
class NoteEvent:
    """A single MIDI note; all timing fields are measured in ticks."""

    def __init__(self, note, velocity, start_time, duration):
        # start_time and duration are in MIDI ticks, not seconds.
        self.note = note
        self.velocity = velocity
        self.start_time = start_time
        self.duration = duration

    def __str__(self):
        fields = (
            f"Note {self.note}",
            f"velocity {self.velocity}",
            f"start_time {self.start_time}",
            f"duration {self.duration}",
        )
        return ", ".join(fields)
class Track:
def __init__(self, track, ticks_per_beat):
self.tempo_events = self._parse_tempo_events(track)
self.events = self._parse_note_events(track)
self.ticks_per_beat = ticks_per_beat
def _parse_tempo_events(self, track):
tempo_events = []
current_tempo = 500000 # Default MIDI tempo is 120 BPM which is 500000 microseconds per beat
for msg in track:
if msg.type == 'set_tempo':
tempo_events.append((msg.time, msg.tempo))
elif not msg.is_meta:
tempo_events.append((msg.time, current_tempo))
return tempo_events
def _parse_note_events(self, track):
events = []
start_time = 0
for msg in track:
if not msg.is_meta:
start_time += msg.time
if msg.type == 'note_on' and msg.velocity > 0:
note_on_time = start_time
elif msg.type == 'note_on' and msg.velocity == 0:
duration = start_time - note_on_time
events.append(NoteEvent(msg.note, msg.velocity, note_on_time, duration))
return events
def synthesize_track(self, diffSynthSampler, sample_rate=16000):
track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
duration_note_mapping = {}
for event in tqdm(self.events[:50]):
current_tempo = self._get_tempo_at(event.start_time)
seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
start_time_sec = event.start_time * seconds_per_tick
# Todo: set a minimum duration
duration_sec = event.duration * seconds_per_tick
duration_sec = max(duration_sec, 0.75)
start_sample = int(start_time_sec * sample_rate)
if not (str(duration_sec) in duration_note_mapping):
note_sample = diffSynthSampler(event.velocity, duration_sec)
duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
end_sample = start_sample + len(note_audio)
track_audio[start_sample:end_sample] += note_audio
return track_audio
def _get_tempo_at(self, time_tick):
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
elapsed_ticks = 0
for tempo_change in self.tempo_events:
if elapsed_ticks + tempo_change[0] > time_tick:
return current_tempo
elapsed_ticks += tempo_change[0]
current_tempo = tempo_change[1]
return current_tempo
def _get_total_time(self):
total_time = 0
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
for event in self.events:
current_tempo = self._get_tempo_at(event.start_time)
seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
total_time += event.duration * seconds_per_tick
return total_time
class DiffSynth:
    """Diffusion-based synthesizer.

    Builds one note-sampler closure per configured instrument (via a latent
    diffusion inpainting model plus a VQ-VAE decoder) and renders multi-track
    mido MIDI files into a single mixed audio array.
    """

    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        # Latent height = spectrogram frequency bins divided by the VAE downscale factor.
        self.height = int(freq_resolution/VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer
        # instruments_configs is a dict: string -> (condition, negative_condition, guidance_scale, sample_steps, seed, initial_noise, sampler)
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()

    def _update_instruments(self):
        """Build a sampler closure for every instrument config and cache it
        in self.diffSynthSamplers (keyed by instrument name)."""

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                # NOTE(review): conditions on an empty-string text embedding,
                # and `velocity` is unused in this body — confirm intended.
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']
                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
                # Latent width scales with the requested duration; the +1
                # presumably adds a second of release tail — TODO confirm.
                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)
                # NOTE(review): height=128 and channels=4 are hard-coded here
                # while the rest of the class uses self.height/self.channels —
                # confirm these always agree with the constructor arguments.
                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))
                # mask = 1, freeze
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                # Freeze the attack region at the start and the release region
                # at the end; the middle is left free for inpainting.
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release+1) / 4) / self.VAE_scale):] = 1.0
                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler,
                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)
                # Keep only the final denoising step's latent.
                latent_representations = latent_representations[-1]
                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                # Todo: remove hard-coding
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                                                           quantized_latent_representations,
                                                                                                           resolution=(
                                                                                                               512,
                                                                                                               width * self.VAE_scale),
                                                                                                           original_STFT_batch=None,
                                                                                                           )
                # Return the reconstructed waveform of the single generated sample.
                return rec_signals[0]

            return diffSynthSampler

        for key in self.instruments_configs.keys():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])

    def get_music(self, mid, instrument_names, sample_rate=16000):
        """Render a mido.MidiFile to one mixed mono audio array.

        instrument_names must map 1:1 onto mid.tracks; each name selects the
        sampler built by _update_instruments.
        """
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) == len(instrument_names), f"len(tracks) = {len(tracks)} != {len(instrument_names)} = len(instrument_names)"
        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]
        # Pad every track to the length of the longest so they can be summed.
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # mix buffer, initialised to silence
        for audio in track_audios:
            # A track may be shorter than the longest; zero-pad its tail.
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix this track into the output
        return full_audio