"""Utilities for llama.cpp GBNF grammars: the ``LlamaGrammar`` wrapper, bundled example grammars, and a JSON-schema-to-GBNF converter."""
| |
|
| | |
import json
import re
from itertools import groupby
from pathlib import Path
from typing import (
    Any,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)
| |
|
# Name of the start rule every grammar is expected to define.
LLAMA_GRAMMAR_DEFAULT_ROOT = "root"


class LlamaGrammar:
    """Holder for a GBNF grammar string.

    The grammar text is stored verbatim in ``_grammar``; it is not parsed or
    validated by this class.
    """

    def __init__(self, *args, _grammar: str, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored
        # (presumably for compatibility with older call sites).
        self._grammar = _grammar
        self._root = LLAMA_GRAMMAR_DEFAULT_ROOT

    @classmethod
    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
        """Wrap GBNF source text.

        ``verbose`` is accepted for API compatibility and is currently unused.
        """
        return cls(_grammar=grammar)

    @classmethod
    def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
        """Read GBNF source from *file* (decoded as UTF-8) and wrap it.

        Raises:
            Exception: if the file cannot be opened or decoded; the original
                error is chained as ``__cause__``.
            ValueError: if the file is empty.
        """
        try:
            # Grammar files may contain non-ASCII text (e.g. CJK character
            # ranges), so do not depend on the platform's default encoding.
            with open(file, encoding="utf-8") as f:
                grammar = f.read()
        except Exception as err:
            raise Exception(
                f"{cls.from_file.__name__}: error reading grammar file: {err}"
            ) from err

        if grammar:
            return cls.from_string(grammar, verbose=verbose)

        raise ValueError(
            f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
        )

    @classmethod
    def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
        """Convert a JSON schema (given as a JSON string) to GBNF and wrap it."""
        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
| |
|
| |
|
# Example GBNF grammars mirrored from vendor/llama.cpp/grammars.

ARITHMETIC_GBNF = r"""
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
"""

C_GBNF = r"""
root ::= (declaration)*

declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"

dataType ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*

parameter ::= dataType identifier

statement ::=
    ( dataType identifier ws "=" ws expression ";" ) |
    ( identifier ws "=" ws expression ";" ) |
    ( identifier ws "(" argList? ")" ";" ) |
    ( "return" ws expression ";" ) |
    ( "while" "(" condition ")" "{" statement* "}" ) |
    ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
    ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
    ( singleLineComment ) |
    ( multiLineComment )

forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression

condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")

expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*

factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"

argList ::= expression ("," ws expression)*

number ::= [0-9]+

singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"

ws ::= ([ \t\n]+)
"""

# Previously this constant held a copy of the JSON grammar; it now carries the
# chess-move grammar its name refers to.
CHESS_GBNF = r"""
# Specifies chess moves as a list in algebraic notation, using PGN conventions

# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
move ::= (pawn | nonpawn | castle) [+#]?

# piece type, optional file/rank, optional capture, dest file & rank
nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]

# optional file & capture, dest file & rank, optional promotion
pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?

castle ::= "O-O" "-O"?
"""

# Previously this constant held a copy of the JSON grammar; it now carries the
# Japanese-text grammar its name refers to. Ranges are written as GBNF \u
# escapes: hiragana U+3041-U+309F, katakana U+30A1-U+30FF, CJK punctuation
# U+3001-U+303E, CJK unified ideographs U+4E00-U+9FFF.
JAPANESE_GBNF = r"""
# A probably incorrect grammar for Japanese
root ::= jp-char+ ([ \t\n] jp-char+)*
jp-char ::= hiragana | katakana | punctuation | cjk
hiragana ::= [\u3041-\u309F]
katakana ::= [\u30A1-\u30FF]
punctuation ::= [\u3001-\u303E]
cjk ::= [\u4E00-\u9FFF]
"""

JSON_ARR_GBNF = r"""
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws

arr ::=
  "[\n" ws (
    value
    (",\n" ws value)*
  )? "]"

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""

# The exponent accepts 1-16 digits of any value; the previous
# "[0-9] [1-9]{0,15}" rejected exponents such as "e10".
JSON_GBNF = r"""
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [0-9]{0,15})? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= | " " | "\n" [ \t]{0,20}
"""

LIST_GBNF = r"""
root ::= item+

# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""
| |
|
"""llama.cpp JSON-schema-to-grammar converter, ported from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
| | import json |
| | import re |
| | from typing import List, Optional |
| |
|
| | |
| | |
# Rule body for the optional single space allowed between JSON tokens.
SPACE_RULE = '" "?'

# Runs of characters that are not valid in GBNF rule names (replaced by "-").
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
# Characters that must be escaped inside a double-quoted GBNF literal.
# The backslash is included because GBNF treats it as an escape introducer:
# without it, JSON text produced by json.dumps (e.g. the two characters "\n"
# or "\\" inside a const/enum string or property name) would be decoded by the
# grammar parser instead of being matched verbatim, forcing invalid JSON.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"', "\\": "\\\\"}
| |
|
| |
|
| | def _build_repetition( |
| | item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False |
| | ): |
| | if not separator_rule: |
| | if min_items == 0 and max_items == 1: |
| | return f"{item_rule}?" |
| | elif min_items == 1 and max_items is None: |
| | return f"{item_rule}+" |
| |
|
| | result = "" |
| |
|
| | if min_items > 0: |
| | if item_rule_is_literal and separator_rule is None: |
| | result = '"' + (item_rule[1:-1] * min_items) + '"' |
| | else: |
| | result = (f" {separator_rule} " if separator_rule else " ").join( |
| | [item_rule] * min_items |
| | ) |
| |
|
| | def opt_repetitions(up_to_n, prefix_with_sep=False): |
| | """ |
| | - n=4, no sep: '(a (a (a (a)?)?)?)?' |
| | - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' |
| | - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' |
| | """ |
| |
|
| | content = ( |
| | f"{separator_rule} {item_rule}" |
| | if prefix_with_sep and separator_rule |
| | else item_rule |
| | ) |
| | if up_to_n == 0: |
| | return "" |
| | elif up_to_n == 1: |
| | return f"({content})?" |
| | elif separator_rule and not prefix_with_sep: |
| | return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?" |
| | else: |
| | return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n) |
| |
|
| | if min_items > 0 and max_items != min_items: |
| | result += " " |
| |
|
| | if max_items is not None: |
| | result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) |
| | else: |
| | item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' |
| |
|
| | if min_items == 0 and separator_rule: |
| | result = f"({item_rule} {item_operator}*)?" |
| | else: |
| | result += f"{item_operator}*" |
| |
|
| | return result |
| |
|
| |
|
class BuiltinRule:
    """A predefined GBNF rule body together with the rule names it references."""

    def __init__(self, content: str, deps: list = None):
        self.content = content
        # A missing or empty dependency list becomes a fresh empty list; a
        # non-empty list is stored as given (not copied).
        if deps:
            self.deps = deps
        else:
            self.deps = []
| |
|
| |
|
# Optional run of up to 15 digits; caps the length of numeric literals.
_up_to_15_digits = _build_repetition("[0-9]", 0, 15)

# Built-in rules for JSON primitives, keyed by rule name. Each entry lists the
# other built-in rules its body references.
PRIMITIVE_RULES = {
    "boolean": BuiltinRule('("true" | "false") space', []),
    "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
    "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
    "number": BuiltinRule(
        '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
        ["integral-part", "decimal-part"],
    ),
    "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
    "value": BuiltinRule(
        "object | array | string | number | boolean | null",
        ["object", "array", "string", "number", "boolean", "null"],
    ),
    "object": BuiltinRule(
        '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
        ["string", "value"],
    ),
    "array": BuiltinRule(
        '"[" space ( value ("," space value)* )? "]" space', ["value"]
    ),
    # 8-4-4-4-12 hex digits separated by "-", inside JSON string quotes.
    "uuid": BuiltinRule(
        r'"\"" '
        + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
        + r' "\"" space',
        [],
    ),
    # One JSON string character. Raw control characters (U+0000-U+001F) must
    # be escaped in JSON, so they are excluded here, matching JSON_GBNF's
    # string rule (which also excludes U+007F).
    "char": BuiltinRule(
        r'[^"\\\x7F\x00-\x1F] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
        [],
    ),
    "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
    "null": BuiltinRule('"null" space', []),
}

# Built-in rules for the "date", "time" and "date-time" string formats; the
# "*-string" variants wrap the value in JSON string quotes.
STRING_FORMAT_RULES = {
    "date": BuiltinRule(
        '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
        [],
    ),
    "time": BuiltinRule(
        '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
        [],
    ),
    "date-time": BuiltinRule('date "T" time', ["date", "time"]),
    "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
    "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
    "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
}

# Character classes used for a regex "." in patterns: any code point (dotall)
# or anything except LF (\x0A) and CR (\x0D).
DOTALL = "[\\U00000000-\\U0010FFFF]"
DOT = "[^\\x0A\\x0D]"

# Rule names owned by built-in rules; schema-derived rule names that collide
# with these get a "-" suffix.
RESERVED_NAMES = {"root", "dot", *PRIMITIVE_RULES, *STRING_FORMAT_RULES}

# Regex metacharacters that end a literal run while translating a pattern.
NON_LITERAL_SET = set("|.()[]{}*+?")
# Escaped regex characters whose backslash is dropped when they are copied
# into a GBNF literal.
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?")
| |
|
| |
|
class SchemaConverter:
    """Translate a JSON schema into a collection of named GBNF rules.

    ``visit`` walks a schema and registers rules in ``self._rules``
    (rule name -> rule body); ``format_grammar`` renders them as GBNF text.
    ``resolve_refs`` should run first so ``$ref`` targets are present in
    ``self._refs``.
    """

    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        """Create an empty converter.

        Args:
            prop_order: dict of property name -> rank used to order object
                properties; unranked names sort after ranked ones, keeping
                their schema order.
            allow_fetch: allow ``resolve_refs`` to download ``https://`` refs.
            dotall: if true, ``.`` in a ``pattern`` also matches CR and LF.
            raw_pattern: if true, ``pattern`` rules are emitted without the
                surrounding JSON string quotes and ``"`` is left unescaped.
        """
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        # Rule name -> GBNF rule body; "space" is always available.
        self._rules = {
            "space": SPACE_RULE,
        }
        # Absolute $ref string -> referenced (sub)schema, filled by resolve_refs.
        self._refs = {}
        # $refs currently being expanded; guards against infinite recursion.
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        """Return *literal* as a double-quoted GBNF literal.

        Characters matched by GRAMMAR_LITERAL_ESCAPE_RE are replaced using
        GRAMMAR_LITERAL_ESCAPES.
        """
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
        )
        return f'"{escaped}"'

    def not_literal(
        self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
    ) -> str:
        """Build a GBNF expression for text that diverges from *literal*.

        Examples (exact output):
            not_literal('a')   -> '([^a])'
            not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'

        With ``maybe_escaped_underscores``, an ``_`` position also accepts
        the underscore preceded by an optional backslash. ``dotall`` is
        currently unused.
        """
        assert len(literal) > 0, "Empty literal not supported"

        def recurse(i: int):
            c = literal[i]
            if maybe_escaped_underscores and c == "_":
                yield f"[^{c}\\\\]"
                yield " | "
                yield f'"\\\\"? "{c}"'
            else:
                yield f"[^{c}]"
            if i < len(literal) - 1:
                yield " | "
                yield self._format_literal(c)
                yield " ("
                yield from recurse(i + 1)
                yield ")?"

        return "".join(("(", *recurse(0), ")"))

    def _add_rule(self, name, rule):
        """Register *rule* under a sanitized form of *name*; return the key used.

        Re-adding an identical rule reuses its key; a different rule with the
        same name gets a numeric suffix.
        """
        esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while (
                f"{esc_name}{i}" in self._rules
                and self._rules[f"{esc_name}{i}"] != rule
            ):
                i += 1
            key = f"{esc_name}{i}"
        self._rules[key] = rule
        return key

    def resolve_refs(self, schema: dict, url: str):
        """
        Resolves all $ref fields in the given schema, fetching any remote schemas,
        replacing $ref with absolute reference URL and populating self._refs with the
        respective referenced (sub)schema dictionaries.
        """

        def visit(n: dict):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get("$ref")
                if ref is not None and ref not in self._refs:
                    if ref.startswith("https://"):
                        assert (
                            self._allow_fetch
                        ), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
                        import requests

                        frag_split = ref.split("#")
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            target = self.resolve_refs(
                                requests.get(ref).json(), base_url
                            )
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == "":
                            # Whole-document reference: also record it under
                            # the exact ref string (e.g. "...#") so that
                            # _resolve_ref can look it up later.
                            self._refs[ref] = target
                            return target
                    elif ref.startswith("#/"):
                        target = schema
                        ref = f"{url}{ref}"
                        n["$ref"] = ref
                    else:
                        raise ValueError(f"Unsupported ref {ref}")

                    # Walk the JSON pointer fragment ("#/a/b") into the target.
                    for sel in ref.split("#")[-1].split("/")[1:]:
                        assert (
                            target is not None and sel in target
                        ), f"Error resolving ref {ref}: {sel} not in {target}"
                        target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n

        return visit(schema)

    def _generate_union_rule(self, name, alt_schemas):
        """Visit each alternative schema and join the resulting rules with ``|``."""
        return " | ".join(
            (
                self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
                for i, alt_schema in enumerate(alt_schemas)
            )
        )

    def _visit_pattern(self, pattern, name):
        """
        Transforms a regular expression pattern into a GBNF rule.

        Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
        Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

        Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

        Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
        we define sub-rules to keep the output lean.
        """

        assert pattern.startswith("^") and pattern.endswith(
            "$"
        ), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: Tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return '"' + txt + '"' if is_literal else txt

        def transform() -> Tuple[str, bool]:
            """
            Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            """
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # (text, is_literal) pairs for the components of this sequence.
            # Kept flat so repetition operators apply to the last item; "|"
            # is passed through as-is since GBNF alternation matches regex.
            seq: list[Tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Any character except the \n and \r line breaks.
                    rule = DOT
                return self._add_rule(f"dot", rule)

            def join_seq():
                nonlocal seq
                ret = []
                for is_literal, g in groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append(("".join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                return (" ".join(to_rule(x) for x in seq), False)

            while i < length:
                c = pattern[i]
                if c == ".":
                    seq.append((get_dot(), False))
                    i += 1
                elif c == "(":
                    i += 1
                    if i < length:
                        assert (
                            pattern[i] != "?"
                        ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f"({to_rule(transform())})", False))
                elif c == ")":
                    i += 1
                    assert (
                        start > 0 and pattern[start - 1] == "("
                    ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
                    return join_seq()
                elif c == "[":
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != "]":
                        if pattern[i] == "\\":
                            square_brackets += pattern[i : i + 2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert (
                        i < length
                    ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
                    square_brackets += "]"
                    i += 1
                    seq.append((square_brackets, False))
                elif c == "|":
                    seq.append(("|", False))
                    i += 1
                elif c in ("*", "+", "?"):
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == "{":
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != "}":
                        curly_brackets += pattern[i]
                        i += 1
                    assert (
                        i < length
                    ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
                    curly_brackets += "}"
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(
                            f"Invalid quantifier {curly_brackets} in /{pattern}/"
                        )

                    (sub, sub_is_literal) = seq[-1]

                    # Non-literal items get their own sub-rule so the
                    # repetition expression can reference them by name.
                    if not sub_is_literal:
                        id = sub_rule_ids.get(sub)
                        if id is None:
                            id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub)
                            sub_rule_ids[sub] = id
                        sub = id

                    seq[-1] = (
                        _build_repetition(
                            f'"{sub}"' if sub_is_literal else sub,
                            min_times,
                            max_times,
                            item_rule_is_literal=sub_is_literal,
                        ),
                        False,
                    )
                else:
                    literal = ""
                    while i < length:
                        if pattern[i] == "\\" and i < length - 1:
                            next = pattern[i + 1]
                            if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i : i + 2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and (
                            i == length - 1
                            or literal == ""
                            or pattern[i + 1] == "."
                            or pattern[i + 1] not in NON_LITERAL_SET
                        ):
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))

            return join_seq()

        return self._add_rule(
            name,
            (
                to_rule(transform())
                if self._raw_pattern
                else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
            ),
        )

    def _resolve_ref(self, ref):
        """Return the rule name for *ref*, visiting its target on first use.

        The rule is named after the last path segment of *ref*. A ref that is
        already being resolved is returned by name only, which allows
        recursive schemas.
        """
        ref_name = ref.split("/")[-1]
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        """Return a GBNF literal matching the JSON encoding of *value*."""
        return self._format_literal(json.dumps(value))

    def visit(self, schema, name):
        """Register rule(s) for *schema* and return the name of its rule.

        *name* seeds the rule name: an empty name becomes "root", and names in
        RESERVED_NAMES get a "-" suffix to avoid clashing with built-in rules.
        """
        schema_type = schema.get("type")
        schema_format = schema.get("format")
        rule_name = name + "-" if name in RESERVED_NAMES else name or "root"

        if (ref := schema.get("$ref")) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif "oneOf" in schema or "anyOf" in schema:
            return self._add_rule(
                rule_name,
                self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
            )

        elif isinstance(schema_type, list):
            return self._add_rule(
                rule_name,
                self._generate_union_rule(name, [{"type": t} for t in schema_type]),
            )

        elif "const" in schema:
            return self._add_rule(
                rule_name, self._generate_constant_rule(schema["const"])
            )

        elif "enum" in schema:
            rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, "object") and (
            "properties" in schema
            or (
                "additionalProperties" in schema
                and schema["additionalProperties"] is not True
            )
        ):
            required = set(schema.get("required", []))
            properties = list(schema.get("properties", {}).items())
            return self._add_rule(
                rule_name,
                self._build_object_rule(
                    properties, required, name, schema.get("additionalProperties")
                ),
            )

        elif schema_type in (None, "object") and "allOf" in schema:
            required = set()
            properties = []
            hybrid_name = name

            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get("$ref")) is not None:
                    comp_schema = self._refs[ref]

                if "properties" in comp_schema:
                    for prop_name, prop_schema in comp_schema["properties"].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            # Properties from "anyOf" members inside "allOf" are optional;
            # all others are required.
            for t in schema["allOf"]:
                if "anyOf" in t:
                    for tt in t["anyOf"]:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            return self._add_rule(
                rule_name,
                self._build_object_rule(
                    properties, required, hybrid_name, additional_properties=[]
                ),
            )

        elif schema_type in (None, "array") and (
            "items" in schema or "prefixItems" in schema
        ):
            # Use "items" unless it is absent/falsy and "prefixItems" exists.
            # (Indexing "prefixItems" unconditionally raised KeyError for an
            # empty "items": {} schema without "prefixItems".)
            items = schema.get("items")
            if not items and "prefixItems" in schema:
                items = schema["prefixItems"]
            if isinstance(items, list):
                return self._add_rule(
                    rule_name,
                    '"[" space '
                    + ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)
                    )
                    + ' "]" space',
                )
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(
                    rule_name,
                    '"[" space '
                    + _build_repetition(
                        item_rule_name, min_items, max_items, separator_rule='"," space'
                    )
                    + ' "]" space',
                )

        elif schema_type in (None, "string") and "pattern" in schema:
            return self._visit_pattern(schema["pattern"], rule_name)

        elif schema_type in (None, "string") and re.match(
            r"^uuid[1-5]?$", schema_format or ""
        ):
            return self._add_primitive(
                "root" if rule_name == "root" else schema_format,
                PRIMITIVE_RULES["uuid"],
            )

        elif (
            schema_type in (None, "string")
            and f"{schema_format}-string" in STRING_FORMAT_RULES
        ):
            prim_name = f"{schema_format}-string"
            return self._add_rule(
                rule_name,
                self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
            )

        elif schema_type == "string" and (
            "minLength" in schema or "maxLength" in schema
        ):
            char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
            min_len = schema.get("minLength", 0)
            max_len = schema.get("maxLength")

            return self._add_rule(
                rule_name,
                r'"\"" '
                + _build_repetition(char_rule, min_len, max_len)
                + r' "\"" space',
            )

        elif (schema_type == "object") or (len(schema) == 0):
            return self._add_rule(
                rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
            )

        else:
            assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
            # Plain primitive type: reuse the built-in rule directly.
            return self._add_primitive(
                "root" if rule_name == "root" else schema_type,
                PRIMITIVE_RULES[schema_type],
            )

    def _add_primitive(self, name: str, rule: BuiltinRule):
        """Register built-in *rule* as *name* plus any missing dependencies.

        Returns the key under which *rule* was registered.
        """
        n = self._add_rule(name, rule.content)

        for dep in rule.deps:
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
            assert dep_rule, f"Rule {dep} not known"
            if dep not in self._rules:
                self._add_primitive(dep, dep_rule)
        return n

    def _build_object_rule(
        self,
        properties: List[Tuple[str, Any]],
        required: Set[str],
        name: str,
        additional_properties: Union[bool, Any],
    ):
        """Build the GBNF body for a JSON object.

        Required properties come first, in ``prop_order`` order. Optional
        properties (and, when *additional_properties* is True or a schema
        dict, extra key/value pairs) follow as nested optional groups.
        """
        prop_order = self._prop_order
        # Sort by rank in prop_order (unranked last), then by definition order.
        sorted_props = [
            kv[0]
            for _, kv in sorted(
                enumerate(properties),
                key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]),
            )
        ]

        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(
                prop_schema, f'{name}{"-" if name else ""}{prop_name}'
            )
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}',
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties == True or isinstance(additional_properties, dict):
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit(
                {} if additional_properties == True else additional_properties,
                f"{sub_name}-value",
            )
            prop_kv_rule_names["*"] = self._add_rule(
                f"{sub_name}-kv",
                self._add_primitive("string", PRIMITIVE_RULES["string"])
                + f' ":" space {value_rule}',
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += " ("
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                # Rule for "first of ks, then optionally any of the rest".
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                if k == "*":
                    res = self._add_rule(
                        f'{name}{"-" if name else ""}additional-kvs',
                        f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*",
                    )
                elif first_is_optional:
                    res = f'( "," space {kv_rule_name} )?'
                else:
                    res = kv_rule_name
                if len(rest) > 0:
                    res += " " + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True),
                    )
                return res

            rule += " | ".join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += " )"
            rule += " )?"

        rule += ' "}" space'

        return rule

    def format_grammar(self):
        """Render all rules as GBNF text, one ``name ::= body`` line per rule, sorted by name."""
        return "\n".join(
            f"{name} ::= {rule}"
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
        )
| |
|
| |
|
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON schema, given as a JSON string, into GBNF grammar text.

    Args:
        schema: JSON-encoded schema document.
        prop_order: optional property names; object properties are ordered by
            their position in this list, unlisted ones following in schema
            order.

    Remote ``$ref`` fetching is disabled, ``.`` in patterns does not match
    line breaks, and pattern rules are wrapped as JSON strings.
    """
    parsed = json.loads(schema)
    rank_by_name = {
        prop_name: position for position, prop_name in enumerate(prop_order or [])
    }
    converter = SchemaConverter(
        prop_order=rank_by_name,
        allow_fetch=False,
        dotall=False,
        raw_pattern=False,
    )
    resolved = converter.resolve_refs(parsed, "stdin")
    converter.visit(resolved, "")
    return converter.format_grammar()
| |
|