| |
| |
| |
| |
|
|
| """Implements tracking of constraints for a beam item. |
| |
| A list of constraints is given as a list of one or more token |
| sequences, each of length at least one token. For example, for an input sentence |
| |
| > Die maschinelle Übersetzung ist schwer zu kontrollieren. |
| |
| We could have the constraints: |
| * to influence |
| * hard |
| |
| There are two implementations: |
| * OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. |
| * UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. |
| |
| The difference is that in the first, the constraints are assumed to be |
| in order; the algorithm will permit zero or more tokens between them. |
| In the second, the constraints are not ordered, so many orderings will |
| be explored. |
| |
| The same sequence can be present any number of times, and will appear |
| that many times in the output. |
| """ |
|
|
| from collections import Counter |
| from typing import List, Optional, Set, Tuple |
|
|
| import torch |
|
|
|
|
class ConstraintState:
    """Common base class for the constraint-tracking states in this module.

    It carries no state of its own; the concrete subclasses hold the data
    describing progress through a set of constraints.
    """

    def __init__(self):
        """Create the base state; there is nothing to initialize here."""
|
|
|
|
def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor:
    """Pack per-sentence constraint tensors into one padded 2-D tensor.

    Takes a list of list of constraints in tensor form (a list of
    tensor constraints for each sentence) and transforms it into a
    packed Tensor. For example, here is a batch of size 3 with 3, 0,
    and 1 constraints:

        [ [ [3 1 2], [3], [4 5 6 7], ]
          [],
          [ [1 8 9 10 1 4 11 12], ]
        ]

    Its corresponding packed structure is:

        [ [ 3  3  1  2  0  3  0  4  5  6  7  0],
          [ 0  0  0  0  0  0  0  0  0  0  0  0],
          [ 1  1  8  9 10  1  4 11 12  0  0  0] ]

    The packed tensor has shape (batch size, maxlen), where
    maxlen is defined below. Each row contains concatenated
    constraint tokens for that sentence, with 0 appended after
    each constraint. The first item in each row is the number
    of constraints for that sentence. So maxlen is the maximum
    of

        (number of constraints) + (sum length of constraints) + 1.

    across all sentences in the batch.

    Args:
        batch_constraints: one list of 1-D constraint tensors per sentence;
            a sentence without constraints contributes an empty list.

    Returns:
        A ``torch.long`` tensor of shape (batch size, maxlen). An empty
        batch yields a tensor of shape (0, 1).
    """
    # Row width: count cell + all tokens + one 0 terminator per constraint.
    # ``default=1`` keeps the count column even when the batch is empty.
    max_constraints_len = max(
        (
            1 + sum(c.size(0) for c in sentence_constraints) + len(sentence_constraints)
            for sentence_constraints in batch_constraints
        ),
        default=1,
    )

    batch_size = len(batch_constraints)
    # Allocate as long right away instead of building floats and casting.
    constraints_tensor = torch.zeros(
        (batch_size, max_constraints_len), dtype=torch.long
    )
    for i, sentence_constraints in enumerate(batch_constraints):
        constraints_tensor[i, 0] = len(sentence_constraints)
        offset = 1
        for constraint in sentence_constraints:
            this_len = constraint.size(0)
            constraints_tensor[i, offset : offset + this_len] = constraint
            # Skip one extra cell: the zero already there is the terminator.
            offset += this_len + 1

    return constraints_tensor
|
|
|
|
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]:
    """Split *one row* of a packed constraint tensor into its constraints.

    The row layout is: the number of constraints in cell 0, followed by each
    constraint's tokens, every constraint being terminated by a 0.

    Args:
        constraint_tensor: a single packed row (e.g., for one sentence in
            the batch).

    Returns:
        A list with one tensor per constraint. Each entry is a slice of
        ``constraint_tensor`` holding that constraint's tokens without the
        terminating 0.

    Raises:
        ValueError: if a constraint announced by the count cell has no
            terminating 0 in the row.
    """
    constraint_list = []
    # Convert explicitly: range() over a tensor only works through __index__,
    # which fails for non-integer dtypes.
    num_constraints = int(constraint_tensor[0])
    constraints = constraint_tensor.tolist()
    offset = 1
    for _ in range(num_constraints):
        # The next 0 at or after ``offset`` marks the end of this constraint.
        where = constraints.index(0, offset)
        constraint_list.append(constraint_tensor[offset:where])
        offset = where + 1

    return constraint_list
