import glob
import os
import random

import keras
import numpy as np
import tensorflow as tf
from music21 import stream, chord

# Number of note tokens that form one chord: generate_chorale reshapes the
# generated token stream into rows of this width and builds one Chord per row,
# and generate_sequence multiplies max_len by it.
_NOTES_PER_CHORD = 4


def predict_next_token(model, input_sequence, temperature=1, seed=42):
    """Sample the next token for each sequence in a batch.

    The model is run on the whole context, the logits at the last position
    are divided by ``temperature`` and one token id per batch row is drawn
    from the resulting categorical distribution.

    Args:
        model: Object exposing ``predict_on_batch``; its output is indexed as
            ``[:, -1, :]``, i.e. it must be rank 3 with the logits on the
            last axis.
        input_sequence: Rank-2 context, ``(batch_size, sequence_len)``.
        temperature: Positive divisor applied to the logits before sampling.
        seed: Seed forwarded to ``tf.random.categorical``.

    Returns:
        The result of ``tf.random.categorical(..., num_samples=1)``: one
        sampled token id per batch row.

    Raises:
        ValueError: If ``input_sequence`` is not rank 2 or ``temperature`` is
            not strictly positive.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    if keras.ops.ndim(input_sequence) != 2:
        raise ValueError("function expects input_sequence to be (batch_size, sequence_len)")
    # Zero would divide the logits by zero; a negative value would invert the
    # distribution, which is never what a sampling temperature means.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    last_step_logits = model.predict_on_batch(input_sequence)[:, -1, :]
    scaled_logits = last_step_logits / temperature
    return tf.random.categorical(scaled_logits, num_samples=1, seed=seed)


def generate_sequence(init_context, model, include_init_context=False, max_len=25, temperature=1, seed=42):
    """Autoregressively extend ``init_context`` by sampling from ``model``.

    Each step samples one token with ``predict_next_token`` and appends it to
    the context that is fed to the next step.

    Args:
        init_context: Rank-2 prompt, ``(batch_size, sequence_len)``.
        model: Model passed through to ``predict_next_token``.
        include_init_context: When true the prompt is kept at the front of
            the returned sequence; otherwise only the new tokens are returned.
        max_len: Number of chords to generate; ``max_len * 4`` tokens are
            sampled in total.
        temperature: Sampling temperature, see ``predict_next_token``.
        seed: Seed passed to every sampling step.

    Returns:
        The extended context, or just the generated suffix when
        ``include_init_context`` is false.

    Raises:
        ValueError: If ``init_context`` is not rank 2, or (from the sampling
            step) if ``temperature`` is not strictly positive.
    """
    if keras.ops.ndim(init_context) != 2:
        raise ValueError("function expects init_context to be (batch_size, sequence_len)")

    prompt_len = init_context.shape[1]
    context = init_context
    # NOTE(review): the same op-level seed is handed to the sampler on every
    # step; confirm this gives the intended randomness across steps for the
    # TensorFlow execution mode in use.
    for _ in range(max_len * _NOTES_PER_CHORD):
        next_token = predict_next_token(model, context, temperature=temperature, seed=seed)
        context = keras.ops.concatenate([context, next_token], axis=1)

    if include_init_context:
        return context
    return context[:, prompt_len:]


def generate_chorale(model, sample_seed_path, note2id, id2note, file_name="samples/chorale.mid", max_len=25,
                     temperature=1, sample_seed_rows: slice = slice(0, 100), include_init_context=False, seed=42):
    """Generate a chorale from a CSV prompt and save it as a MIDI file.

    The prompt is read from ``sample_seed_path``, mapped to token ids with
    ``note2id``, continued with ``generate_sequence``, mapped back with
    ``id2note`` and written to ``file_name`` as a stream of four-note chords.

    Args:
        model: Model passed through to ``generate_sequence``.
        sample_seed_path: Comma-separated file holding the prompt; its first
            line is skipped.
        note2id: Callable mapping the note array to token ids.
        id2note: Callable mapping token ids back to notes.
        file_name: Output MIDI path. Missing parent directories are created.
        max_len: Number of chords to generate.
        temperature: Sampling temperature.
        sample_seed_rows: Slice applied to the *flattened* CSV values to pick
            the prompt tokens.
        include_init_context: When true the prompt is written out as well; the
            total number of tokens must then be a multiple of four for the
            reshape into chords to succeed.
        seed: Seed passed to the sampler.
    """
    prompt = np.loadtxt(sample_seed_path, skiprows=1, delimiter=",").flatten()[sample_seed_rows].reshape(1, -1)
    prompt = note2id(prompt)

    chorale = generate_sequence(prompt, model, include_init_context=include_init_context, max_len=max_len,
                                temperature=temperature, seed=seed)
    # One row per chord, one column per note in the chord.
    chorale = keras.ops.convert_to_numpy(keras.ops.reshape(id2note(chorale), (-1, _NOTES_PER_CHORD)))
    strm = stream.Stream([chord.Chord(notes.tolist()) for notes in chorale])

    # The default target lives in "samples/", which may not exist yet.
    out_dir = os.path.dirname(file_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    strm.write('midi', fp=file_name)
    print(f"chorale saved as {file_name}")


def draw_random_sample(csv_dir, seed=42):
    """Return the path of one pseudo-randomly chosen ``*.csv`` file in ``csv_dir``.

    The candidate list is sorted before shuffling so that a given ``seed``
    selects the same file regardless of the order in which the filesystem
    lists the directory. A private ``random.Random`` instance is used, so the
    process-wide random state is left untouched.

    Args:
        csv_dir: Directory to search (non-recursive).
        seed: Seed for the selection.

    Returns:
        Path of the selected CSV file.

    Raises:
        IndexError: If ``csv_dir`` contains no ``*.csv`` files.
    """
    files = sorted(glob.glob(os.path.join(csv_dir, "*.csv")))
    if not files:
        raise IndexError(f"no CSV files found in {csv_dir!r}")
    random.Random(seed).shuffle(files)
    return files[0]