| | |
| | |
| | |
| | |
| |
|
| | """Implements tracking of constraints for a beam item. |
| | |
| | A list of constraints is given as a list of one or more token |
| | sequences, each of length at least one token. For example, for an input sentence |
| | |
| | > Die maschinelle Übersetzung ist schwer zu kontrollieren. |
| | |
| | We could have the constraints: |
| | * to influence |
| | * hard |
| | |
| | There are two implementations: |
| | * OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. |
| | * UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. |
| | |
| | The difference is that in the first, the constraints are assumed to be |
| | in order; the algorithm will permit zero or more tokens between them. |
| | In the second, the constraints are not ordered, so many orderings will |
| | be explored. |
| | |
| | The same sequence can be present any number of times, and will appear |
| | that many times in the output. |
| | """ |
| |
|
| | from collections import Counter |
| | from typing import List, Optional, Set, Tuple |
| |
|
| | import torch |
| |
|
| |
|
class ConstraintState:
    """Abstract base class for per-beam-item constraint tracking states.

    Concrete implementations (OrderedConstraintState, UnorderedConstraintState)
    record how far a hypothesis has progressed through its constraints.
    """

    def __init__(self):
        pass
| |
|
| |
|
def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor:
    """Takes a list of list of constraints in tensor form (a list of
    tensor constraints for each sentence) and transforms it into a
    packed Tensor. For example, here is a batch of size 3 with 3, 0,
    and 1 constraints:

        [ [ [3 1 2], [3], [4 5 6 7], ]
          [],
          [ [1 8 9 10 1 4 11 12], ]
        ]

    Its corresponding packed structure is:

        [ [ 3  3  1  2  0  3  0  4  5  6  7  0],
          [ 0  0  0  0  0  0  0  0  0  0  0  0],
          [ 1  1  8  9 10  1  4 11 12  0  0  0] ]

    The packed tensor has shape (batch size, maxlen), where
    maxlen is defined below. Each row contains concatenated
    constraint tokens for that sentence, with 0 appended after
    each constraint. The first item in each row is the number
    of constraints for that sentence. So maxlen is the maximum
    of

        (number of constraints) + (sum length of constraints) + 1.

    across all sentences in the batch.

    Args:
        batch_constraints: one list of 1-D long tensors per sentence.

    Returns:
        A (batch_size, maxlen) long tensor in the packed format above.
    """
    # Width of the widest packed row: 1 slot for the constraint count,
    # plus the token lengths, plus one 0-separator after each constraint.
    max_constraints_len = 1
    for sentence_constraints in batch_constraints:
        if len(sentence_constraints):
            constraints_len = (
                1
                + sum(c.size(0) for c in sentence_constraints)
                + len(sentence_constraints)
            )
            max_constraints_len = max(max_constraints_len, constraints_len)

    batch_size = len(batch_constraints)
    # Allocate directly as long; zeros double as padding and separators.
    constraints_tensor = torch.zeros((batch_size, max_constraints_len), dtype=torch.long)
    for i, sentence_constraints in enumerate(batch_constraints):
        constraints_tensor[i, 0] = len(sentence_constraints)
        offset = 1
        for constraint in sentence_constraints:
            this_len = constraint.size(0)
            constraints_tensor[i, offset : offset + this_len] = constraint
            # +1 skips over the 0 separator left by the zeros() init.
            offset += this_len + 1

    return constraints_tensor