|
|
|
|
class ConstraintNode:
    """
    One node of the trie that stores a set of unordered constraints.
    """

    def __init__(self, token: int = None, parent=None):
        """Build a node labelled ``token`` hanging below ``parent``.

        Both arguments default to None, which is how ``create`` builds
        the root node.
        """
        # Link to the node above; stays None when no parent is given.
        self.parent = parent
        # Label of this node, normalized to a plain int when present.
        self.token = None if token is None else int(token)
        # Child nodes, keyed by their integer token.
        self.children = {}
        # How many constraints end exactly at this node.
        self.terminal = 0
        # How many constraints end at this node or anywhere below it.
        self.num_constraints = 0

    @property
    def id(self):
        """The token labelling this node."""
        return self.token

    def __str__(self):
        is_terminal = self.terminal != 0
        return f"[{self.token}].{is_terminal}#{self.num_constraints}"

    def __getitem__(self, key: int):
        """Return the child labelled ``key``, or None if there is none."""
        return self.children.get(key)

    def next_tokens(self) -> Set[int]:
        """The set of child labels."""
        return set(self.children)

    @staticmethod
    def create(constraints: List[List[int]]):
        """Build a trie holding every sequence in ``constraints``; return its root."""
        trie = ConstraintNode()
        for constraint in constraints:
            trie.add_sequence(constraint)
        return trie

    @staticmethod
    def print_graph(node: "ConstraintNode"):
        """Render the subtree under ``node`` as a parenthesized string."""
        if not node.children:
            return str(node)
        pieces = [f"({node}"]
        pieces.extend(
            ConstraintNode.print_graph(kid) for kid in node.children.values()
        )
        return " ".join(pieces) + ")"

    def token_counts(self) -> Counter:
        """Returns a counter of the number of times each token is used
        in a constraint.
        """
        counts = Counter()
        # Depth-first walk over every descendant (this node is not counted).
        pending = list(self.children.values())
        while pending:
            current = pending.pop()
            counts[current.id] += current.num_constraints
            pending.extend(current.children.values())
        return counts

    def tokens(self) -> Set[int]:
        """Returns the set of tokens in constraints."""
        return set(self.token_counts())

    def add_sequence(self, sequence: List[int]):
        """Adds a constraint, represented as a list of integers, to
        the trie."""
        assert len(sequence) > 0

        # Walk down from this node, creating missing children on the way.
        cursor = self
        for raw_token in sequence:
            token = int(raw_token)
            if token not in cursor.children:
                cursor.children[token] = ConstraintNode(token, parent=cursor)
            cursor = cursor.children[token]

        # The last node ends the constraint; bump the subtree count on it
        # and on every ancestor up to the top of the trie.
        cursor.terminal += 1
        while cursor is not None:
            cursor.num_constraints += 1
            cursor = cursor.parent
