Spaces:
Paused
Paused
| import torch | |
| from torchvision.datasets import ImageFolder | |
| from torch.utils.data import Dataset | |
| import random | |
| import numpy as np | |
| from tqdm.auto import tqdm | |
| from PIL import Image | |
| from datasets import load_dataset | |
class SortDataset(Dataset):
    """Synthetic dataset for learning to sort real-valued sequences.

    Items are generated on the fly: each one is a vector of standard-normal
    samples together with the index permutation that sorts that vector in
    ascending order.

    Args:
        N (int): Number of values in every generated sequence.
        length (int): Value reported by ``len()``. Because items are sampled
            rather than stored, this only determines how many items make up
            one pass over the dataset. Defaults to 10,000,000, the value that
            used to be hard-coded.
    """

    def __init__(self, N, length=10000000):
        self.N = N
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(inputs, ordering)`` for a freshly sampled sequence.

        ``idx`` is ignored: every call draws a new random sequence.
        ``ordering`` is ``torch.argsort(inputs)``.
        """
        inputs = torch.randn(self.N)
        ordering = torch.argsort(inputs)
        return inputs, ordering
class QAMNISTDataset(Dataset):
    """Question-answering dataset: modular +/- arithmetic over sampled digit images.

    Every item holds a set of images drawn at random from ``base_dataset`` and
    a randomly generated question. The question selects one image as the
    starting value and then repeatedly adds or subtracts the label of another
    selected image, reducing modulo ``modulo_base`` (10) after every step.

    Args:
        base_dataset: Indexable, sized collection of ``(image, label)`` pairs.
            Images must be stackable with ``torch.stack``.
        num_images (int): Nominal number of images per item; also the initial
            value of ``current_num_digits``.
        num_images_delta (int): Half-width of ``num_images_range``.
        num_repeats_per_input (int): Number of times each image and each
            question token is repeated.
        num_operations (int): Nominal number of operations per question; also
            the initial value of ``current_num_operations``.
        num_operations_delta (int): Half-width of ``num_operations_range``.
    """

    def __init__(self, base_dataset, num_images, num_images_delta, num_repeats_per_input, num_operations, num_operations_delta):
        self.base_dataset = base_dataset

        # Image-count configuration (the range is validated on construction).
        self.num_images = num_images
        self.num_images_delta = num_images_delta
        self.num_images_range = self._calculate_num_images_range()

        # Operation configuration (the range is validated on construction).
        self.operators = ["+", "-"]
        self.num_operations = num_operations
        self.num_operations_delta = num_operations_delta
        self.num_operations_range = self._calculate_num_operations_range()

        self.num_repeats_per_input = num_repeats_per_input

        # Values actually used when sampling; adjustable through the setters.
        self.current_num_digits = num_images
        self.current_num_operations = num_operations

        self.modulo_base = 10
        self.output_range = [0, 9]

    def _calculate_num_images_range(self):
        """Return ``[min, max]`` image counts; the minimum must be >= 1."""
        min_val, max_val = (
            self.num_images - self.num_images_delta,
            self.num_images + self.num_images_delta,
        )
        assert min_val >= 1, f"Minimum number of images must be at least 1, got {min_val}"
        return [min_val, max_val]

    def _calculate_num_operations_range(self):
        """Return ``[min, max]`` operation counts; the minimum must be >= 1."""
        min_val, max_val = (
            self.num_operations - self.num_operations_delta,
            self.num_operations + self.num_operations_delta,
        )
        assert min_val >= 1, f"Minimum number of operations must be at least 1, got {min_val}"
        return [min_val, max_val]

    def set_num_digits(self, num_digits):
        """Set how many images each subsequent item contains."""
        self.current_num_digits = num_digits

    def set_num_operations(self, num_operations):
        """Set how many operations each subsequent question contains."""
        self.current_num_operations = num_operations

    def _get_target_and_question(self, targets):
        """Sample a question over ``targets`` and evaluate it.

        Returns:
            tuple: ``(target, question, question_readable)``.
            ``question`` is a flat list of tokens, each repeated
            ``num_repeats_per_input`` times: non-negative tokens are indices
            into ``targets``; ``-1`` encodes '+' and ``-2`` encodes '-'.
            ``question_readable`` has one equation per line.
        """
        repeats = self.num_repeats_per_input
        digit_count = self.current_num_digits
        operation_count = self.current_num_operations

        # The running value starts at the first selected digit.
        start_idx = np.random.randint(digit_count)
        question = [start_idx] * repeats
        current_value = targets[start_idx] % self.modulo_base

        equations = []
        for _ in range(operation_count):
            # Draw the operator first, then the operand (draw order matters
            # for reproducibility under a fixed seed).
            operator_idx = np.random.randint(len(self.operators))
            operator = self.operators[operator_idx]
            operand_idx = np.random.randint(digit_count)
            digit = targets[operand_idx]

            # Operators are encoded as negative tokens: -1 for '+', -2 for '-'.
            question += [-(operator_idx + 1)] * repeats
            question += [operand_idx] * repeats

            # Reduce immediately so the running value stays within the base.
            if operator == '+':
                new_value = (current_value + digit) % self.modulo_base
            else:
                new_value = (current_value - digit) % self.modulo_base

            equations.append(f"({current_value} {operator} {digit}) mod {self.modulo_base} = {new_value}")
            current_value = new_value

        return current_value, question, "\n".join(equations)

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(observations, question, question_readable, target)``.

        ``idx`` is ignored; the images are drawn at random from the base
        dataset on every call.
        """
        picks = [
            self.base_dataset[np.random.randint(len(self))]
            for _ in range(self.current_num_digits)
        ]
        images = [image for image, _ in picks]
        targets = [label for _, label in picks]

        # Each image appears num_repeats_per_input times, back to back.
        observations = torch.stack(images, dim=0).repeat_interleave(
            self.num_repeats_per_input, dim=0
        )
        target, question, question_readable = self._get_target_and_question(targets)
        return observations, question, question_readable, target