| |
|
| |
|
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Transforms *one row* of a packed constraint tensor (e.g., for one
    sentence in the batch) into a list of constraint tensors.

    Args:
        constraint_tensor: a 1-D row produced by :func:`pack_constraints`
            (slot 0 = number of constraints, each constraint 0-terminated).

    Returns:
        A list of (views into) the constraint token tensors.
    """
    constraint_list = []
    # Slot 0 holds the number of constraints packed into this row.
    # Convert the 0-dim tensor to a plain int before using it in range().
    num_constraints = int(constraint_tensor[0])
    constraints = constraint_tensor.tolist()
    offset = 1
    for _ in range(num_constraints):
        # Each constraint is terminated by a 0 token; find the next one.
        where = constraints.index(0, offset)
        constraint_list.append(constraint_tensor[offset:where])
        offset = where + 1

    return constraint_list
| |
|
| |
|
class ConstraintNode:
    """
    Represents a node in a trie managing unordered constraints.

    Each root-to-terminal path in the trie spells out one constraint
    (a token sequence); `num_constraints` counts how many constraints
    end at or below a node so traversal can tell when a subtree is
    exhausted.
    """

    def __init__(self, token: Optional[int] = None, parent: Optional["ConstraintNode"] = None):
        # The token this node represents (None only for the trie root).
        self.token = int(token) if token is not None else None
        # The parent node (None only for the trie root).
        self.parent = parent
        # Number of constraints that end exactly at this node. A value > 1
        # means the same constraint was added multiple times.
        self.terminal = 0
        # Child nodes, keyed by token id.
        self.children = {}
        # Number of constraints ending at or below this node.
        self.num_constraints = 0

    @property
    def id(self):
        """The token label of this node (None for the root)."""
        return self.token

    def __str__(self):
        term = self.terminal != 0
        return f"[{self.token}].{term}#{self.num_constraints}"

    def __getitem__(self, key: int):
        """Returns the child for token `key`, or None if there is none."""
        return self.children.get(key, None)

    def next_tokens(self) -> Set[int]:
        """The set of child labels."""
        return set(self.children.keys())

    @staticmethod
    def create(constraints: List[List[int]]) -> "ConstraintNode":
        """Builds a trie (returned as its root) from a list of constraints."""
        root = ConstraintNode()
        for sequence in constraints:
            root.add_sequence(sequence)

        return root

    @staticmethod
    def print_graph(node: "ConstraintNode") -> str:
        """Returns a parenthesized string rendering of the subtree at `node`."""
        if len(node.children) == 0:
            return str(node)
        else:
            s = f"({node}"
            for child in node.children.values():
                s += " " + ConstraintNode.print_graph(child)
            s += ")"
            return s

    def token_counts(self) -> Counter:
        """Returns a counter of the number of times each token is used
        in a constraint.
        """
        token_counts = Counter()
        # Iterative depth-first walk over all descendants of this node.
        kids = list(self.children.values())
        while len(kids) > 0:
            kid = kids.pop()
            token_counts[kid.id] += kid.num_constraints
            kids += list(kid.children.values())

        return token_counts

    def tokens(self) -> Set[int]:
        """Returns the set of tokens in constraints."""
        return set(self.token_counts().keys())

    def add_sequence(self, sequence: List[int]):
        """Adds a constraint, represented as a list of integers, to
        the trie."""
        assert len(sequence) > 0

        token = int(sequence[0])
        if token not in self.children:
            self.children[token] = ConstraintNode(token, parent=self)

        node = self.children[token]
        if len(sequence) == 1:
            # End of the constraint: mark the terminal and bump the
            # constraint count on the whole ancestor chain.
            node.terminal += 1
            node.num_constraints += 1
            parent = node.parent
            while parent is not None:
                parent.num_constraints += 1
                parent = parent.parent
        else:
            # Recurse with the remainder of the constraint.
            node.add_sequence(sequence[1:])
| |
|
| |
|
class UnorderedConstraintState(ConstraintState):
    """
    Records progress through the set of constraints for each item in the beam
    using a trie.
    """

    def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None):
        # Current position in the constraint trie.
        self.node = node

        if copy_from is None:
            # The root node of the trie.
            self.root = node
            # Counter of trie nodes whose constraint has been completed.
            self.completed = Counter()
            # Counter of trie nodes visited on in-progress / finished paths.
            self.generated = Counter()
            # Tokens appearing anywhere in the constraints.
            # NOTE(review): needed_tokens is not copied in the else-branch
            # below and is never read elsewhere in this class — confirm it
            # is safe to treat as init-only before relying on it.
            self.needed_tokens = self.root.tokens()
        else:
            # Copy-construct: share the trie, clone the progress counters.
            self.completed = Counter(copy_from.completed)
            self.generated = Counter(copy_from.generated)
            self.root = copy_from.root

        # Record the visit to this (non-root) node.
        if self.node != self.root:
            self.generated[node] += 1

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        # Builds the root state for one sentence from a packed constraint row.
        constraint_list = unpack_constraints(constraint_tensor)
        constraint_trie_root = ConstraintNode.create(constraint_list)
        return UnorderedConstraintState(constraint_trie_root)

    def __str__(self):
        gen_str = ",".join([str(node) for node in self.generated])
        return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}"

    def __copy__(self):
        copied_state = UnorderedConstraintState(self.node, copy_from=self)
        return copied_state

    def copy(self):
        return self.__copy__()

    @property
    def name(self):
        # Label of the current trie position ("ROOT" for the root node).
        if self.node.id is None:
            return "ROOT"
        else:
            return str(self.node.id)

    @property
    def is_root(self):
        # True when no constraint is partially in progress.
        return self.node == self.root

    @property
    def bank(self):
        # Total number of constraint tokens generated so far.
        return sum(self.generated.values())

    @property
    def num_completed(self):
        """The number of constraints (not constraint tokens) that are completed.
        In addition to the already-completed states, we need to account for the
        current state, which might get marked as completed when another token
        is generated.
        """
        in_final = self.node.terminal and self.completed[self.node] < self.node.terminal
        return sum(self.completed.values()) + in_final

    @property
    def finished(self):
        # True once every constraint in the trie has been completed.
        return self.root.num_constraints - self.num_completed == 0

    @property
    def token_counts(self):
        # Multiplicity of each token across all constraints.
        return self.root.token_counts()

    @property
    def tokens(self):
        # The set of tokens used in any constraint.
        return self.root.tokens()

    @property
    def num_constraint_tokens(self):
        # Total token count across all constraints (with multiplicity).
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the list of tokens that could come next.
        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state."""

        if self.node != self.root:
            return self.root.next_tokens().union(self.node.next_tokens())
        else:
            return self.root.next_tokens()

    def advance(self, token: int):
        """Reads in a token and advances the state. Here's how it works.

        We can advance to the next state if:
        - there is a matching child
        - its path isn't blocked

        A path is blocked when all constraints that are descendants of
        that node have already been generated, in the current state.

        If we are not able to advance from the current state, we "fall
        off the graph" and return to the root state. There, we again
        try to advance, checking the same criteria.

        In any case, when falling off the graph, we need to do some
        bookkeeping. We:
        - check whether any constraints were met (all prefixes of
          current state)
        - if one is found, mark it as completed
        - adjust visited nodes accordingly
        """
        token = int(token)

        next_state = None
        # Case 1: the token extends the current trie path, and that child's
        # subtree still has constraints left to generate (not "blocked").
        child = self.node[token]
        if child is not None and self.generated[child] < child.num_constraints:
            next_state = UnorderedConstraintState(child, copy_from=self)

        def rewind():
            """If we're mid-trie and an "illegal" token is chosen next, we need
            to reset our state to the root state. However, along the way, we need
            to check whether a prefix of the current trie state represents a state
            we could mark as completed.
            """
            node = self.node
            while node != self.root:
                if node.terminal and self.completed[node] < node.terminal:
                    # A prefix of the abandoned path is itself a completed
                    # constraint: record it and keep its generated credit.
                    next_state.completed[node] += 1
                    return

                # Otherwise give back the visit credit for this node.
                next_state.generated[node] -= 1
                node = node.parent

        # Case 2: fall back to the root and try the token from there.
        if next_state is None and token in self.root.next_tokens():
            child = self.root[token]
            # Only start down that branch if it is not blocked; otherwise
            # the new state stays at the root.
            if self.generated[child] < child.num_constraints:
                next_state = UnorderedConstraintState(child, copy_from=self)
            else:
                next_state = UnorderedConstraintState(self.root, copy_from=self)

            # Undo bookkeeping for the abandoned mid-trie path.
            rewind()

        # Case 3: no match anywhere; return to the root.
        elif next_state is None:
            next_state = UnorderedConstraintState(self.root, copy_from=self)
            # Undo bookkeeping for the abandoned mid-trie path.
            rewind()

        return next_state
