File size: 337 Bytes
36a0566
 
 
cee9630
36a0566
 
 
 
 
 
 
cee9630
1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data import Dataset

class MIDIDataset(Dataset):
    """Map-style dataset pairing input rows of ``X`` with target rows of ``Y``.

    Item ``idx`` is the tuple ``(X[idx], Y[idx])``; the dataset length is the
    shared number of rows ``N``.
    """

    def __init__(self, X_tensor, Y_tensor):
        """Store the paired inputs and targets.

        Args:
            X_tensor: Inputs, indexed along the first dimension.
            Y_tensor: Targets, indexed along the first dimension.

        Raises:
            ValueError: If ``X_tensor`` and ``Y_tensor`` differ in length. An
                unchecked mismatch would either raise IndexError mid-epoch
                or silently ignore trailing targets.
        """
        if len(X_tensor) != len(Y_tensor):
            raise ValueError(
                f"X_tensor and Y_tensor must have the same length, "
                f"got {len(X_tensor)} and {len(Y_tensor)}"
            )
        self.X = X_tensor          # [N, feature_dim]
        self.Y = Y_tensor          # [N, seq_len]

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.X[idx], self.Y[idx]