File size: 11,083 Bytes
d896bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import os
import time
import sys
import multiprocessing
import itertools
import argparse
from itertools import product

import numpy as np
import tqdm
import pypianoroll as pproll
import muspy

import constants
from constants import PitchToken, DurationToken


def preprocess_midi_file(filepath, dest_dir, n_bars, resolution):

    print("Preprocessing file {}".format(filepath))

    filename = os.path.basename(filepath)
    saved_samples = 0

    # Load the file both as a pypianoroll song and a muspy song
    # (Need to load both since muspy.to_pypianoroll() is expensive)
    try:
        pproll_song = pproll.read(filepath, resolution=resolution)
        muspy_song = muspy.read(filepath)
    except Exception as e:
        print("Song skipped (Invalid song format)")
        return 0

    # Only accept songs that have a time signature of 4/4 and no time changes
    for t in muspy_song.time_signatures:
        if t.numerator != 4 or t.denominator != 4:
            print("Song skipped ({}/{} time signature)".
                  format(t.numerator, t.denominator))
            return 0

    # Gather tracks of pypianoroll song based on MIDI program number
    drum_tracks = []
    bass_tracks = []
    guitar_tracks = []
    strings_tracks = []

    for track in pproll_song.tracks:
        if track.is_drum:
            track.name = 'Drums'
            drum_tracks.append(track)
        elif 0 <= track.program <= 31:
            track.name = 'Guitar'
            guitar_tracks.append(track)
        elif 32 <= track.program <= 39:
            track.name = 'Bass'
            bass_tracks.append(track)
        else:
            # Tracks with program > 39 are all considered as strings tracks
            # and will be merged into a single track later on
            strings_tracks.append(track)

    # Filter song if it does not contain drum, guitar, bass or strings tracks
    # if not guitar_tracks \
    if not drum_tracks or not guitar_tracks \
            or not bass_tracks or not strings_tracks:
        print("Song skipped (does not contain drum or "
              "guitar or bass or strings tracks)")
        return 0

    # Merge strings tracks into a single pypianoroll track
    strings = pproll.Multitrack(tracks=strings_tracks)
    strings_track = pproll.Track(pianoroll=strings.blend(mode='max'),
                                 program=48, name='Strings')

    combinations = list(product(drum_tracks, bass_tracks, guitar_tracks))

    # Single instruments can have multiple tracks.
    # Consider all possible combinations of drum, bass, and guitar tracks
    for i, combination in enumerate(combinations):

        print("Processing combination {} of {}".format(i + 1, 
                                                       len(combinations)))

        # Process combination (called 'subsong' from now on)
        drum_track, bass_track, guitar_track = combination
        tracks = [drum_track, bass_track, guitar_track, strings_track]

        pproll_subsong = pproll.Multitrack(
            tracks=tracks,
            tempo=pproll_song.tempo,
            resolution=resolution
        )
        muspy_subsong = muspy.from_pypianoroll(pproll_subsong)

        tracks_notes = [track.notes for track in muspy_subsong.tracks]

        # Obtain length of subsong (maximum of each track's length)
        length = 0
        for notes in tracks_notes:
            track_length = max(note.end for note in notes) if notes else 0
            length = max(length, track_length)
        length += 1

        # Add timesteps until length is a multiple of resolution
        length = length if length % (4*resolution) == 0 \
            else length + (4*resolution-(length % (4*resolution)))

        tracks_content = []
        tracks_structure = []

        for notes in tracks_notes:

            # track_content: length x MAX_SIMU_TOKENS x 2
            # This is used as a basis to build the final content tensors for
            # each sequence.
            # The last dimension contains pitches and durations. int16 is enough
            # to encode small to medium duration values.
            track_content = np.zeros((length, constants.MAX_SIMU_TOKENS, 2), 
                                    np.int16)

            track_content[:, :, 0] = PitchToken.PAD.value
            track_content[:, 0, 0] = PitchToken.SOS.value
            track_content[:, :, 1] = DurationToken.PAD.value
            track_content[:, 0, 1] = DurationToken.SOS.value

            # Keeps track of how many notes have been stored in each timestep
            # (int8 imposes MAX_SIMU_TOKENS < 256)
            notes_counter = np.ones(length, dtype=np.int8)

            # Todo: np.put_along_axis?
            for note in notes:
                # Insert note in the lowest position available in the timestep

                t = note.time

                if notes_counter[t] >= constants.MAX_SIMU_TOKENS-1:
                    # Skip note if there is no more space
                    continue

                pitch = max(min(note.pitch, constants.MAX_PITCH_TOKEN), 0)
                track_content[t, notes_counter[t], 0] = pitch
                dur = max(min(note.duration, constants.MAX_DUR_TOKEN + 1), 1)
                track_content[t, notes_counter[t], 1] = dur-1
                notes_counter[t] += 1

            # Add EOS token
            t_range = np.arange(0, length)
            track_content[t_range, notes_counter, 0] = PitchToken.EOS.value
            track_content[t_range, notes_counter, 1] = DurationToken.EOS.value

            # Get track activations, a boolean tensor indicating whether notes
            # are being played in a timestep (sustain does not count)
            # (needed for graph rep.)
            activations = np.array(notes_counter-1, dtype=bool)

            tracks_content.append(track_content)
            tracks_structure.append(activations)

        # n_tracks x length x MAX_SIMU_TOKENS x 2
        subsong_content = np.stack(tracks_content, axis=0)

        # n_tracks x length
        subsong_structure = np.stack(tracks_structure, axis=0)

        # Slide window over 'subsong_content' and 'subsong_structure' along the
        # time axis (2nd dimension) with the stride of a bar
        # Todo: np.lib.stride_tricks.as_strided(song_proll)?
        for i in range(0, length-n_bars*4*resolution+1, 4*resolution):

            # Get the content and structure tensors of a single sequence
            c_tensor = subsong_content[:, i:i+n_bars*4*resolution, :]
            s_tensor = subsong_structure[:, i:i+n_bars*4*resolution]
            c_tensor = np.copy(c_tensor)
            s_tensor = np.copy(s_tensor)

            if n_bars > 1:
                # Skip sequence if it contains more than one bar of consecutive
                # silence in at least one track
                bars = s_tensor.reshape(s_tensor.shape[0], n_bars, -1)
                bars_acts = np.any(bars, axis=2)

                if 1 in np.diff(np.where(bars_acts == 0)[1]):
                    continue

                # Skip sequence if it contains one bar of complete silence
                silences = np.logical_not(np.any(bars_acts, axis=0))
                if np.any(silences):
                    continue

            else:
                # Skip if all tracks are silenced
                bar_acts = np.any(s_tensor, axis=1)
                if not np.any(bar_acts):
                    continue

            # Randomly transpose the pitches of the sequence (-5 to 6 semitones)
            # Not considering SOS, EOS or PAD tokens. Not transposing drums.
            shift = np.random.choice(np.arange(-5, 7), 1)
            cond = (c_tensor[1:, :, :, 0] != PitchToken.PAD.value) &           \
                   (c_tensor[1:, :, :, 0] != PitchToken.SOS.value) &           \
                   (c_tensor[1:, :, :, 0] != PitchToken.EOS.value)
            non_drums = c_tensor[1:, ...]
            non_drums[cond, 0] += shift
            non_drums[cond, 0] = np.clip(non_drums[cond, 0], a_min=0, 
                                         a_max=constants.MAX_PITCH_TOKEN)

            # Save sample (content and structure) to file
            sample_filepath = os.path.join(
                dest_dir, filename+str(saved_samples))
            np.savez(sample_filepath, c_tensor=c_tensor, s_tensor=s_tensor)

            saved_samples += 1


