# Spaces: Running  (stray status-badge text; not part of the module)
| # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Sampling utilities.""" | |
| import abc | |
| import collections | |
| import inspect | |
| import types | |
| from typing import Any, Callable, List, Optional, Tuple | |
| from absl import logging | |
| from clrs._src import algorithms | |
| from clrs._src import probing | |
| from clrs._src import specs | |
| import jax | |
| import numpy as np | |
# Type aliases used throughout this module.
_Array = np.ndarray
_DataPoint = probing.DataPoint
# A `Trajectory` is the list of probes (one `DataPoint` per probe) produced by
# a single algorithm unroll; `Trajectories` is a list of such samples.
Trajectory = List[_DataPoint]
Trajectories = List[Trajectory]
# A callable that runs an algorithm on sampled data and returns a pair whose
# second element is the probes (see its use in `Sampler._make_batch`).
Algorithm = Callable[..., Any]
# `Features` bundles one batch of model inputs: input probes, hint probes, and
# the true length of each trajectory in the batch.
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])

# Chunked variant of `Features`: instead of per-trajectory lengths it carries
# `is_first`/`is_last` markers delimiting trajectories within fixed-size
# chunks. The typename must match the Python name (previously it was
# mistakenly 'Features'), so that `repr` and pickling resolve to this class.
FeaturesChunked = collections.namedtuple(
    'FeaturesChunked', ['inputs', 'hints', 'is_first', 'is_last'])

