| |
|
|
| #include <stdio.h> |
| #include <stdlib.h> |
| #include <ctype.h> |
| #include <stdint.h> |
| #include <time.h> |
| #include <math.h> |
| #include <string.h> |
| #include <fcntl.h> |
| #if defined _WIN32 |
| #include "win.h" |
| #else |
| #include <unistd.h> |
| #include <sys/mman.h> |
| #endif |
| |
| |
/* Quantization group size: how many consecutive int8 values share a single
 * float scale inside a QuantizedTensor. Assigned from the checkpoint header
 * by read_checkpoint(); it stays 0 until a checkpoint has been read. */
int GS = 0;
|
|
| |
| |
|
|
/* Model hyperparameters. read_checkpoint() fills this with one fread() of
 * sizeof(Config) bytes, so the member order and types mirror the on-disk
 * header and must not change. */
typedef struct {
    int dim;        /* width of the residual stream x */
    int hidden_dim; /* output width of the w1/w3 feed-forward projections */
    int n_layers;   /* number of transformer layers */
    int n_heads;    /* number of query heads */
    int n_kv_heads; /* number of key/value heads; n_heads / n_kv_heads query heads share one */
    int vocab_size; /* number of tokens (rows of the embedding table) */
    int seq_len;    /* maximum number of positions held in the KV cache */
} Config;
|
|
/* Group-quantized tensor: int8 values plus one float scale per group of GS
 * consecutive values (value i dequantizes to q[i] * s[i / GS]). */