class ImageNet(Dataset):
    """Torch ``Dataset`` wrapper around the Hugging Face ``imagenet-1k`` dataset."""

    def __init__(self, which_split, transform):
        """
        Load one split of ``imagenet-1k`` and store the per-image transform.
        Args:
            which_split: Split selector forwarded to ``load_dataset`` as ``split``.
            transform (callable): Applied to every image after it has been
                converted to RGB.
        """
        self.base_dataset = load_dataset('imagenet-1k', split=which_split, trust_remote_code=True)
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, label)`` for record ``idx``."""
        record = self.base_dataset[idx]
        rgb_image = record['image'].convert('RGB')
        return self.transform(rgb_image), record['label']
class MazeImageFolder(ImageFolder):
    """
    An ``ImageFolder`` of maze images that keeps every image in memory and
    precomputes the solution route of each maze.

    After scaling pixel values to [0, 1], a maze is expected to mark its start
    with a pure red pixel ``[1, 0, 0]``, its goal with a pure green pixel
    ``[0, 1, 0]`` and the route between them with pure blue pixels
    ``[0, 0, 1]``. Images are assumed to be H x W x 3 arrays once loaded.
    Routes are fixed-length integer arrays using the codes
    0 = up, 1 = down, 2 = left, 3 = right, 4 = padding (no move).

    Args:
        root (string): Root directory path.
        transform (callable, optional): Forwarded to ``ImageFolder``.
        target_transform (callable, optional): Forwarded to ``ImageFolder``.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): Forwarded to ``ImageFolder``.
        which_set (str): Random augmentation is applied only when this is 'train'.
        augment_p (float): Probability of each individual augmentation
            (90-degree rotation, horizontal flip, vertical flip).
        maze_route_length (int): Length of every returned route array.
        trunc (bool): If True, only the first 1000 images are loaded.
        expand_range (bool): If True, returned images are rescaled with
            ``x * 2 - 1``; otherwise they stay in [0, 1].
    Attributes:
        preloaded_samples (list): Float32 image arrays scaled to [0, 1].
        all_paths (dict): Maps sample index to its route array, or to None
            when the maze has no start or goal pixel.
    """

    # (dy, dx) offsets indexed by direction code. The order is also the order
    # in which neighbours are probed: up, down, left, right.
    _MOVES = ((-1, 0), (1, 0), (0, -1), (0, 1))
    _PAD = 4
    _TRUNC_SIZE = 1000

    # Lookup tables: index = direction code before the augmentation,
    # value = direction code after it. Padding (4) is never changed.
    _ROT_CW = np.array([3, 2, 0, 1, 4])   # np.rot90 with k=-1
    _ROT_CCW = np.array([2, 3, 1, 0, 4])  # np.rot90 with k=+1
    _FLIP_LR = np.array([0, 1, 3, 2, 4])
    _FLIP_UD = np.array([1, 0, 2, 3, 4])

    def __init__(self, root, transform=None, target_transform=None,
                 loader=Image.open,
                 is_valid_file=None,
                 which_set='train',
                 augment_p=0.5,
                 maze_route_length=10,
                 trunc=False,
                 expand_range=True):
        super().__init__(root, transform, target_transform, loader, is_valid_file)
        self.which_set = which_set
        self.augment_p = augment_p
        self.maze_route_length = maze_route_length
        self.trunc = trunc
        self.expand_range = expand_range

        self._preload()

        print('Solving all mazes...')
        self.all_paths = {
            index: self.get_solution(sample)
            for index, sample in enumerate(self.preloaded_samples)
        }

    def _preload(self):
        """Load the images into memory as float32 arrays scaled to [0, 1]."""
        total = len(self.samples)
        if self.trunc:
            # Report the truncated size so the progress bar can reach 100%.
            total = min(total, self._TRUNC_SIZE)

        preloaded_samples = []
        progress = tqdm(self.samples[:total], total=total, desc='Loading mazes',
                        leave=True, position=0, dynamic_ncols=True)
        for path, _ in progress:
            image = self.loader(path)
            preloaded_samples.append(np.array(image).astype(np.float32) / 255)
        self.preloaded_samples = preloaded_samples

    def __len__(self):
        # Before preloading has finished, fall back to the ImageFolder length.
        preloaded = getattr(self, 'preloaded_samples', None)
        if preloaded is not None:
            return len(preloaded)
        return super().__len__()

    def get_solution(self, x):
        """
        Trace the route from the start pixel to the goal pixel.

        Args:
            x (np.ndarray): Maze image with values in [0, 1].
        Returns:
            np.ndarray of length ``maze_route_length`` holding direction codes,
            padded with 4 after the route ends, or None when the start or goal
            pixel is missing. Routes longer than ``maze_route_length`` are
            truncated. If the route hits a dead end before reaching the goal,
            tracing stops there and the remainder stays padded.
        """
        x = np.copy(x)  # visited cells are overwritten below
        start_coords = np.argwhere((x == [1, 0, 0]).all(axis=2))
        end_coords = np.argwhere((x == [0, 1, 0]).all(axis=2))
        if len(start_coords) == 0 or len(end_coords) == 0:
            print("Start or end point not found.")
            return None

        height, width = x.shape[:2]
        current = tuple(start_coords[0])
        goal = tuple(end_coords[0])

        path = [self._PAD] * self.maze_route_length
        step = 0
        while current != goal and step < len(path):
            for direction, (dy, dx) in enumerate(self._MOVES):
                next_y, next_x = current[0] + dy, current[1] + dx
                if not (0 <= next_y < height and 0 <= next_x < width):
                    continue
                pixel = x[next_y, next_x]
                if (pixel == [0, 0, 1]).all() or (pixel == [0, 1, 0]).all():
                    break
            else:
                # Dead end: no neighbouring route or goal pixel. Stop instead
                # of recording an invalid direction and stepping to (-1, -1),
                # which would wrap around through negative indexing.
                break

            path[step] = direction
            step += 1
            # Mark the current cell as visited to avoid going in circles.
            x[current] = [255, 255, 255]
            current = (next_y, next_x)

        return np.array(path, dtype=int)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where sample is a C x H x W float tensor
            with the route pixels painted white, and target is the route array
            adjusted for whatever augmentation was applied.
        """
        sample = np.copy(self.preloaded_samples[index])
        path = np.copy(self.all_paths[index])

        if self.which_set == 'train':
            # Randomly rotate -90 or +90 degrees
            if random.random() < self.augment_p:
                which_rot = random.choice([-1, 1])
                sample = np.rot90(sample, k=which_rot, axes=(0, 1))
                table = self._ROT_CW if which_rot == -1 else self._ROT_CCW
                path = table[path]
            # Random horizontal flip
            if random.random() < self.augment_p:
                sample = np.fliplr(sample)
                path = self._FLIP_LR[path]
            # Random vertical flip
            if random.random() < self.augment_p:
                sample = np.flipud(sample)
                path = self._FLIP_UD[path]

        # rot90 / flip return views; copy so torch gets a contiguous array.
        sample = torch.from_numpy(np.copy(sample)).permute(2, 0, 1)

        # Paint the route white so the solution is not visible in the input.
        blue_mask = (sample[0] == 0) & (sample[1] == 0) & (sample[2] == 1)
        sample[:, blue_mask] = 1

        if not self.expand_range:
            return sample, path
        return (sample * 2) - 1, path
class ParityDataset(Dataset):
    """Random sequences of -1 / +1 labelled with their running parity.

    Each item is a float vector of ``sequence_length`` entries drawn from
    {-1, +1}. The target at position ``i`` is 1 when the number of -1 entries
    among the first ``i + 1`` positions is odd, and 0 otherwise.

    Args:
        sequence_length (int): Number of entries per sequence.
        length (int): Value reported by ``len()``; items are sampled on the
            fly and the requested index is ignored.
    """

    def __init__(self, sequence_length=64, length=100000):
        self.length = length
        self.sequence_length = sequence_length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # One random bit per position: 0 becomes -1 and 1 becomes +1.
        bits = torch.randint(0, 2, (self.sequence_length,))
        vector = (2 * bits - 1).float()

        # A zero bit is a -1 entry, so (1 - bits) flags the negatives.
        negatives_so_far = (1 - bits).cumsum(dim=0)
        target = torch.remainder(negatives_so_far, 2)
        return vector, target