Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Add craft model blocks to graph of RASPExpr.""" | |
| from typing import Any, Callable, Optional | |
| import networkx as nx | |
| from tracr.compiler import nodes | |
| from tracr.craft import bases | |
| from tracr.craft.chamber import categorical_attn | |
| from tracr.craft.chamber import categorical_mlp | |
| from tracr.craft.chamber import numerical_mlp | |
| from tracr.craft.chamber import selector_width | |
| from tracr.rasp import rasp | |
| def _transform_fun_to_basis_fun( | |
| fun: Callable[..., Any], | |
| output_direction_name: Optional[str] = None) -> Callable[..., Any]: | |
| """Transforms a function acting on values into one acting on directions.""" | |
| def bases_fun(*args): | |
| values = [d.value for d in args] | |
| result = fun(*values) | |
| if output_direction_name: | |
| return bases.BasisDirection(output_direction_name, result) | |
| return result | |
| return bases_fun | |
| def _check_selector_expression(expr, graph): | |
| """Check graph structure and encodings for an aggregate or selector width.""" | |
| sel_expr = expr.selector | |
| # Check graph structure | |
| assert sel_expr.label in graph.predecessors(expr.label) | |
| assert sel_expr.keys.label in graph.predecessors(sel_expr.label) | |
| assert sel_expr.queries.label in graph.predecessors(sel_expr.label) | |
| if (not rasp.is_categorical(sel_expr.queries) or | |
| not rasp.is_categorical(sel_expr.keys)): | |
| raise ValueError("Selector keys and queries must be categorical.") | |
| def add_craft_components_to_rasp_graph( | |
| graph: nx.DiGraph, | |
| bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"), | |
| one_dir: bases.BasisDirection = bases.BasisDirection("one"), | |
| causal: bool = False, | |
| mlp_exactness: float = 100, | |
| ) -> None: | |
| """Translates expressions to craft blocks and attaches them to the graph. | |
| Sets the `MODEL_BLOCK` attribute for all nodes in `graph`. | |
| Args: | |
| graph: RASP graph with `VALUE_SET` but not `MODEL_BLOCK` attributes. | |
| bos_dir: Basis direction representing beginning of sequence (bos) token. | |
| one_dir: Auxiliary basis direction that must contain 1. | |
| causal: If True, marks attention blocks as causal. | |
| mlp_exactness: Controls the approximation of the MLP layers. | |
| Raises: | |
| ValueError: On invalid input (if `MODEL_BLOCK` is set already, or | |
| `VALUE_SET` is not set already) | |
| NotImplementedError: If the graph contains an unsupported expression. | |
| """ | |
| one_space = bases.VectorSpaceWithBasis([one_dir]) | |
| for node_id, node in graph.nodes.items(): | |
| expr = node[nodes.EXPR] | |
| if not isinstance(expr, rasp.SOp): | |
| continue | |
| if nodes.MODEL_BLOCK in node and node[nodes.MODEL_BLOCK]: | |
| raise ValueError("Input graph cannot have model blocks set already.") | |
| if nodes.VALUE_SET not in node: | |
| raise ValueError( | |
| "Craft components can only be added after basis inference.") | |
| if expr is rasp.tokens or expr is rasp.indices: | |
| block = None | |
| elif isinstance(expr, rasp.Map): | |
| inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label] | |
| assert inner_expr.label in graph.predecessors(node_id) | |
| input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS]) | |
| output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) | |
| if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr): | |
| basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label) | |
| block = categorical_mlp.map_categorical_mlp( | |
| input_space=input_space, | |
| output_space=output_space, | |
| operation=basis_fun) | |
| elif rasp.is_categorical(inner_expr) and rasp.is_numerical(expr): | |
| block = categorical_mlp.map_categorical_to_numerical_mlp( | |
| input_space=input_space, | |
| output_space=output_space, | |
| operation=expr.f, | |
| ) | |
| elif rasp.is_numerical(inner_expr) and rasp.is_categorical(expr): | |
| block = numerical_mlp.map_numerical_to_categorical_mlp( | |
| f=expr.f, | |
| input_space=input_space, | |
| output_space=output_space, | |
| input_value_set=inner_node[nodes.VALUE_SET], | |
| one_space=one_space, | |
| hidden_name=f"_hidden_{expr.label}_", | |
| large_number=mlp_exactness) | |
| elif rasp.is_numerical(inner_expr) and rasp.is_numerical(expr): | |
| block = numerical_mlp.map_numerical_mlp( | |
| f=expr.f, | |
| input_space=input_space, | |
| output_space=output_space, | |
| input_value_set=inner_node[nodes.VALUE_SET], | |
| one_space=one_space, | |
| hidden_name=f"_hidden_{expr.label}_", | |
| large_number=mlp_exactness) | |
| else: | |
| raise NotImplementedError("Map does no support " | |
| f"in_type '{inner_expr.type}' and" | |
| f" out_type '{expr.type}'!") | |
| elif isinstance(expr, rasp.SequenceMap): | |
| fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label] | |
| snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label] | |
| # Check graph structure | |
| assert fst_expr.label in graph.predecessors(node_id) | |
| assert snd_expr.label in graph.predecessors(node_id) | |
| fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS]) | |
| snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS]) | |
| out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) | |
| if (isinstance(expr, rasp.LinearSequenceMap) and | |
| not all(rasp.is_numerical(x) for x in (fst_expr, snd_expr, expr))): | |
| raise NotImplementedError("Linear SequenceMap only supports numerical " | |
| "inputs/outputs.") | |
| elif ( | |
| not isinstance(expr, rasp.LinearSequenceMap) and | |
| not all(rasp.is_categorical(x) for x in (fst_expr, snd_expr, expr))): | |
| raise NotImplementedError("(Non-linear) SequenceMap only supports " | |
| "categorical inputs/outputs.") | |
| if isinstance(expr, rasp.LinearSequenceMap): | |
| assert len(fst_space.basis) == 1 | |
| assert len(snd_space.basis) == 1 | |
| assert len(out_space.basis) == 1 | |
| block = numerical_mlp.linear_sequence_map_numerical_mlp( | |
| input1_basis_direction=fst_space.basis[0], | |
| input2_basis_direction=snd_space.basis[0], | |
| output_basis_direction=out_space.basis[0], | |
| input1_factor=expr.fst_fac, | |
| input2_factor=expr.snd_fac, | |
| hidden_name=f"_hidden_{expr.label}_") | |
| elif fst_space == snd_space: | |
| # It's okay to use the local variable expr.f because it is | |
| # only used within the same loop iteration to create the MLP. | |
| # pylint: disable=cell-var-from-loop | |
| basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x), | |
| expr.label) | |
| block = categorical_mlp.map_categorical_mlp( | |
| input_space=fst_space, output_space=out_space, operation=basis_fun) | |
| else: | |
| basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label) | |
| block = categorical_mlp.sequence_map_categorical_mlp( | |
| input1_space=fst_space, | |
| input2_space=snd_space, | |
| output_space=out_space, | |
| operation=basis_fun, | |
| one_space=one_space, | |
| hidden_name=f"_hidden_{expr.label}_") | |
| elif isinstance(expr, rasp.Aggregate): | |
| sel_expr: rasp.Select = expr.selector | |
| agg_expr: rasp.Aggregate = expr | |
| if not isinstance(sel_expr, rasp.Select): | |
| raise TypeError("Compiling composite Selectors is not supported. " | |
| f"Got a {sel_expr}.") | |
| queries = graph.nodes[sel_expr.queries.label] | |
| keys = graph.nodes[sel_expr.keys.label] | |
| sop = graph.nodes[agg_expr.sop.label] | |
| _check_selector_expression(expr, graph) | |
| assert agg_expr.sop.label in graph.predecessors(node_id) | |
| if rasp.get_encoding(agg_expr.sop) != rasp.get_encoding(agg_expr): | |
| raise ValueError( | |
| "sop encoding must match output encoding of the aggregate.") | |
| if rasp.is_categorical(agg_expr) and agg_expr.default is not None: | |
| raise ValueError("Default for a categorical aggregate must be None. " | |
| f"Got {agg_expr.default}") | |
| if rasp.is_numerical(agg_expr) and agg_expr.default != 0: | |
| raise ValueError("Default for a numerical aggregate must be 0. " | |
| f"Got {agg_expr.default}") | |
| bos_space = bases.VectorSpaceWithBasis([bos_dir]) | |
| one_space = bases.VectorSpaceWithBasis([one_dir]) | |
| query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]) | |
| key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]) | |
| value_space = bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS]) | |
| output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) | |
| # Argument order is different in craft / transformers than RASP selectors | |
| def attn_basis_fn(query: bases.BasisDirection, | |
| key: bases.BasisDirection) -> bool: | |
| # It's okay to use the local variable sel_expr because this function is | |
| # only used within the same loop iteration to create an attention head. | |
| # pylint: disable=cell-var-from-loop | |
| selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate) | |
| return selector_basis_fn(key, query) | |
| block = categorical_attn.categorical_attn( | |
| query_space=query_space, | |
| key_space=key_space, | |
| value_space=value_space, | |
| output_space=output_space, | |
| bos_space=bos_space, | |
| one_space=one_space, | |
| attn_fn=attn_basis_fn, | |
| default_output=output_space.null_vector(), | |
| causal=causal, | |
| always_attend_to_bos=False, | |
| use_bos_for_default_output=True, | |
| softmax_coldness=100) | |
| elif isinstance(expr, rasp.SelectorWidth): | |
| sel_expr = expr.selector | |
| queries = graph.nodes[sel_expr.queries.label] | |
| keys = graph.nodes[sel_expr.keys.label] | |
| _check_selector_expression(expr, graph) | |
| bos_space = bases.VectorSpaceWithBasis([bos_dir]) | |
| query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]) | |
| key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]) | |
| output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]) | |
| # Argument order is different in craft / transformers than RASP selectors | |
| def attn_basis_fn(query: bases.BasisDirection, | |
| key: bases.BasisDirection) -> bool: | |
| # It's okay to use the local variable sel_expr because this function is | |
| # only used within the same loop iteration to create an attention head. | |
| selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate) # pylint: disable=cell-var-from-loop | |
| return selector_basis_fn(key, query) | |
| block = selector_width.selector_width( | |
| query_space=query_space, | |
| key_space=key_space, | |
| output_space=output_space, | |
| bos_space=bos_space, | |
| one_space=one_space, | |
| attn_fn=attn_basis_fn, | |
| out_value_set=node[nodes.VALUE_SET], | |
| categorical_output=rasp.is_categorical(expr), | |
| causal=False, | |
| softmax_coldness=100, | |
| mlp_large_number=mlp_exactness, | |
| label=expr.label) | |
| else: | |
| raise NotImplementedError(f"Expression {expr} cannot be translated to " | |
| "a model component.") | |
| graph.nodes[node_id][nodes.MODEL_BLOCK] = block | |