typedef struct {
    int8_t *q; /* quantized values */
    float *s;  /* per-group scale factors */
} QuantizedTensor;
|
|
| typedef struct { |
| |
| QuantizedTensor *q_tokens; |
| float* token_embedding_table; |
|
|
| |
| float* rms_att_weight; |
| float* rms_ffn_weight; |
| |
| QuantizedTensor *wq; |
| QuantizedTensor *wk; |
| QuantizedTensor *wv; |
| QuantizedTensor *wo; |
| |
| QuantizedTensor *w1; |
| QuantizedTensor *w2; |
| QuantizedTensor *w3; |
| |
| float* rms_final_weight; |
| |
| QuantizedTensor *wcls; |
| } TransformerWeights; |
|
|
| typedef struct { |
| |
| float *x; |
| float *xb; |
| float *xb2; |
| float *hb; |
| float *hb2; |
| QuantizedTensor xq; |
| QuantizedTensor hq; |
| float *q; |
| float *k; |
| float *v; |
| float *att; |
| float *logits; |
| |
| float* key_cache; |
| float* value_cache; |
| } RunState; |
|
|
| typedef struct { |
| Config config; |
| TransformerWeights weights; |
| RunState state; |
| |
| int fd; |
| float* data; |
| ssize_t file_size; |
| } Transformer; |
|
|
| void malloc_run_state(RunState* s, Config* p) { |
| |
| int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; |
| s->x = calloc(p->dim, sizeof(float)); |
| s->xb = calloc(p->dim, sizeof(float)); |
| s->xb2 = calloc(p->dim, sizeof(float)); |
| s->hb = calloc(p->hidden_dim, sizeof(float)); |
| s->hb2 = calloc(p->hidden_dim, sizeof(float)); |
| s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) }; |
| s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) }; |
| s->q = calloc(p->dim, sizeof(float)); |
| s->k = calloc(kv_dim, sizeof(float)); |
| s->v = calloc(kv_dim, sizeof(float)); |
| s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); |
| s->logits = calloc(p->vocab_size, sizeof(float)); |
| s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); |
| s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); |
| |
| if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q |
| || !s->k || !s->v || !s->att || !s->logits || !s->key_cache |
| || !s->value_cache) { |
| fprintf(stderr, "malloc failed!\n"); |
| exit(EXIT_FAILURE); |
| } |
| } |
|
|
| void free_run_state(RunState* s) { |
| free(s->x); |
| free(s->xb); |
| free(s->xb2); |
| free(s->hb); |
| free(s->hb2); |
| free(s->xq.q); |
| free(s->xq.s); |
| free(s->hq.q); |
| free(s->hq.s); |
| free(s->q); |
| free(s->k); |
| free(s->v); |
| free(s->att); |
| free(s->logits); |
| free(s->key_cache); |
| free(s->value_cache); |
| } |
|
|
| |
| |
|
|
| void dequantize(QuantizedTensor *qx, float* x, int n) { |
| for (int i = 0; i < n; i++) { |
| x[i] = qx->q[i] * qx->s[i / GS]; |
| } |
| } |
|
|
| void quantize(QuantizedTensor *qx, float* x, int n) { |
| int num_groups = n / GS; |
| float Q_MAX = 127.0f; |
|
|
| for (int group = 0; group < num_groups; group++) { |
|
|
| |
| float wmax = 0.0; |
| for (int i = 0; i < GS; i++) { |
| float val = fabs(x[group * GS + i]); |
| if (val > wmax) { |
| wmax = val; |
| } |
| } |
|
|
| |
| float scale = wmax / Q_MAX; |
| qx->s[group] = scale; |
|
|
| |
| for (int i = 0; i < GS; i++) { |
| float quant_value = x[group * GS + i] / scale; |
| int8_t quantized = (int8_t) round(quant_value); |
| qx->q[group * GS + i] = quantized; |
| } |
| } |
| } |
|
|
| |
| QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) { |
| void *p = *ptr; |
| QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor)); |
| for(int i=0; i<n; i++) { |
| |
| res[i].q = (int8_t*)p; |
| p = (int8_t*)p + size_each; |
| |
| res[i].s = (float*)p; |
| p = (float*)p + size_each / GS; |
| } |
| *ptr = p; |
| return res; |
| } |
|
|
| void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) { |
| int head_size = p->dim / p->n_heads; |
| |
| float* fptr = (float*) ptr; |
| w->rms_att_weight = fptr; |
| fptr += p->n_layers * p->dim; |
| w->rms_ffn_weight = fptr; |
| fptr += p->n_layers * p->dim; |
| w->rms_final_weight = fptr; |
| fptr += p->dim; |
|
|
| |
| ptr = (void*)fptr; |
| w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim); |
| |
| w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float)); |
| dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim); |
|
|
| w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size)); |
| w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size)); |
| w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size)); |
| w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim); |
|
|
| w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim); |
| w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim); |
| w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim); |
|
|
| w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size); |
| } |
|
|
| void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights, |
| int* fd, float** data, ssize_t* file_size) { |
| FILE *file = fopen(checkpoint, "rb"); |
| if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); } |
| |
| uint32_t magic_number; |
| if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { exit(EXIT_FAILURE); } |
| if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); } |
| |
| int version; |
| if (fread(&version, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); } |
| if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); } |
| int header_size = 256; |
| |
| if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); } |
| |
| uint8_t shared_classifier; |
| if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { exit(EXIT_FAILURE); } |
| int group_size; |
| if (fread(&group_size, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); } |
| GS = group_size; |
| |
| #if defined _WIN32 |
| _fseeki64(file, 0, SEEK_END); |
| *file_size = _ftelli64(file); |
| #else |
| fseek(file, 0, SEEK_END); |
| *file_size = ftell(file); |
| #endif |
| fclose(file); |
| |
| *fd = open(checkpoint, O_RDONLY); |
| if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); } |
| *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); |
| if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } |
| void* weights_ptr = ((char*)*data) + header_size; |
| memory_map_weights(weights, config, weights_ptr, shared_classifier); |
| } |
|
|
| void build_transformer(Transformer *t, char* checkpoint_path) { |
| |
| read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); |
| |
| malloc_run_state(&t->state, &t->config); |
| } |
|
|
| void free_transformer(Transformer* t) { |
| |
| free(t->weights.q_tokens); |
| free(t->weights.token_embedding_table); |
| free(t->weights.wq); |
| free(t->weights.wk); |
| free(t->weights.wv); |
| free(t->weights.wo); |
| free(t->weights.w1); |
| free(t->weights.w2); |
| free(t->weights.w3); |
| if(t->weights.wcls != t->weights.q_tokens) { free(t->weights.wcls); } |
| |
| if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); } |
| if (t->fd != -1) { close(t->fd); } |
| |
| free_run_state(&t->state); |
| } |
|
|
| |
| |
|
|
| void rmsnorm(float* o, float* x, float* weight, int size) { |
| |
| float ss = 0.0f; |
| for (int j = 0; j < size; j++) { |
| ss += x[j] * x[j]; |
| } |
| ss /= size; |
| ss += 1e-5f; |
| ss = 1.0f / sqrtf(ss); |
| |
| for (int j = 0; j < size; j++) { |
| o[j] = weight[j] * (ss * x[j]); |
| } |
| } |
|
|
| void softmax(float* x, int size) { |
| |
| float max_val = x[0]; |
| for (int i = 1; i < size; i++) { |
| if (x[i] > max_val) { |
| max_val = x[i]; |
| } |
| } |
| |
| float sum = 0.0f; |
| for (int i = 0; i < size; i++) { |
| x[i] = expf(x[i] - max_val); |
| sum += x[i]; |
| } |
| |
| for (int i = 0; i < size; i++) { |
| x[i] /= sum; |
| } |
| } |
|
|
| void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) { |
| |
| |
| |
|
|
| int i; |
| #pragma omp parallel for private(i) |
| for (i = 0; i < d; i++) { |
|
|
| float val = 0.0f; |
| int32_t ival = 0; |
| int in = i * n; |
|
|
| |
| int j; |
| for (j = 0; j <= n - GS; j += GS) { |
| for (int k = 0; k < GS; k++) { |
| ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]); |
| } |
| val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS]; |
| ival = 0; |
| } |
|
|
| xout[i] = val; |
| } |
| } |
|
|
| float* forward(Transformer* transformer, int token, int pos) { |
|
|
| |
| Config* p = &transformer->config; |
| TransformerWeights* w = &transformer->weights; |
| RunState* s = &transformer->state; |
| float *x = s->x; |
| int dim = p->dim; |
| int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; |
| int kv_mul = p->n_heads / p->n_kv_heads; |
| int hidden_dim = p->hidden_dim; |
| int head_size = dim / p->n_heads; |
|
|
| |
| memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float)); |
|
|
| |
| for(unsigned long long l = 0; l < p->n_layers; l++) { |
|
|
| |
| rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); |
|
|
| |
| quantize(&s->xq, s->xb, dim); |
| matmul(s->q, &s->xq, w->wq + l, dim, dim); |
| matmul(s->k, &s->xq, w->wk + l, dim, kv_dim); |
| matmul(s->v, &s->xq, w->wv + l, dim, kv_dim); |
|
|
| |
| for (int i = 0; i < p->n_heads; i++) { |
| for (int j = 0; j < head_size; j += 2) { |
| float freq = 1.0f / powf(500000.0f, (float)j / (float)head_size); |
| float val = pos * freq; |
| float fcr = cosf(val); |
| float fci = sinf(val); |
| float q0 = s->q[i * head_size + j]; |
| float q1 = s->q[i * head_size + j + 1]; |
| s->q[i * head_size + j] = q0 * fcr - q1 * fci; |
| s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr; |
| if (i < p->n_kv_heads) { |
| float k0 = s->k[i * head_size + j]; |
| float k1 = s->k[i * head_size + j + 1]; |
| s->k[i * head_size + j] = k0 * fcr - k1 * fci; |
| s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr; |
| } |
| } |
| } |
|
|
| |
| int loff = l * p->seq_len * kv_dim; |
| float* key_cache_row = s->key_cache + loff + pos * kv_dim; |
| float* value_cache_row = s->value_cache + loff + pos * kv_dim; |
| memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row)); |
| memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row)); |
|
|
| |
| int h; |
| #pragma omp parallel for private(h) |
| for (h = 0; h < p->n_heads; h++) { |
| |
| float* q = s->q + h * head_size; |
| |
| float* att = s->att + h * p->seq_len; |
| |
| for (int t = 0; t <= pos; t++) { |
| |
| float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size; |
| |
| float score = 0.0f; |
| for (int i = 0; i < head_size; i++) { |
| score += q[i] * k[i]; |
| } |
| score /= sqrtf(head_size); |
| |
| att[t] = score; |
| } |
|
|
| |
| softmax(att, pos + 1); |
|
|
| |
| float* xb = s->xb + h * head_size; |
| memset(xb, 0, head_size * sizeof(float)); |
| for (int t = 0; t <= pos; t++) { |
| |
| float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size; |
| |
| float a = att[t]; |
| |
| for (int i = 0; i < head_size; i++) { |
| xb[i] += a * v[i]; |
| } |
| } |
| } |
|
|
| |
| quantize(&s->xq, s->xb, dim); |
| matmul(s->xb2, &s->xq, w->wo + l, dim, dim); |
|
|
| |
| for (int i = 0; i < dim; i++) { |
| x[i] += s->xb2[i]; |
| } |
|
|
| |
| rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim); |
|
|
| |
| |
| quantize(&s->xq, s->xb, dim); |
| matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim); |
| matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim); |
|
|
| |
| for (int i = 0; i < hidden_dim; i++) { |
| float val = s->hb[i]; |
| |
| val *= (1.0f / (1.0f + expf(-val))); |
| |
| val *= s->hb2[i]; |
| s->hb[i] = val; |
| } |
|
|
| |
| quantize(&s->hq, s->hb, hidden_dim); |
| matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim); |
|
|
| |
| for (int i = 0; i < dim; i++) { |
| x[i] += s->xb[i]; |
| } |
| } |
|
|
| |
| rmsnorm(x, x, w->rms_final_weight, dim); |
|
|
| |
| quantize(&s->xq, x, dim); |
| matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size); |
| return s->logits; |
| } |
|
|
| |
| |
|
|
/* A vocabulary string paired with its token id, so the vocabulary can be
 * sorted by string and searched with bsearch(). */