| |
|
| |
|
class ConstraintSequence:
    def __init__(self, sequences: List[List[int]]):
        """Represents a set of possibly multitoken constraints by
        concatenating them and internally recording the end points.

        Args:
            sequences: one token sequence per constraint. (Callers in this
                file pass 1-D tensors; iterating them yields scalar items,
                which are stored as-is.)
        """
        # Flat concatenation of all constraint tokens.
        self.sequences = []
        # endpoints[i] is True iff sequences[i] is the last token of a constraint.
        self.endpoints = []
        # Total number of constraint tokens.
        self.num_tokens = 0
        # Set of distinct tokens appearing in any constraint.
        self.tokens = set()
        for sequence in sequences:
            self.tokens.update(sequence)
            self.num_tokens += len(sequence)
            self.endpoints += [False for x in range(len(sequence) - 1)] + [True]
            self.sequences += sequence

    def token_counts(self) -> Counter:
        """Returns a counter of how many times each token appears across
        all constraints.

        This method was missing although OrderedConstraintState.token_counts
        delegates to it, which raised AttributeError when accessed.
        """
        return Counter(self.sequences)

    def __getitem__(self, key: int):
        return self.sequences[key]

    def __len__(self):
        return len(self.sequences)

    def __str__(self):
        return str(self.sequences)