# `Feedback` pairs a `Features` batch with its ground-truth outputs.
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
# CLRS-30 baseline spec.
# Split sizes used by the CLRS-30 benchmark: 1000 training trajectories of
# length 16, 32 validation trajectories of length 16, and 32 test trajectories
# of length 64 (longer than the train/val lengths). Wrapped in
# `MappingProxyType` so the baseline configuration is read-only.
CLRS30 = types.MappingProxyType({
    'train': {
        'num_samples': 1000,
        'length': 16,
        'seed': 1,
    },
    'val': {
        'num_samples': 32,
        'length': 16,
        'seed': 2,
    },
    'test': {
        'num_samples': 32,
        'length': 64,
        'seed': 3,
    },
})
class Sampler(abc.ABC):
  """Sampler abstract base class.

  Subclasses implement `_sample_data` to draw the inputs of one problem
  instance; this base class handles unrolling the algorithm on those inputs,
  batching/padding the resulting probes, and serving batches via `next`.
  """

  def __init__(
      self,
      algorithm: Algorithm,
      spec: specs.Spec,
      num_samples: int,
      *args,
      seed: Optional[int] = None,
      **kwargs,
  ):
    """Initializes a `Sampler`.

    Args:
      algorithm: The algorithm to sample from.
      spec: The algorithm spec.
      num_samples: Number of algorithm unrolls to sample. If positive, all the
        samples will be generated in the constructor, and at each call
        of the `next` method a batch will be randomly selected among them.
        If -1, samples are generated on the fly with each call to `next`.
      *args: Algorithm args.
      seed: RNG seed.
      **kwargs: Algorithm kwargs.
    """
    # Use `RandomState` to ensure deterministic sampling across Numpy versions.
    self._rng = np.random.RandomState(seed)
    self._spec = spec
    self._num_samples = num_samples
    self._algorithm = algorithm
    self._args = args
    self._kwargs = kwargs

    if num_samples < 0:
      logging.warning('Sampling dataset on-the-fly, unlimited samples.')
      # Just get an initial estimate of max hint length by unrolling the
      # algorithm on 1000 throwaway samples; `next` grows the estimate if it
      # later encounters a longer trajectory.
      self.max_steps = -1
      for _ in range(1000):
        data = self._sample_data(*args, **kwargs)
        _, probes = algorithm(*data)
        _, _, hint = probing.split_stages(probes, spec)
        for dp in hint:
          assert dp.data.shape[1] == 1  # batching axis
          if dp.data.shape[0] > self.max_steps:
            self.max_steps = dp.data.shape[0]
    else:
      logging.info('Creating a dataset with %i samples.', num_samples)
      (self._inputs, self._outputs, self._hints,
       self._lengths) = self._make_batch(num_samples, spec, 0, algorithm, *args,
                                         **kwargs)

  def _make_batch(self, num_samples: int, spec: specs.Spec, min_length: int,
                  algorithm: Algorithm, *args, **kwargs):
    """Generates a batch of data by unrolling the algorithm `num_samples` times.

    Args:
      num_samples: Number of unrolls in the batch.
      spec: The algorithm spec, used to split probes into stages.
      min_length: Hints are padded to at least this many time steps.
      algorithm: The algorithm to unroll.
      *args: Forwarded to `_sample_data`.
      **kwargs: Forwarded to `_sample_data`.

    Returns:
      A tuple (inputs, outputs, hints, lengths) of batched probes.
    """
    inputs = []
    outputs = []
    hints = []

    for _ in range(num_samples):
      data = self._sample_data(*args, **kwargs)
      _, probes = algorithm(*data)
      inp, outp, hint = probing.split_stages(probes, spec)
      inputs.append(inp)
      outputs.append(outp)
      hints.append(hint)
      if len(hints) % 1000 == 0:
        logging.info('%i samples created', len(hints))

    # Batch and pad trajectories to max(T).
    inputs = _batch_io(inputs)
    outputs = _batch_io(outputs)
    hints, lengths = _batch_hints(hints, min_length)
    return inputs, outputs, hints, lengths

  def next(self, batch_size: Optional[int] = None) -> Feedback:
    """Subsamples trajectories from the pre-generated dataset.

    Args:
      batch_size: Optional batch size. If `None`, returns entire dataset.

    Returns:
      Subsampled trajectories.

    Raises:
      ValueError: If `batch_size` exceeds a pre-generated dataset's size.
    """
    if batch_size:
      if self._num_samples < 0:  # generate on the fly
        inputs, outputs, hints, lengths = self._make_batch(
            batch_size, self._spec, self.max_steps,
            self._algorithm, *self._args, **self._kwargs)
        if hints[0].data.shape[0] > self.max_steps:
          logging.warning('Increasing hint length from %i to %i',
                          self.max_steps, hints[0].data.shape[0])
          self.max_steps = hints[0].data.shape[0]
      else:
        if batch_size > self._num_samples:
          raise ValueError(
              f'Batch size {batch_size} > dataset size {self._num_samples}.')

        # Returns a fixed-size random batch. Sampling is with replacement.
        indices = self._rng.choice(self._num_samples, (batch_size,),
                                   replace=True)
        inputs = _subsample_data(self._inputs, indices, axis=0)
        outputs = _subsample_data(self._outputs, indices, axis=0)
        # Hints are laid out [time, batch, ...], hence axis=1.
        hints = _subsample_data(self._hints, indices, axis=1)
        lengths = self._lengths[indices]

    else:
      # Returns the full dataset.
      assert self._num_samples >= 0
      inputs = self._inputs
      hints = self._hints
      lengths = self._lengths
      outputs = self._outputs

    return Feedback(Features(inputs, hints, lengths), outputs)

  @abc.abstractmethod
  def _sample_data(self, length: int, *args, **kwargs) -> List[_Array]:
    """Returns the list of algorithm inputs for one sampled problem instance."""

  def _random_sequence(self, length, low=0.0, high=1.0):
    """Random sequence of uniform values in [low, high)."""
    return self._rng.uniform(low=low, high=high, size=(length,))

  def _random_string(self, length, chars=4):
    """Random string of integers in [0, chars)."""
    return self._rng.randint(0, high=chars, size=(length,))

  def _random_er_graph(self, nb_nodes, p=0.5, directed=False, acyclic=False,
                       weighted=False, low=0.0, high=1.0):
    """Random Erdos-Renyi graph (adjacency, optionally weighted)."""

    mat = self._rng.binomial(1, p, size=(nb_nodes, nb_nodes))
    if not directed:
      # Keep an edge only if both directed draws succeeded -> symmetric matrix.
      mat *= np.transpose(mat)
    elif acyclic:
      mat = np.triu(mat, k=1)
      p = self._rng.permutation(nb_nodes)  # To allow nontrivial solutions
      mat = mat[p, :][:, p]
    if weighted:
      weights = self._rng.uniform(low=low, high=high, size=(nb_nodes, nb_nodes))
      if not directed:
        # Symmetrize weights so both directions of an edge agree.
        weights *= np.transpose(weights)
        weights = np.sqrt(weights + 1e-3)  # Add epsilon to protect underflow
      mat = mat.astype(float) * weights
    return mat

  def _random_community_graph(self, nb_nodes, k=4, p=0.5, eps=0.01,
                              directed=False, acyclic=False, weighted=False,
                              low=0.0, high=1.0):
    """Random perturbed k-community graph."""
    mat = np.zeros((nb_nodes, nb_nodes))
    if k > nb_nodes:
      raise ValueError(f'Cannot generate graph of too many ({k}) communities.')

    # Partition nodes into k contiguous communities (each an ER graph); the
    # last community absorbs the remainder when nb_nodes is not divisible by k.
    los, his = [], []
    lo = 0
    for i in range(k):
      if i == k - 1:
        hi = nb_nodes
      else:
        hi = lo + nb_nodes // k
      mat[lo:hi, lo:hi] = self._random_er_graph(
          hi - lo, p=p, directed=directed,
          acyclic=acyclic, weighted=weighted,
          low=low, high=high)
      los.append(lo)
      his.append(hi)
      lo = hi

    # Sparse (p=eps) perturbation edges between communities.
    toggle = self._random_er_graph(nb_nodes, p=eps, directed=directed,
                                   acyclic=acyclic, weighted=weighted,
                                   low=low, high=high)

    # Prohibit closing new cycles
    for i in range(k):
      for j in range(i):
        toggle[los[i]:his[i], los[j]:his[j]] *= 0

    mat = np.where(toggle > 0.0, (1.0 - (mat > 0.0)) * toggle, mat)
    p = self._rng.permutation(nb_nodes)  # To allow nontrivial solutions
    mat = mat[p, :][:, p]
    return mat

  def _random_bipartite_graph(self, n, m, p=0.25):
    """Random bipartite graph-based flow network."""
    nb_nodes = n + m + 2  # n left nodes, m right nodes, supersource, supersink
    s = 0
    t = n + m + 1
    mat = np.zeros((nb_nodes, nb_nodes))
    mat[s, 1:n+1] = 1.0  # supersource
    mat[n+1:n+m+1, t] = 1.0  # supersink
    mat[1:n+1, n+1:n+m+1] = self._rng.binomial(1, p, size=(n, m))
    return mat
