| | import numpy as np |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from sklearn.model_selection import train_test_split |
| | from torch.utils.data.dataset import IterableDataset |
| | from collections import deque |
| |
|
| | from numpy.random import default_rng |
| |
|
# Preprocessed Sudoku archive, opened once at import time. Entries are
# looked up by name elsewhere in this module (e.g. "quizzes", "solutions").
DATA = np.load("sudoku_reshaped_3_million.npz")

# Module-wide, unseeded NumPy random generator.
rng = np.random.default_rng()
| |
|
| |
|
def _fill_excess_holes(quizzes, solutions, max_holes, generator):
    """Reveal random cells so that no quiz keeps more than ``max_holes`` holes.

    A cell counts as a hole when its channel-1 vector (the last axis of
    ``quizzes[:, 1]``) sums to 0. For every quiz with too many holes, a
    uniformly random subset of its holes is copied from the matching
    solution. ``quizzes`` is modified in place.

    Args:
        quizzes: 5-D array; cells are addressed by axes 2 and 3 (presumably
            row/column — TODO confirm), channels by axis 1.
        solutions: array with the same layout as ``quizzes``.
        max_holes: maximum number of holes to leave in each quiz.
        generator: ``numpy.random.Generator`` used to pick the cells.

    Raises:
        ValueError: if ``max_holes`` is negative.
    """
    if max_holes < 0:
        raise ValueError("max_holes must be non-negative")
    holes = quizzes[:, 1].sum(-1) == 0  # one flag per (sample, cell)
    nb_holes = holes.sum((1, 2))
    for idx in np.flatnonzero(nb_holes > max_holes):
        hole_cells = np.argwhere(holes[idx])  # one (axis2, axis3) pair per hole
        # Drawing all excess holes at once without replacement is equivalent
        # in distribution to revealing one random hole at a time, but needs a
        # single pass instead of re-scanning every quiz per revealed cell.
        chosen = generator.choice(
            len(hole_cells), size=nb_holes[idx] - max_holes, replace=False
        )
        rows, cols = hole_cells[chosen].T
        # Identical index expressions on both sides, so the shapes match.
        quizzes[idx, :, rows, cols, :] = solutions[idx, :, rows, cols, :]


def get_datasets(
    add_proba_fill=False, train_size=1280 // 2, test_size=1280 // 2, max_holes=None
):
    """Build train/test ``TensorDataset``s of (quiz, solution) pairs.

    The first ``train_size + test_size`` entries of ``DATA["quizzes"]`` and
    ``DATA["solutions"]`` are used. They are optionally partially filled in,
    flattened to ``(n, 2, -1)`` (``(n, 2, 729)`` for 9x9x9 grids), and split
    with ``train_test_split(..., random_state=42)``.

    Args:
        add_proba_fill: if True, append a noisy copy of the training quizzes.
            In the copy, every entry that is 1 in the corresponding solution
            is replaced by a random 0 or 1 (drawn from the global
            ``np.random``). The copy is paired with the same solutions, which
            doubles the training set.
        train_size: together with ``test_size``, sets how many leading
            samples are read. The training split gets whatever remains after
            ``test_size`` samples are held out.
        test_size: number of samples held out for the test split.
        max_holes: if not None, reveal random cells from the solution so each
            quiz has at most this many holes. ``0`` reveals every hole.

    Returns:
        ``(train, test)``: two ``torch.utils.data.TensorDataset`` objects
        holding ``torch.Tensor`` (default float dtype) inputs and targets.

    Raises:
        ValueError: if ``max_holes`` is negative.
    """
    n_samples = train_size + test_size
    # Copy so the in-place hole filling never writes into the loaded arrays.
    X = np.array(DATA["quizzes"][:n_samples])
    solutions = DATA["solutions"][:n_samples]
    # `is not None` (not truthiness) so that max_holes=0 is honoured instead of
    # silently disabling the limit.
    if max_holes is not None:
        _fill_excess_holes(X, solutions, max_holes, rng)
    # Flatten everything after the channel axis into a single feature axis.
    X = X.reshape(len(X), 2, -1)
    solutions = solutions.reshape(len(solutions), 2, -1)

    X_train, X_test, solutions_train, solutions_test = train_test_split(
        X, solutions, test_size=test_size, random_state=42
    )
    if add_proba_fill:
        X_train_bis = X_train.copy()
        mask = solutions_train == 1
        X_train_bis[mask] = np.random.randint(0, 2, size=mask.sum())
        X_train = np.concatenate([X_train, X_train_bis])
        solutions_train = np.concatenate([solutions_train, solutions_train])

    train = torch.utils.data.TensorDataset(
        torch.Tensor(X_train), torch.Tensor(solutions_train)
    )
    test = torch.utils.data.TensorDataset(
        torch.Tensor(X_test), torch.Tensor(solutions_test)
    )
    return train, test
