| # Convert MNIS h5 transformer model to ggml format | |
| # | |
| # Load the (state_dict) saved model using PyTorch | |
| # Iterate over all variables and write them to a binary file. | |
| # | |
| # For each variable, write the following: | |
| # - Number of dimensions (int) | |
| # - Name length (int) | |
| # - Dimensions (int[n_dims]) | |
| # - Name (char[name_length]) | |
| # - Data (float[n_dims]) | |
| # | |
| # At the start of the ggml file we write the model parameters | |
| import sys | |
| import struct | |
| import json | |
| import numpy as np | |
| import re | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.datasets as dsets | |
| import torchvision.transforms as transforms | |
| from torch.autograd import Variable | |
| if len(sys.argv) != 2: | |
| print("Usage: convert-h5-to-ggml.py model\n") | |
| sys.exit(1) | |
| state_dict_file = sys.argv[1] | |
| fname_out = "models/mnist/ggml-model-f32.bin" | |
| state_dict = torch.load(state_dict_file, map_location=torch.device('cpu')) | |
| #print (model) | |
| list_vars = state_dict | |
| print (list_vars) | |
| fout = open(fname_out, "wb") | |
| fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex | |
| for name in list_vars.keys(): | |
| data = list_vars[name].squeeze().numpy() | |
| print("Processing variable: " + name + " with shape: ", data.shape) | |
| n_dims = len(data.shape); | |
| fout.write(struct.pack("i", n_dims)) | |
| data = data.astype(np.float32) | |
| for i in range(n_dims): | |
| fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) | |
| # data | |
| data.tofile(fout) | |
| fout.close() | |
| print("Done. Output file: " + fname_out) | |
| print("") | |