def build_sampler(
    name: str,
    num_samples: int,
    *args,
    seed: Optional[int] = None,
    **kwargs,
) -> Tuple[Sampler, specs.Spec]:
  """Builds a sampler. See `Sampler` documentation."""

  if name not in specs.SPECS or name not in SAMPLERS:
    raise NotImplementedError(f'No implementation of algorithm {name}.')
  spec = specs.SPECS[name]
  algorithm = getattr(algorithms, name)
  sampler_class = SAMPLERS[name]
  # Drop any kwargs that the sampler's `_sample_data` does not accept.
  accepted = inspect.signature(sampler_class._sample_data).parameters  # pylint:disable=protected-access
  clean_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
  dropped = set(kwargs).difference(clean_kwargs)
  if dropped:
    logging.warning('Ignoring kwargs %s when building sampler class %s',
                    dropped, sampler_class)
  sampler = sampler_class(algorithm, spec, num_samples, *args, seed=seed,
                          **clean_kwargs)
  return sampler, spec
class SortingSampler(Sampler):
  """Sorting sampler. Generates a random sequence of U[0, 1]."""

  def _sample_data(
      self,
      length: int,
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns a single uniform random array of the requested length."""
    return [self._random_sequence(length=length, low=low, high=high)]
class SearchSampler(Sampler):
  """Search sampler. Generates a random sequence and target (of U[0, 1])."""

  def _sample_data(
      self,
      length: int,
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns a scalar target and a sorted haystack array."""
    # Draw the sequence first, then the target, to keep the RNG order fixed.
    haystack = self._random_sequence(length=length, low=low, high=high)
    haystack.sort()
    target = self._rng.uniform(low=low, high=high)
    return [target, haystack]
class MaxSubarraySampler(Sampler):
  """Maximum subarray sampler. Generates a random sequence of U[-1, 1]."""

  def _sample_data(
      self,
      length: int,
      low: float = -1.,
      high: float = 1.,
  ):
    """Returns a single uniform random array (default range [-1, 1])."""
    return [self._random_sequence(length=length, low=low, high=high)]
class LCSSampler(Sampler):
  """Longest Common Subsequence sampler. Generates two random ATCG strings."""

  def _sample_data(
      self,
      length: int,
      length_2: Optional[int] = None,
      chars: int = 4,
  ):
    """Returns two integer strings; splits `length` in two if needed."""
    if length_2 is None:
      # Assume provided length is total length.
      length_2 = length // 2
      length -= length_2
    first = self._random_string(length=length, chars=chars)
    second = self._random_string(length=length_2, chars=chars)
    return [first, second]
class OptimalBSTSampler(Sampler):
  """Optimal BST sampler. Samples array of probabilities, splits it into two."""

  def _sample_data(
      self,
      length: int,
  ):
    """Returns key probabilities `p` (size length) and gap probabilities `q`."""
    total = length + (length + 1)
    probs = self._random_sequence(length=total, low=0.0, high=1.0)
    # Normalize so p and q jointly form a probability distribution.
    probs /= np.sum(probs)
    return [probs[:length], probs[length:]]
class ActivitySampler(Sampler):
  """Activity sampler. Samples start and finish times from U[0, 1]."""

  def _sample_data(
      self,
      length: int,
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns elementwise-ordered (start, finish) time arrays."""
    draw_a = self._random_sequence(length=length, low=low, high=high)
    draw_b = self._random_sequence(length=length, low=low, high=high)
    # Order the two draws elementwise so each start precedes its finish.
    return [np.minimum(draw_a, draw_b), np.maximum(draw_a, draw_b)]
class TaskSampler(Sampler):
  """Task sampler. Samples deadlines (integers) and values (U[0, 1])."""

  def _sample_data(
      self,
      length: int,
      max_deadline: Optional[int] = None,
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns integer deadlines in [1, max_deadline] and scalar values."""
    if max_deadline is None:
      max_deadline = length
    deadlines = self._random_string(length=length, chars=max_deadline) + 1
    values = self._random_sequence(length=length, low=low, high=high)
    return [deadlines, values]
class DfsSampler(Sampler):
  """DFS sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
  ):
    """Returns one directed (possibly cyclic) Erdos-Renyi graph."""
    return [self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p),
        directed=True, acyclic=False, weighted=False)]
class BfsSampler(Sampler):
  """BFS sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
  ):
    """Returns an undirected ER graph and a random source node."""
    # Graph first, then source, to keep the RNG call order fixed.
    adjacency = self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p),
        directed=False, acyclic=False, weighted=False)
    source = self._rng.choice(length)
    return [adjacency, source]
