from torch.utils.data import Dataset


class MIDIDataset(Dataset):
    """Map-style dataset pairing input rows with target rows by index.

    Sample ``i`` is the tuple ``(X_tensor[i], Y_tensor[i])``.
    """

    def __init__(self, X_tensor, Y_tensor):
        """Store the paired inputs and targets.

        Args:
            X_tensor: Inputs, indexed along the first axis; expected
                shape ``[N, feature_dim]``.
            Y_tensor: Targets aligned row-for-row with ``X_tensor``;
                expected shape ``[N, seq_len]``.

        Raises:
            ValueError: If the two inputs differ in length along the
                first axis.
        """
        if len(X_tensor) != len(Y_tensor):
            # __len__ only consults X: a shorter Y would fail with IndexError
            # mid-epoch, and a longer Y would have its extra rows silently
            # dropped. Fail fast at construction instead.
            raise ValueError(
                "X_tensor and Y_tensor must have the same length, "
                f"got {len(X_tensor)} and {len(Y_tensor)}"
            )
        self.X = X_tensor  # [N, feature_dim]
        self.Y = Y_tensor  # [N, seq_len]

    def __len__(self):
        """Return the number of samples (first-axis length of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.X[idx], self.Y[idx]