Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Inferring the vector spaces taken on by certain operations.""" | |
| import dataclasses | |
| import itertools | |
| import networkx as nx | |
| from tracr.compiler import nodes | |
| from tracr.craft import bases | |
| from tracr.rasp import rasp | |
| from tracr.utils import errors | |
# Short module-level alias for `nodes.Node`; used below to annotate the `sink`
# argument of `infer_bases`.
Node = nodes.Node
@dataclasses.dataclass
class InferBasesOutput:
    """Record wrapping a graph.

    Declared as a dataclass so that it can be constructed as
    `InferBasesOutput(graph=...)`; without the decorator the annotated field
    below would not generate an `__init__` and the class could not hold a
    graph at all.

    Attributes:
      graph: A `networkx` directed graph.
        NOTE(review): presumably the graph annotated by `infer_bases`, but
        nothing in this module constructs this class -- confirm with callers.
    """

    graph: nx.DiGraph
def infer_bases(
    graph: nx.DiGraph,
    sink: Node,
    vocab: set[rasp.Value],
    max_seq_len: int,
) -> None:
    """Annotates S-Op nodes of `graph` in place with value sets and output bases.

    Nodes are visited in DFS post-order over the reversed graph, starting from
    the sink's ID. For every visited node whose expression is an S-Op, the set
    of values it can output is stored under `nodes.VALUE_SET` and the basis of
    its output vector space under `nodes.OUTPUT_BASIS`. Nodes holding anything
    other than an S-Op are skipped.

    Args:
      graph: Graph whose node attribute dicts hold an expression under
        `nodes.EXPR`; it is mutated in place.
      sink: Node at which the traversal starts; only `sink[nodes.ID]` is read.
      vocab: Value set assigned to `rasp.tokens`.
      max_seq_len: Maximum sequence length; bounds the value sets of
        `rasp.indices`, selector widths and numerical aggregates.

    Raises:
      NotImplementedError: If a numerical aggregate averages an S-Op whose
        values, truncated with `int`, are not exactly {0, 1}.
      ValueError: If an S-Op is of an unsupported kind, or is neither
        categorical nor numerical.
    """

    def stored_values(operand: rasp.SOp) -> set[rasp.Value]:
        """Returns the value set previously recorded on the graph for `operand`."""
        return graph.nodes[operand.label][nodes.VALUE_SET]

    def compute_value_set(sop: rasp.SOp) -> set[rasp.Value]:
        """Derives the value set of `sop` from its operands' stored value sets."""
        # Primitives and constants: the value set is known directly.
        if sop is rasp.tokens:
            return vocab
        if sop is rasp.indices:
            return set(range(max_seq_len))
        if isinstance(sop, rasp.SelectorWidth):
            # A width ranges from 0 up to and including max_seq_len.
            return set(range(max_seq_len + 1))
        if isinstance(sop, rasp.Full):
            return {sop.fill}

        if isinstance(sop, rasp.Map):
            operand_values = stored_values(sop.inner)
            # Apply f to every possible input; None results are left out.
            mapped = (
                errors.ignoring_arithmetic_errors(sop.f)(item)
                for item in operand_values
            )
            return {result for result in mapped if result is not None}

        if isinstance(sop, rasp.SequenceMap):
            guarded_f = errors.ignoring_arithmetic_errors(sop.f)
            left_values = stored_values(sop.fst)
            right_values = stored_values(sop.snd)
            # Apply f to every pair of possible inputs; None results are left out.
            combined = (
                guarded_f(left, right)
                for left, right in itertools.product(left_values, right_values)
            )
            return {result for result in combined if result is not None}

        if isinstance(sop, rasp.Aggregate):
            if rasp.is_categorical(sop):
                # Simply pass on the value set of the underlying S-Op.
                return stored_values(sop.sop)
            if rasp.is_numerical(sop):
                # TODO(b/255936408): This doesn't work if we average arbitrary
                # values. But most examples only average binary variables.
                averaged_values = stored_values(sop.sop)
                if {int(item) for item in averaged_values} != {0, 1}:
                    raise NotImplementedError(
                        "Attention patterns can currently only "
                        "average binary variables. Not:", averaged_values)
                # Each input value divided by every length 1..max_seq_len.
                return {
                    item / length
                    for item in averaged_values
                    for length in range(1, max_seq_len + 1)
                }

        raise ValueError(f"Unsupported S-Op: {sop}")

    traversal = nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID])
    for node_id in traversal:
        attributes = graph.nodes[node_id]
        expr = attributes[nodes.EXPR]
        if not isinstance(expr, rasp.SOp):
            # Only S-Ops have output vector spaces.
            continue

        # The value set is recorded before the basis is built, so it is
        # already on the graph even if the encoding check below raises.
        possible_values = compute_value_set(expr)
        attributes[nodes.VALUE_SET] = possible_values

        if rasp.is_categorical(expr):
            # Categorical: space built from the label and the possible values.
            space = bases.VectorSpaceWithBasis.from_values(
                expr.label, possible_values)
        elif rasp.is_numerical(expr):
            # Numerical: space built from the label alone.
            space = bases.VectorSpaceWithBasis.from_names([expr.label])
        else:
            raise ValueError(f"Unsupported S-Op type: {expr.type}")

        attributes[nodes.OUTPUT_BASIS] = space.basis