File size: 1,840 Bytes
1485644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

#include <errno.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdlib.h>
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>

#include "qmamba.h"
#include "tokenizer.h"
#include "sampler.h"











static void help(char *name, void *defaults[]) {
	LOG("Usage: %s [-pntsvh]\n\n", name);
	LOG("Infers a Mamba language model.\n\n");
	LOG("Options:\n");
	LOG("\t-p <seed_text>		  The seed_text to start the generation with.	(default NONE)\n");
	LOG("\t-n <n_predict>	   The number of tokens to predict.			(default %lu)\n", *(uint64_t *)defaults[0]);
	LOG("\t-t <temperature>	 The temperature of the softmax.			 (default %.1f)\n", *(fp32_t *)defaults[1]);
	LOG("\t-s <seed>			The seed for the random number generator.   (default %lu)\n", *(uint64_t *)defaults[2]);
	LOG("\t-v				   Enables verbose mode.					   (default %s)\n", *(bool *)defaults[3] ? "true" : "false");
	LOG("\t-h				   Prints this help message.\n\n");

	exit(EXIT_FAILURE);
}



/*
 * Returns the value that follows the option at argv[*i] and advances *i past
 * it. Returns NULL (leaving *i unchanged) when the option is the last argument
 * and therefore has no value.
 */
static char *option_value(int argc, char *argv[], int *i) {
	if (*i + 1 >= argc)
		return NULL;
	*i += 1;
	return argv[*i];
}

/*
 * Parses all of `text` as a base-10 unsigned integer into *out.
 * Returns false (leaving *out untouched) when `text` has no digits, has
 * trailing characters, or is out of range for strtoull.
 */
static bool parse_uint64(const char *text, uint64_t *out) {
	char *end = NULL;
	errno = 0;
	const unsigned long long value = strtoull(text, &end, 10);
	if (end == text || *end != '\0' || errno == ERANGE)
		return false;
	*out = (uint64_t)value;
	return true;
}

/*
 * Parses all of `text` as a floating-point number into *out.
 * Returns false (leaving *out untouched) when `text` is not a number, has
 * trailing characters, or strtod reports ERANGE.
 */
static bool parse_double(const char *text, double *out) {
	char *end = NULL;
	errno = 0;
	const double value = strtod(text, &end);
	if (end == text || *end != '\0' || errno == ERANGE)
		return false;
	*out = value;
	return true;
}

/*
 * Entry point. Parses the command line into `n_predict` and the global
 * `sampler`, logs the model when verbose, then runs generation.
 *
 *   -p <text>   seed text              -n <count>  tokens to predict
 *   -t <temp>   sampler temperature    -s <seed>   sampler RNG seed
 *   -v          verbose mode           -h          print usage (help() exits)
 *
 * Returns 0 after a successful generation and EXIT_FAILURE on a malformed
 * command line (unknown option, missing or unparsable option value).
 */
int main(int argc, char *argv[]) {

	char *seed_text = NULL;
	uint64_t n_predict = 256;

	for (int i = 1; i < argc; i++) {
		const char *option = argv[i];

		if (option[0] != '-') {
			LOG("Invalid argument: %s\n", option);
			return EXIT_FAILURE;
		}

		char *value = NULL;  // operand of the current option, if it takes one
		bool valid = true;   // false when a required operand is missing or malformed

		switch (option[1]) {
			case 'p':
				value = option_value(argc, argv, &i);
				valid = value != NULL;
				if (valid)
					seed_text = value;
				break;
			case 'n':
				value = option_value(argc, argv, &i);
				valid = value != NULL && parse_uint64(value, &n_predict);
				break;
			case 't': {
				double temperature = 0.0;
				value = option_value(argc, argv, &i);
				valid = value != NULL && parse_double(value, &temperature);
				if (valid)
					sampler.temperature = temperature;
				break;
			}
			case 's': {
				uint64_t rng_seed = 0;
				value = option_value(argc, argv, &i);
				valid = value != NULL && parse_uint64(value, &rng_seed);
				if (valid)
					sampler.rng_seed = rng_seed;
				break;
			}
			case 'v':
				sampler.verbose = true;
				break;
			case 'h':
				goto help_;
			default:
				LOG("Invalid argument: %s\n", option);
				return EXIT_FAILURE;
		}

		if (!valid) {
			if (value == NULL)
				LOG("Missing value for option: %s\n", option);
			else
				LOG("Invalid value for option %s: %s\n", option, value);
			return EXIT_FAILURE;
		}
	}

	if (sampler.verbose)
		mamba.log(&mamba);

	// A failed generation is answered with the usage text (help() exits the process).
	if (sampler.generate(&sampler, seed_text, n_predict) == EXIT_FAILURE)
		goto help_;

	return 0;

	help_:
		// NOTE(review): help() receives the *current* values, so options parsed
		// before -h (e.g. "-n 5 -h") are displayed as the "default".
		help(argv[0], (void *[]) {&n_predict, &sampler.temperature, &sampler.rng_seed, &sampler.verbose});
		return EXIT_FAILURE;  // not reached: help() calls exit()
}