| |
|
| |
|
# Datasets built eagerly at import time with get_datasets' default arguments.
[train_dataset, test_dataset] = get_datasets()
| |
|
| |
|
def data_loader(
    batch_size=32,
    add_proba_fill=False,
    *,
    train_size=1280 // 2,
    test_size=1280 // 2,
    max_holes=None,
    shuffle=False,
):
    """Return ``(train_loader, test_loader)`` built from ``get_datasets``.

    Args:
        batch_size: batch size used by both loaders.
        add_proba_fill: forwarded to ``get_datasets``.
        train_size: forwarded to ``get_datasets`` (same default).
        test_size: forwarded to ``get_datasets`` (same default).
        max_holes: forwarded to ``get_datasets`` (same default).
        shuffle: reshuffle the training set every epoch. Defaults to False,
            which matches the previous behaviour. The test loader is never
            shuffled.

    Returns:
        ``(train_loader, test_loader)`` — two ``torch.utils.data.DataLoader``.
    """
    train, test = get_datasets(
        add_proba_fill=add_proba_fill,
        train_size=train_size,
        test_size=test_size,
        max_holes=max_holes,
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=shuffle
    )
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size)
    return train_loader, test_loader
| |
|
| |
|
class DataIterBuffer(IterableDataset):
    """Iterable dataset that interleaves a fixed dataset with a FIFO buffer.

    Items come from ``raw_dataset`` (in index order, each read once) and from
    ``self.buffer``, which callers may refill via :meth:`append` while
    iterating. On each step a fresh ``raw_dataset`` item is preferred with
    probability ``prop_new``, but only while the buffer holds at most
    ``buffer_optim`` entries. Otherwise the buffer is drained first. When one
    source is empty the other is used. Iteration stops once both are empty.

    Args:
        raw_dataset: indexable sequence of items (defaults to an empty list).
        buffer_optim: buffer length above which fresh items stop being
            preferred.
        prop_new: probability of preferring a fresh item on a given step.
        seed: seed for the generator that makes those draws.
    """

    def __init__(self, raw_dataset=None, buffer_optim=50, prop_new=0.1, seed=1):
        # None sentinel instead of a mutable default list shared by instances.
        self.raw_dataset = [] if raw_dataset is None else raw_dataset
        self.buffer = deque()
        self.buffer_optim = buffer_optim
        self.prop_new = prop_new
        self.rng = default_rng(seed=seed)
        # Index of the next unread raw_dataset item.
        # NOTE(review): __iter__ never resets this, so a second pass only
        # drains the buffer — confirm single-pass use is intended.
        self.idx_dataset = 0

    def __iter__(self):
        while True:
            has_fresh = self.idx_dataset < len(self.raw_dataset)
            # Use self.rng (not the global NumPy RNG) so `seed` has an effect.
            prefer_fresh = (self.rng.random() < self.prop_new) and (
                len(self.buffer) <= self.buffer_optim
            )
            if has_fresh and (prefer_fresh or not self.buffer):
                yield self.raw_dataset[self.idx_dataset]
                self.idx_dataset += 1
            elif self.buffer:
                yield self.buffer.popleft()
            else:
                break

    def append(self, X, Y) -> None:
        """Queue the samples of a batch whose masked ``X`` differs from ``Y``.

        ``X`` is zeroed wherever ``Y == 0``. This is done on a detached copy,
        so the caller's tensor is not modified and no autograd history is
        retained. Each sample (first axis) whose masked ``X`` differs from
        ``Y`` in any position is appended to the buffer as an ``(x, y)``
        tuple.

        Args:
            X: batched tensor with the same shape as ``Y``; presumably model
                outputs — TODO confirm against callers.
            Y: batched target tensor.
        """
        X = X.detach().masked_fill(Y == 0, 0)
        Y = Y.detach()
        mismatched = ~(X == Y).flatten(start_dim=1).all(dim=1)
        self.buffer.extend(zip(X[mismatched], Y[mismatched]))

    def __len__(self):
        # Buffered items plus every raw_dataset item, whether consumed or not.
        return len(self.buffer) + len(self.raw_dataset)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|