mamba / c_inferer /main.c
flopml's picture
added c inferer (mamba.h)
1485644
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>
#include <stdlib.h>
#include "qmamba.h"
#include "tokenizer.h"
#include "sampler.h"
static void help(char *name, void *defaults[]) {
LOG("Usage: %s [-pntsvh]\n\n", name);
LOG("Infers a Mamba language model.\n\n");
LOG("Options:\n");
LOG("\t-p <seed_text> The seed_text to start the generation with. (default NONE)\n");
LOG("\t-n <n_predict> The number of tokens to predict. (default %lu)\n", *(uint64_t *)defaults[0]);
LOG("\t-t <temperature> The temperature of the softmax. (default %.1f)\n", *(fp32_t *)defaults[1]);
LOG("\t-s <seed> The seed for the random number generator. (default %lu)\n", *(uint64_t *)defaults[2]);
LOG("\t-v Enables verbose mode. (default %s)\n", *(bool *)defaults[3] ? "true" : "false");
LOG("\t-h Prints this help message.\n\n");
exit(EXIT_FAILURE);
}
int main(int argc, char *argv[]) {
char *seed_text = NULL;
uint64_t n_predict = 256;
for (int i = 1; i < argc; i++) {
if (argv[i][0] != '-') {
LOG("Invalid argument: %s\n", argv[i]);
return 1;
}
switch (argv[i][1]) {
case 'p':
seed_text = argv[++i];
break;
case 'n':
n_predict = strtoull(argv[++i], NULL, 10);
break;
case 't':
sampler.temperature = strtod(argv[++i], NULL);
break;
case 's':
sampler.rng_seed = strtoull(argv[++i], NULL, 10);
break;
case 'v':
sampler.verbose = true;
break;
case 'h':
goto help_;
break;
default:
LOG("Invalid argument: %s\n", argv[i]);
return 1;
}
}
if (sampler.verbose)
mamba.log(&mamba);
if (sampler.generate(&sampler, seed_text, n_predict) == EXIT_FAILURE)
goto help_;
return 0;
help_:
help(argv[0], (void *[]) {&n_predict, &sampler.temperature, &sampler.rng_seed, &sampler.verbose});
}