| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include "ggml/ggml.h" |
| |
|
| | #include "main-mtl.h" |
| |
|
| | #include <cmath> |
| | #include <cstdio> |
| | #include <cstring> |
| | #include <ctime> |
| | #include <fstream> |
| | #include <vector> |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | int mnist_eval( |
| | const char * fname_cgraph, |
| | std::vector<float> digit |
| | ) { |
| | |
| | struct ggml_context * ctx_data = NULL; |
| | struct ggml_context * ctx_eval = NULL; |
| |
|
| | struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval); |
| |
|
| | |
| | static size_t buf_size = 128ull*1024*1024; |
| | static void * buf = malloc(buf_size); |
| |
|
| | struct ggml_init_params params = { |
| | buf_size, |
| | buf, |
| | false, |
| | }; |
| |
|
| | struct ggml_context * ctx_work = ggml_init(params); |
| |
|
| | |
| | auto ctx_mtl = mnist_mtl_init(ctx_data, ctx_eval, ctx_work, &gf); |
| |
|
| | int prediction = -1; |
| |
|
| | for (int i = 0; i < 1; ++i) { |
| | struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "input"); |
| |
|
| | if (i % 2 == 0) { |
| | memcpy(input->data, digit.data(), ggml_nbytes(input)); |
| | } else { |
| | memset(input->data, 0, ggml_nbytes(input)); |
| | } |
| |
|
| | |
| | prediction = mnist_mtl_eval(ctx_mtl, &gf); |
| | } |
| |
|
| | mnist_mtl_free(ctx_mtl); |
| |
|
| | ggml_free(ctx_work); |
| | ggml_free(ctx_data); |
| | ggml_free(ctx_eval); |
| |
|
| | return prediction; |
| | } |
| |
|
| | int main(int argc, char ** argv) { |
| | srand(time(NULL)); |
| | ggml_time_init(); |
| |
|
| | if (argc != 3) { |
| | fprintf(stderr, "Usage: %s models/mnist/mnist.ggml models/mnist/t10k-images.idx3-ubyte\n", argv[0]); |
| | exit(0); |
| | } |
| |
|
| | uint8_t buf[784]; |
| | std::vector<float> digit; |
| |
|
| | |
| | { |
| | std::ifstream fin(argv[2], std::ios::binary); |
| | if (!fin) { |
| | fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]); |
| | return 1; |
| | } |
| |
|
| | |
| | fin.seekg(16 + 784 * (rand() % 10000)); |
| | fin.read((char *) &buf, sizeof(buf)); |
| | } |
| |
|
| | |
| | { |
| | digit.resize(sizeof(buf)); |
| |
|
| | for (int row = 0; row < 28; row++) { |
| | for (int col = 0; col < 28; col++) { |
| | fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_'); |
| | digit[row*28 + col] = ((float)buf[row*28 + col]); |
| | } |
| |
|
| | fprintf(stderr, "\n"); |
| | } |
| |
|
| | fprintf(stderr, "\n"); |
| | } |
| |
|
| | const int prediction = mnist_eval(argv[1], digit); |
| |
|
| | fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction); |
| |
|
| | return 0; |
| | } |
| |
|