class TopoSampler(Sampler):
  """Topological Sorting sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
  ):
    """Returns one directed acyclic Erdos-Renyi graph."""
    return [self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p),
        directed=True, acyclic=True, weighted=False)]
class ArticulationSampler(Sampler):
  """Articulation Point sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.2,),
  ):
    """Returns one sparse undirected Erdos-Renyi graph."""
    return [self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p), directed=False,
        acyclic=False, weighted=False)]
class MSTSampler(Sampler):
  """MST sampler for Kruskal's algorithm."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.2,),  # lower p to account for class imbalance
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns one weighted undirected Erdos-Renyi graph."""
    weighted_graph = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=False,
        acyclic=False,
        weighted=True,
        low=low,
        high=high)
    return [weighted_graph]
class BellmanFordSampler(Sampler):
  """Bellman-Ford sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns a weighted undirected ER graph and a random source node."""
    # Graph first, then source, to keep the RNG call order fixed.
    weighted_graph = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=False,
        acyclic=False,
        weighted=True,
        low=low,
        high=high)
    source = self._rng.choice(length)
    return [weighted_graph, source]
class DAGPathSampler(Sampler):
  """Sampler for DAG shortest paths."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns a weighted DAG and a random source node."""
    # Graph first, then source, to keep the RNG call order fixed.
    dag = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=True,
        acyclic=True,
        weighted=True,
        low=low,
        high=high)
    source = self._rng.choice(length)
    return [dag, source]
