| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from collections import defaultdict |
| | import torch |
| | from gemma import config |
| | from gemma import model as gemma_model |
| | import numpy as np |
| | import argparse |
| | import os |
| |
|
| | |
| |
|
def check_file_exists(value):
    """argparse ``type=`` callback that only accepts paths present on disk.

    Args:
        value: The command-line value (or string default) naming a path.

    Returns:
        ``value`` unchanged when ``os.path.exists`` reports that it exists.

    Raises:
        argparse.ArgumentTypeError: If the path does not exist.
    """
    path = str(value)
    if os.path.exists(path):
        return value
    raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
| | |
| |
|
def check_model_types(value):
    """argparse ``type=`` callback that validates and normalizes the model type.

    The membership test is case-insensitive, so the accepted value is returned
    lower-cased. Returning the raw value would let e.g. ``2B`` pass validation
    and then fail later in lookups that are keyed by ``"2b"`` / ``"7b"``.

    Args:
        value: The command-line value (or string default) for the model type.

    Returns:
        The lower-cased model type, either ``"2b"`` or ``"7b"``.

    Raises:
        argparse.ArgumentTypeError: If the value is not one of the supported
            model types.
    """
    model_type = str(value).lower()
    if model_type not in ("2b", "7b"):
        raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
    return model_type
| | |
| |
|
# Command-line interface. argparse also runs the ``type`` callback on string
# defaults, so the default tokenizer/weights paths are checked for existence
# when the corresponding flag is omitted.
parser = argparse.ArgumentParser()

# Tokenizer file handed to the model config.
parser.add_argument(
    "--tokenizer",
    type=check_file_exists,
    default="models/tokenizer.spm",
    dest="tokenizer",
    help="Location of tokenizer file (.model or .spm)",
)

# Checkpoint to read the weights from.
parser.add_argument(
    "--weights",
    type=check_file_exists,
    default="models/gemma-2b-it.ckpt",
    dest="weights",
    help="Location of input checkpoint file (.ckpt)",
)

# Destination of the converted weights.
parser.add_argument(
    "--output_file",
    type=str,
    default="2bit-f32.sbs",
    dest="output_file",
    help="Location to write converted weights",
)

# Selects which model configuration and shape tables are used.
parser.add_argument(
    "--model_type",
    type=check_model_types,
    default="2b",
    dest="model_type",
    help="Model size / type (2b, 7b)",
)

# Parsed at import time; the rest of the script reads the module-level ``args``.
args = parser.parse_args()
| |
|
| |
|
def _identity(x):
    """Return the array unchanged."""
    return x


def _with_leading_axis(x):
    """Prepend a length-1 axis in front of the array's first two axes."""
    return x[np.newaxis, :, :]


def _reshaped(shape, axes=None):
    """Build a transform that reshapes to ``shape`` and then, if ``axes`` is
    given, permutes the axes with ``transpose(axes)``."""

    def transform(x):
        out = x.reshape(shape)
        if axes is None:
            return out
        return out.transpose(axes)

    return transform


# Per-model layout transforms, keyed by the parameter's layer-local name.
# Each value is a defaultdict, so names without an explicit entry fall back to
# the identity transform.
# NOTE(review): the literal sizes look like (heads, head_dim, model_dim)
# layouts for each model size -- confirm against the model configs.
TRANSFORMATIONS = {
    "2b": defaultdict(
        lambda: _identity,
        {
            "embedder.weight": _identity,
            "self_attn.qkv_proj.weight": _reshaped((10, 256, 2048)),
            "self_attn.o_proj.weight": _reshaped((2048, 8, 256), [1, 0, 2]),
            "mlp.gate_proj.weight": _with_leading_axis,
            "mlp.up_proj.weight": _with_leading_axis,
            "mlp.down_proj.weight": _identity,
        },
    ),
    "7b": defaultdict(
        lambda: _identity,
        {
            "embedder.weight": _identity,
            "self_attn.qkv_proj.weight": _reshaped((3, 16, 256, 3072), [1, 0, 2, 3]),
            "self_attn.o_proj.weight": _reshaped((3072, 16, 256), [1, 0, 2]),
            "mlp.gate_proj.weight": _with_leading_axis,
            "mlp.up_proj.weight": _with_leading_axis,
            "mlp.down_proj.weight": _identity,
        },
    ),
}
| |
|
def _has_shape(*dims):
    """Return a predicate that is true when its argument's ``shape`` equals ``dims``."""

    def predicate(x):
        return x.shape == dims

    return predicate