typedef struct {
    char *str; /* token text (borrowed from Tokenizer.vocab) */
    int id;    /* token id */
} TokenIndex;
|
|
| typedef struct { |
| char** vocab; |
| float* vocab_scores; |
| TokenIndex *sorted_vocab; |
| int vocab_size; |
| unsigned int max_token_length; |
| unsigned char byte_pieces[512]; |
| } Tokenizer; |
|
|
| int compare_tokens(const void *a, const void *b) { |
| return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); |
| } |
|
|
| void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { |
| |
| t->vocab_size = vocab_size; |
| |
| t->vocab = (char**)malloc(vocab_size * sizeof(char*)); |
| t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); |
| t->sorted_vocab = NULL; |
| for (int i = 0; i < 256; i++) { |
| t->byte_pieces[i * 2] = (unsigned char)i; |
| t->byte_pieces[i * 2 + 1] = '\0'; |
| } |
| |
| FILE *file = fopen(tokenizer_path, "rb"); |
| if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } |
| if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } |
| int len; |
| for (int i = 0; i < vocab_size; i++) { |
| if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);} |
| if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } |
| t->vocab[i] = (char *)malloc(len + 1); |
| if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } |
| t->vocab[i][len] = '\0'; |
| } |
| fclose(file); |
| } |
|
|
| void free_tokenizer(Tokenizer* t) { |
| for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } |
| free(t->vocab); |
| free(t->vocab_scores); |
| free(t->sorted_vocab); |
| } |
|
|
| char* decode(Tokenizer* t, int prev_token, int token) { |
| char *piece = t->vocab[token]; |
|
|
|
|
| |
| |
| unsigned char byte_val; |
| if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { |
| piece = (char*)t->byte_pieces + byte_val * 2; |
| } |
| return piece; |
| } |
|
|
/* Print a decoded piece to stdout. NULL and empty pieces are skipped, and so
 * is a single-byte piece that is neither printable nor whitespace. */
