from musc.dtw.mrmsdtw import sync_via_mrmsdtw_with_anchors
from musc.dtw.utils import make_path_strictly_monotonic
import numpy as np
from musc.transcriber import Transcriber
from typing import Dict


class Synchronizer(Transcriber):
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)

    def synchronize(self, audio, midi, batch_size=128, include_pitch_bends=True, to_midi=True, debug=False,
                    include_velocity=False, alignment_padding=50, timing_refinement_range_with_f0s=0):
| """ | |
| Synchronize an audio file or mono waveform in numpy or torch with a MIDI file. | |
| :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor | |
| :param midi: str, pathlib.Path, or pretty_midi.PrettyMIDI | |
| :param batch_size: frames to process at once | |
| :param include_pitch_bends: whether to include pitch bends in the MIDI file | |
| :param to_midi: whether to return a MIDI file or a list of note events (as tuple) | |
| :param debug: whether to plot the alignment path and compare the alignment with the predicted notes | |
| :param include_velocity: whether to embed the note confidence in place of the velocity in the MIDI file | |
| :param alignment_padding: how many frames to pad the audio and MIDI representations with | |
| :param timing_refinement_range_with_f0s: how many frames to refine the alignment with the f0 confidence | |
| :return: aligned MIDI file as a pretty_midi.PrettyMIDI object | |
| Args: | |
| debug: | |
| to_midi: | |
| include_pitch_bends: | |
| """ | |
        audio = self.predict(audio, batch_size)
        notes_and_midi = self.out2sync(audio, midi, include_velocity=include_velocity,
                                       alignment_padding=alignment_padding)
        if notes_and_midi:  # it might be None
            notes, midi = notes_and_midi

            if debug:
                import matplotlib.pyplot as plt
                import pandas as pd
                estimated_notes = self.out2note(audio, postprocessing='spotify', include_pitch_bends=True)
                est_df = pd.DataFrame(estimated_notes).sort_values(by=0)
                note_df = pd.DataFrame(notes).sort_values(by=0)

                fig, ax = plt.subplots(figsize=(20, 10))
                for row in notes:
                    t_start, t_end = row[0], row[1]  # sec
                    pitch = row[2]  # MIDI pitch
                    ax.hlines(pitch, t_start, t_end, color='k', linewidth=3, zorder=2, alpha=0.5)
                for row in estimated_notes:
                    t_start, t_end = row[0], row[1]  # sec
                    pitch = row[2]  # MIDI pitch
                    ax.hlines(pitch, t_start, t_end, color='r', linewidth=3, zorder=2, alpha=0.5)
                fig.suptitle('alignment (black) vs. estimated (red)')
                fig.show()

            if not include_pitch_bends:
                if to_midi:
                    return midi['midi']
                else:
                    return notes
            else:
                # convert the note boundaries from seconds to frame indices for pitch-bend extraction
                notes = [(np.argmin(np.abs(audio['time'] - note[0])),
                          np.argmin(np.abs(audio['time'] - note[1])),
                          note[2], note[3]) for note in notes]
                notes = self.get_pitch_bends(audio["f0"], notes, timing_refinement_range_with_f0s)
                # and back from frame indices to seconds
                notes = [(audio['time'][note[0]], audio['time'][note[1]], note[2], note[3], note[4])
                         for note in notes]
                if to_midi:
                    return self.note2midi(notes, 120)  # alternative: int(midi['midi'].estimate_tempo())
                else:
                    return notes
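
    # Usage sketch (hedged): `labeling`, the file paths, and the pretrained weights
    # loaded by Transcriber are assumptions, not defined in this module.
    #
    #     sync = Synchronizer(labeling, instrument='Violin')
    #     aligned = sync.synchronize('performance.wav', 'score.mid')
    #     aligned.write('score_aligned.mid')  # a pretty_midi.PrettyMIDI when to_midi=True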

    def out2sync_old(self, out: Dict[str, np.ndarray], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file.

        Args:
            out: model output dictionary
            midi: path to the MIDI file or a PrettyMIDI object
            include_velocity: whether to encode the note confidence in place of velocity
            alignment_padding: number of frames to pad the MIDI features with zeros
            debug: visualize the alignment

        Returns:
            Note events and the aligned PrettyMIDI object, or None if no alignment is possible.
        """
        midi = self.labeling.represent_midi(midi, self.sr / self.hop_length)
        audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr / self.hop_length,
                                                              pad_length=alignment_padding)
        if isinstance(audio_midi_anchors, str):
            print(audio_midi_anchors)
            return None  # the file is corrupted: no alignment is possible at all
        else:
            audio, midi, anchor_pairs = audio_midi_anchors

        ALPHA = 0.6  # weight of the onsets; offsets are weighted by 1 - ALPHA
        wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T,
                                           f_onset1=np.hstack([ALPHA * audio['onset'],
                                                               (1 - ALPHA) * audio['offset']]).T,
                                           f_chroma2=midi['note'].T,
                                           f_onset2=np.hstack([ALPHA * midi['onset'],
                                                               (1 - ALPHA) * midi['offset']]).T,
                                           input_feature_rate=self.sr / self.hop_length,
                                           step_weights=np.array([1.5, 1.5, 2.0]),
                                           threshold_rec=10 ** 6,
                                           verbose=debug, normalize_chroma=False,
                                           anchor_pairs=anchor_pairs)
        wp = make_path_strictly_monotonic(wp).astype(int)

        # map each warping-path point to absolute times on both axes
        audio_time = np.take(audio['time'], wp[0])
        midi_time = np.take(midi['time'], wp[1])

        notes = []
        for instrument in midi['midi'].instruments:
            for note in instrument.notes:
                # warp the score note boundaries into performance time
                note.start = np.interp(note.start, midi_time, audio_time)
                note.end = np.interp(note.end, midi_time, audio_time)
                if note.end - note.start <= 0.012:  # notes should last at least 12 ms (i.e. 2 frames)
                    note.start = note.start - 0.003
                    note.end = note.start + 0.012
                if include_velocity:  # encode the note confidence in place of velocity
                    velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                       np.argmin(np.abs(audio['time'] - note.end)),
                                         note.pitch - self.labeling.midi_centers[0]])
                    # velocity must be an int of at least 1, otherwise MIDI drops the note
                    note.velocity = max(1, int(velocity * 127))
                else:
                    velocity = note.velocity / 127
                notes.append((note.start, note.end, note.pitch, velocity))
        return notes, midi

    def out2sync(self, out: Dict[str, np.ndarray], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file.

        Args:
            out: model output dictionary
            midi: path to the MIDI file or a PrettyMIDI object
            include_velocity: whether to encode the note confidence in place of velocity
            alignment_padding: number of frames to pad the MIDI features with zeros
            debug: visualize the alignment

        Returns:
            Note events and the aligned PrettyMIDI object, or None if no alignment is possible.
        """
        midi = self.labeling.represent_midi(midi, self.sr / self.hop_length)
        audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr / self.hop_length,
                                                              pad_length=alignment_padding)
        if isinstance(audio_midi_anchors, str):
            print(audio_midi_anchors)
            return None  # the file is corrupted: no alignment is possible at all
        else:
            audio, midi, anchor_pairs = audio_midi_anchors

        ALPHA = 0.6  # weight of the onsets; offsets are weighted by 1 - ALPHA
        # convert the anchor pairs (in seconds) back to frame indices and crop the
        # features to the region between the anchors before running MrMsDTW
        starts = (np.array(anchor_pairs[0]) * self.sr / self.hop_length).astype(int)
        ends = (np.array(anchor_pairs[1]) * self.sr / self.hop_length).astype(int)
        wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T[:, starts[0]:ends[0]],
                                           f_onset1=np.hstack([ALPHA * audio['onset'],
                                                               (1 - ALPHA) * audio['offset']]).T[:, starts[0]:ends[0]],
                                           f_chroma2=midi['note'].T[:, starts[1]:ends[1]],
                                           f_onset2=np.hstack([ALPHA * midi['onset'],
                                                               (1 - ALPHA) * midi['offset']]).T[:, starts[1]:ends[1]],
                                           input_feature_rate=self.sr / self.hop_length,
                                           step_weights=np.array([1.5, 1.5, 2.0]),
                                           threshold_rec=10 ** 6,
                                           verbose=debug, normalize_chroma=False,
                                           anchor_pairs=None)
        wp = make_path_strictly_monotonic(wp).astype(int)
        # shift the path back to absolute frame indices and append the end anchor
        wp[0] += starts[0]
        wp[1] += starts[1]
        wp = np.hstack((wp, ends[:, np.newaxis]))

        audio_time = np.take(audio['time'], wp[0])
        midi_time = np.take(midi['time'], wp[1])

        notes = []
        for instrument in midi['midi'].instruments:
            for note in instrument.notes:
                # warp the score note boundaries into performance time
                note.start = np.interp(note.start, midi_time, audio_time)
                note.end = np.interp(note.end, midi_time, audio_time)
                if note.end - note.start <= 0.012:  # notes should last at least 12 ms (i.e. 2 frames)
                    note.start = note.start - 0.003
                    note.end = note.start + 0.012
                if include_velocity:  # encode the note confidence in place of velocity
                    velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                       np.argmin(np.abs(audio['time'] - note.end)),
                                         note.pitch - self.labeling.midi_centers[0]])
                    # velocity must be an int of at least 1, otherwise MIDI drops the note
                    note.velocity = max(1, int(velocity * 127))
                else:
                    velocity = note.velocity / 127
                notes.append((note.start, note.end, note.pitch, velocity))
        return notes, midi
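
    # Minimal sketch (made-up numbers) of the time remapping above: the strictly
    # monotonic warping path is sampled as paired (midi_time, audio_time) arrays,
    # and any score timestamp maps into performance time by linear interpolation.
    #
    #     midi_time = np.array([0.0, 1.0, 2.0, 3.0])   # score seconds
    #     audio_time = np.array([0.0, 1.2, 2.1, 3.5])  # performance seconds
    #     np.interp(1.5, midi_time, audio_time)        # -> 1.65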

    @staticmethod
    def pad_representations(dict_of_representations, pad_length=10):
        """
        Pad the representations so that the DTW is not forced to stretch them over the entire duration.

        Args:
            dict_of_representations: audio or MIDI representations
            pad_length: how many frames to pad

        Returns:
            Padded representations.
        """
        for key, value in dict_of_representations.items():
            if key == 'time':
                padded_time = dict_of_representations[key]
                padded_time = np.concatenate([padded_time[:2 * pad_length], padded_time + padded_time[2 * pad_length]])
                # shift so that the frame times before the real zero time are negative
                dict_of_representations[key] = padded_time - padded_time[pad_length]
            elif key in ['onset', 'offset', 'note']:
                dict_of_representations[key] = np.pad(value, ((pad_length, pad_length), (0, 0)))
            elif key in ['start_anchor', 'end_anchor']:
                # snap the anchor time (seconds) to the nearest padded frame index
                anchor_time = dict_of_representations[key][0][0]
                anchor_time = np.argmin(np.abs(dict_of_representations['time'] - anchor_time))
                dict_of_representations[key][:, 0] = anchor_time
                dict_of_representations[key] = dict_of_representations[key].astype(int)
        return dict_of_representations
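
    # Shape sketch (hypothetical input): with pad_length=10 and 100 frames/s, a
    # (1000, 88) 'note' matrix becomes (1020, 88), and the first 10 'time' entries
    # turn negative so that frame 10 still maps to t = 0.
    #
    #     midi = {'time': np.arange(1000) / 100.0, 'note': np.zeros((1000, 88)),
    #             'onset': np.zeros((1000, 88)), 'offset': np.zeros((1000, 88))}
    #     padded = Synchronizer.pad_representations(midi, pad_length=10)
    #     padded['note'].shape  # -> (1020, 88)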

    def prepare_for_synchronization(self, audio, midi, feature_rate=44100 / 256, pad_length=100):
        """
        MrMsDTW works better with start and end anchors. This function finds the start and end anchors for the
        audio based on the MIDI notes. It also pads the MIDI representations: MIDI files most often start and end
        with an active note, so without padding the DTW would try to align those active notes to the entire
        duration of the audio, which is not desirable. We therefore pad the MIDI representations with a few
        frames of silence at the beginning and at the end.

        Args:
            audio: model output dictionary for the audio
            midi: MIDI representation dictionary
            feature_rate: frames per second of the representations
            pad_length: number of frames to pad the MIDI representations with

        Returns:
            The audio and MIDI representations together with the anchor pairs in seconds, or the string
            'corrupted' when no reliable anchors can be found.
        """
        # first pad the MIDI
        midi = self.pad_representations(midi, pad_length)

        # sometimes f0s are more reliable than the notes, so we use both together to find the start and end
        # anchors. f0_lookup_bins is the number of bins to look around the f0 to assign a note to it.
        f0_lookup_bins = int(100 // (2 * self.labeling.f0_granularity_c))

        # find the start anchor for the audio:
        # first decide which notes to use for the start anchor (take the entire chord where the MIDI file starts)
        anchor_notes = midi['start_anchor'][:, 1] - self.labeling.midi_centers[0]
        # now find which f0 bins to look at for the start anchor
        anchor_f0s = [self.midi_pitch_to_contour_bin(an + self.labeling.midi_centers[0]) for an in anchor_notes]
        anchor_f0s = np.array([list(range(f0 - f0_lookup_bins, f0 + f0_lookup_bins + 1)) for f0 in anchor_f0s]).reshape(-1)
        # the first start-anchor proposals come from the notes
        anchor_vals = np.any(audio['note'][:, anchor_notes] > 0.5, axis=1)
        # now the f0s
        anchor_vals_f0 = np.any(audio['f0'][:, anchor_f0s] > 0.5, axis=1)
        # combine the two
        anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
        if not any(anchor_vals):
            return 'corrupted'  # do not consider the file if we cannot find the start anchor
        audio_start = np.argmax(anchor_vals)

        # now the end anchor (most string instruments use chords in cadences: in general the end anchor is polyphonic)
        anchor_notes = midi['end_anchor'][:, 1] - self.labeling.midi_centers[0]
        anchor_f0s = [self.midi_pitch_to_contour_bin(an + self.labeling.midi_centers[0]) for an in anchor_notes]
        anchor_f0s = np.array([list(range(f0 - f0_lookup_bins, f0 + f0_lookup_bins + 1)) for f0 in anchor_f0s]).reshape(-1)
        # the same procedure as above, searching backwards from the end of the audio
        anchor_vals = np.any(audio['note'][::-1, anchor_notes] > 0.5, axis=1)
        anchor_vals_f0 = np.any(audio['f0'][::-1, anchor_f0s] > 0.5, axis=1)
        anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
        if not any(anchor_vals):
            return 'corrupted'  # do not consider the file if we cannot find the end anchor
        audio_end = audio['note'].shape[0] - np.argmax(anchor_vals)

        if audio_end - audio_start < (midi['end_anchor'][0][0] - midi['start_anchor'][0][0]) / 10:
            return 'corrupted'  # the interval between the anchors is too short (no one plays 10x faster)

        # place the anchors a few frames outside the detected region, clamped to valid frame indices
        anchor_pairs = [(audio_start - 5, midi['start_anchor'][0][0] - 5),
                        (audio_end + 5, midi['end_anchor'][0][0] + 5)]
        if anchor_pairs[0][0] < 1:
            anchor_pairs[0] = (1, midi['start_anchor'][0][0])
        if anchor_pairs[1][0] > audio['note'].shape[0] - 1:
            anchor_pairs[1] = (audio['note'].shape[0] - 1, midi['end_anchor'][0][0])

        # return the anchors in seconds
        return audio, midi, [(anchor_pairs[0][0] / feature_rate, anchor_pairs[0][1] / feature_rate),
                             (anchor_pairs[1][0] / feature_rate, anchor_pairs[1][1] / feature_rate)]
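

# Worked sketch (made-up numbers) of the anchor unit round trip: the method above
# returns anchor pairs in seconds, and out2sync multiplies by sr / hop_length to
# recover frame indices before cropping the features.
def _anchor_round_trip_demo(sr=16000, hop_length=160):
    feature_rate = sr / hop_length  # 100 frames per second
    audio_start_frame, midi_start_frame = 95, 100
    start_anchor = ((audio_start_frame - 5) / feature_rate,
                    (midi_start_frame - 5) / feature_rate)  # (0.90, 0.95) seconds
    starts = (np.array(start_anchor) * feature_rate).astype(int)
    assert tuple(starts) == (90, 95)  # frame indices recovered exactly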