# Expected array shapes per model type, keyed by the parameter's layer-local
# name. Each value is a predicate taking the array and returning the result of
# comparing its shape with the expected tuple.
VALIDATIONS = {
    "2b": {
        "embedder.weight": _has_shape(256000, 2048),
        "model.norm.weight": _has_shape(2048),
        "self_attn.qkv_proj.weight": _has_shape(10, 256, 2048),
        "self_attn.o_proj.weight": _has_shape(8, 2048, 256),
        "mlp.gate_proj.weight": _has_shape(1, 16384, 2048),
        "mlp.up_proj.weight": _has_shape(1, 16384, 2048),
        "mlp.down_proj.weight": _has_shape(2048, 16384),
        "input_layernorm.weight": _has_shape(2048),
        "post_attention_layernorm.weight": _has_shape(2048),
    },
    "7b": {
        "embedder.weight": _has_shape(256000, 3072),
        "model.norm.weight": _has_shape(3072),
        "self_attn.qkv_proj.weight": _has_shape(16, 3, 256, 3072),
        "self_attn.o_proj.weight": _has_shape(16, 3072, 256),
        "mlp.gate_proj.weight": _has_shape(1, 24576, 3072),
        "mlp.up_proj.weight": _has_shape(1, 24576, 3072),
        "mlp.down_proj.weight": _has_shape(3072, 24576),
        "input_layernorm.weight": _has_shape(3072),
        "post_attention_layernorm.weight": _has_shape(3072),
    },
}
| |
|
| |
|
def param_names(num_hidden_layers: int) -> list:
    """Return parameter names in the order they are expected for deserialization.

    Args:
        num_hidden_layers: Number of ``model.layers.<i>`` blocks to enumerate.

    Returns:
        A list of ``(full_name, layer_name)`` tuples. ``full_name`` is the
        fully qualified parameter name and ``layer_name`` is the layer-local
        name without the ``model.layers.<i>.`` prefix. The two model-wide
        parameters come first (their two names are identical), followed by
        the per-layer parameters of layer 0, layer 1, and so on.
    """
    # The order within a layer is significant: the result is consumed
    # sequentially, so do not sort or regroup these entries.
    layer_params = (
        "self_attn.o_proj.weight",
        "self_attn.qkv_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
    )

    names = [
        ("embedder.weight",) * 2,
        ("model.norm.weight",) * 2,
    ]
    # Extend in place rather than rebuilding the whole list on every iteration.
    names.extend(
        (f"model.layers.{layer}.{layer_param}", layer_param)
        for layer in range(num_hidden_layers)
        for layer_param in layer_params
    )
    return names
| |
|
| |
|
def _transformed_array(model_dict, model_type, name, layer_name):
    """Return parameter ``name`` as a NumPy array with its layout transform applied."""
    arr = model_dict[name].detach().numpy()
    return TRANSFORMATIONS[model_type][layer_name](arr)


def convert_weights():
    """Load the checkpoint named by ``args`` and write it as a flat float32 file.

    The model is built on the CPU in float32 and its weights are loaded from
    ``args.weights``. Every parameter listed by ``param_names`` is transformed
    and its shape validated first. The output file is written only if all
    validations pass. It contains each transformed array flattened, cast to
    float32 and appended in ``param_names`` order.

    Returns:
        ``True`` if the output file was written, ``False`` if at least one
        shape validation failed (nothing is written in that case).
    """
    model_type = args.model_type
    output_file = args.output_file

    model_config = config.get_model_config(model_type)
    model_config.dtype = "float32"
    model_config.tokenizer = args.tokenizer
    device = torch.device("cpu")
    torch.set_default_dtype(torch.float)
    model = gemma_model.GemmaForCausalLM(model_config)

    model.load_weights(args.weights)
    model.to(device).eval()

    model_dict = dict(model.named_parameters())
    param_order = param_names(model_config.num_hidden_layers)

    # First pass: validate every parameter before touching the output file, so
    # a shape mismatch never leaves a partially written file behind.
    all_ok = True
    print("Checking transformations ...")
    for name, layer_name in param_order:
        arr = _transformed_array(model_dict, model_type, name, layer_name)
        check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
        if check == "FAILED":
            all_ok = False
        print(f" {name : <60}{str(arr.shape) : <20}{check}")

    if not all_ok:
        # Make the refusal explicit instead of returning silently.
        print(f"Validation failed; {output_file} was not written.")
        return False

    # Second pass: the transforms are recomputed rather than cached so that
    # only one transformed array needs to be held at a time.
    print("Writing parameters ...")
    with open(output_file, "wb") as bin_handle:
        for name, layer_name in param_order:
            arr = _transformed_array(model_dict, model_type, name, layer_name)
            check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
            print(f" {name : <60}{str(arr.shape) : <20}{check}")
            arr.flatten().astype(np.float32).tofile(bin_handle)
    return True
| |
|
| |
|
# Script entry point: run the conversion, then report completion.
if __name__ == "__main__":
    convert_weights()
    print("Done")
| |
|