void safe_printf(char *piece) {
    if (piece == NULL || piece[0] == '\0') {
        return;
    }
    if (piece[1] == '\0') {
        unsigned char only = piece[0];
        if (!isprint(only) && !isspace(only)) {
            return;
        }
    }
    printf("%s", piece);
}
|
|
| int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { |
| |
| TokenIndex tok = { .str = str }; |
| TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); |
| return res != NULL ? res->id : -1; |
| } |
|
|
| void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) { |
| |
| |
| if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); } |
|
|
| if (t->sorted_vocab == NULL) { |
| |
| t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex)); |
| for (int i = 0; i < t->vocab_size; i++) { |
| t->sorted_vocab[i].str = t->vocab[i]; |
| t->sorted_vocab[i].id = i; |
| } |
| qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens); |
| } |
|
|
| |
| |
| char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); |
| size_t str_len = 0; |
|
|
| |
| *n_tokens = 0; |
|
|
| |
| if (bos) tokens[(*n_tokens)++] = 128000; |
|
|
| |
| |
| |
| |
|
|
|
|
|
|
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| for (char *c = text; *c != '\0'; c++) { |
|
|
| |
| |
| |
| |
| |
| if ((*c & 0xC0) != 0x80) { |
| |
| |
| str_len = 0; |
| } |
|
|
| |
| str_buffer[str_len++] = *c; |
| str_buffer[str_len] = '\0'; |
|
|
| |
| |
| if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) { |
| continue; |
| } |
|
|
| |
| int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); |
|
|
| if (id != -1) { |
| |
| tokens[(*n_tokens)++] = id; |
| } else { |
| |
| |
| |
| for (int i=0; i < str_len; i++) { |
| tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3; |
| } |
| } |
| str_len = 0; |
| } |
|
|
| |
| while (1) { |
| float best_score = -1e10; |
| int best_id = -1; |
| int best_idx = -1; |
| int best_len = 2; |
|
|
| |
| for (int i = 0; i < (*n_tokens - 1); i++) { |
| |
| sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); |
| int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); |
| if (id != -1 && t->vocab_scores[id] > best_score) { |
| |
| best_score = t->vocab_scores[id]; |
| best_id = id; |
| best_idx = i; |
| } |
| } |
|
|
| |
| if (best_idx == -1) { |
| for (int i = 0; i < (*n_tokens - 2); i++) { |
| |
| sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]); |
| int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); |
| if (id != -1 && t->vocab_scores[id] > best_score) { |
| |
| best_score = t->vocab_scores[id]; |
| best_id = id; |
| best_idx = i; |
| best_len = 3; |
| } |
| } |
| } |
|
|
| if (best_idx == -1) { |
| break; |
| } |
|
|
| |
| tokens[best_idx] = best_id; |
| |
| for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) { |
| tokens[i] = tokens[i + best_len - 1]; |
| } |
| (*n_tokens) -= (best_len - 1); |
| } |
|
|
| |
| if (eos) tokens[(*n_tokens)++] = 128001; |
|
|
| free(str_buffer); |
| } |
|
|
| |
| |
| |
|
|
/* A probability paired with its token index, used when sorting candidates
 * during top-p sampling. */