def preprocess_midi_dataset(midi_dataset_dir, preprocessed_dir, n_bars, 
                            resolution, n_files=None, n_workers=1):

    print("Starting preprocessing")
    start = time.time()

    # Visit recursively the directories inside the dataset directory
    with multiprocessing.Pool(n_workers) as pool:

        walk = os.walk(midi_dataset_dir)
        fn_gen = itertools.chain.from_iterable(
            ((os.path.join(dirpath, file), preprocessed_dir, n_bars, resolution)
                for file in files)
                for dirpath, dirs, files in walk
        )

        r = list(tqdm.tqdm(pool.starmap(preprocess_midi_file, fn_gen),
                           total=n_files))

    end = time.time()
    hours, rem = divmod(end-start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Preprocessing completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
          .format(int(hours), int(minutes), seconds))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Preprocesses a MIDI dataset. MIDI files can be arranged " 
            "hierarchically in subdirectories, similarly to the Lakh MIDI "
            "Dataset (lmd_matched) and the MetaMIDI Dataset."
    )
    parser.add_argument(
        'midi_dataset_dir',
        type=str, 
        help='Directory of the MIDI dataset.'
    )
    parser.add_argument(
        'preprocessed_dir',
        type=str,
        help='Directory to save the preprocessed dataset.'
    )
    parser.add_argument(
        '--n_bars',
        type=int,
        default=2,
        help="Number of bars for each sequence of the resulting preprocessed "
            "dataset. Defaults to 2 bars."
    )
    parser.add_argument(
        '--resolution',
        type=int,
        default=8,
        help="Number of timesteps per beat. When set to r, given that only "
            "4/4 songs are preprocessed, there will be 4*r timesteps in a bar. "
            "Defaults to 8."
    )
    parser.add_argument(
        '--n_files',
        type=int,
        help="Number of files in the MIDI dataset. If set, the script "
            "will provide statistics on the time remaining."
    )
    parser.add_argument(
        '--n_workers',
        type=int,
        default=1,
        help="Number of parallel workers. Defaults to 1."
    )

    args = parser.parse_args()
    
    # Create the output directory if it does not exist
    if not os.path.exists(args.preprocessed_dir):
        os.makedirs(args.preprocessed_dir)

    preprocess_midi_dataset(args.midi_dataset_dir, args.preprocessed_dir, 
                            args.n_bars, args.resolution, args.n_files,
                            n_workers=args.n_workers)