File size: 1,840 Bytes
1485644 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>
#include <stdlib.h>
#include "qmamba.h"
#include "tokenizer.h"
#include "sampler.h"
/*
 * Prints the usage text and terminates the process.
 *
 * name     - program name shown in the usage line (the caller passes argv[0]).
 * defaults - array of four pointers, read in this order:
 *              [0] uint64_t  number of tokens to predict
 *              [1] fp32_t    sampling temperature
 *              [2] uint64_t  RNG seed
 *              [3] bool      verbose flag
 *
 * Never returns: always calls exit(EXIT_FAILURE), including when help was
 * requested explicitly.
 */
static void help(char *name, void *defaults[]) {
    LOG("Usage: %s [-pntsvh]\n\n", name);
    LOG("Infers a Mamba language model.\n\n");
    LOG("Options:\n");
    LOG("\t-p <seed_text> The seed_text to start the generation with. (default NONE)\n");
    /* uint64_t is not necessarily `unsigned long`; cast to unsigned long long
       and print with %llu so the specifier matches on every platform. */
    LOG("\t-n <n_predict> The number of tokens to predict. (default %llu)\n",
        (unsigned long long)*(uint64_t *)defaults[0]);
    LOG("\t-t <temperature> The temperature of the softmax. (default %.1f)\n", *(fp32_t *)defaults[1]);
    LOG("\t-s <seed> The seed for the random number generator. (default %llu)\n",
        (unsigned long long)*(uint64_t *)defaults[2]);
    LOG("\t-v Enables verbose mode. (default %s)\n", *(bool *)defaults[3] ? "true" : "false");
    LOG("\t-h Prints this help message.\n\n");
    exit(EXIT_FAILURE);
}
/*
 * Program entry point: parses the command-line options, optionally logs the
 * model, then runs generation through the sampler.
 *
 * Recognised options (each must be its own argv entry starting with '-'):
 *   -p <seed_text>    prompt handed to sampler.generate (NULL when absent)
 *   -n <n_predict>    number of tokens to predict (starts at 256)
 *   -t <temperature>  stored in sampler.temperature
 *   -s <seed>         stored in sampler.rng_seed
 *   -v                sets sampler.verbose
 *   -h                prints the usage text
 *
 * Returns 0 on success and 1 on an unrecognised argument. Every path that
 * reaches the usage text ends in help(), which terminates the process itself.
 */
int main(int argc, char *argv[]) {
    char *seed_text = NULL;
    uint64_t n_predict = 256;

    for (int i = 1; i < argc; i++) {
        if (argv[i][0] != '-') {
            LOG("Invalid argument: %s\n", argv[i]);
            return 1;
        }

        const char opt = argv[i][1];

        /* Options that take a value consume the next argv entry. Without this
           check a trailing "-n"/"-t"/"-s" would hand argv[argc] (a null
           pointer) to strtoull/strtod. */
        if ((opt == 'p' || opt == 'n' || opt == 't' || opt == 's') && i + 1 >= argc) {
            LOG("Missing value for option: %s\n", argv[i]);
            goto help_;
        }

        switch (opt) {
        case 'p':
            seed_text = argv[++i];
            break;
        case 'n':
            n_predict = strtoull(argv[++i], NULL, 10);
            break;
        case 't':
            sampler.temperature = strtod(argv[++i], NULL);
            break;
        case 's':
            sampler.rng_seed = strtoull(argv[++i], NULL, 10);
            break;
        case 'v':
            sampler.verbose = true;
            break;
        case 'h':
            goto help_;
        default:
            LOG("Invalid argument: %s\n", argv[i]);
            return 1;
        }
    }

    if (sampler.verbose) {
        mamba.log(&mamba);
    }

    if (sampler.generate(&sampler, seed_text, n_predict) == EXIT_FAILURE) {
        goto help_;
    }

    return 0;

help_:
    /* NOTE(review): the values passed here are the current settings, so any
       option parsed before this point is reported as the "default". */
    help(argv[0], (void *[]) {&n_predict, &sampler.temperature, &sampler.rng_seed, &sampler.verbose});
    /* help() calls exit(); this return only makes the control flow explicit. */
    return EXIT_FAILURE;
}
|