typedef struct {
    float prob; /* probability of the token */
    int index;  /* token id */
} ProbIndex;
|
|
| typedef struct { |
| int vocab_size; |
| ProbIndex* probindex; |
| float temperature; |
| float topp; |
| unsigned long long rng_state; |
| } Sampler; |
|
|
/* Index of the largest of the n values; the first one wins on ties. */
int sample_argmax(float* probabilities, int n) {
    int best = 0;
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > probabilities[best]) {
            best = i;
        }
    }
    return best;
}
|
|
/* Inverse-CDF sampling: return the first index whose cumulative probability
 * exceeds coin (expected in [0, 1)). Falls back to n - 1 when rounding keeps
 * the running sum at or below coin. */
int sample_mult(float* probabilities, int n, float coin) {
    float running = 0.0f;
    int i = 0;
    while (i < n) {
        running += probabilities[i];
        if (coin < running) {
            return i;
        }
        ++i;
    }
    return n - 1;
}
|
|
| int compare(const void* a, const void* b) { |
| ProbIndex* a_ = (ProbIndex*) a; |
| ProbIndex* b_ = (ProbIndex*) b; |
| if (a_->prob > b_->prob) return -1; |
| if (a_->prob < b_->prob) return 1; |
| return 0; |
| } |
|
|
| int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) { |
| |
| |
| |
| |
|
|
| int n0 = 0; |
| |
| |
| |
| const float cutoff = (1.0f - topp) / (n - 1); |
| for (int i = 0; i < n; i++) { |
| if (probabilities[i] >= cutoff) { |
| probindex[n0].index = i; |
| probindex[n0].prob = probabilities[i]; |
| n0++; |
| } |
| } |
| qsort(probindex, n0, sizeof(ProbIndex), compare); |
|
|
| |
| float cumulative_prob = 0.0f; |
| int last_idx = n0 - 1; |
| for (int i = 0; i < n0; i++) { |
| cumulative_prob += probindex[i].prob; |
| if (cumulative_prob > topp) { |
| last_idx = i; |
| break; |
| } |
| } |
|
|
| |
| float r = coin * cumulative_prob; |
| float cdf = 0.0f; |
| for (int i = 0; i <= last_idx; i++) { |
| cdf += probindex[i].prob; |
| if (r < cdf) { |
| return probindex[i].index; |
| } |
| } |
| return probindex[last_idx].index; |
| } |
|
|
| void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) { |
| sampler->vocab_size = vocab_size; |
| sampler->temperature = temperature; |
| sampler->topp = topp; |
| sampler->rng_state = rng_seed; |
| |
| sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex)); |
| } |
|
|
| void free_sampler(Sampler* sampler) { |
| free(sampler->probindex); |
| } |
|
|
/* xorshift* step: advance *state and return the high 32 bits of the
 * multiplied state. A zero state stays zero. */
