# BachNet / src / dataset.py
# (Hugging Face upload metadata: user hoom4n, "Upload 17 files", commit f320de7)
import tensorflow as tf
import numpy as np
import keras
import glob
import os
AUTOTUNE = tf.data.AUTOTUNE
def NoteEncoder(vocab_path, samples_path=None):
    """Load or build an integer note vocabulary and return lookup layers.

    Args:
        vocab_path: Directory containing (or to receive) ``vocab.npy``.
        samples_path: Optional directory of CSV note files used to adapt
            the vocabulary when ``vocab.npy`` does not exist yet.

    Returns:
        Tuple ``(note2id, id2note, vocab)``: forward and inverse
        ``IntegerLookup`` layers plus the int64 vocabulary array.

    Raises:
        ValueError: If no vocab file exists and ``samples_path`` is None,
            or if ``samples_path`` contains no CSV files.
    """
    vocab_file = os.path.join(vocab_path, "vocab.npy")
    if os.path.exists(vocab_file):
        print("vocab.npy found, loading from disk...")
        vocab = np.load(vocab_file)
    elif samples_path is not None:
        print("vocab.npy not found, adapting from sample files...")
        files = glob.glob(os.path.join(samples_path, "*.csv"))
        if not files:
            # np.hstack([]) raises an opaque shape error; fail clearly instead.
            raise ValueError(f"no CSV files found in {samples_path}")
        vocab = np.unique(np.hstack([np.loadtxt(p, delimiter=",", skiprows=1).flatten() for p in files]))
        os.makedirs(vocab_path, exist_ok=True)
        np.save(vocab_file, vocab)
        print(f"vocab adapted and saved to {vocab_file}")
    else:
        raise ValueError("vocab file not found and samples_path not provided.")
    # np.loadtxt yields float64, but IntegerLookup requires an integer
    # vocabulary; cast explicitly (also normalizes float vocab.npy files
    # previously saved to disk).
    vocab = np.asarray(vocab).astype(np.int64)
    note2id = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab)
    id2note = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab, invert=True)
    return note2id, id2note, vocab
def parse_and_flatten(line):
    """Decode one CSV line of four integer note columns and yield each
    note as its own scalar dataset element."""
    record_defaults = [0, 0, 0, 0]  # four int32 voice columns
    notes = tf.io.decode_csv(line, record_defaults)
    return tf.data.Dataset.from_tensor_slices(notes)
def seq2seq_from_chorale(path, seq_len, window_shift):
    """Build overlapping (input, target) note windows from one chorale CSV.

    The target sequence is the input sequence advanced by ``window_shift``
    notes, so each pair trains next-note prediction over ``seq_len`` steps.
    """
    window_len = seq_len + window_shift
    notes = tf.data.TextLineDataset(path).skip(1).flat_map(parse_and_flatten)
    windows = notes.window(window_len, shift=window_shift, drop_remainder=True)
    sequences = windows.flat_map(lambda w: w.batch(window_len))
    return sequences.map(
        lambda seq: (seq[:-window_shift], seq[window_shift:]), AUTOTUNE)
def seq2seq_dataset(files_path, lookup_fn, seq_len=256, window_shift=1,
                    batch_size=64, shuffle_buffer=None, seed=42):
    """Build a batched seq2seq dataset from a glob of chorale CSV files.

    Args:
        files_path: Glob pattern matching the chorale CSV files.
        lookup_fn: Callable mapping raw note values to vocabulary ids
            (e.g. the ``note2id`` layer from ``NoteEncoder``).
        seq_len: Length of each input/target sequence.
        window_shift: Offset between input and target windows.
        batch_size: Number of sequences per batch.
        shuffle_buffer: If truthy, shuffle with this buffer size before batching.
        seed: Seed for the shuffle.

    Returns:
        A ``tf.data.Dataset`` of ``(inputs, targets)`` id batches.
    """
    # flat_map both builds and flattens the per-file window datasets; the
    # original map(...) + flat_map(identity) pair collapses to this single
    # call with identical element order.
    dataset = tf.data.Dataset.list_files(files_path, shuffle=False)\
        .flat_map(lambda path: seq2seq_from_chorale(path, seq_len, window_shift))\
        .map(lambda inp, tar: (lookup_fn(inp), lookup_fn(tar)), AUTOTUNE)\
        .cache()  # cache the encoded sequences before (re)shuffling each epoch
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer, seed=seed)
    return dataset.batch(batch_size).prefetch(AUTOTUNE)