class FloydWarshallSampler(Sampler):
  """Sampler for all-pairs shortest paths."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
      low: float = 0.,
      high: float = 1.,
  ):
    """Returns one weighted undirected Erdos-Renyi graph."""
    return [self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=False,
        acyclic=False,
        weighted=True,
        low=low,
        high=high)]
class SccSampler(Sampler):
  """Sampler for strongly connected component (SCC) tasks."""

  def _sample_data(
      self,
      length: int,
      k: int = 4,
      p: Tuple[float, ...] = (0.5,),
      eps: float = 0.01,
  ):
    """Returns one directed k-community graph with sparse perturbations."""
    return [self._random_community_graph(
        nb_nodes=length, k=k, p=self._rng.choice(p), eps=eps,
        directed=True, acyclic=False, weighted=False)]
class BipartiteSampler(Sampler):
  """Sampler for bipartite matching-based flow networks."""

  def _sample_data(
      self,
      length: int,
      length_2: Optional[int] = None,
      p: Tuple[float, ...] = (0.3,),
  ):
    """Returns [flow_network, n, m, source, sink] for a bipartite instance."""
    if length_2 is None:
      # Assume provided length is total length.
      length_2 = length // 2
      length -= length_2
    network = self._random_bipartite_graph(n=length, m=length_2,
                                           p=self._rng.choice(p))
    # Source is node 0; sink is the last node (n + m + 1).
    return [network, length, length_2, 0, length + length_2 + 1]
class MatcherSampler(Sampler):
  """String matching sampler; embeds needle in a random haystack."""

  def _sample_data(
      self,
      length: int,  # length of haystack + needle, i.e., total number of nodes
      length_needle: Optional[int] = None,
      chars: int = 4,
  ):
    """Samples a (haystack, needle) pair with the needle embedded.

    Args:
      length: Total number of characters (haystack plus needle).
      length_needle: Needle length. If `None`, defaults to `length // 5`
        (at least 1). If negative, the needle length is drawn uniformly
        from [1, -length_needle].
      chars: Alphabet size.

    Returns:
      A list [haystack, needle] of integer arrays; the needle is copied into
      the haystack at a random position so a match always exists.
    """
    if length_needle is None:
      if length < 5:
        length_needle = 1
      else:
        length_needle = length // 5
    elif length_needle < 0:  # randomize needle length
      length_needle = self._rng.randint(1, high=1 - length_needle)
    length_haystack = length - length_needle
    needle = self._random_string(length=length_needle, chars=chars)
    haystack = self._random_string(length=length_haystack, chars=chars)
    # Guard the degenerate case where the haystack has no slack:
    # `rng.choice(0)` raises ValueError, but position 0 is the only valid
    # embedding point there.
    if length_haystack > length_needle:
      embed_pos = self._rng.choice(length_haystack - length_needle)
    else:
      embed_pos = 0
    haystack[embed_pos:embed_pos + length_needle] = needle
    return [haystack, needle]
class SegmentsSampler(Sampler):
  """Two-segment sampler of points from (U[0, 1], U[0, 1])."""

  def _sample_data(self, length: int, low: float = 0., high: float = 1.):
    """Samples two 2-D segments; intersecting with probability one half."""
    del length  # There are exactly four endpoints.

    # Quick CCW check (ignoring collinearity) for rejection sampling
    def ccw(x_a, y_a, x_b, y_b, x_c, y_c):
      return (y_c - y_a) * (x_b - x_a) > (y_b - y_a) * (x_c - x_a)

    def intersect(xs, ys):
      return ccw(xs[0], ys[0], xs[2], ys[2], xs[3], ys[3]) != ccw(
          xs[1], ys[1], xs[2], ys[2], xs[3], ys[3]) and ccw(
              xs[0], ys[0], xs[1], ys[1], xs[2], ys[2]) != ccw(
                  xs[0], ys[0], xs[1], ys[1], xs[3], ys[3])

    # Decide (with uniform probability) should this sample intersect
    want_intersection = self._rng.binomial(1, 0.5)

    # Rejection-sample endpoint draws until the intersection status matches
    # the coin flip.
    xs = self._random_sequence(length=4, low=low, high=high)
    ys = self._random_sequence(length=4, low=low, high=high)
    while intersect(xs, ys) != want_intersection:
      xs = self._random_sequence(length=4, low=low, high=high)
      ys = self._random_sequence(length=4, low=low, high=high)

    return [xs, ys]
class ConvexHullSampler(Sampler):
  """Convex hull sampler of points over a disk of radius r."""

  def _sample_data(self, length: int, origin_x: float = 0.,
                   origin_y: float = 0., radius: float = 2.):
    """Samples `length` points uniformly over a disk."""
    # Angles first, then radii, to keep the RNG call order fixed. Taking the
    # square root of U[0, 1] makes the point density uniform over the disk.
    angles = self._random_sequence(length=length, low=0.0, high=2.0 * np.pi)
    dists = radius * np.sqrt(
        self._random_sequence(length=length, low=0.0, high=1.0))
    return [dists * np.cos(angles) + origin_x,
            dists * np.sin(angles) + origin_y]
# Maps each algorithm name (a key of `specs.SPECS`) to the sampler class that
# generates its inputs; consumed by `build_sampler`.
SAMPLERS = {
    'insertion_sort': SortingSampler,
    'bubble_sort': SortingSampler,
    'heapsort': SortingSampler,
    'quicksort': SortingSampler,
    'quickselect': SortingSampler,
    'minimum': SortingSampler,
    'binary_search': SearchSampler,
    'find_maximum_subarray': MaxSubarraySampler,
    'find_maximum_subarray_kadane': MaxSubarraySampler,
    'matrix_chain_order': SortingSampler,
    'lcs_length': LCSSampler,
    'optimal_bst': OptimalBSTSampler,
    'activity_selector': ActivitySampler,
    'task_scheduling': TaskSampler,
    'dfs': DfsSampler,
    'topological_sort': TopoSampler,
    'strongly_connected_components': SccSampler,
    'articulation_points': ArticulationSampler,
    'bridges': ArticulationSampler,
    'bfs': BfsSampler,
    'mst_kruskal': MSTSampler,
    'mst_prim': BellmanFordSampler,
    'bellman_ford': BellmanFordSampler,
    'dag_shortest_paths': DAGPathSampler,
    'dijkstra': BellmanFordSampler,
    'floyd_warshall': FloydWarshallSampler,
    'bipartite_matching': BipartiteSampler,
    'naive_string_matcher': MatcherSampler,
    'kmp_matcher': MatcherSampler,
    'segments_intersect': SegmentsSampler,
    'graham_scan': ConvexHullSampler,
    'jarvis_march': ConvexHullSampler,
}
def _batch_io(traj_io: Trajectories) -> Trajectory:
  """Batches a trajectory of input/output samples along the time axis per probe.

  Args:
    traj_io: An i/o trajectory of `DataPoint`s indexed by time then probe.

  Returns:
    A |num probes| list of `DataPoint`s with the time axis stacked into `data`.
  """

  assert traj_io  # non-empty

  # Every sample must be unbatched (leading axis of size 1) and list its
  # probes in the same order as the first sample.
  reference = traj_io[0]
  for sample_io in traj_io:
    for i, dp in enumerate(sample_io):
      assert dp.data.shape[0] == 1  # batching axis
      assert reference[i].name == dp.name

  # Concatenate matching probes of all samples along the batch axis.
  return jax.tree_util.tree_map(lambda *x: np.concatenate(x), *traj_io)
def _batch_hints(
    traj_hints: Trajectories, min_steps: int) -> Tuple[Trajectory, List[int]]:
  """Batches a trajectory of hints samples along the time axis per probe.

  Unlike i/o, hints have a variable-length time dimension. Before batching, each
  trajectory is padded to the maximum trajectory length.

  Args:
    traj_hints: A hint trajectory of `DataPoints`s indexed by time then probe
    min_steps: Hints will be padded at least to this length - if any hint is
      longer than this, the greater length will be used.

  Returns:
    A |num probes| list of `DataPoint`s with the time axis stacked into `data`,
    and a |sample| list containing the length of each trajectory.
  """

  assert traj_hints  # non-empty

  # The padded time dimension is the longest hint, but never below min_steps.
  max_steps = min_steps
  for sample_hint in traj_hints:
    for dp in sample_hint:
      assert dp.data.shape[1] == 1  # batching axis
      max_steps = max(max_steps, dp.data.shape[0])
  time_and_batch = (max_steps, len(traj_hints))

  # Create zero-filled space for the batched hints, then copy each hint
  # up to the corresponding length.
  batched_traj = jax.tree_util.tree_map(
      lambda x: np.zeros(time_and_batch + x.shape[2:]),
      traj_hints[0])
  hint_lengths = np.zeros(len(traj_hints))

  for sample_idx, cur_sample in enumerate(traj_hints):
    for i, dp in enumerate(cur_sample):
      assert batched_traj[i].name == dp.name
      cur_data = dp.data
      cur_length = cur_data.shape[0]
      batched_traj[i].data[:cur_length, sample_idx:sample_idx + 1] = cur_data
      if i > 0:
        # All probes of one sample must agree on the trajectory length.
        assert hint_lengths[sample_idx] == cur_length
      else:
        hint_lengths[sample_idx] = cur_length

  return batched_traj, hint_lengths
def _subsample_data(
    trajectory: Trajectory,
    idx: List[int],
    axis: int = 0,
) -> Trajectory:
  """New `Trajectory` where each `DataPoint`'s data is subsampled along axis."""
  return [
      probing.DataPoint(dp.name, dp.location, dp.type_,
                        np.take(dp.data, idx, axis=axis))
      for dp in trajectory
  ]