|
|
|
|
class UnorderedConstraintState(ConstraintState):
    """
    Records progress through the set of constraints for each item in the beam
    using a trie.
    """

    def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None):
        """Create a state positioned at ``node``.

        Args:
            node: the trie node this state sits on.
            copy_from: an existing state whose root and counters are copied.
                When omitted, ``node`` itself is taken as the trie root and
                the counters start empty.
        """
        self.node = node

        if copy_from is None:
            # The root of the trie is the node we were handed.
            self.root = node
            # Per-node count of constraints marked as completed.
            self.completed = Counter()
            # Per-node count of how often the node has been entered.
            self.generated = Counter()
            # All tokens that appear anywhere in the trie.
            self.needed_tokens = self.root.tokens()
        else:
            # Copy the counters so the new state can diverge from the old one.
            self.completed = Counter(copy_from.completed)
            self.generated = Counter(copy_from.generated)
            self.root = copy_from.root
            # Carry the token set over as well; previously only freshly
            # created states had this attribute. It is shared, not copied:
            # nothing in this class mutates it.
            needed = getattr(copy_from, "needed_tokens", None)
            self.needed_tokens = self.root.tokens() if needed is None else needed

        # Entering any non-root node counts as generating it once more.
        if self.node != self.root:
            self.generated[node] += 1

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Build a root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        constraint_trie_root = ConstraintNode.create(constraint_list)
        return UnorderedConstraintState(constraint_trie_root)

    def __str__(self):
        gen_str = ",".join([str(node) for node in self.generated])
        return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}"

    def __copy__(self):
        """Return a new state on the same node with copied counters.

        Note that for a non-root node the constructor increments
        ``generated[node]`` again on the copy.
        """
        copied_state = UnorderedConstraintState(self.node, copy_from=self)
        return copied_state

    def copy(self):
        """Same as ``__copy__``."""
        return self.__copy__()

    @property
    def name(self):
        """``"ROOT"`` when the current node has no token, else the token as a string."""
        if self.node.id is None:
            return "ROOT"
        else:
            return str(self.node.id)

    @property
    def is_root(self):
        """Whether the state currently sits on the trie root."""
        return self.node == self.root

    @property
    def bank(self):
        """Sum of all ``generated`` counts."""
        return sum(self.generated.values())

    @property
    def num_completed(self):
        """The number of constraints (not constraint tokens) that are completed.
        In addition to the already-completed states, we need to account for the
        current state, which might get marked as completed when another token
        is generated.
        """
        # Truthy only when the current node ends a constraint that has not
        # yet been recorded in ``completed``; adds 1 to the total in that case.
        in_final = self.node.terminal and self.completed[self.node] < self.node.terminal
        return sum(self.completed.values()) + in_final

    @property
    def finished(self):
        """Whether every constraint in the trie has been completed."""
        return self.root.num_constraints - self.num_completed == 0

    @property
    def token_counts(self):
        """Counter of token usage across all constraints in the trie."""
        return self.root.token_counts()

    @property
    def tokens(self):
        """The set of tokens used by the constraints in the trie."""
        return self.root.tokens()

    @property
    def num_constraint_tokens(self):
        """Total of all values in ``token_counts``."""
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the list of tokens that could come next.
        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state."""

        if self.node != self.root:
            return self.root.next_tokens().union(self.node.next_tokens())
        else:
            return self.root.next_tokens()

    def advance(self, token: int):
        """Reads in a token and advances the state. Here's how it works.

        We can advance to the next state if:
        - there is a matching child
        - its path isn't blocked

        A path is blocked when all constraints that are descendants of
        that node have already been generated, in the current state.

        If we are not able to advance from the current state, we "fall
        off the graph" and return to the root state. There, we again
        try to advance, checking the same criteria.

        In any case, when falling off the graph, we need to do some
        bookkeeping. We:
        - check whether any constraints were met (all prefixes of
          current state)
        - if one is found, mark it as completed
        - adjust visited nodes accordingly

        Args:
            token: the token that was just generated.

        Returns:
            A new ``UnorderedConstraintState``; ``self`` is not modified.
        """
        token = int(token)

        next_state = None
        child = self.node[token]
        # Follow the child only while it still has unused constraints below it.
        if child is not None and self.generated[child] < child.num_constraints:
            next_state = UnorderedConstraintState(child, copy_from=self)

        def rewind():
            """If we're mid-trie and an "illegal" token is chosen next, we need
            to reset our state to the root state. However, along the way, we need
            to check whether a prefix of the current trie state represents a state
            we could mark as completed.
            """
            # ``next_state`` is looked up when rewind() runs, i.e. after the
            # branches below have assigned it.
            node = self.node
            while node != self.root:
                if node.terminal and self.completed[node] < node.terminal:
                    # A prefix of the abandoned path is a full constraint:
                    # record it and keep the counts for the nodes above it.
                    next_state.completed[node] += 1
                    return

                # This node led nowhere, so undo its "generated" credit.
                next_state.generated[node] -= 1
                node = node.parent

        # Could not continue from the current node: retry from the root.
        if next_state is None and token in self.root.next_tokens():
            child = self.root[token]
            # Start a new constraint if the root's child is not used up.
            if self.generated[child] < child.num_constraints:
                next_state = UnorderedConstraintState(child, copy_from=self)
            else:
                next_state = UnorderedConstraintState(self.root, copy_from=self)

            # We left the old path, so settle its bookkeeping.
            rewind()

        elif next_state is None:
            # The token starts nothing at all: go back to the root.
            next_state = UnorderedConstraintState(self.root, copy_from=self)
            rewind()

        return next_state
|
|
|
|
class ConstraintSequence:
    """A set of possibly multitoken constraints stored as one flat sequence.

    Attributes are filled in by ``__init__``:
        sequences: all constraint tokens concatenated, as plain ints.
        endpoints: parallel list of booleans; True marks the last token of
            a constraint.
        num_tokens: total number of constraint tokens.
        tokens: the set of distinct token ids.
    """

    def __init__(self, sequences: List[List[int]]):
        """Represents a set of possibly multitoken constraints by
        concatenating them and internally recording the end points.

        Args:
            sequences: the constraints, each an iterable of token ids
                (lists of ints or 1-D tensors).
        """
        self.sequences = []
        self.endpoints = []
        self.num_tokens = 0
        self.tokens = set()
        for sequence in sequences:
            # Normalize to plain ints so that tensor elements hash and
            # compare by value (matching what ConstraintNode does).
            token_ids = [int(token) for token in sequence]
            if not token_ids:
                # An empty constraint has no last token; recording an
                # endpoint for it would misalign ``endpoints``.
                continue
            self.tokens.update(token_ids)
            self.num_tokens += len(token_ids)
            self.endpoints += [False] * (len(token_ids) - 1) + [True]
            self.sequences += token_ids

    def __getitem__(self, key: int):
        """Return the token at position ``key`` of the flat sequence."""
        return self.sequences[key]

    def __len__(self):
        """Total number of tokens in the flat sequence."""
        return len(self.sequences)

    def __str__(self):
        return str(self.sequences)
|
|
|
|
class OrderedConstraintState(ConstraintState):
    """
    Records progress through the set of linear nonbranching constraints with gaps.
    """

    def __init__(self, sequence: ConstraintSequence, state: int = -1):
        """Create a state over ``sequence``.

        Args:
            sequence: the concatenated constraints.
            state: index of the last matched token in ``sequence``;
                -1 means nothing has been matched yet.
        """
        self.sequence = sequence
        self.state = state

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Build a root state from one row of a packed constraint tensor."""
        constraint_list = unpack_constraints(constraint_tensor)
        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)

    def __str__(self):
        return f"{self.state}/{self.bank}x{self.num_completed}"

    def __copy__(self):
        """Return a new state at the same position, sharing the sequence."""
        return OrderedConstraintState(self.sequence, self.state)

    def copy(self):
        """Same as ``__copy__``."""
        return self.__copy__()

    @property
    def num_completed(self):
        """Number of constraint endpoints at or before the current position."""
        if self.state == -1:
            return 0
        # Endpoint flags are booleans, so summing them counts the True ones.
        return sum(1 for flag in self.sequence.endpoints[: self.state + 1] if flag)

    @property
    def is_root(self):
        """Whether no token has been matched yet."""
        return self.state == -1

    @property
    def name(self):
        """``"ROOT"`` in the initial state, else the current token as a string."""
        if self.state == -1:
            return "ROOT"
        else:
            return str(self.sequence[self.state])

    @property
    def bank(self) -> int:
        """Number of sequence tokens matched so far."""
        return self.state + 1

    @property
    def finished(self):
        """Whether the whole concatenated sequence has been matched."""
        return self.state + 1 == len(self.sequence)

    @property
    def token_counts(self):
        """Counter of how often each token occurs in the constraints."""
        # ConstraintSequence defines no token_counts() method, so the
        # counts are computed here from its flat token list. int() makes
        # tensor elements count by value.
        return Counter(int(token) for token in self.sequence.sequences)

    @property
    def tokens(self):
        """The token set recorded on the underlying sequence."""
        return self.sequence.tokens

    @property
    def num_constraint_tokens(self):
        """Total of all values in ``token_counts``."""
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the set of tokens that could come next.

        This is the token following the current position (unless the
        sequence is finished), plus the very first token of the sequence
        when the current position is past index 0.
        """
        tokens = set()
        if self.state > 0:
            tokens.add(int(self.sequence[0]))
        if not self.finished:
            tokens.add(int(self.sequence[self.state + 1]))
        return tokens

    def advance(self, token: int):
        """Reads in a token and returns the resulting state.

        The cases are checked in this order:
        - every constraint token is already matched: the position is kept;
        - ``token`` is the next token of the sequence: move one step forward;
        - nothing is matched yet, or the current position is the last token
          of a constraint: the position is kept, so other tokens may appear
          between constraints;
        - otherwise we are in the middle of a constraint and the token does
          not continue it: restart at position 0 if ``token`` equals the
          first token of the sequence, else go back to the initial state.

        Args:
            token: the token that was just generated.

        Returns:
            A new ``OrderedConstraintState``; ``self`` is not modified.
        """
        token = int(token)

        if self.finished:
            # Nothing left to match.
            next_state = self.copy()

        elif self.sequence[self.state + 1] == token:
            # The token continues the sequence.
            next_state = OrderedConstraintState(self.sequence, self.state + 1)

        elif self.is_root or self.sequence.endpoints[self.state]:
            # Between constraints (or before the first one): stay put.
            # The explicit root check replaces the former reliance on
            # ``endpoints[-1]`` being True for state -1.
            next_state = self.copy()

        elif token == self.sequence[0]:
            # Broken mid-constraint, but the token restarts the sequence.
            next_state = OrderedConstraintState(self.sequence, 0)
        else:
            # Broken mid-constraint: start over.
            next_state = OrderedConstraintState(self.sequence, -1)

        return next_state
|
|