File size: 2,494 Bytes
f320de7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
import numpy as np
import keras
import glob
import os

# Shared tf.data tuning constant: passed in this module as the parallelism
# argument of `map` calls and as the buffer size of `prefetch`.
AUTOTUNE = tf.data.AUTOTUNE


def NoteEncoder(vocab_path, samples_path=None):
    """Load or build the note vocabulary and return lookup layers for it.

    The vocabulary is read from ``<vocab_path>/vocab.npy`` when that file
    exists. Otherwise it is built from the unique values found in every
    ``*.csv`` file directly inside ``samples_path`` (the first row of each
    file is skipped) and saved to ``vocab.npy`` so later calls can load it.

    Args:
        vocab_path: Directory that holds, or will receive, ``vocab.npy``.
            It is created if needed when a new vocabulary is saved.
        samples_path: Directory of CSV note files used to build the
            vocabulary when no saved one exists. May be omitted if
            ``vocab.npy`` is already present.

    Returns:
        A tuple ``(note2id, id2note, vocab)`` where ``note2id`` is an
        ``IntegerLookup`` layer over ``vocab`` with no out-of-vocabulary
        slots, ``id2note`` is the matching inverted layer, and ``vocab`` is
        the vocabulary array itself.

    Raises:
        ValueError: If no saved vocabulary exists and ``samples_path`` is
            not given, or if ``samples_path`` contains no ``*.csv`` files.
    """
    vocab_file = os.path.join(vocab_path, "vocab.npy")

    if os.path.exists(vocab_file):
        print("vocab.npy found, loading from disk...")
        vocab = np.load(vocab_file)
    elif samples_path is not None:
        print("vocab.npy not found, adapting from sample files...")
        files = glob.glob(os.path.join(samples_path, "*.csv"))
        if not files:
            # Without this guard np.hstack([]) fails with an opaque
            # "need at least one array to concatenate" error.
            raise ValueError(
                f"no CSV files found in {samples_path!r}; cannot build the vocabulary."
            )
        per_file_notes = [
            np.loadtxt(p, delimiter=",", skiprows=1).flatten() for p in files
        ]
        # np.unique also sorts, so the result does not depend on glob order.
        vocab = np.unique(np.hstack(per_file_notes))
        os.makedirs(vocab_path, exist_ok=True)
        np.save(vocab_file, vocab)
        print(f"vocab adapted and saved to {vocab_file}")
    else:
        raise ValueError("vocab file not found and samples_path not provided.")

    note2id = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab)
    id2note = keras.layers.IntegerLookup(num_oov_indices=0, vocabulary=vocab, invert=True)
    return note2id, id2note, vocab

def parse_and_flatten(line):
    """Split one CSV row of four integer fields into a dataset of single notes.

    The row is decoded with four integer columns (each defaulting to 0), and
    the decoded fields are then sliced into a dataset that yields them one
    at a time.
    """
    column_defaults = [0, 0, 0, 0]
    notes = tf.io.decode_csv(records=line, record_defaults=column_defaults)
    return tf.data.Dataset.from_tensor_slices(notes)
    
def seq2seq_from_chorale(path, seq_len, window_shift):
    """Build overlapping (input, target) windows from one chorale CSV file.

    The file is read line by line with its first line skipped, every row is
    flattened into a stream of individual notes, and that stream is cut into
    windows of ``seq_len + window_shift`` notes that advance by
    ``window_shift`` notes each; incomplete trailing windows are dropped.
    Each window ``w`` yields the pair ``(w[:-window_shift], w[window_shift:])``.
    """
    window_len = seq_len + window_shift

    def split_input_target(window):
        # Input drops the last `window_shift` notes, target drops the first.
        return window[:-window_shift], window[window_shift:]

    notes = tf.data.TextLineDataset(path).skip(1).flat_map(parse_and_flatten)
    windows = notes.window(window_len, shift=window_shift, drop_remainder=True)
    # Each window is itself a dataset; batching it whole turns it into one tensor.
    sequences = windows.flat_map(lambda window: window.batch(window_len))
    return sequences.map(split_input_target, AUTOTUNE)

def seq2seq_dataset(files_path, lookup_fn, seq_len=256, window_shift=1,
                    batch_size=64, shuffle_buffer=None, seed=42):
    """Build a batched dataset of encoded (input, target) sequences from chorale CSV files.

    Every file matched by ``files_path`` is turned into sliding-window
    sequence pairs by ``seq2seq_from_chorale``; the pairs from all files are
    concatenated in file order, encoded with ``lookup_fn``, cached,
    optionally shuffled, then batched and prefetched.

    Args:
        files_path: File pattern(s) passed to ``tf.data.Dataset.list_files``.
            Files are listed with ``shuffle=False``.
        lookup_fn: Callable applied separately to the input and the target
            tensor of every pair.
        seq_len: Length of each input/target sequence.
        window_shift: Step between consecutive windows; also the offset
            between an input sequence and its target.
        batch_size: Number of sequence pairs per batch.
        shuffle_buffer: Shuffle buffer size. Shuffling is skipped when this
            is ``None`` or any other falsy value.
        seed: Seed passed to the shuffle step.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(inputs, targets)``.
    """
    file_paths = tf.data.Dataset.list_files(files_path, shuffle=False)
    dataset = (
        file_paths
        # One flat_map expands each path into its windows and concatenates
        # them in file order, replacing a map followed by an identity flat_map.
        .flat_map(lambda path: seq2seq_from_chorale(path, seq_len, window_shift))
        .map(lambda inp, tar: (lookup_fn(inp), lookup_fn(tar)), AUTOTUNE)
        # Cache after encoding so parsing and lookup are not redone on later passes.
        .cache()
    )

    if shuffle_buffer:
        # Shuffle after caching so the cache stores the unshuffled elements.
        dataset = dataset.shuffle(shuffle_buffer, seed=seed)

    return dataset.batch(batch_size).prefetch(AUTOTUNE)