def _preprocess_permutations(probes, enforce_permutations):
  """Replace should-be permutations with proper permutation pointer + mask."""
  output = []
  for probe in probes:
    # Pass through every probe that is not a should-be-permutation.
    if probe.type_ != specs.Type.SHOULD_BE_PERMUTATION:
      output.append(probe)
      continue
    assert probe.location == specs.Location.NODE
    if not enforce_permutations:
      # Demote to a plain pointer probe; the data is left untouched.
      output.append(probing.DataPoint(name=probe.name, location=probe.location,
                                      type_=specs.Type.POINTER,
                                      data=probe.data))
      continue
    # Convert to a cyclic permutation pointer plus a one-hot "first" mask.
    perm, mask = probing.predecessor_to_cyclic_predecessor_and_first(probe.data)
    output.append(
        probing.DataPoint(
            name=probe.name,
            location=probe.location,
            type_=specs.Type.PERMUTATION_POINTER,
            data=perm))
    output.append(
        probing.DataPoint(
            name=probe.name + '_mask',
            location=probe.location,
            type_=specs.Type.MASK_ONE,
            data=mask))
  return output
def process_permutations(spec, sample_iterator, enforce_permutations):
  """Replace should-be permutations with proper permutation pointer + mask."""

  def _iterate():
    while True:
      feedback = next(sample_iterator)
      features = feedback.features
      inputs = _preprocess_permutations(features.inputs, enforce_permutations)
      hints = _preprocess_permutations(features.hints, enforce_permutations)
      outputs = _preprocess_permutations(feedback.outputs, enforce_permutations)
      features = features._replace(inputs=tuple(inputs), hints=tuple(hints))
      yield feedback._replace(features=features, outputs=outputs)

  # Rewrite the spec to match the transformed probes.
  new_spec = {}
  for k, v in spec.items():
    stage, loc, typ = v[0], v[1], v[2]
    if loc == specs.Location.NODE and typ == specs.Type.SHOULD_BE_PERMUTATION:
      if enforce_permutations:
        new_spec[k] = (stage, loc, specs.Type.PERMUTATION_POINTER)
        new_spec[k + '_mask'] = (stage, loc, specs.Type.MASK_ONE)
      else:
        new_spec[k] = (stage, loc, specs.Type.POINTER)
    else:
      new_spec[k] = v

  return new_spec, _iterate()
