| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Create a craft model from a computational graph.""" | |
| import collections | |
| from typing import Dict, List, Sequence | |
| import networkx as nx | |
| from tracr.compiler import nodes | |
| from tracr.craft import bases | |
| from tracr.craft import transformers | |
| from tracr.rasp import rasp | |
| Node = nodes.Node | |
| NodeID = nodes.NodeID | |
def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node],
                                     node: Node) -> int:
  """Returns the lengths of the longest path from sources to node.

  Only SOps count towards the length of a path.

  Args:
    graph: DAG to compute longest path in.
    sources: List of starting nodes, longest path will be a maximum over all.
    node: Target node.

  Returns:
    Number of steps needed for the longest path from the source to the node, or
    -1 if there is no path from any of the sources to the target node.
  """
  if node in sources:
    return 0

  def sop_count(path: Sequence[NodeID]) -> int:
    # Only nodes whose expression is an SOp contribute to path length.
    return sum(
        1 for node_id in path
        if isinstance(graph.nodes[node_id][nodes.EXPR], rasp.SOp))

  best = -1
  for src in sources:
    paths = nx.all_simple_paths(graph, src[nodes.ID], node[nodes.ID])
    # `default=-1` handles the no-path case; subtracting 1 excludes the
    # target node itself from the step count.
    candidate = max(map(sop_count, paths), default=-1) - 1
    best = max(best, candidate)
  return best
def _node_is_attn(node: Node) -> bool:
  """Returns True if node is an attention layer."""
  if nodes.MODEL_BLOCK not in node:
    return False
  block = node[nodes.MODEL_BLOCK]
  return isinstance(
      block, (transformers.AttentionHead, transformers.MultiAttentionHead))
def _node_is_mlp(node: Node) -> bool:
  """Returns True if node is an MLP layer."""
  if nodes.MODEL_BLOCK not in node:
    return False
  return isinstance(node[nodes.MODEL_BLOCK], transformers.MLP)
def _node_is_residual_block(node: Node) -> bool:
  """Returns True if node is a valid residual block (Attn followed by MLP)."""
  block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
  # Guard clauses replace the original nested-if pyramid; the accepted shape
  # is a SeriesWithResiduals of exactly [attention, MLP].
  if not (block and isinstance(block, transformers.SeriesWithResiduals)):
    return False
  if len(block.blocks) != 2:
    return False
  attn, mlp = block.blocks
  attn_types = (transformers.AttentionHead, transformers.MultiAttentionHead)
  return isinstance(attn, attn_types) and isinstance(mlp, transformers.MLP)
| def _all_attn_nodes(node_list: Sequence[Node]) -> bool: | |
| """Returns True iff all nodes are attention layers (or nodes is empty).""" | |
| for node in node_list: | |
| if not _node_is_attn(node): | |
| return False | |
| return True | |
| def _all_mlp_nodes(node_list: Sequence[Node]) -> bool: | |
| """Returns True iff all nodes are MLP layers (or nodes is empty).""" | |
| for node in node_list: | |
| if not _node_is_mlp(node): | |
| return False | |
| return True | |
def _allocate_modules_to_layers(graph: nx.DiGraph,
                                sources: Sequence[Node]) -> Dict[int, int]:
  """Allocate all nodes in compute graph to layers.

  First, computes the longest path from the input to each node that is a model
  component (not input and output nodes). The longest path to a model component
  (its "depth") determines a layer in which we can place it while ensuring that
  all necessary previous computations have already happened.

  This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...]

  In the special case where there are only Attention layers at one depth level
  and only MLP layers in the next depth layer, they are treated as if they
  are at the same depth because attention layers always come before MLP layers
  for the same depth.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes.

  Returns:
    A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
    are in the order attention, mlp, attention, mlp, ...
  """
  # Unallocated nodes default to layer -1.
  layer_allocation: Dict[int, int] = collections.defaultdict(lambda: -1)
  depth_by_node_id: Dict[int, int] = dict()
  nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list)

  # Compute depth of all model components (longest path from source to node)
  for node_id, node in graph.nodes.items():
    if (_node_is_mlp(node) or _node_is_attn(node)
        or _node_is_residual_block(node)):
      # Node is a model component
      longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
      depth_by_node_id[node_id] = longest_path_len
      nodes_by_depth[longest_path_len].append(node)

  # If at level `depth` there are only attention heads and at level `depths + 1`
  # there are only MLPs, we can condense them into one level
  # TODO(b/255936816): Think about improving this heuristic. The heuristic is
  # not optimal, and only catches very basic opportunities for optimization. It
  # is easy to come up with opportunities for optimization that it does not
  # catch.
  # NOTE(review): min()/max() raise ValueError if no model components were
  # found above (empty graph) — presumably callers always pass a non-trivial
  # graph; confirm upstream.
  min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys())

  depth = min_depth
  while depth < max_depth:
    if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
        nodes_by_depth[depth + 1]):
      # Condense by decrementing the depth of all nodes starting from depth+1.
      # Shifting every deeper level (not just depth+1) keeps depths contiguous.
      for update_depth in range(depth + 1, max_depth + 1):
        for node in nodes_by_depth[update_depth]:
          node_id = node[nodes.ID]
          depth_by_node_id[node_id] = update_depth - 1
        nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
        nodes_by_depth[update_depth] = []
      max_depth -= 1
    depth += 1

  # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
  # Each depth level consumes one (attention, MLP) layer pair, hence the +2
  # per depth step below.
  current_layer = 0
  current_depth = 1
  for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
    while depth > current_depth:
      current_depth += 1
      current_layer += 2
    if depth == current_depth:
      if _node_is_residual_block(graph.nodes[node_id]):
        # A residual block spans the attention layer; its MLP half is split
        # out into the following layer by the caller.
        layer_allocation[node_id] = current_layer
      else:
        # MLPs go to the odd layer of the pair, attention heads to the even.
        is_mlp = _node_is_mlp(graph.nodes[node_id])
        layer_allocation[node_id] = current_layer + int(is_mlp)

  return layer_allocation
def craft_graph_to_model(
    graph: nx.DiGraph,
    sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
  """Translates a RASP graph with craft blocks into a full craft model.

  1. Allocates modules to layers, assuming layers are arranged in the order
     attention, mlp, attention, mlp, ...
  2. Creates subspaces for all inputs and outputs, and builds residual stream.
  3. Assembles everything into a craft model and returns it.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes.

  Returns:
    A craft model that can be compiled to model weights.

  Raises:
    ValueError: On invalid input (if the craft_graph does not have craft blocks
      already specified).
  """
  # NOTE(review): no explicit `raise ValueError` is visible in this function;
  # presumably it propagates from a callee — confirm before relying on it.
  layer_allocation = _allocate_modules_to_layers(graph, sources)
  blocks_by_layer = collections.defaultdict(list)
  model_blocks = []

  # The residual stream is grown incrementally to the join of every block's
  # residual space.
  residual_space = bases.VectorSpaceWithBasis([])

  for node_id, layer_no in layer_allocation.items():
    node = graph.nodes[node_id]
    block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None

    if _node_is_residual_block(node):
      assert isinstance(block, transformers.SeriesWithResiduals)
      assert len(block.blocks) == 2
      residual_space = bases.join_vector_spaces(residual_space,
                                                block.blocks[0].residual_space,
                                                block.blocks[1].residual_space)
      # Split the residual block: its attention half stays in this layer, its
      # MLP half goes into the next layer.
      blocks_by_layer[layer_no].append(block.blocks[0])
      blocks_by_layer[layer_no + 1].append(block.blocks[1])
    elif block:
      residual_space = bases.join_vector_spaces(
          residual_space, node[nodes.MODEL_BLOCK].residual_space)
      blocks_by_layer[layer_no].append(block)

  for layer_no, layer_blocks in sorted(
      blocks_by_layer.items(), key=lambda x: x[0]):
    # Rebase every block onto the full (joined) residual stream before
    # combining them into a single layer.
    for block in layer_blocks:
      block.residual_space = residual_space

    if layer_blocks:
      if layer_no % 2 == 0:  # Attention Layer
        multi_head_attn = transformers.MultiAttentionHead(layer_blocks)
        model_blocks.append(multi_head_attn)
      else:  # MLP Layer
        parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks)
        model_blocks.append(parallel_mlp)

  return transformers.SeriesWithResiduals(model_blocks)