# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Assemble weights of a transformer model from a craft residual stack.""" | |
| import dataclasses | |
| from typing import Any, Callable, Optional, Protocol | |
| import chex | |
| import einops | |
| import haiku as hk | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from tracr.craft import bases | |
| from tracr.craft import transformers | |
| from tracr.craft import vectorspace_fns | |
| from tracr.transformer import encoder | |
| from tracr.transformer import model | |


@chex.dataclass
class AssembledTransformerModelOutput:
  decoded: list[Any]  # length T.
  unembedded: jax.Array  # [B, T]   B = 1 always.
  layer_outputs: list[jax.Array]  # [B, T, D]
  residuals: list[jax.Array]  # [B, T, D]
  attn_logits: list[jax.Array]  # [B, T, T, H]
  transformer_output: jax.Array  # [B, T, D]
  input_embeddings: jax.Array


class ModelForward(Protocol):

  def __call__(
      self,
      params: hk.Params,
      emb: jax.Array,
  ) -> model.CompiledTransformerModelOutput:
    """A hk-transformed forward pass through the compiled model."""


@dataclasses.dataclass
class AssembledTransformerModel:
  """Model architecture and parameters from assembling a model."""
  forward: ModelForward
  get_compiled_model: Callable[[], model.CompiledTransformerModel]
  params: hk.Params
  model_config: model.TransformerConfig
  residual_labels: list[str]
  input_encoder: Optional[encoder.Encoder] = None
  output_encoder: Optional[encoder.Encoder] = None

  def apply(self, tokens: list[bases.Value]) -> AssembledTransformerModelOutput:
    """Returns output from running the model on a set of input tokens."""
    if self.input_encoder:
      tokens = self.input_encoder.encode(tokens)
    tokens = jnp.array([tokens])
    output = self.forward(self.params, tokens)
    decoded = output.unembedded_output[0].tolist()
    if self.output_encoder:
      decoded = self.output_encoder.decode(decoded)
    if self.input_encoder and self.input_encoder.bos_token:
      # Special case for decoding the bos token position, for which the output
      # decoder might have unspecified behavior.
      decoded = [self.input_encoder.bos_token] + decoded[1:]
    return AssembledTransformerModelOutput(
        decoded=decoded,
        unembedded=output.unembedded_output,
        layer_outputs=output.transformer_output.layer_outputs,
        residuals=output.transformer_output.residuals,
        attn_logits=output.transformer_output.attn_logits,
        transformer_output=output.transformer_output.output,
        input_embeddings=output.transformer_output.input_embeddings)
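

# A minimal usage sketch (illustrative only; `my_input_encoder` and
# `my_output_encoder` are hypothetical tracr encoders, and `assembled`
# comes from assemble_craft_model below):
#
#   assembled = assemble_craft_model(craft_model, tokens_space, indices_space,
#                                    output_space, categorical_output=True)
#   assembled.input_encoder = my_input_encoder
#   assembled.output_encoder = my_output_encoder
#   out = assembled.apply(["BOS", 1, 2, 3])
#   print(out.decoded)  # One decoded value per input position.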


@dataclasses.dataclass
class EmbeddingModules:
  """Modules for embedding tokens and positions, and unembedding results."""
  token_embed: model.CallableHaikuModule
  pos_embed: model.CallableHaikuModule
  unembed: model.CallableHaikuModule


def _get_model_config_and_module_names(
    craft_model: transformers.SeriesWithResiduals
) -> tuple[model.TransformerConfig, list[str]]:
  """Returns model config and locations (in params) for halflayers."""

  multi_attn_heads: list[list[transformers.AttentionHead]] = []
  mlps: list[transformers.MLP] = []
  module_names: list[str] = []

  candidate_module_names = []
  for layer in range(len(craft_model.blocks)):
    candidate_module_names.append(f"transformer/layer_{layer}/attn")
    candidate_module_names.append(f"transformer/layer_{layer}/mlp")
  candidate_module_names = iter(candidate_module_names)
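
  # Each craft block consumes the next candidate module of its own type; for
  # example, blocks [attn, mlp, mlp] map to layer_0/attn, layer_0/mlp and
  # layer_1/mlp, while the skipped layer_1/attn keeps all-zero weights.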
  for module in craft_model.blocks:
    if isinstance(module, transformers.MLP):
      mlps.append(module)
      layer_type = "mlp"
    else:
      multi_attn_heads.append(list(module.as_multi().heads()))
      layer_type = "attn"
    # Find the next layer of the required type. Modules in between that are
    # not added to module_names will be disabled later by setting all their
    # weights to 0.
    module_name = next(candidate_module_names)
    while layer_type not in module_name:
      module_name = next(candidate_module_names)
    module_names.append(module_name)
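
  # The last module name has the form "transformer/layer_{i}/{attn|mlp}", so
  # the total layer count is that final layer index plus one.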
  num_layers = int(module_names[-1].split("_")[1].split("/")[0]) + 1
  heads = sum(multi_attn_heads, [])

  if multi_attn_heads:
    num_heads = max(len(heads) for heads in multi_attn_heads)
    key_size = max(max(head.w_qk.matrix.shape) for head in heads)
  else:
    # No attention blocks: fall back to size-1 placeholders so the config
    # still describes a valid (all-zero) attention module.
    num_heads, key_size = 1, 1

  if mlps:
    mlp_hidden_size = max(mlp.fst.output_space.num_dims for mlp in mlps)
  else:
    mlp_hidden_size = 1
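
  # The config takes the maximum sizes over all blocks; smaller heads and
  # MLPs are padded with zero weights when they are assembled below.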
  model_config = model.TransformerConfig(
      num_heads=num_heads,
      num_layers=num_layers,
      key_size=key_size,
      mlp_hidden_size=mlp_hidden_size,
      dropout_rate=0.,
      activation_function=jax.nn.relu,
      layer_norm=False,
      causal=False,
  )
  return model_config, module_names


def _make_embedding_modules(
    residual_space: bases.VectorSpaceWithBasis,
    tokens_space: bases.VectorSpaceWithBasis,
    indices_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis) -> EmbeddingModules:
  """Creates embedding and unembedding modules from vector spaces.

  Args:
    residual_space: Full residual space of the model.
    tokens_space: Subspace to embed tokens to.
    indices_space: Subspace to embed indices/position embeddings to.
    output_space: Subspace to unembed outputs from.

  Returns:
    EmbeddingModules containing modules for token embeddings, position
    embeddings and unembeddings.
  """
  tokens_to_res = vectorspace_fns.project(tokens_space, residual_space)

  # If we use the 'one' direction, make sure all inputs have a 1 here.
  one_dir = bases.BasisDirection("one")
  if one_dir in residual_space:
    one_to_res = vectorspace_fns.Linear.from_action(
        tokens_space, residual_space,
        lambda x: residual_space.vector_from_basis_direction(one_dir))
    tokens_to_res = vectorspace_fns.Linear.combine_in_parallel(
        [tokens_to_res, one_to_res])

  # Projection from the residual space to the output space, used for
  # unembedding below.
  res_to_out = vectorspace_fns.project(residual_space, output_space)

  # Token embeddings.
  token_embed = hk.Embed(
      embedding_matrix=tokens_to_res.matrix, name="token_embed")

  # Positional embeddings.
  index_to_res = vectorspace_fns.project(indices_space, residual_space)
  # The zeroth position should not have any positional embeddings,
  # so we add one row of padding at the zeroth position.
  pos_matrix = np.concatenate(
      [np.zeros((1, residual_space.num_dims)), index_to_res.matrix], axis=0)
  pos_embed = hk.Embed(embedding_matrix=pos_matrix, name="pos_embed")
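
  # Unembedding projects the residual stream onto the output subspace. For
  # categorical outputs the caller takes an argmax over that subspace; a
  # 1-dimensional numerical output is squeezed to one value per position.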
  def unembed(x, use_unembed_argmax):
    out = x @ res_to_out.matrix
    if use_unembed_argmax:
      return jnp.argmax(out, axis=-1)
    elif out.shape[-1] == 1:
      return out.squeeze(-1)
    return out

  unembed_mod = hk.to_module(unembed)()

  return EmbeddingModules(
      token_embed=token_embed, pos_embed=pos_embed, unembed=unembed_mod)


def assemble_craft_model(
    craft_model: transformers.SeriesWithResiduals,
    tokens_space: bases.VectorSpaceWithBasis,
    indices_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    categorical_output: bool,
    causal: bool = False,
) -> AssembledTransformerModel:
  """Assembles the given components into a Haiku model with parameters.

  Args:
    craft_model: Model to assemble weights for.
    tokens_space: Vectorspace to embed the input tokens to.
    indices_space: Vectorspace to embed the indices to (position encodings).
    output_space: Vectorspace that the model will write outputs to that should
      be unembedded.
    categorical_output: Whether the output is categorical. If True, we take an
      argmax when unembedding.
    causal: Whether to output a causally-masked model.

  Returns:
    An AssembledTransformerModel that contains the model and parameters of the
    assembled transformer.
  """
  # TODO(b/255936413): Make embeddings only retain the tokens and indices that
  # are actually used.
  # TODO(b/255936496): Think about enabling layer norm and reversing it somehow

  model_config, module_names = _get_model_config_and_module_names(craft_model)
  model_config.causal = causal
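
  # The assembled residual space is the union of the craft model's residual
  # space and the token, index, and output subspaces, so that embedding and
  # unembedding projections are defined on the same basis.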
  residual_space = bases.join_vector_spaces(craft_model.residual_space,
                                            tokens_space, indices_space,
                                            output_space)
  residual_labels = [str(basis_dir) for basis_dir in residual_space.basis]

  # Build model with embedding and unembedding layers.
  def get_compiled_model():
    transformer = model.Transformer(model_config)
    embed_modules = _make_embedding_modules(
        residual_space=residual_space,
        tokens_space=tokens_space,
        indices_space=indices_space,
        output_space=output_space)
    return model.CompiledTransformerModel(
        transformer=transformer,
        token_embed=embed_modules.token_embed,
        position_embed=embed_modules.pos_embed,
        unembed=embed_modules.unembed,
        use_unembed_argmax=categorical_output)

  @hk.without_apply_rng
  @hk.transform
  def forward(emb):
    compiled_model = get_compiled_model()
    return compiled_model(emb, use_dropout=False)

  # Initialize with a dummy input to create parameters of the right shapes,
  # then zero out all transformer weights so that modules which receive no
  # craft weights below leave the residual stream unchanged.
  params = forward.init(jax.random.PRNGKey(0), jnp.array([[1, 2, 3]]))
  params = {k: dict(v) for k, v in params.items()}
  for key in params:
    if "transformer" in key:
      for par in params[key]:
        params[key][par] = np.zeros_like(params[key][par])

  # Assemble attention and MLP weights.
  project = lambda space: vectorspace_fns.project(residual_space, space).matrix

  for module_name, module in zip(module_names, craft_model.blocks):
    if isinstance(module, transformers.MLP):
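      # Project the residual stream into the craft MLP's input space, apply
      # both craft matrices, and leave the padding hidden units (beyond this
      # module's hidden size) at zero.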
      hidden_size = module.fst.output_space.num_dims
      residual_to_fst_input = project(module.fst.input_space)
      snd_output_to_residual = project(module.snd.output_space).T
      params[f"{module_name}/linear_1"]["w"][:, :hidden_size] = (
          residual_to_fst_input @ module.fst.matrix)
      params[f"{module_name}/linear_2"]["w"][:hidden_size, :] = (
          module.snd.matrix @ snd_output_to_residual)
    else:  # Attention module.
      query, key, value, linear = [], [], [], []
      for head in module.as_multi().heads():
        key_size = head.w_qk.matrix.shape[1]
        query_mat = np.zeros((residual_space.num_dims, model_config.key_size))
        residual_to_query = project(head.w_qk.left_space)
        query_mat[:, :key_size] = residual_to_query @ head.w_qk.matrix
        query.append(query_mat)

        key_mat = np.zeros((residual_space.num_dims, model_config.key_size))
        key_mat[:, :key_size] = project(head.w_qk.right_space)
        key.append(key_mat)

        value_size = head.w_ov.matrix.shape[1]
        value_mat = np.zeros((residual_space.num_dims, model_config.key_size))
        residual_to_ov_input = project(head.w_ov.input_space)
        value_mat[:, :value_size] = residual_to_ov_input @ head.w_ov.matrix
        value.append(value_mat)

        linear_mat = np.zeros((model_config.key_size, residual_space.num_dims))
        linear_mat[:value_size, :] = project(head.w_ov.output_space).T
        linear.append(linear_mat)

      # Fill up heads that are not used with zero weights.
      for _ in range(model_config.num_heads - module.as_multi().num_heads):
        query.append(np.zeros_like(query[0]))
        key.append(np.zeros_like(key[0]))
        value.append(np.zeros_like(value[0]))
        linear.append(np.zeros_like(linear[0]))
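
      # Stack the per-head matrices into combined Haiku attention parameters:
      # query/key/value weights of shape [d_model, num_heads * key_size] and
      # an output projection of shape [num_heads * key_size, d_model].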
      query = einops.rearrange(query,
                               "heads input output -> input (heads output)")
      key = einops.rearrange(key, "heads input output -> input (heads output)")
      value = einops.rearrange(value,
                               "heads input output -> input (heads output)")
      linear = einops.rearrange(linear,
                                "heads input output -> (heads input) output")

      params[f"{module_name}/query"]["w"][:, :] = query
      params[f"{module_name}/key"]["w"][:, :] = key
      params[f"{module_name}/value"]["w"][:, :] = value
      params[f"{module_name}/linear"]["w"][:, :] = linear

  params = jax.tree_util.tree_map(jnp.array, params)

  return AssembledTransformerModel(
      forward=forward.apply,
      get_compiled_model=get_compiled_model,
      params=params,
      model_config=model_config,
      residual_labels=residual_labels,
  )