'''
This file has been 100% copied from this PR to the Transformers library:
https://github.com/huggingface/transformers/pull/27557

Author: Saibo-creator
Author GitHub: https://github.com/Saibo-creator

All credits go to the author.
'''

import logging
import re
import time
from abc import ABC
from functools import lru_cache
from typing import Dict, List

import torch

from modules import shared

logger = logging.getLogger(__name__)


########################
# EBNF Grammar Parsing #
########################

# Markers used in the flat grammar encoding. The three zero-valued markers
# never collide because they are distinguished by position, not value.
END_OF_ALTERNATE_MARKER = 0
END_OF_RULE_MARKER = 0
TO_BE_FILLED_MARKER = 0
REF_RULE_MARKER = 1
LITERAL_MARKER = 2
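
# Orientation note (annotation added here, not from the upstream PR): each
# rule is flattened as
#
#     rule_id, [alt_size, elements..., END_OF_ALTERNATE_MARKER]*, END_OF_RULE_MARKER
#
# and the whole encoding is terminated by 0xFFFF. As a worked example,
# `root ::= [0-9]` with rule id 0 encodes to [0, 4, 2, 48, 57, 0, 0]:
# one alternate whose size slot is 4 (the slot itself plus three element
# ints), holding a single char range with 2 bounds, 48-57 ('0'-'9').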


class ParseState:
    def __init__(self):
        self.symbol_ids = {}  # symbol name -> symbol/rule id
        self.grammar_encoding = []  # flat integer encoding of all rules


def get_symbol_id(state, src):
    if src not in state.symbol_ids:
        state.symbol_ids[src] = len(state.symbol_ids)
    return state.symbol_ids[src]


def generate_symbol_id(state, base_name):
    # mint a fresh id for a synthesized sub-rule (used for (...), *, + and ?)
    next_id = len(state.symbol_ids)
    state.symbol_ids[base_name + "_" + str(next_id)] = next_id
    return next_id


def is_word_char(c):
    return c.isalnum() or c == "-" or c == "_"


def hex_to_int(c):
    if c.isdigit():
        return int(c)
    elif "a" <= c.lower() <= "f":
        return ord(c.lower()) - ord("a") + 10
    return -1


def remove_leading_white_space(src, newline_ok):
    """
    Skips over whitespace and comments in the input string.

    This function processes the input string, skipping over any spaces, tabs,
    and content following a '#' character, which denotes a comment. The parsing
    of a comment continues until the end of the line (denoted by newline characters
    '\r' or '\n'). If the 'newline_ok' parameter is set to False, the function
    will stop processing and return the remaining string upon encountering a
    newline character, otherwise it will skip over newline characters as well.

    Parameters:
    src (str): The input string to be processed.
    newline_ok (bool): A flag indicating whether encountering a newline character
    should stop the parsing (False) or if it should be skipped (True).

    Returns:
    str: The remaining portion of the input string after skipping whitespace and comments.
    """
    pos = 0
    while pos < len(src) and (src[pos].isspace() or src[pos] == "#"):
        if src[pos] == "#":
            while pos < len(src) and src[pos] not in ("\r", "\n"):
                pos += 1
        else:
            if not newline_ok and src[pos] in ("\r", "\n"):
                break
            pos += 1
    return src[pos:]
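

# A small, self-contained illustration of the behaviour documented above
# (added as a sketch; the sample grammar text is made up):
def _example_remove_leading_white_space():
    src = "# a comment\nroot ::= [0-9]"
    # comments are always skipped up to the newline that ends them...
    assert remove_leading_white_space(src, newline_ok=True) == "root ::= [0-9]"
    # ...but the newline itself is only consumed when newline_ok is True
    assert remove_leading_white_space(src, newline_ok=False) == "\nroot ::= [0-9]"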


def parse_name(src):
    pos = 0
    while pos < len(src) and is_word_char(src[pos]):
        pos += 1
    if pos == 0:
        raise RuntimeError("expecting name at " + src)
    return src[:pos], src[pos:]


def parse_char(src):
    """
    Parse the leading char from the input string.

    :param src:
    :return: char, remaining_src
    """
    if src[0] == "\\":
        esc = src[1]
        if esc == "x":
            first = hex_to_int(src[2])
            if first > -1:
                second = hex_to_int(src[3])
                if second > -1:
                    # decode \xNN to its character so callers can apply
                    # ord() uniformly, as they do for every other branch
                    return chr((first << 4) + second), src[4:]
            raise RuntimeError("expecting \\xNN at " + src)
        elif esc in ('"', "[", "]", "\\"):
            # '\\' is accepted here as well, matching the GBNF escape set
            return esc, src[2:]
        elif esc == "r":
            return "\r", src[2:]
        elif esc == "n":
            return "\n", src[2:]
        elif esc == "t":
            return "\t", src[2:]
        raise RuntimeError("unknown escape at " + src)
    elif src:
        return src[0], src[1:]
    raise RuntimeError("unexpected end of input")


def parse_sequence(state, src, rule_name, outbuf, is_nested):
    out_start_pos = len(outbuf)

    # sequence size, will be replaced at end when known
    outbuf.append(TO_BE_FILLED_MARKER)

    last_sym_start = len(outbuf)
    remaining_src = src
    while remaining_src:
        if remaining_src[0] == '"':  # literal string
            remaining_src = remaining_src[1:]
            last_sym_start = len(outbuf)
            while remaining_src[0] != '"':
                char, remaining_src = parse_char(remaining_src)

                # each char of a literal is encoded as a "range" of char - char
                outbuf.append(LITERAL_MARKER)
                outbuf.append(ord(char))
                outbuf.append(ord(char))
            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        elif remaining_src[0] == "[":  # char range(s)
            remaining_src = remaining_src[1:]
            last_sym_start = len(outbuf)
            # num chars in range - replaced at end of loop
            outbuf.append(TO_BE_FILLED_MARKER)
            while remaining_src[0] != "]":
                char, remaining_src = parse_char(remaining_src)

                outbuf.append(ord(char))
                if remaining_src[0] == "-" and remaining_src[1] != "]":
                    endchar_pair, remaining_src = parse_char(remaining_src[1:])
                    outbuf.append(ord(endchar_pair))
                else:
                    # a single char is stored as a range of char - char
                    outbuf.append(ord(char))
            # replace num chars with the actual count
            outbuf[last_sym_start] = len(outbuf) - last_sym_start - 1
            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        elif is_word_char(remaining_src[0]):  # rule reference
            name, remaining_src = parse_name(remaining_src)
            ref_rule_id = get_symbol_id(state, name)
            remaining_src = remove_leading_white_space(remaining_src, is_nested)
            last_sym_start = len(outbuf)
            outbuf.append(REF_RULE_MARKER)
            outbuf.append(ref_rule_id)
        elif remaining_src[0] == "(":  # grouping
            # parse nested alternates into a synthesized rule
            remaining_src = remove_leading_white_space(remaining_src[1:], True)
            sub_rule_id = generate_symbol_id(state, rule_name)
            remaining_src = parse_alternates(state, remaining_src, rule_name, sub_rule_id, True)
            last_sym_start = len(outbuf)
            # output a reference to the synthesized rule
            outbuf.append(REF_RULE_MARKER)
            outbuf.append(sub_rule_id)
            if remaining_src[0] != ")":
                raise RuntimeError("expecting ')' at " + remaining_src)
            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        elif remaining_src[0] in ("*", "+", "?"):  # repetition operator
            if len(outbuf) - out_start_pos - 1 == 0:
                raise RuntimeError("expecting preceding item to */+/? at " + remaining_src)
            out_grammar = state.grammar_encoding

            # apply transformation to the previous symbol (last_sym_start to
            # end) according to the rewrite rules:
            # S* --> S' ::= S S' |
            # S+ --> S' ::= S S' | S
            # S? --> S' ::= S |
            sub_rule_id = generate_symbol_id(state, rule_name)
            out_grammar.append(sub_rule_id)
            sub_rule_start = len(out_grammar)
            # placeholder for the size of the 1st alternate
            out_grammar.append(TO_BE_FILLED_MARKER)
            # add the preceding symbol to the generated rule
            out_grammar.extend(outbuf[last_sym_start:])
            if remaining_src[0] in ("*", "+"):
                # cause the generated rule to recurse
                out_grammar.append(REF_RULE_MARKER)
                out_grammar.append(sub_rule_id)
            # apply the actual size
            out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start
            # mark the end of the 1st alternate
            out_grammar.append(END_OF_ALTERNATE_MARKER)
            sub_rule_start = len(out_grammar)
            # placeholder for the size of the 2nd alternate
            out_grammar.append(TO_BE_FILLED_MARKER)
            if remaining_src[0] == "+":
                # add the preceding symbol as an alternate only for '+'
                out_grammar.extend(outbuf[last_sym_start:])
            # apply the actual size of the 2nd alternate
            out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start
            # mark the end of the 2nd alternate, then the end of the rule
            out_grammar.append(END_OF_ALTERNATE_MARKER)
            out_grammar.append(END_OF_RULE_MARKER)

            # in the original rule, replace the previous symbol with a
            # reference to the generated rule
            outbuf[last_sym_start:] = [REF_RULE_MARKER, sub_rule_id]

            remaining_src = remove_leading_white_space(remaining_src[1:], is_nested)
        else:
            break
    # replace the sequence size placeholder (the count includes the slot itself)
    outbuf[out_start_pos] = len(outbuf) - out_start_pos
    # mark the end of the alternate
    outbuf.append(END_OF_ALTERNATE_MARKER)
    return remaining_src


def parse_alternates(state, src, rule_name, rule_id, is_nested):
    outbuf = []
    remaining_src = parse_sequence(state, src, rule_name, outbuf, is_nested)
    while remaining_src and remaining_src[0] == "|":
        remaining_src = remove_leading_white_space(remaining_src[1:], True)
        remaining_src = parse_sequence(state, remaining_src, rule_name, outbuf, is_nested)

    state.grammar_encoding.append(rule_id)
    state.grammar_encoding.extend(outbuf)
    state.grammar_encoding.append(END_OF_RULE_MARKER)
    return remaining_src


def parse_rule(state, src):
    name, remaining_src = parse_name(src)
    remaining_src = remove_leading_white_space(remaining_src, False)
    rule_id = get_symbol_id(state, name)

    if remaining_src[:3] != "::=":
        raise RuntimeError("expecting ::= at " + remaining_src)
    remaining_src = remove_leading_white_space(remaining_src[3:], True)

    remaining_src = parse_alternates(state, remaining_src, name, rule_id, False)

    # consume the newline that terminates the rule
    # (slicing avoids an IndexError when the input ends right after '\r')
    if remaining_src and remaining_src[0] == "\r":
        remaining_src = remaining_src[2:] if remaining_src[1:2] == "\n" else remaining_src[1:]
    elif remaining_src and remaining_src[0] == "\n":
        remaining_src = remaining_src[1:]
    elif remaining_src:
        raise RuntimeError("expecting newline or end at " + remaining_src)
    return remove_leading_white_space(remaining_src, True)


def parse_ebnf(src):
    try:
        state = ParseState()
        grammar_repr = remove_leading_white_space(src, True)
        last_grammar_repr = ""
        while grammar_repr:
            if last_grammar_repr:
                last_parsed_rule_len = len(last_grammar_repr) - len(grammar_repr)
                logger.debug(f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}")
            last_grammar_repr = grammar_repr
            grammar_repr = parse_rule(state, grammar_repr)
        state.grammar_encoding.append(0xFFFF)
        return state
    except RuntimeError as err:
        logger.warning("error parsing grammar: %s", err)
        return ParseState()


def print_rule(file, grammar_encoding, index, symbol_id_names):
    rule_id = grammar_encoding[index]
    print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file)
    pos = index + 1
    while grammar_encoding[pos]:
        if pos - 1 > index:
            print("|", end=" ", file=file)
        pos += 1  # skip the alternate's size slot
        while grammar_encoding[pos]:
            if grammar_encoding[pos] == REF_RULE_MARKER:
                ref_rule_id = grammar_encoding[pos + 1]
                print(
                    f"<{pos}>{symbol_id_names[ref_rule_id]}",
                    end=" ",
                    file=file,
                )
                pos += 2
            else:
                print("<{}>[".format(pos), end="", file=file)
                num_chars = grammar_encoding[pos]
                pos += 1

                for i in range(0, num_chars, 2):
                    print("{}-".format(chr(grammar_encoding[pos + i])), end="", file=file)
                    if i + 1 < num_chars:
                        print("{}".format(chr(grammar_encoding[pos + i + 1])), end="", file=file)
                print("]", end=" ", file=file)
                pos += num_chars
        pos += 1  # skip the end-of-alternate marker
    print(file=file)
    return pos + 1


def print_grammar(file, state):
    pos = 0
    symbol_id_names = {v: k for k, v in state.symbol_ids.items()}
    print("Grammar Rules:", file=file)

    while state.grammar_encoding[pos] != 0xFFFF:
        pos = print_rule(file, state.grammar_encoding, pos, symbol_id_names)

    pos = 0
    print("\nBinary representation:", file=file)
    while state.grammar_encoding[pos] != 0xFFFF:
        print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file)
        pos += 1
    print("ffff\n", file=file)


#################################
# Grammar-constrained decoding  #
#################################


class GrammarConstraint(ABC):
    def __init__(self, grammar_str, start_rule_name, tokenizer):
        # cumulative time and call count for token_acceptance_for_stack,
        # kept as rough profiling counters
        self.tt = 0
        self.nt = 0
        state = parse_ebnf(grammar_str)
        grammar_encoding = state.grammar_encoding
        self.start_rule_id = state.symbol_ids.get(start_rule_name)

        self.eos_token_id = tokenizer.eos_token_id
        self.token_trie = TokenTrie(tokenizer)
        self.tokenizer = tokenizer
        self.grammar_encoding = grammar_encoding

        pos = 0
        rules: Dict[int, int] = {}

        while grammar_encoding[pos] != 0xFFFF:
            rule_id = grammar_encoding[pos]

            # store the current position in the encoding under rule_id,
            # mapping each rule to its offset
            rules[rule_id] = pos
            pos += 1

            # skip this rule's alternates: each alternate's size slot counts
            # itself plus its elements, and one more skips the trailing
            # end-of-alternate marker
            while grammar_encoding[pos]:
                pos += 1 + grammar_encoding[pos]
            # skip the 0 that marks the end of the rule
            pos += 1

        self.start_rule_pos = rules[self.start_rule_id]
        self.rules_pos_dict: Dict[int, int] = rules

    def init_stacks(self):
        # a rule is laid out as [rule_id, size, elements...], so the first
        # element of the start rule's first alternate sits two slots past
        # the rule's position
        stack = [self.start_rule_pos + 2]
        # convert to tuple for caching (immutable)
        return self.advance_stack(tuple(stack))

    # For each stack, resolve rule references until the stack tops are
    # actual char ranges/literals that can be matched against input.
    # This is the core method enforcing the grammar.
    @lru_cache(maxsize=32768)
    def advance_stack(self, stack):
        stack = list(stack)
        # an empty stack means the grammar has been completed
        if len(stack) == 0:
            return [stack]

        pos = stack[-1]

        # values greater than 1 encode a char range (a literal is stored as
        # a one-char range), so the stack already ends in a matchable
        # element and needs no expansion
        if self.grammar_encoding[pos] > 1:
            return [stack]

        # otherwise the element is REF_RULE_MARKER: replace it with every
        # alternate of the referenced rule, producing one stack per alternate
        referenced_rule_id = self.grammar_encoding[pos + 1]

        # subpos tracks the position within the referenced rule's encoding,
        # starting at the size slot of its first alternate
        subpos = self.rules_pos_dict[referenced_rule_id] + 1
        stacks: List[List[int]] = []

        # each iteration processes one alternate; the loop ends at the 0
        # that marks the end of the rule
        while self.grammar_encoding[subpos]:
            new_stack = stack[:-1]
            if self.grammar_encoding[pos + 2]:
                # the sequence continues after the reference, so push the
                # continuation back on the stack first
                new_stack.append(pos + 2)

            # if the alternate is non-empty, push its first element
            if self.grammar_encoding[subpos + 1]:
                new_stack.append(subpos + 1)
            stacks.extend(self.advance_stack(tuple(new_stack)))
            # advance to the next alternate
            subpos += self.grammar_encoding[subpos] + 1
        return stacks

    def accept_char(self, *args, **kwargs):
        """Process a byte according to the grammar rules."""
        raise NotImplementedError

    def accept_token_id(self, *args, **kwargs):
        """Process a token according to the grammar rules."""
        raise NotImplementedError

    def filter_vocab(self, *args, **kwargs):
        raise NotImplementedError


class IncrementalGrammarConstraint(GrammarConstraint):
    def __init__(self, grammar_str, start_rule_name, tokenizer):
        super().__init__(grammar_str, start_rule_name, tokenizer)

    def accept_char(self, byte, stacks):
        new_stacks = []
        for stack in stacks:
            # an empty stack cannot accept any more input
            if not stack:
                continue

            pos = stack[-1]
            num_chars = self.grammar_encoding[pos]

            # advance pos to the first bound of the char range
            pos += 1
            found = False
            for i in range(0, num_chars, 2):
                if self.grammar_encoding[pos + i] <= byte <= self.grammar_encoding[pos + i + 1]:
                    found = True
                    break
            if not found:
                continue

            # the byte matched: pop the range, push the rest of the
            # sequence (if any), then re-expand rule references
            pos += num_chars
            new_stack = stack[:-1]
            if self.grammar_encoding[pos]:
                new_stack.append(pos)
            new_stacks.extend(self.advance_stack(tuple(new_stack)))

        return new_stacks

    def accept_string(self, string: str, stacks: List[List[int]]):
        _bytes = bytes(string, "utf-8")
        for byte in _bytes:
            stacks = self.accept_char(byte, stacks)
        return stacks

    def accept_token_id(self, token_id: int, stacks: List[List[int]]):
        if token_id == self.eos_token_id:
            if stacks and all(len(stack) != 0 for stack in stacks):
                raise Exception(
                    "At least one of the stacks should be empty when EOS is reached. However, "
                    f"the stacks are {stacks}"
                )
            return []

        for byte in self.token_trie.id2str(token_id):
            stacks = self.accept_char(byte, stacks)
            # if stacks is empty here, the token's bytes cannot be produced
            # by the grammar from the current state

        return stacks

    def accept_token_ids(self, token_ids: List[int], stacks: List[List[int]], as_string=True):
        if as_string:
            string = self.tokenizer.decode(token_ids)
            stacks = self.accept_string(string, stacks)
        else:
            for token_id in token_ids:
                stacks = self.accept_token_id(token_id, stacks)
        return stacks

    def batch_filter_vocab(self, batch_stacks, device):
        batch_acceptance = []
        for stacks in batch_stacks:
            batch_acceptance.append(self.filter_vocab(stacks, device))
        return torch.stack(batch_acceptance)

    def filter_vocab(self, stacks, device):
        if not stacks:
            # no stack means no token is accepted: return an all-False
            # acceptance mask over the vocabulary
            vocab_size = len(self.token_trie)
            logger.debug(f"sum of acceptance: {0}")
            return torch.zeros(vocab_size, dtype=torch.bool, device=device)

        acceptance_matrix = torch.cat([self.token_acceptance_for_stack(tuple(stack), device) for stack in stacks])
        # a token is accepted if any stack accepts it
        acceptance = acceptance_matrix.reshape(len(stacks), -1).any(dim=0)
        logger.debug(f"sum of acceptance: {acceptance.sum()}")
        return acceptance

    # precomputed per-position byte acceptance table for char ranges
    @lru_cache(maxsize=None)
    def pos_char_acceptance(self, pos):
        acceptance = [False] * 256
        num_chars = self.grammar_encoding[pos]
        pos += 1
        for i in range(0, num_chars, 2):
            start = self.grammar_encoding[pos + i]
            end = self.grammar_encoding[pos + i + 1]
            for j in range(start, end + 1):
                acceptance[j] = True
        return acceptance

    # Probably this should be configurable: if the grammar can reach an
    # exceedingly large number of states, the right setting is a tradeoff
    # between memory usage and computation time. The main variable pushing
    # usage up here is the number of distinct stacks the grammar produces.
    @lru_cache(maxsize=32768)
    def token_acceptance_for_stack(self, stack, device):
        st = time.time()
        stack = list(stack)  # needs to come in as a tuple for lru_cache

        accepts = [False] * len(self.token_trie)
        accepts[self.eos_token_id] = len(stack) == 0
        if len(stack) == 0:
            logger.debug("empty stack")

        def traverse_trie(trie, stacks):
            for byte, next_trie in trie.items():
                if byte == LEAF:
                    token_id = next_trie
                    if token_id != self.eos_token_id:
                        # the full token was consumed; it is acceptable iff
                        # at least one stack survived
                        accepts[token_id] = bool(stacks)
                    continue

                new_stacks = []
                for stk in stacks:
                    if not stk:
                        continue

                    pos = stk[-1]
                    num_chars = self.grammar_encoding[pos]

                    if not self.pos_char_acceptance(pos)[byte]:
                        continue

                    pos += num_chars + 1
                    new_stack = stk[:-1]
                    if self.grammar_encoding[pos]:
                        new_stack.append(pos)
                    new_stacks.extend(self.advance_stack(tuple(new_stack)))

                if new_stacks:
                    traverse_trie(next_trie, new_stacks)

        traverse_trie(self.token_trie.trie, [stack])

        et = time.time() - st
        x = torch.tensor(accepts, dtype=torch.bool, device=device)
        self.tt += et
        self.nt += 1
        return x
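

# Sketch of driving the incremental constraint by hand (assumes a model and
# tokenizer are already loaded into `shared`, as initialize_grammar() at the
# bottom of this file does; the grammar string is illustrative):
def _example_accept_string():
    constraint = IncrementalGrammarConstraint('root ::= [0-9]+\n', "root", shared.tokenizer)
    stacks = constraint.init_stacks()
    stacks = constraint.accept_string("123", stacks)
    # a non-empty stack list means "123" is a valid prefix under the grammar
    return bool(stacks)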


class StaticGrammarConstraint(GrammarConstraint):
    def __init__(self, grammar_str, start_rule_name, tokenizer):
        super().__init__(grammar_str, start_rule_name, tokenizer)

    def accept_char(self):
        raise NotImplementedError


##############
# Token trie #
##############

# marks a complete token inside the trie's nested dicts
LEAF = -1


class TokenTrie:
    def __init__(self, tokenizer):
        self.eos_token_id = tokenizer.eos_token_id
        self.tokens = []
        self.trie = {}
        self.load_tokens(tokenizer)

    def id2str(self, token_id):
        return self.tokens[token_id]

    def __len__(self):
        return len(self.tokens)

    def load_tokens(self, tokenizer):
        def replace_hex(match):
            hex_value = match.group(1)
            return chr(int(hex_value, 16))

        if "gpt2" in tokenizer.__class__.__name__.lower():
            special = tokenizer.additional_special_tokens_ids

            # the decoder normally does a string replace on sequences like
            # ' .' for '.', which breaks the assumption that a token always
            # has exactly one representation; decoding with
            # clean_up_tokenization_spaces=False avoids that cleanup
            def fmt_token(id):
                if id in special:
                    return None
                return bytes(tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8")

        elif "llama" in tokenizer.__class__.__name__.lower():

            def fmt_token(id):
                token = tokenizer.convert_ids_to_tokens(id)
                token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
                token = token.replace("▁", " ")
                return bytes(token, "utf-8")

        else:
            logger.warning("unrecognized tokenizer: using default token formatting")

            def fmt_token(id):
                token = tokenizer.convert_ids_to_tokens(id)
                return bytes(token, "utf-8")

        # note: vocab_size doesn't work here because there are also
        # get_added_vocab() tokens
        self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))]
        for token_id, token_bytes in enumerate(self.tokens):
            if token_bytes is not None:
                self.insert_into_trie(self.trie, token_bytes, token_id)

    def insert_into_trie(self, trie, token_bytes, token_id):
        current = trie
        for byte in token_bytes:
            if byte not in current:
                current[byte] = {}
            current = current[byte]
        current[LEAF] = token_id
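

# Tiny illustration of the trie layout built above (a sketch; token id 7 is
# arbitrary). Bytes are nested dict keys and LEAF marks a complete token:
def _example_trie_layout():
    trie = {}
    # insert_into_trie never touches `self`, so it can be exercised unbound
    TokenTrie.insert_into_trie(None, trie, b"ab", 7)
    assert trie == {ord("a"): {ord("b"): {LEAF: 7}}}
    return trie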


@lru_cache(maxsize=5)
def initialize_grammar(grammar_string):
    return IncrementalGrammarConstraint(grammar_string.strip(), start_rule_name="root", tokenizer=shared.tokenizer)
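

# Sketch of wiring the constraint into a decoding step (hypothetical `scores`
# tensor of shape [batch, vocab]; mirrors how a logits processor would use
# batch_filter_vocab):
def _example_mask_logits(scores, batch_stacks, constraint):
    acceptance = constraint.batch_filter_vocab(batch_stacks, scores.device)
    # disallowed tokens get -inf so softmax assigns them zero probability
    return scores.masked_fill(~acceptance, float("-inf"))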
|