Diva / utility /dataset.py
rrayy
Changes to be committed: 전처리 중 EOS, padding index 변경 100,-1 -> 차원별, 16
cee9630
from torch.utils.data import Dataset
class MIDIDataset(Dataset):
    """Map-style dataset pairing each input row ``X[i]`` with its target ``Y[i]``.

    The two containers are stored as-is (no copy); indexing simply forwards
    ``idx`` to both of them.
    """

    def __init__(self, X_tensor, Y_tensor):
        """Store the paired input/target containers.

        Args:
            X_tensor: Per-sample inputs, indexed along the first axis.
                Expected shape is presumably ``[N, feature_dim]`` (per the
                original author's note) — confirm against the caller.
            Y_tensor: Per-sample targets, indexed along the first axis.
                Expected shape is presumably ``[N, seq_len]`` (per the
                original author's note) — confirm against the caller.

        Raises:
            ValueError: If ``X_tensor`` and ``Y_tensor`` do not have the same
                number of samples along the first axis.
        """
        # Fail fast: a mismatch would otherwise make __len__ (which only looks
        # at X) disagree with Y, causing an IndexError mid-epoch or silently
        # ignoring trailing targets.
        if len(X_tensor) != len(Y_tensor):
            raise ValueError(
                "X_tensor and Y_tensor must have the same number of samples, "
                f"got {len(X_tensor)} and {len(Y_tensor)}"
            )
        self.X = X_tensor  # [N, feature_dim] per original comment
        self.Y = Y_tensor  # [N, seq_len] per original comment

    def __len__(self) -> int:
        """Return the number of samples N (the length of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``."""
        return self.X[idx], self.Y[idx]