import tensorflow as tf import numpy as np import keras import glob import os AUTOTUNE = tf.data.AUTOTUNE def NoteEncoder(vocab_path, samples_path=None): """Loads or builds a vocabulary from CSV note files and returns IntegerLookup layers for encoding and decoding notes""" vocab_file = os.path.join(vocab_path, "vocab.npy") if os.path.exists(vocab_file): print("vocab.npy found, loading from disk...") vocab = np.load(vocab_file) elif samples_path is not None: print("vocab.npy not found, adapting from sample files...") files = glob.glob(os.path.join(samples_path, "*.csv")) vocab = np.unique(np.hstack([np.loadtxt(p, delimiter=",", skiprows=1).flatten() for p in files])) os.makedirs(vocab_path, exist_ok=True) np.save(vocab_file, vocab) print(f"vocab adapted and saved to {vocab_file}") else: raise ValueError("vocab file not found and samples_path not provided.") note2id = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab) id2note = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab, invert=True) return note2id, id2note, vocab def parse_and_flatten(line): """Parses a line of csv note data and flattens it into individual note tensors.""" fields = tf.io.decode_csv(line, [0,0,0,0]) return tf.data.Dataset.from_tensor_slices(fields) def seq2seq_from_chorale(path, seq_len, window_shift): """creates seq2seq overlapping windows from a sequence""" return tf.data.TextLineDataset(path).skip(1)\ .flat_map(parse_and_flatten)\ .window(seq_len + window_shift, shift=window_shift, drop_remainder=True)\ .flat_map(lambda yushi: yushi.batch(seq_len + window_shift))\ .map(lambda aiden: (aiden[:-window_shift] , aiden[window_shift:]), AUTOTUNE) def seq2seq_dataset(files_path, lookup_fn, seq_len=256, window_shift=1, batch_size=64, shuffle_buffer=None, seed=42): """Converts a single chorale CSV file into input–target note sequences using sliding windows.""" dataset = tf.data.Dataset.list_files(files_path, shuffle=False)\ .map(lambda geralt: seq2seq_from_chorale(geralt, seq_len, window_shift), AUTOTUNE)\ .flat_map(lambda joe:joe)\ .map(lambda inp, tar: (lookup_fn(inp), lookup_fn(tar)), AUTOTUNE)\ .cache() if shuffle_buffer: dataset = dataset.shuffle(shuffle_buffer, seed=seed) return dataset.batch(batch_size).prefetch(AUTOTUNE)