| #pragma once
|
|
|
| #include "llama.h"
|
|
|
| #include <map>
|
| #include <regex>
|
| #include <string>
|
| #include <vector>
|
|
|
| struct llama_vocab;
|
|
|
|
|
| enum llama_gretype {
|
|
|
| LLAMA_GRETYPE_END = 0,
|
|
|
|
|
| LLAMA_GRETYPE_ALT = 1,
|
|
|
|
|
| LLAMA_GRETYPE_RULE_REF = 2,
|
|
|
|
|
| LLAMA_GRETYPE_CHAR = 3,
|
|
|
|
|
| LLAMA_GRETYPE_CHAR_NOT = 4,
|
|
|
|
|
|
|
| LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
|
|
|
|
|
|
| LLAMA_GRETYPE_CHAR_ALT = 6,
|
|
|
|
|
| LLAMA_GRETYPE_CHAR_ANY = 7,
|
|
|
|
|
| LLAMA_GRETYPE_TOKEN = 8,
|
|
|
|
|
| LLAMA_GRETYPE_TOKEN_NOT = 9,
|
| };
|
|
|
| typedef struct llama_grammar_element {
|
| enum llama_gretype type;
|
| uint32_t value;
|
| } llama_grammar_element;
|
|
|
| struct llama_partial_utf8 {
|
| uint32_t value;
|
| int n_remain;
|
| };
|
|
|
| struct llama_grammar_candidate {
|
| size_t index;
|
| const uint32_t * code_points;
|
| llama_partial_utf8 partial_utf8;
|
| llama_token id;
|
| };
|
|
|
| using llama_grammar_rule = std::vector< llama_grammar_element>;
|
| using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
|
|
| using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
| using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
| using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
|
|
|
|
| const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
| llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
|
|
|
|
|
|
|
|
|
|
| void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
|
|
|
| std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
| const llama_grammar_rules & rules,
|
| const llama_grammar_stack & stack,
|
| const llama_grammar_candidates & candidates);
|
|
|
| struct llama_grammar_parser {
|
| const llama_vocab * vocab;
|
| std::map<std::string, uint32_t> symbol_ids;
|
|
|
| llama_grammar_rules rules;
|
|
|
| llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
|
|
|
| llama_grammar_stack c_rules() const;
|
|
|
| uint32_t get_symbol_id(const char * src, size_t len);
|
| uint32_t generate_symbol_id(const std::string & base_name);
|
|
|
| void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
|
|
|
| const char * parse_alternates(
|
| const char * src,
|
| const std::string & rule_name,
|
| uint32_t rule_id,
|
| bool is_nested);
|
|
|
| const char * parse_sequence(
|
| const char * src,
|
| const std::string & rule_name,
|
| llama_grammar_rule & rule,
|
| bool is_nested);
|
|
|
| const char * parse_rule(const char * src);
|
|
|
| bool parse(const char * src);
|
| void print(FILE * file);
|
| };
|
|
|
| struct llama_grammar_trigger_pattern {
|
| std::string pattern;
|
| std::regex regex;
|
|
|
| size_t find(const std::string & input) const;
|
| };
|
|
|
| struct llama_grammar {
|
|
|
| using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
|
|
|
|
|
| const llama_vocab * vocab;
|
|
|
| const llama_grammar_rules rules;
|
| llama_grammar_stacks stacks;
|
|
|
|
|
| llama_partial_utf8 partial_utf8;
|
|
|
|
|
|
|
|
|
| bool lazy = false;
|
| bool awaiting_trigger = false;
|
| std::string trigger_buffer;
|
| std::vector<token_pos> trigger_buffer_positions;
|
| std::vector<llama_token> trigger_tokens;
|
| std::vector<llama_grammar_trigger_pattern>
|
| trigger_patterns;
|
|
|
|
|
| };
|
|
|
|
|
|
|
|
|
|
|
|
|
| struct llama_grammar * llama_grammar_init_impl(
|
| const struct llama_vocab * vocab,
|
| const llama_grammar_element ** rules,
|
| size_t n_rules,
|
| size_t start_rule_index);
|
|
|
| struct llama_grammar * llama_grammar_init_impl(
|
| const struct llama_vocab * vocab,
|
| const char * grammar_str,
|
| const char * grammar_root,
|
| bool lazy,
|
| const char ** trigger_patterns,
|
| size_t num_trigger_patterns,
|
| const llama_token * trigger_tokens,
|
| size_t num_trigger_tokens);
|
|
|
| void llama_grammar_free_impl(struct llama_grammar * grammar);
|
|
|
| struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
|
|
|
|
|
| void llama_grammar_apply_impl(
|
| const struct llama_grammar & grammar,
|
| llama_token_data_array * cur_p);
|
|
|
| void llama_grammar_accept_impl(
|
| struct llama_grammar & grammar,
|
| llama_token token);
|
|
|
| void llama_grammar_accept_str(
|
| struct llama_grammar & grammar,
|
| const std::string & piece);
|
|
|
| void llama_grammar_accept_token(
|
| struct llama_grammar & grammar,
|
| llama_token token,
|
| const std::string & piece);
|
|
|