from torch.utils.data import Dataset
class MIDIDataset(Dataset):
    """Map-style dataset pairing each input row with its target row.

    Sample ``i`` is the tuple ``(X[i], Y[i])``; the dataset length is the
    length of ``X``.
    """

    def __init__(self, X_tensor, Y_tensor):
        """Store the paired inputs and targets.

        Args:
            X_tensor: Inputs, indexed along the first dimension
                (``[N, feature_dim]``).
            Y_tensor: Targets, indexed along the first dimension
                (``[N, seq_len]``).

        Raises:
            ValueError: If the two containers hold a different number of
                samples.
        """
        super().__init__()
        num_inputs = len(X_tensor)
        num_targets = len(Y_tensor)
        # Fail at construction time: __len__ is driven by X alone, so a
        # mismatch would otherwise surface as an IndexError in the middle of
        # iteration (Y shorter) or as silently ignored targets (Y longer).
        if num_inputs != num_targets:
            raise ValueError(
                "X_tensor and Y_tensor must contain the same number of "
                f"samples, got {num_inputs} and {num_targets}"
            )
        self.X = X_tensor  # [N, feature_dim]
        self.Y = Y_tensor  # [N, seq_len]

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at position ``idx``."""
        return self.X[idx], self.Y[idx]