Spaces:
Sleeping
Sleeping
| import copy | |
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| import muspy | |
| from prettytable import PrettyTable | |
| from constants import PitchToken, DurationToken | |
| import constants | |
| import generation_config | |
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), stores the seed in the ``PYTHONHASHSEED`` environment
    variable and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Python-level sources of randomness
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for run-to-run determinism
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def append_dict(dest_d, source_d):
    """Append each value of ``source_d`` to the list stored in ``dest_d``.

    For every key of ``source_d``, the value is appended to
    ``dest_d[key]``. ``dest_d`` is modified in place and must already map
    each key to an object with an ``append`` method (or create one on
    lookup, as ``collections.defaultdict(list)`` does).

    Args:
        dest_d: Mapping from keys to lists, updated in place.
        source_d: Mapping whose values are appended.

    Raises:
        KeyError: If ``dest_d`` is a plain dict lacking one of the keys.
    """
    for key in source_d:
        dest_d[key].append(source_d[key])
def print_params(model):
    """Print a per-parameter table of trainable parameter counts.

    Only parameters with ``requires_grad`` set are listed and counted.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        The total number of trainable parameter elements.
    """
    # Collect (name, element count) for trainable parameters only
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Parameters: {total_params}")
    return total_params
def print_divider():
    """Print a horizontal divider made of 40 '—' characters."""
    divider = '—' * 40
    print(divider)
| # Builds multitrack pianoroll (mtp) from content tensor containing logits and | |
| # structure binary tensor | |
| # c_logits: num_nodes x MAX_SIMU_TOKENS x d_token | |
| # s_tensor: n_batches x n_bars x n_tracks x n_timesteps | |
def mtp_from_logits(c_logits, s_tensor):
    """Scatter content logits into a dense multitrack pianoroll tensor.

    Positions where ``s_tensor`` is non-zero receive the rows of
    ``c_logits`` in order; all other positions receive a "silence"
    pattern whose first token slot is one-hot on ``PitchToken.EOS`` and
    whose remaining slots are one-hot on ``PitchToken.PAD``.

    Args:
        c_logits: Tensor of shape (num_nodes, MAX_SIMU_TOKENS, d_token).
            ``num_nodes`` must equal the number of non-zero entries of
            ``s_tensor``.
        s_tensor: Binary structure tensor of shape
            (n_batches, n_bars, n_tracks, n_timesteps).

    Returns:
        Tensor of shape (n_batches, n_bars, n_tracks, n_timesteps,
        MAX_SIMU_TOKENS, d_token) with the device and dtype of
        ``c_logits``.
    """
    n_tokens, d_token = c_logits.size(-2), c_logits.size(-1)
    factory = dict(device=c_logits.device, dtype=c_logits.dtype)
    full_shape = (tuple(s_tensor.size(i) for i in range(4))
                  + (n_tokens, d_token))

    # Work on a flat (n_positions, MAX_SIMU_TOKENS, d_token) layout so
    # that a 1-D boolean mask can address whole timesteps
    flat = torch.zeros(full_shape, **factory).reshape(-1, n_tokens, d_token)

    # Silence pattern: EOS in the first slot, PAD in all the others
    rest = torch.zeros((n_tokens, d_token), **factory)
    rest[0, PitchToken.EOS.value] = 1.
    rest[1:, PitchToken.PAD.value] = 1.

    # Fill active positions with logits and the rest with silence
    active = s_tensor.bool().reshape(-1)
    flat[active] = c_logits
    flat[~active] = rest

    return flat.reshape(full_shape)
| # mtp: n_bars x n_tracks x n_timesteps x MAX_SIMU_TOKENS x d_token | |
def muspy_from_mtp(mtp):
    """Convert a multitrack pianoroll tensor into a ``muspy.Music`` object.

    For every track and timestep the token slots are decoded in order by
    taking the argmax of the pitch part (first ``constants.N_PITCH_TOKENS``
    entries) and of the duration part (remaining entries) of each token.
    Decoding of a timestep stops at the first EOS/PAD pitch or duration
    token; tokens decoded as SOS are skipped.

    Args:
        mtp: Tensor of shape
            (n_bars, n_tracks, n_timesteps, MAX_SIMU_TOKENS, d_token).

    Returns:
        A ``muspy.Music`` with one track per entry of the track axis, named
        after ``constants.TRACKS`` and using the programs configured in
        ``generation_config.MIDI_PROGRAMS``. Resolution is
        ``n_timesteps // 4``.
    """
    n_timesteps = mtp.size(2)
    resolution = n_timesteps // 4

    # Collapse bars dimension:
    # (n_tracks, n_bars * n_timesteps, MAX_SIMU_TOKENS, d_token)
    mtp = mtp.permute(1, 0, 2, 3, 4)
    mtp = mtp.reshape(mtp.shape[0], -1, mtp.shape[3], mtp.shape[4])
    total_timesteps = mtp.size(1)

    # Decode every token in one shot instead of calling argmax once per
    # (track, timestep, slot); the loops below then run on plain ints
    n_pitch = constants.N_PITCH_TOKENS
    pitch_ids = mtp[..., :n_pitch].argmax(dim=-1).tolist()
    dur_ids = mtp[..., n_pitch:].argmax(dim=-1).tolist()

    pitch_stop = (PitchToken.EOS.value, PitchToken.PAD.value)
    dur_stop = (DurationToken.EOS.value, DurationToken.PAD.value)
    pitch_sos = PitchToken.SOS.value
    # NOTE(review): the original tested the pitch SOS token twice; the
    # second test is taken to be a typo for the duration SOS token. This
    # assumes DurationToken defines SOS like PitchToken does -- confirm
    # against constants.py.
    dur_sos = DurationToken.SOS.value

    tracks = []
    for track_idx, (track_pitches, track_durs) in enumerate(
            zip(pitch_ids, dur_ids)):
        notes = []
        for t, (slot_pitches, slot_durs) in enumerate(
                zip(track_pitches, track_durs)):
            for pitch, dur in zip(slot_pitches, slot_durs):
                if pitch in pitch_stop or dur in dur_stop:
                    # The chord contains no additional notes,
                    # go to next chord
                    break
                if pitch == pitch_sos or dur == dur_sos:
                    # Skip this note
                    continue
                # Remapping duration values from [0, 95] to [1, 96] and
                # do not sustain notes beyond sequence limit
                dur = min(dur + 1, total_timesteps - t)
                notes.append(muspy.Note(time=t, pitch=pitch,
                                        duration=dur, velocity=64))

        track_name = constants.TRACKS[track_idx]
        midi_program = generation_config.MIDI_PROGRAMS[track_name]
        is_drum = (track_name == 'Drums')
        # `notes` is a fresh list per track, so it can be handed over
        # without copying
        tracks.append(muspy.Track(
            name=track_name,
            is_drum=is_drum,
            program=(0 if is_drum else midi_program),
            notes=notes,
        ))

    meta = muspy.Metadata()
    music = muspy.Music(tracks=tracks, metadata=meta, resolution=resolution)
    return music
