#include #include #include #include #include "qmamba.h" #include "tokenizer.h" #include "sampler.h" static void help(char *name, void *defaults[]) { LOG("Usage: %s [-pntsvh]\n\n", name); LOG("Infers a Mamba language model.\n\n"); LOG("Options:\n"); LOG("\t-p The seed_text to start the generation with. (default NONE)\n"); LOG("\t-n The number of tokens to predict. (default %lu)\n", *(uint64_t *)defaults[0]); LOG("\t-t The temperature of the softmax. (default %.1f)\n", *(fp32_t *)defaults[1]); LOG("\t-s The seed for the random number generator. (default %lu)\n", *(uint64_t *)defaults[2]); LOG("\t-v Enables verbose mode. (default %s)\n", *(bool *)defaults[3] ? "true" : "false"); LOG("\t-h Prints this help message.\n\n"); exit(EXIT_FAILURE); } int main(int argc, char *argv[]) { char *seed_text = NULL; uint64_t n_predict = 256; for (int i = 1; i < argc; i++) { if (argv[i][0] != '-') { LOG("Invalid argument: %s\n", argv[i]); return 1; } switch (argv[i][1]) { case 'p': seed_text = argv[++i]; break; case 'n': n_predict = strtoull(argv[++i], NULL, 10); break; case 't': sampler.temperature = strtod(argv[++i], NULL); break; case 's': sampler.rng_seed = strtoull(argv[++i], NULL, 10); break; case 'v': sampler.verbose = true; break; case 'h': goto help_; break; default: LOG("Invalid argument: %s\n", argv[i]); return 1; } } if (sampler.verbose) mamba.log(&mamba); if (sampler.generate(&sampler, seed_text, n_predict) == EXIT_FAILURE) goto help_; return 0; help_: help(argv[0], (void *[]) {&n_predict, &sampler.temperature, &sampler.rng_seed, &sampler.verbose}); }