| | ''' |
| | This file has been 100% copied from this PR to the Transformers library: |
| | https://github.com/huggingface/transformers/pull/27557 |
| | |
| | Author: Saibo-creator |
| | Author GitHub: https://github.com/Saibo-creator |
| | |
| | All credits go to the author. |
| | ''' |
| |
|
| | import math |
| |
|
| | import torch |
| | from transformers.generation.logits_process import LogitsProcessor |
| | from transformers.utils import add_start_docstrings |
| |
|
# Shared call-signature docstring, injected into `__call__` below via the
# `add_start_docstrings` decorator — matches the convention used by the
# logits processors in the `transformers` library.
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
            search or log softmax for each vocabulary token when using beam search

    Return:
        `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

"""
| |
|
| |
|
class GrammarConstrainedLogitsProcessor(LogitsProcessor):
    """Logits processor that forbids tokens the grammar cannot accept next.

    Wraps a grammar-constraint helper object and tracks, for every batch
    element, the parser stacks reached after the tokens generated so far.
    On each call, logits of grammar-invalid tokens are overwritten with
    ``-inf`` in place.

    NOTE(review): an instance is single-use per generation run — the internal
    state (``last_size`` / ``batch_stacks``) is never reset, so a fresh
    processor must be created for each new input sequence (see the
    RuntimeError raised in ``process_logits``).
    """

    def __init__(self, grammar_constraint):
        # Sequence length observed on the previous call; None until first use.
        self.last_size = None
        # Helper object exposing init_stacks / accept_token_id /
        # accept_token_ids / batch_filter_vocab (see its own module).
        self.grammar_constraint = grammar_constraint
        # One parser-stack state per batch element; created lazily on first call.
        self.batch_stacks = None

    def filter_logits(self, logits, device):
        """Overwrite logits of grammar-rejected tokens with ``-inf`` (in place)."""
        # Boolean acceptance mask over the vocabulary, one row per batch element.
        acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device)
        rejected = ~acceptance
        logits[rejected] = -math.inf

    def process_logits(self, input_ids, scores, parse_start_index=None):
        """Advance the grammar state with the newest token(s) and mask ``scores``.

        :param input_ids: batch of token-id sequences generated so far
        :param scores: logits for the next token; modified in place
        :param parse_start_index: default None, which means generate from
            scratch. Set to 0 to parse all input_ids
        :return: ``scores`` with grammar-invalid entries set to ``-inf``
        :raises RuntimeError: if the sequence length neither matches a first
            call nor extends the previous call by exactly one token
        """
        # Lazily create one initial parser stack per batch element.
        if self.batch_stacks is None:
            self.batch_stacks = [
                self.grammar_constraint.init_stacks() for _ in range(len(input_ids))
            ]

        if self.last_size is None:
            # First call: optionally feed a prefix of the prompt through the
            # grammar; with parse_start_index=None nothing is parsed yet.
            advanced_stacks = []
            for sequence, stack in zip(input_ids, self.batch_stacks):
                prefix = [] if parse_start_index is None else sequence[parse_start_index:]
                advanced_stacks.append(
                    self.grammar_constraint.accept_token_ids(prefix, stack)
                )
            self.batch_stacks = advanced_stacks
        elif len(input_ids[0]) == self.last_size + 1:
            # Incremental call during generation: only the newest token of each
            # sequence needs to be fed to the grammar.
            self.batch_stacks = [
                self.grammar_constraint.accept_token_id(sequence[-1], stack)
                for sequence, stack in zip(input_ids, self.batch_stacks)
            ]
        else:
            # Length jumped by more (or less) than one token: the processor's
            # state no longer corresponds to this sequence.
            raise RuntimeError(
                "Input ID's length is inconsistent with the current state of "
                "the GrammarConstrainedLogitsProcessor. If you want to process "
                "another input sequence, please instantiate a new "
                "GrammarConstrainedLogitsProcessor."
            )

        # Mask scores in place according to the freshly advanced stacks.
        self.filter_logits(scores, scores.device)

        self.last_size = len(input_ids[0])
        return scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return self.process_logits(input_ids, scores)
| |
|