def process_pred_as_input(spec, sample_iterator):
  """Move pred_h hint to pred input."""

  def _iterate():
    while True:
      feedback = next(sample_iterator)
      features = feedback.features
      pred_h_probes = [h for h in features.hints if h.name == 'pred_h']
      if pred_h_probes:
        assert len(pred_h_probes) == 1
        pred_h = pred_h_probes[0]
        remaining_hints = [h for h in features.hints if h.name != 'pred_h']
        # `pred_h` must be constant over (valid) time steps within each
        # trajectory; otherwise collapsing it into a single time-independent
        # input would lose information.
        for i in range(len(features.lengths)):
          assert np.sum(np.abs(pred_h.data[1:int(features.lengths[i]), i] -
                               pred_h.data[0, i])) == 0.0
        inputs = tuple(features.inputs) + (
            probing.DataPoint(name='pred', location=pred_h.location,
                              type_=pred_h.type_, data=pred_h.data[0]),)
        features = features._replace(inputs=tuple(inputs),
                                     hints=tuple(remaining_hints))
        feedback = feedback._replace(features=features)
      yield feedback

  # Rewrite the spec: the `pred_h` hint entry becomes a `pred` input entry.
  new_spec = {}
  for k in spec:
    if k == 'pred_h':
      assert spec[k] == (specs.Stage.HINT, specs.Location.NODE,
                         specs.Type.POINTER)
      new_spec['pred'] = (specs.Stage.INPUT, specs.Location.NODE,
                         specs.Type.POINTER)
    else:
      new_spec[k] = spec[k]

  return new_spec, _iterate()