unsigned int random_u32(unsigned long long *state) {
    unsigned long long s = *state;
    s ^= s >> 12;
    s ^= s << 25;
    s ^= s >> 27;
    *state = s;
    return (unsigned int)((s * 0x2545F4914F6CDD1Dull) >> 32);
}
/* Uniform float in [0, 1): the top 24 bits of random_u32() divided by 2^24,
 * which is exactly representable in a float. */
float random_f32(unsigned long long *state) {
    unsigned int top24 = random_u32(state) >> 8;
    return top24 / 16777216.0f;
}
|
|
| int sample(Sampler* sampler, float* logits) { |
| |
| int next; |
| if (sampler->temperature == 0.0f) { |
| |
| next = sample_argmax(logits, sampler->vocab_size); |
| } else { |
| |
| for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; } |
| |
| softmax(logits, sampler->vocab_size); |
| |
| float coin = random_f32(&sampler->rng_state); |
| |
| if (sampler->topp <= 0 || sampler->topp >= 1) { |
| |
| next = sample_mult(logits, sampler->vocab_size, coin); |
| } else { |
| |
| next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin); |
| } |
| } |
| return next; |
| } |
|
|
| |
| |
|
|
| long time_in_ms() { |
| |
| struct timespec time; |
| clock_gettime(CLOCK_REALTIME, &time); |
| return time.tv_sec * 1000 + time.tv_nsec / 1000000; |
| } |
|
|
| |
| |
|
|
| void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) { |
| char *empty_prompt = ""; |
| if (prompt == NULL) { prompt = empty_prompt; } |
|
|
| |
| int num_prompt_tokens = 0; |
| int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); |
| encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); |
| if (num_prompt_tokens < 1) { |
| fprintf(stderr, "something is wrong, expected at least 1 prompt token\n"); |
| exit(EXIT_FAILURE); |
| } |
|
|
| |
| long start = 0; |
| int next; |
| int token = prompt_tokens[0]; |
| int pos = 0; |
|
|
| while (pos < steps) { |
|
|
| |
| float* logits = forward(transformer, token, pos); |
|
|
| |
| if (pos < num_prompt_tokens - 1) { |
| |
| next = prompt_tokens[pos + 1]; |
| } else { |
| |
| next = sample(sampler, logits); |
| } |
| pos++; |
|
|
| |
| if ((next == 128001 || next == 128009) && pos > num_prompt_tokens) break; |
| |
| char* piece = decode(tokenizer, token, next); |
| safe_printf(piece); |
| fflush(stdout); |
| token = next; |
|
|
| |
| if (start == 0) { start = time_in_ms(); } |
| } |
| printf("\n"); |
|
|
| |
| if (pos > 1) { |
| long end = time_in_ms(); |
| fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000); |
| } |
|
|
| free(prompt_tokens); |
| } |
|
|
/*
 * Print `guide`, then read one line from stdin into buffer (at most
 * bufsize - 1 bytes), stripping a trailing newline.
 *
 * On EOF or read error the buffer is set to the empty string. Previously it
 * was left unchanged, so callers could strcmp() uninitialized memory. With
 * bufsize == 0 nothing is read and the buffer is not touched.
 */
