| |
| |
| |
| |
| |
| |
|
|
|
|
| from typing import Dict, List |
|
|
|
|
class Trie:
    """Prefix tree over sequences of hashable tokens (typically token ids).

    The tree is stored as nested dictionaries in ``trie_dict``: each key is a
    token and each value is the sub-dictionary of tokens that may follow it.
    A sequence ends where the nested dictionary is empty.

    ``len(trie)`` reports the number of sequences that were added (duplicates
    are counted each time), not the number of distinct paths.

    A second trie can be chained with :meth:`append`; see
    :meth:`_get_from_trie` for how lookups fall through to it.
    """

    def __init__(self, sequences: List[List[int]] = None):
        """Build a trie, optionally pre-populated with ``sequences``.

        Args:
            sequences: Iterable of token sequences to insert. ``None`` (the
                default) or an empty iterable creates an empty trie.
        """
        self.trie_dict = {}
        self.len = 0
        # ``None`` replaces the former mutable ``[]`` default argument.
        if sequences is not None:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1

        # Optional chained trie and the token that triggers the hand-over.
        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        """Chain ``trie`` behind this one.

        Args:
            trie: Another ``Trie`` consulted by :meth:`get` as a fallback.
            bos_token_id: Token that, when offered as a continuation by this
                trie, is replaced by the root tokens of ``trie``.
        """
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert one sequence and increment the stored length."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        """Return the tokens that may follow ``prefix_sequence``.

        Returns:
            A list of next tokens; empty when the prefix is unknown or is a
            complete sequence with no continuation.
        """
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        """Wrap an existing nested dictionary in a ``Trie``.

        The dictionary is used as-is (not copied). The length is set to the
        number of sequences produced by iterating the resulting trie.
        """
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        """Insert ``sequence`` into the nested dictionary ``trie_dict``.

        Implemented as a loop instead of recursion with ``sequence[1:]``
        slices, so long sequences neither hit the recursion limit nor copy
        the remaining list at every step.
        """
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        """Return the tokens that may follow ``prefix_sequence``.

        Args:
            prefix_sequence: Tokens already consumed.
            trie_dict: Root nested dictionary to walk.
            append_trie: Optional chained ``Trie``.
            bos_token_id: Token swapped for the chained trie's root tokens.

        Returns:
            A list of candidate next tokens.

            * If the whole prefix is found, the children of the reached node
              are returned. When ``append_trie`` is truthy and
              ``bos_token_id`` is among them, that token is removed and the
              root tokens of ``append_trie`` are appended instead.
            * If some token is missing and ``append_trie`` is truthy, the
              remaining suffix (starting at the missing token) is looked up
              in ``append_trie``.
            * Otherwise an empty list is returned.
        """
        node = trie_dict
        for position, token in enumerate(prefix_sequence):
            if token not in node:
                # NOTE: truthiness of a Trie goes through ``__len__``, so an
                # appended trie whose length is 0 is ignored here.
                if append_trie:
                    return append_trie.get(prefix_sequence[position:])
                return []
            node = node[token]

        output = list(node)
        if append_trie and bos_token_id in output:
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict)
        return output

    def __iter__(self):
        """Iterate over every root-to-leaf token sequence.

        Sequences are produced depth-first in dictionary order, each as a new
        list. An empty trie yields a single empty sequence. The walk uses an
        explicit stack, so deep tries do not hit the recursion limit.
        """

        def _traverse(root):
            if not root:
                yield []
                return

            path = []
            stack = [iter(root.items())]
            while stack:
                for token, child in stack[-1]:
                    path.append(token)
                    if child:
                        # Descend; resume this level's iterator afterwards.
                        stack.append(iter(child.items()))
                    else:
                        yield list(path)
                        path.pop()
                    break
                else:
                    # Level exhausted: go back up and drop the token that
                    # led into it (the root level has no such token).
                    stack.pop()
                    if path:
                        path.pop()

        return _traverse(self.trie_dict)

    def __len__(self):
        """Return the number of sequences recorded as added."""
        return self.len

    def __getitem__(self, value):
        """Alias for :meth:`get`: ``trie[prefix]`` returns the next tokens."""
        return self.get(value)