| |
|
| |
|
class OrderedConstraintState(ConstraintState):
    """
    Records progress through the set of linear nonbranching constraints with gaps.

    `state` is an index into the flattened constraint sequence: -1 means no
    progress (the "root"); otherwise it points at the last matched token.
    """

    def __init__(self, sequence: ConstraintSequence, state: int = -1):
        # The flattened constraint sequence being tracked.
        self.sequence = sequence
        # Index of the last matched token in the sequence (-1 = root).
        self.state = state

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        """Builds a root state from one packed constraint row."""
        constraint_list = unpack_constraints(constraint_tensor)
        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)

    def __str__(self):
        return f"{self.state}/{self.bank}x{self.num_completed}"

    def __copy__(self):
        return OrderedConstraintState(self.sequence, self.state)

    def copy(self):
        return self.__copy__()

    @property
    def num_completed(self):
        """Number of whole constraints completed so far (endpoints passed)."""
        if self.state == -1:
            return 0
        count = len(
            list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1]))
        )
        return count

    @property
    def is_root(self):
        # True when no constraint is partially in progress.
        return self.state == -1

    @property
    def name(self):
        # Label of the current position ("ROOT" or the last matched token).
        if self.state == -1:
            return "ROOT"
        else:
            return str(self.sequence[self.state])

    @property
    def bank(self) -> int:
        # Number of constraint tokens consumed so far.
        return self.state + 1

    @property
    def finished(self):
        # True once the whole flattened sequence has been matched.
        return self.state + 1 == len(self.sequence)

    @property
    def token_counts(self):
        # Fix: the previous implementation delegated to
        # self.sequence.token_counts(), a method ConstraintSequence does not
        # define, so accessing this property raised AttributeError. Count the
        # flattened token list directly instead.
        return Counter(self.sequence.sequences)

    @property
    def tokens(self):
        # The set of tokens used in any constraint.
        return self.sequence.tokens

    @property
    def num_constraint_tokens(self):
        # Total token count across all constraints (with multiplicity).
        return sum(self.token_counts.values())

    def next_tokens(self) -> Set[int]:
        """Returns the list of tokens that could come next.
        These are (a) all tokens extending the root state and, for
        non-root states, additionally all tokens extending the current
        state."""

        tokens = set()
        if self.state > 0:
            tokens.add(self.sequence[0])
        if not self.finished:
            tokens.add(self.sequence[self.state + 1])
        return tokens

    def advance(self, token: int):
        """Reads in a token and advances the state.

        If the token matches the next expected constraint token, advance one
        position. Otherwise: a finished state stays finished; a state sitting
        at a constraint endpoint stays put (free tokens are allowed between
        constraints); a mid-constraint mismatch restarts at position 0 if the
        token happens to be the first constraint token, and falls back to the
        root (-1) otherwise.
        """
        token = int(token)
        # NOTE(review): when self.state == -1 the endpoint check below reads
        # endpoints[-1] (the final endpoint, always True), so an unmatched
        # token at the root also takes the "stay put" branch — confirm this
        # is the intended root behavior.

        if self.finished:
            # Once every constraint is matched, nothing changes.
            next_state = self.copy()

        elif self.sequence[self.state + 1] == token:
            # The token extends the current (or next) constraint.
            next_state = OrderedConstraintState(self.sequence, self.state + 1)

        elif self.sequence.endpoints[self.state]:
            # Between constraints: arbitrary tokens are permitted here.
            next_state = self.copy()

        elif token == self.sequence[0]:
            # Mid-constraint mismatch, but the token restarts the first constraint.
            next_state = OrderedConstraintState(self.sequence, 0)
        else:
            # Mid-constraint mismatch: fall back to the root.
            next_state = OrderedConstraintState(self.sequence, -1)

        return next_state
| |
|