void read_stdin(const char* guide, char* buffer, size_t bufsize) {
    printf("%s", guide);
    if (bufsize == 0) {
        return;
    }
    if (fgets(buffer, bufsize, stdin) != NULL) {
        size_t len = strlen(buffer);
        if (len > 0 && buffer[len - 1] == '\n') {
            buffer[len - 1] = '\0';
        }
    } else {
        buffer[0] = '\0';
    }
}
|
|
| |
| |
| |
| |
| |
|
|
| void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, |
| char *cli_user_prompt, char *cli_system_prompt, int steps) { |
|
|
| |
| |
| char* system_prompt = (char*)malloc(32768 * sizeof(char)); |
| char* user_prompt = (char*)malloc(32768 * sizeof(char)); |
| int num_prompt_tokens = 0; |
| int* prompt_tokens = (int*)malloc(32768 * sizeof(int)); |
| int* system_prompt_tokens = (int*)malloc(32768 * sizeof(int)); |
| int* user_prompt_tokens = (int*)malloc(32768 * sizeof(int)); |
| int user_idx=0; |
|
|
| |
| int8_t user_turn = 1; |
| int next; |
| int token; |
|
|
| int pos = 0; |
| while (pos < steps) { |
|
|
| |
| if (user_turn) { |
| |
| if (pos == 0) { |
| |
| prompt_tokens[num_prompt_tokens++] = 128000; |
| prompt_tokens[num_prompt_tokens++] = 128006; |
| prompt_tokens[num_prompt_tokens++] = 9125; |
| prompt_tokens[num_prompt_tokens++] = 128007; |
| prompt_tokens[num_prompt_tokens++] = 271; |
| if (cli_system_prompt == NULL) { |
| |
| read_stdin("Enter system prompt (optional): ", system_prompt, 32768); |
| } else { |
| |
| strcpy(system_prompt, cli_system_prompt); |
| } |
| if (system_prompt != NULL) { |
| int num_system_prompt_tokens = 0; |
| encode(tokenizer, system_prompt, 0, 0, system_prompt_tokens, &num_system_prompt_tokens); |
| for (int i=0; i<num_system_prompt_tokens; i++) { |
| prompt_tokens[num_prompt_tokens++] = system_prompt_tokens[i]; |
| } |
| } |
| prompt_tokens[num_prompt_tokens++] = 128009; |
| } else { |
| num_prompt_tokens = 0; |
| } |
| prompt_tokens[num_prompt_tokens++] = 128006; |
| prompt_tokens[num_prompt_tokens++] = 882; |
| prompt_tokens[num_prompt_tokens++] = 128007; |
| prompt_tokens[num_prompt_tokens++] = 271; |
| |
| if (pos == 0 && cli_user_prompt != NULL) { |
| |
| strcpy(user_prompt, cli_user_prompt); |
| } else { |
| |
| read_stdin("User (or exit): ", user_prompt, 32768); |
| if(strcmp(user_prompt, "exit")==0) break; |
| } |
| int num_user_prompt_tokens = 0; |
| |
| encode(tokenizer, user_prompt, 0, 0, user_prompt_tokens, &num_user_prompt_tokens); |
| for (int i=0; i<num_user_prompt_tokens; i++) { |
| prompt_tokens[num_prompt_tokens++] = user_prompt_tokens[i]; |
| } |
| prompt_tokens[num_prompt_tokens++] = 128009; |
| prompt_tokens[num_prompt_tokens++] = 128006; |
| prompt_tokens[num_prompt_tokens++] = 78191; |
| prompt_tokens[num_prompt_tokens++] = 128007; |
| prompt_tokens[num_prompt_tokens++] = 271; |
|
|
|
|
| user_idx = 0; |
| user_turn = 0; |
| printf("Assistant: "); |
| } |
|
|
| |
| if (user_idx < num_prompt_tokens) { |
| |
| token = prompt_tokens[user_idx++]; |
| } else { |
| |
| token = next; |
| } |
| |
| if (user_idx >= num_prompt_tokens && (token == 128009 || token == 128001)) { user_turn = 1; } |
|
|
| |
| float* logits = forward(transformer, token, pos); |
| next = sample(sampler, logits); |
| pos++; |
|
|
| if (user_idx >= num_prompt_tokens && next != 128009 && next != 128001 && next != 128006) { |
| |
| char* piece = decode(tokenizer, token, next); |
| safe_printf(piece); |
| fflush(stdout); |
| } |
| if (user_idx >= num_prompt_tokens && next == 128009 || next == 128001) { printf("\n"); } |
| } |
| printf("\n"); |
| free(prompt_tokens); |
| free(system_prompt_tokens); |
| free(user_prompt_tokens); |
| free(system_prompt); |
| free(user_prompt); |
| } |
|
|
|
|
| |
| |
| #ifndef TESTING |
|
|
/* Print command-line usage to stderr and exit with EXIT_FAILURE. */
void error_usage() {
    static const char *const usage_lines[] = {
        "Usage: run <checkpoint> [options]\n",
        "Example: run model.bin -n 4096 -i \"Once upon a time\"\n",
        "Options:\n",
        " -t <float> temperature in [0,inf], default 1.0\n",
        " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n",
        " -s <int> random seed, default time(NULL)\n",
        " -n <int> number of steps to run for, default 4096. 0 = max_seq_len\n",
        " -i <string> input prompt\n",
        " -z <string> optional path to custom tokenizer\n",
        " -m <string> mode: generate|chat, default: generate\n",
        " -y <string> (optional) system prompt in chat mode\n",
    };
    for (size_t i = 0; i < sizeof usage_lines / sizeof usage_lines[0]; i++) {
        fputs(usage_lines[i], stderr);
    }
    exit(EXIT_FAILURE);
}
|
|
| int main(int argc, char *argv[]) { |
|
|
| |
| char *checkpoint_path = NULL; |
| char *tokenizer_path = "tokenizer.bin"; |
| float temperature = 1.0f; |
| float topp = 0.9f; |
| int steps = 4096; |
| char *prompt = NULL; |
| unsigned long long rng_seed = 0; |
| char *mode = "generate"; |
| char *system_prompt = NULL; |
|
|
| |
| if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); } |
| for (int i = 2; i < argc; i+=2) { |
| |
| if (i + 1 >= argc) { error_usage(); } |
| if (argv[i][0] != '-') { error_usage(); } |
| if (strlen(argv[i]) != 2) { error_usage(); } |
| |
| if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); } |
| else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); } |
| else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); } |
| else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } |
| else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } |
| else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } |
| else if (argv[i][1] == 'm') { mode = argv[i + 1]; } |
| else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; } |
| else { error_usage(); } |
| } |
|
|
| |
| if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL); |
| if (temperature < 0.0) temperature = 0.0; |
| if (topp < 0.0 || 1.0 < topp) topp = 0.9; |
| if (steps < 0) steps = 0; |
|
|
| |
| Transformer transformer; |
| build_transformer(&transformer, checkpoint_path); |
| if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; |
|
|
| |
| Tokenizer tokenizer; |
| build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size); |
|
|
| |
| Sampler sampler; |
| build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed); |
|
|
| |
| if (strcmp(mode, "generate") == 0) { |
| generate(&transformer, &tokenizer, &sampler, prompt, steps); |
| } else if (strcmp(mode, "chat") == 0) { |
| chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps); |
| } else { |
| fprintf(stderr, "unknown mode: %s\n", mode); |
| error_usage(); |
| } |
|
|
| |
| free_sampler(&sampler); |
| free_tokenizer(&tokenizer); |
| free_transformer(&transformer); |
| return 0; |
| } |
| #endif |
|
|