def process_random_pos(sample_iterator, rng):
  """Randomize the `pos` input from a sampler.

  The `pos` input is, by default, a scalar uniformly spaced between 0 and 1
  across the nodes. The exception are string algorithms (naive_string_matcher,
  kmp_string_matcher and lcs_length), where the `pos` sequence is split into
  needle and haystack (or first and second string, for lcs_length). Here
  we replace the uniformly spaced `pos` with an ordered sequence of random
  scalars, or, for string algorithms, two ordered sequences of random scalars.

  Args:
    sample_iterator: An iterator producing samples with non-random `pos` inputs.
    rng: Numpy random generator

  Returns:
    An iterator returning the samples with randomized `pos` inputs.
  """

  def _iterate():
    while True:
      feedback = next(sample_iterator)
      inputs = feedback.features.inputs
      # Exactly one `pos` input is expected; the unpacking enforces this.
      pos, = [x for x in inputs if x.name == 'pos']
      batch_size, num_nodes = pos.data.shape
      unsorted = rng.uniform(size=(batch_size, num_nodes))
      new_pos = []
      for i in range(batch_size):  # we check one example at a time.
        # We find if there are splits in the pos sequence, marked by zeros.
        # We know there will always be at least 1 zero, if there's no split.
        split, = np.where(pos.data[i] == 0)
        split = np.concatenate([split, [num_nodes]])
        # We construct the randomized pos by sorting the random values in each
        # split and concatenating them.
        new_pos.append(
            np.concatenate([np.sort(unsorted[i, split[j]:split[j+1]])
                            for j in range(len(split) - 1)]))
      # NOTE(review): `pos.data` is reassigned in place, so the `DataPoint`
      # already held inside `feedback` is mutated too; the list rebuild below
      # therefore reuses the same (mutated) object for the 'pos' slot.
      pos.data = np.array(new_pos)
      inputs = [(pos if x.name == 'pos' else x) for x in inputs]
      features = feedback.features._replace(inputs=inputs)
      feedback = feedback._replace(features=features)
      yield feedback

  return _iterate()