# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Pieces for making transformers.""" | |
| import abc | |
| import dataclasses | |
| from typing import Iterable, List, Optional, Sequence, Union | |
| import numpy as np | |
| from tracr.craft import bases | |
| from tracr.craft import vectorspace_fns | |
| project = vectorspace_fns.project | |
| def _np_softmax(x, axis=-1): | |
| x_max = np.max(x, axis=axis, keepdims=True) | |
| return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True) | |
| def _np_relu(x): | |
| return np.where(x > 0, x, 0) | |
def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
  """Applies an elementwise ReLU to the magnitudes of `x`, keeping its basis."""
  rectified = _np_relu(x.magnitudes)
  return bases.VectorInBasis(x.basis_directions, rectified)
class Block(abc.ABC):
  """Transformer block, acting on a sequence of vector space elements.

  Attributes:
    residual_space: Vector space that contains all subspaces the Block interacts
      with. This can be either the full residual space of a model or a subspace.
  """
  residual_space: bases.VectorSpaceWithBasis

  # `apply` has no implementation here; marking it abstract makes subclasses'
  # obligation explicit and prevents instantiating a bare Block.
  @abc.abstractmethod
  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies self to an input."""
@dataclasses.dataclass
class AttentionHead(Block):
  """A transformer attention head.

  Attributes:
    w_qk: Bilinear form mapping (query, key) pairs to attention logits.
    w_ov: Linear map applied to the attended-to values.
    residual_space: Residual space the head reads from and writes to. If None,
      it is inferred as the join of the w_qk and w_ov subspaces.
    causal: If True, a position cannot attend to later positions.
  """
  w_qk: vectorspace_fns.ScalarBilinear
  w_ov: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None
  causal: bool = False

  def __post_init__(self):
    """Infer residual stream and typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
                                                     self.w_qk.right_space,
                                                     self.w_ov.input_space,
                                                     self.w_ov.output_space)
    assert self.w_qk.left_space.issubspace(self.residual_space)
    assert self.w_qk.right_space.issubspace(self.residual_space)
    assert self.w_ov.input_space.issubspace(self.residual_space)
    assert self.w_ov.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Runs the attention head on a seq_len x residual_space input."""
    assert x in self.residual_space
    # seq_len x query_space
    queries = x.project(self.w_qk.left_space)
    # seq_len x key_space
    keys = x.project(self.w_qk.right_space)
    attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T
    if self.causal:
      # The 1 offset gives us the matrix strictly above the diagonal, so a
      # position may still attend to itself.
      mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
      attn_matrix = attn_matrix + mask
    attn_weights = _np_softmax(attn_matrix)  # seq_len_from, seq_len_to
    values = self.w_ov_residual(x).magnitudes  # seq_len_to, d_model
    magnitudes = attn_weights @ values  # seq_len_from, d_model
    return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes)

  def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Wov but acting on the residual space."""
    x = project(self.residual_space, self.w_ov.input_space)(x)
    out = self.w_ov(x)
    return project(self.w_ov.output_space, self.residual_space)(out)

  # Property so that MultiAttentionHead can sum `sub_block.num_heads` directly.
  @property
  def num_heads(self) -> int:
    return 1

  def as_multi(self) -> "MultiAttentionHead":
    return MultiAttentionHead([self])
@dataclasses.dataclass
class MultiAttentionHead(Block):
  """Applies attention heads in parallel and sums their outputs.

  Attributes:
    sub_blocks: Heads (or nested multi-heads) that must all share one
      residual space.
  """
  sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]

  def __post_init__(self):
    """Checks all sub-blocks act on the same residual space."""
    spaces = [block.residual_space for block in self.sub_blocks]
    self.residual_space, *others = spaces
    assert all(s == self.residual_space for s in others)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Sums the outputs of all sub-blocks applied to `x`."""
    # each element is seq_len x embedding
    outs = [block.apply(x) for block in self.sub_blocks]
    return bases.VectorInBasis.sum(outs)  # seq_len x embedding

  # Must be a property: the sum below accesses `sub_block.num_heads` as an
  # attribute, which would add bound methods (TypeError) otherwise.
  @property
  def num_heads(self) -> int:
    return sum(sub_block.num_heads for sub_block in self.sub_blocks)

  def heads(self) -> Iterable[AttentionHead]:
    """Yields all primitive heads, recursing into nested multi-heads."""
    for sub_block in self.sub_blocks:
      if isinstance(sub_block, AttentionHead):
        yield sub_block
      elif isinstance(sub_block, MultiAttentionHead):
        yield from sub_block.heads()
      else:
        raise NotImplementedError()

  def as_multi(self) -> "MultiAttentionHead":
    return self
@dataclasses.dataclass
class MLP(Block):
  """A transformer MLP block: fst -> ReLU -> snd, embedded in residual space.

  Attributes:
    fst: First linear layer.
    snd: Second linear layer.
    residual_space: Residual space the block reads from and writes to. If None,
      it is inferred as the join of fst's input and snd's output spaces.
  """
  fst: vectorspace_fns.Linear
  snd: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None

  def __post_init__(self):
    """Typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.fst.input_space,
                                                     self.snd.output_space)
    assert self.fst.output_space == self.snd.input_space
    assert self.fst.input_space.issubspace(self.residual_space)
    assert self.snd.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies the two-layer MLP with ReLU nonlinearity to `x`."""
    assert x in self.residual_space
    x = project(self.residual_space, self.fst.input_space)(x)
    hidden = self.fst(x)
    hidden = relu(hidden)
    out = self.snd(hidden)
    return project(self.snd.output_space, self.residual_space)(out)

  # Takes `cls` and calls `cls(...)`: must be a classmethod to be callable
  # as MLP.combine_in_parallel(mlps).
  @classmethod
  def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
    """Merges several MLPs into a single MLP that runs them in parallel."""
    fst = vectorspace_fns.Linear.combine_in_parallel(
        [block.fst for block in mlps])
    snd = vectorspace_fns.Linear.combine_in_parallel(
        [block.snd for block in mlps])
    return cls(fst=fst, snd=snd, residual_space=None)
# Block that fits into a half-layer, without residual connections.
HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]
@dataclasses.dataclass
class SeriesWithResiduals(Block):
  """A series of blocks with residual connections.

  Attributes:
    blocks: Half-layer blocks applied in order; each block's output is added
      back into the shared residual stream.
  """
  blocks: List[HalfLayerBlock]

  def __post_init__(self):
    """Joins all block subspaces into one residual space."""
    spaces = [block.residual_space for block in self.blocks]
    self.residual_space = bases.join_vector_spaces(*spaces)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies each block in turn, adding its output to the residual stream."""
    x = x.project(self.residual_space)
    for block in self.blocks:
      # Each block only sees (and writes) its own subspace of the residuals.
      x_in = x.project(block.residual_space)
      x_out = block.apply(x_in).project(self.residual_space)
      x = x + x_out
    return x