def loop_muspy_music(muspy_music, n_loop, num_bars, resolution):
    """Return a copy of ``muspy_music`` with its notes repeated ``n_loop`` times.

    The input object is left untouched. Repetition ``i`` (starting from 1)
    is shifted forward by ``i * num_bars * 4 * resolution`` timesteps.
    With ``n_loop <= 1`` the result is just a deep copy of the input.

    Args:
        muspy_music: Music object exposing ``tracks``, each with ``notes``
            whose ``time`` attribute can be shifted.
        n_loop: Total number of times the sequence should be played.
        num_bars: Length of the sequence in bars.
        resolution: Timesteps per beat (four beats per bar are assumed by
            the offset formula).

    Returns:
        A new music object containing the original and the shifted notes.
    """
    result = copy.deepcopy(muspy_music)

    for repetition in range(1, n_loop):
        for source_track, target_track in zip(muspy_music.tracks,
                                              result.tracks):
            for source_note in source_track.notes:
                # Copy so that the original notes are never shifted
                shifted = copy.deepcopy(source_note)
                shifted.time += repetition * num_bars * 4 * resolution
                target_track.notes.append(shifted)

    return result
def add_end_of_track(muspy_music, bar_length=32):
    """Mark the end of the last bar with a short note if nothing ends there.

    The last bar is the one containing the latest note start. If no note
    in any track ends exactly at that bar's end position, a note
    (pitch 70, velocity 64, duration 1) starting at the end position is
    appended to the first track. The music object is modified in place.

    Args:
        muspy_music: Music object exposing ``tracks``, each with ``notes``
            that have ``start`` and ``duration`` attributes.
        bar_length: Number of timesteps in one bar.

    Returns:
        The same ``muspy_music`` object. A song without any note is
        returned unchanged.
    """
    all_notes = [note for track in muspy_music.tracks
                 for note in track.notes]
    if not all_notes:
        # No note means there is no last bar to mark (and max() below
        # would raise ValueError on an empty sequence)
        return muspy_music

    # End position of the bar that contains the last note start
    last_time = max(note.start for note in all_notes)
    bar_end = (last_time // bar_length) * bar_length + bar_length

    # Only a note ending exactly on the bar end counts as closing the bar
    bar_is_closed = any(note.start + note.duration == bar_end
                        for note in all_notes)
    if not bar_is_closed:
        muspy_music.tracks[0].notes.append(
            muspy.Note(time=bar_end, duration=1, pitch=70, velocity=64))

    return muspy_music
def save_midi(muspy_song, save_dir, name):
    """Write ``muspy_song`` to ``<save_dir>/<name>.mid``.

    Before writing, a nearly silent note (pitch 23, velocity 1) spanning
    timesteps 61 to 64 is appended to the first track of ``muspy_song``.
    The song object passed in is therefore modified in place.
    NOTE(review): timestep 64 presumably marks the end of the generated
    sequence -- confirm against the caller.

    Args:
        muspy_song: Music object to save; must contain at least one track.
        save_dir: Destination directory.
        name: File name without extension.
    """
    # Add low MIDI note at last timestep
    end_marker = muspy.Note(time=61, duration=3, pitch=23, velocity=1)
    muspy_song.tracks[0].notes.append(end_marker)

    midi_path = os.path.join(save_dir, name + ".mid")
    muspy.write_midi(midi_path, muspy_song)
def save_audio(muspy_song, save_dir, name):
    """Render ``muspy_song`` to ``<save_dir>/<name>.wav``.

    The soundfont at ``generation_config.SOUNDFONT_PATH`` is used when
    that path exists; otherwise ``None`` is passed to
    ``muspy.write_audio`` (presumably selecting muspy's default
    soundfont -- confirm against the muspy documentation).

    Args:
        muspy_song: Music object to render.
        save_dir: Destination directory.
        name: File name without extension.
    """
    soundfont_path = None
    if os.path.exists(generation_config.SOUNDFONT_PATH):
        soundfont_path = generation_config.SOUNDFONT_PATH

    wav_path = os.path.join(save_dir, name + ".wav")
    muspy.write_audio(wav_path, muspy_song, soundfont_path=soundfont_path)