xuantruong commited on
Commit
5ce0f29
·
1 Parent(s): c78d045

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +398 -0
inference.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('code')
3
+
4
+
5
+ import numpy as np
6
+ import os
7
+ import pretty_midi as pyd
8
+ import torch
9
+ import sys
10
+ from model import VAE
11
+ from util_tools.format_converter import melody_data2matrix, melody_matrix2data, chord_data2matrix, chord_matrix2data
12
+ from torch.distributions import kl_divergence, Normal
13
+ from nottingham_dataset import Nottingham
14
+ import copy
15
+
16
def chord_grid2data(est_pitch, bpm=60., start=0., max_simu_note=6, pitch_eos=129, num_step=32, min_pitch=0):
    """Decode an estimated chord pitch grid into a pretty_midi Instrument.

    est_pitch: (num_step, max_simu_note[-1], 1) integer grid of chord tones,
        where `pitch_eos` terminates the simultaneous-note list of a step.
    bpm / start: tempo and onset time (seconds) used to place the notes.
    Returns a piano Instrument holding the decoded chord notes.
    NOTE(review): `min_pitch` is accepted for API symmetry but unused here.
    """
    # Drop the trailing singleton axis; keep only the pitch indices.
    est_pitch = est_pitch[:, :, 0]  # (num_step, max_simu_note-1)
    # If the SOS column is still present, strip it.
    if est_pitch.shape[1] == max_simu_note:
        est_pitch = est_pitch[:, 1:]

    # A step starts a new chord iff its first slot is not EOS.
    harmonic_rhythm = 1. - (est_pitch[:, 0] == pitch_eos) * 1.

    alpha = 0.25 * 60 / bpm  # seconds per 16th-note grid step
    notes = []
    for t in range(num_step):
        # The chord at step t sustains until the next harmonic-rhythm onset.
        # Hoisted out of the slot loop: duration is the same for every note
        # of the step (the original recomputed it per note).
        duration = 1
        for j in range(t + 1, num_step):
            if harmonic_rhythm[j] == 1:
                break
            duration += 1
        for n in range(est_pitch.shape[1]):
            note = est_pitch[t, n]
            if note == pitch_eos:
                break  # remaining slots of this step are empty
            pitch = note + 12 * 4  # shift decoded index into the C4 register
            # NOTE(review): the note end is not clamped to the grid length,
            # so the last chord may ring past num_step steps — confirm intended.
            # (A dead `pr` piano-roll array that *did* clamp was removed; it
            # was written but never read or returned.)
            notes.append(
                pyd.Note(100, int(pitch), start + t * alpha,
                         start + (t + duration) * alpha))
    chord = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano'))
    chord.notes = notes
    return chord
47
+
48
def melody_matrix2data(melody_matrix, tempo=120, start_time=0.0, get_list=False):
    """Convert a one-hot melody matrix back into pretty_midi notes.

    melody_matrix: (num_step, 28) array — pitch classes in columns 0-11,
        hold/rest flags in columns 15-16, register one-hots in the last 10.
    tempo / start_time: placement of the decoded notes in seconds.
    get_list: when True return the raw note list instead of an Instrument.
    """
    HOLD_PITCH = 12
    REST_PITCH = 13
    # Fold hold/rest flags in behind the 12 pitch classes so a single argmax
    # yields either a pitch class, HOLD_PITCH, or REST_PITCH per frame.
    chroma = np.concatenate((melody_matrix[:, :12], melody_matrix[:, 15: 17]), axis=-1)
    register = melody_matrix[:, -10:]
    token_seq = np.argmax(chroma, axis=-1)

    step_sec = 60 / tempo / 4  # seconds per 16th-note frame
    # Frames where a new event begins (anything that is not a hold frame),
    # plus a sentinel so every event has a successor boundary.
    boundaries = [i for i, tok in enumerate(token_seq) if tok != HOLD_PITCH]
    boundaries.append(len(token_seq))

    notes = []
    for onset, nxt in zip(boundaries[:-1], boundaries[1:]):
        if token_seq[onset] == REST_PITCH:
            continue
        midi_pitch = token_seq[onset] + 12 * np.argmax(register[onset])
        notes.append(pyd.Note(velocity=100, pitch=midi_pitch,
                              start=start_time + onset * step_sec,
                              end=start_time + nxt * step_sec))

    if get_list:
        return notes
    melody = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano'))
    melody.notes = notes
    return melody
79
+
80
+
81
def get_gt(chord, melody):
    """Render a ground-truth (melody, chord) pair as a PrettyMIDI object.

    chord: (num_step, max_simu_note, 1) numpy index grid.
    melody: (num_step, 28) numpy melody matrix.
    Instruments are appended melody-first, then chord.
    """
    music = pyd.PrettyMIDI(initial_tempo=120)
    music.instruments.append(melody_matrix2data(melody, 120))
    music.instruments.append(chord_grid2data(chord, 30, pitch_eos=13))
    return music
90
+
91
def shift(original_melody, p_shift):
    """Transpose a one-hot melody matrix by `p_shift` semitones.

    original_melody: (1, num_step*4, 28) torch.FloatTensor with pitch-class
        one-hots in columns 0-11 and register one-hots in the last 10 columns.
    p_shift: signed semitone offset.
    Returns a new (1, num_step*4, 28) tensor; the input is left untouched.
    """
    melody = copy.deepcopy(original_melody).cpu().detach().numpy()[0]
    # Column offset of the register one-hot block. BUG FIX: the original
    # wrote register hits at `register + 17`, which disagrees with reading
    # them from `[:, -10:]` whenever the matrix has 28 columns; derive the
    # base from the actual width so reads and writes address the same columns.
    register_base = melody.shape[1] - 10
    # Assumes each onset row carries exactly one pitch-class hit and one
    # register hit, so the two nonzero() calls align row-for-row — TODO confirm.
    onsets, pitch = np.nonzero(melody[:, :12])
    onsets, register = np.nonzero(melody[:, -10:])
    absolute = pitch + register * 12 + p_shift  # absolute pitch index, shifted
    # NOTE(review): no clamping — a large shift can push `absolute // 12`
    # past the last register column; confirm callers keep shifts in range.
    melody[onsets, :] = 0
    melody[onsets, absolute % 12] = 1.
    melody[onsets, register_base + absolute // 12] = 1.
    return torch.from_numpy(melody).float().unsqueeze(0)
103
+
104
+
105
def reconstruct(chord, melody):
    """Encode a (chord, melody) pair and decode the chord back from the
    posterior mean; return melody + decoded chord as a PrettyMIDI object.

    chord: (1, num_step, max_simu_note, 1) torch.LongTensor.
    melody: (1, num_step*4, 28) torch.FloatTensor.
    """
    # --- encode ---
    lengths = model.get_len_index_tensor(chord)  # (B, num_step)
    chord_emb = model.enc_note_embedding(model.index_tensor_to_multihot_tensor(chord))
    mel_emb = model.enc_note_embedding(melody)  # (B, num_step*4, note_emb_size)
    # Sum every group of four 16th-note frames into one per-beat summary.
    beat_summary = sum(mel_emb[:, k::4, :] for k in range(4))
    dist, mu, = model.encoder(chord_emb, lengths, beat_summary)

    # --- decode deterministically from the posterior mean ---
    decoded = model.decoder(dist.mean, beat_summary,
                            inference=True, x=None, lengths=None,
                            teacher_forcing_ratio1=0., teacher_forcing_ratio2=0.)
    decoded = decoded.max(-1, keepdim=True)[1].cpu().detach().numpy()
    chord_track = chord_grid2data(decoded[0], bpm=120//4, start=0, pitch_eos=13)
    melody_track = melody_matrix2data(melody.cpu().detach().numpy()[0], tempo=120)

    music = pyd.PrettyMIDI()
    music.instruments.append(melody_track)
    music.instruments.append(chord_track)
    return music
129
+
130
def melody_control(chord, melody, new_melody):
    """Encode (chord, melody), then decode the chord conditioned on a
    *different* melody; return new melody + decoded chord as PrettyMIDI.

    chord: (B, num_step, max_simu_note, 1) torch.LongTensor.
    melody, new_melody: (B, num_step*4, 28) torch.FloatTensor.
    """
    # --- encode the original pair ---
    lengths = model.get_len_index_tensor(chord)  # (B, num_step)
    chord_emb = model.enc_note_embedding(model.index_tensor_to_multihot_tensor(chord))
    mel_emb = model.enc_note_embedding(melody)  # (B, num_step*4, note_emb_size)
    # Sum every group of four 16th-note frames into one per-beat summary.
    beat_summary = sum(mel_emb[:, k::4, :] for k in range(4))
    dist, mu, = model.encoder(chord_emb, lengths, beat_summary)

    # --- decode against the replacement melody's beat summary ---
    new_emb = model.enc_note_embedding(new_melody)
    new_beat_summary = sum(new_emb[:, k::4, :] for k in range(4))
    decoded = model.decoder(dist.mean, new_beat_summary,
                            inference=True, x=None, lengths=None,
                            teacher_forcing_ratio1=0., teacher_forcing_ratio2=0.)
    decoded = decoded.max(-1, keepdim=True)[1].cpu().detach().numpy()
    chord_track = chord_grid2data(decoded[0], bpm=120//4, start=0, pitch_eos=13)
    melody_track = melody_matrix2data(new_melody.cpu().detach().numpy()[0], tempo=120)

    music = pyd.PrettyMIDI()
    music.instruments.append(melody_track)
    music.instruments.append(chord_track)
    return music
158
+
159
def melody_prior_control(new_melody):
    """Sample a chord latent from the standard-normal prior and decode it
    conditioned on `new_melody`; return melody + decoded chord as PrettyMIDI.

    new_melody: (B, num_step*4, 28) torch.FloatTensor on the model's device.
    """
    new_mel_ebd = model.enc_note_embedding(new_melody)  # (B, num_step*4, note_emb_size)
    # Sum every group of four 16th-note frames into one per-beat summary.
    new_melody_beat_summary = new_mel_ebd[:, ::4, :] + new_mel_ebd[:, 1::4, :] + new_mel_ebd[:, 2::4, :] + new_mel_ebd[:, 3::4, :]
    # Draw z ~ N(0, I) from the prior. BUG FIX: move it onto the model's
    # device — the original left it on CPU, which fails when the script has
    # moved the model to CUDA.
    # NOTE(review): 128 is assumed to equal model_params['z_size'] — confirm.
    z = Normal(torch.zeros(128), torch.ones(128)).rsample().unsqueeze(0).to(device)
    pitch_outs = model.decoder(z, new_melody_beat_summary,
                               inference=True, x=None, lengths=None,
                               teacher_forcing_ratio1=0., teacher_forcing_ratio2=0.)
    pitch_outs = pitch_outs.max(-1, keepdim=True)[1]
    pitch_outs = pitch_outs.cpu().detach().numpy()
    chord_track = chord_grid2data(pitch_outs[0], bpm=120//4, start=0, pitch_eos=13)

    new_melody = new_melody.cpu().detach().numpy()[0]
    melody_track = melody_matrix2data(new_melody, tempo=120)

    music = pyd.PrettyMIDI()
    music.instruments.append(melody_track)
    music.instruments.append(chord_track)
    return music
178
+
179
+
180
# --- Configuration & model construction -------------------------------------
# NOTE(review): mid-file import kept as-is; `utils` lives under the 'code'
# directory appended to sys.path at the top of the script.
import utils
config_fn = './code/model_config.json'
#train_hyperparams = utils.load_params_dict('train_hyperparams', config_fn)
model_params = utils.load_params_dict('model_params', config_fn)       # architecture sizes
data_repr_params = utils.load_params_dict('data_repr', config_fn)      # data-representation constants
#project_params = utils.load_params_dict('project', config_fn)
#dataset_path = utils.load_params_dict('dataset_paths', config_fn)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the chord VAE from the two config sections: data-representation
# constants first, then encoder/decoder and discriminator sizes.
model = VAE(max_simu_note=data_repr_params['max_simu_note'],
            max_pitch=data_repr_params['max_pitch'],
            min_pitch=data_repr_params['min_pitch'],
            pitch_sos=data_repr_params['pitch_sos'],
            pitch_eos=data_repr_params['pitch_eos'],
            pitch_pad=data_repr_params['pitch_pad'],
            num_step=data_repr_params['num_time_step'],

            note_emb_size=model_params['note_emb_size'],
            enc_notes_hid_size=model_params['enc_notes_hid_size'],
            enc_time_hid_size=model_params['enc_time_hid_size'],
            z_size=model_params['z_size'],
            dec_emb_hid_size=model_params['dec_emb_hid_size'],
            dec_time_hid_size=model_params['dec_time_hid_size'],
            dec_notes_hid_size=model_params['dec_notes_hid_size'],
            discr_nhead = model_params["discr_nhead"],
            discr_hid_size = model_params["discr_hid_size"],
            discr_dropout = model_params["discr_dropout"],
            discr_nlayer = model_params["discr_nlayer"],

            device=device
            )
212
+
213
+
214
# --- Checkpoint loading -----------------------------------------------------
weight_path = './code/ad-ptvae_param.pt'
state = torch.load(weight_path, map_location=torch.device(device))
# Some checkpoints wrap the weights under a 'model_state_dict' key.
if 'model_state_dict' in state:
    state = state['model_state_dict']
model.load_state_dict(state)
# Place the model on the selected device.
if torch.cuda.is_available():
    model.cuda()
else:
    model.cpu()

model.eval()  # inference only — disable dropout etc.
separator = '-' * 100
print(separator)
print(f'Loaded {weight_path}')
print(separator)
228
+
229
+
230
# --- Validation data --------------------------------------------------------
# The cached dataset is stored transposed; .T restores sample-major order.
dataset = np.load('./code/data.npy', allow_pickle=True).T
print('-'*100)
print(f'Loaded ./code/data.npy')
print('-'*100)
# Deterministic shuffle, then keep the last 5% of samples as validation.
np.random.seed(0)
np.random.shuffle(dataset)
anchor = int(dataset.shape[0] * 0.95)
val_data = dataset[anchor:, :]
# NOTE(review): 'chord_fomat' (sic) is presumably the keyword the Nottingham
# class expects — do not fix the spelling here without changing that class.
val_set = Nottingham(dataset=val_data.T,
                     length=128,
                     step_size=16,
                     chord_fomat='pr', shift_high=0, shift_low=0)
print(len(val_set))
# All demo MIDI files produced below are written here.
WRITE_PATH = './code/demo_generate'
if not os.path.exists(WRITE_PATH):
    os.makedirs(WRITE_PATH)
246
+
247
+
248
def _reconstruct_demo(index, tag):
    """Write ground-truth and reconstruction MIDI for one validation item.

    index: position in val_set; tag: integer suffix for the output files.
    Returns the (chord, melody) pair as batched tensors for later reuse.
    """
    chord, _, melody, _ = val_set[index]
    music = get_gt(chord, melody)
    music.write(os.path.join(WRITE_PATH, f'gt_{tag}.mid'))
    chord = torch.from_numpy(chord).long().unsqueeze(0)
    melody = torch.from_numpy(melody).float().unsqueeze(0)
    music = reconstruct(chord, melody)
    music.write(os.path.join(WRITE_PATH, f'recon_{tag}.mid'))
    # CONSISTENCY FIX: the original script omitted this message for item 4
    # only; print it uniformly for every reconstruction.
    print(f'Saved to {WRITE_PATH}/recon_{tag}.mid')
    return chord, melody

# Four hand-picked validation items reused by all demos below.
chord_1, melody_1 = _reconstruct_demo(338, 1)
chord_2, melody_2 = _reconstruct_demo(2749, 2)
chord_3, melody_3 = _reconstruct_demo(3413, 3)
chord_4, melody_4 = _reconstruct_demo(5126, 4)
282
+
283
+
284
def _load_modal_melody(tag):
    """Re-read a reconstructed MIDI's melody track and convert it back into
    a batched melody matrix (the 'modal change' control input).

    tag: integer suffix of the recon_{tag}.mid file written earlier.
    """
    midi = pyd.PrettyMIDI(f'{WRITE_PATH}/recon_{tag}.mid')
    melody = melody_data2matrix(midi.instruments[0], midi.get_downbeats())
    melody = val_set.truncate_melody(melody)
    return torch.from_numpy(melody).float().unsqueeze(0)

# Deduplicated from four identical copy-pasted stanzas; behavior unchanged.
melody_1_modal_change = _load_modal_melody(1)
melody_2_modal_change = _load_modal_melody(2)
melody_3_modal_change = _load_modal_melody(3)
melody_4_modal_change = _load_modal_melody(4)
303
+
304
+
305
+
306
+
307
# --- Control demo 1: transpose each melody up a tritone ---------------------
for tag, (chord_in, melody_in) in enumerate(
        [(chord_1, melody_1), (chord_2, melody_2),
         (chord_3, melody_3), (chord_4, melody_4)], start=1):
    music = melody_control(chord_in, melody_in, shift(melody_in, 6))
    # BUG FIX: item 4 was written as 'control_4_transpose1.mid' while the
    # log message claimed 'control_4_transpose.mid'; name them consistently.
    music.write(os.path.join(WRITE_PATH, f'control_{tag}_transpose.mid'))
    print(f'Saved to {WRITE_PATH}/control_{tag}_transpose.mid')
322
+
323
+
324
+
325
# --- Control demo 2: re-decode chords against the re-extracted melodies -----
modal_inputs = [(chord_1, melody_1, melody_1_modal_change),
                (chord_2, melody_2, melody_2_modal_change),
                (chord_3, melody_3, melody_3_modal_change),
                (chord_4, melody_4, melody_4_modal_change)]
for tag, (chord_in, melody_in, modal_melody) in enumerate(modal_inputs, start=1):
    music = melody_control(chord_in, melody_in, modal_melody)
    music.write(os.path.join(WRITE_PATH, f'control_{tag}_modal_change.mid'))
    print(f'Saved to {WRITE_PATH}/control_{tag}_modal_change.mid')
340
+
341
+
342
+
343
+
344
# --- Control demo 3: decode chords from the prior for each melody -----------
for tag, melody_in in enumerate([melody_1, melody_2, melody_3, melody_4], start=1):
    music = melody_prior_control(melody_in)
    music.write(os.path.join(WRITE_PATH, f'control_{tag}_prior.mid'))
    print(f'Saved to {WRITE_PATH}/control_{tag}_prior.mid')
359
+
360
+
361
+
362
+
363
# --- Control demo 4: swap melodies between every pair of items --------------
song = {1: (chord_1, melody_1), 2: (chord_2, melody_2),
        3: (chord_3, melody_3), 4: (chord_4, melody_4)}
# Same order as the original script: each unordered pair, both directions.
for a, b in [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]:
    for c_tag, m_tag in ((a, b), (b, a)):
        chord_in, melody_in = song[c_tag]
        music = melody_control(chord_in, melody_in, song[m_tag][1])
        music.write(os.path.join(WRITE_PATH, f'control_{c_tag}c+{m_tag}m.mid'))
398
+