use pyo3::prelude::*;
use std::collections::HashMap;
// One node of the trie. Each edge to a child is labelled with a single
// character, so a path from the root spells out a sequence of characters.
struct TrieNode {
// Outgoing edges, keyed by the next character.
children: HashMap<char, TrieNode>,
// ID of the token that ends exactly at this node, or `None` when no token ends here.
token_id: Option<usize>,
}
impl TrieNode {
fn new() -> Self {
TrieNode {
children: HashMap::new(),
token_id: None,
}
}
}
// Character-level trie that maps added words to integer token IDs.
// `#[pyclass]` makes it available to Python as a class.
#[pyclass]
pub struct Trie {
// Root node; it stands for the empty prefix.
root: TrieNode,
next_id: usize, // for assigning unique IDs to tokens
// True when an explicit unknown-token ID was passed to the constructor.
unk_token_set: bool,
// ID emitted for text that matches no token. Without an explicit ID it starts
// at 0 and is moved to `next_id` each time a new token is added.
unk_token_id: usize,
}
// A Trie: See https://en.wikipedia.org/wiki/Trie
// This is a data structure that allows for tokenizing a stream of text
// such that the longest possible tokens are recognized first.
//
// To explain how this works, let's first consider how we add new tokens.
// Let's say we have four possible tokens: 'A', 'B', 'AA', 'AB'
// The trie always has an empty root node. There will be two children:
// the node 'A' with token_id 0 and 'B' with token_id 1. The node 'B' has
// no children since we have no tokens that start with 'B' and continue to
// another character.
// The node 'A' has two children, one other node with 'A' with token_id of 2
// and a node 'B' with token_id of 3.
// Now, to tokenize the following string: 'ABAABA'
// We start from the beginning of the string, continuing until we no longer have
// the substring in our tokens. The first character is 'A'; going down the trie
// we have a node that starts with 'A'. The next character is 'B', and our 'A' node
// has a child that starts with 'B'. The character after is 'A', but our last node,
// with token_id = 3, has no children, so we have our first token 'AB' with token_id 3.
// Similarly, we have 'AA' token_id 2, 'B' with token_id 1, and 'A' with token_id 0.
// So, 'ABAABA' -> [3, 2, 1, 0]
#[pymethods]
impl Trie {
    /// Creates an empty trie.
    ///
    /// `unk_token_id` is the ID that `tokenize` emits for a character that
    /// starts no known token. When it is `None` the unknown ID is not fixed:
    /// it starts at 0 and `add` keeps moving it to the first ID that has not
    /// been assigned to a token.
    #[new]
    pub fn new(unk_token_id: Option<usize>) -> Self {
        Trie {
            root: TrieNode::new(),
            next_id: 0, // token IDs are handed out starting at 0
            unk_token_set: unk_token_id.is_some(),
            unk_token_id: unk_token_id.unwrap_or(0),
        }
    }

    /// Registers `word` as a token and assigns it the next free ID.
    ///
    /// Adding a word that is already a token is a no-op: it keeps its
    /// original ID and no new ID is consumed. An empty `word` still consumes
    /// an ID, but `tokenize` never produces it because a match must consume
    /// at least one character.
    pub fn add(&mut self, word: &str) {
        // Children are kept in a hash map to keep the lookup simple. For the
        // small fan-out expected here a list would probably be faster.
        let mut node = &mut self.root;
        for ch in word.chars() {
            node = node.children.entry(ch).or_insert_with(TrieNode::new);
        }
        if node.token_id.is_none() {
            node.token_id = Some(self.next_id);
            self.next_id += 1;
            // Without an explicit unknown ID, keep it one past the last
            // assigned token ID so it never collides with a real token.
            if !self.unk_token_set {
                self.unk_token_id = self.next_id;
            }
        }
    }

    /// Splits `text` into token IDs using greedy longest-match.
    ///
    /// At every position the trie is followed as far as the text allows and
    /// the longest token seen on that path is emitted. If no token starts at
    /// the current position, the unknown-token ID is emitted and exactly one
    /// character is skipped.
    ///
    /// With the tokens 'A', 'B', 'AA', 'AB' added in that order,
    /// 'ABAABA' tokenizes to [3, 2, 1, 0].
    pub fn tokenize(&self, text: &str) -> Vec<usize> {
        let mut tokens = Vec::new();
        let mut rest = text;
        while let Some(first) = rest.chars().next() {
            let mut node = &self.root;
            // Longest match found so far: (token ID, matched length in bytes).
            let mut best: Option<(usize, usize)> = None;
            for (offset, ch) in rest.char_indices() {
                match node.children.get(&ch) {
                    Some(child) => {
                        node = child;
                        // Remember this match but keep descending: a longer
                        // token may share this prefix.
                        if let Some(id) = node.token_id {
                            best = Some((id, offset + ch.len_utf8()));
                        }
                    }
                    None => break, // the trie has no longer candidate
                }
            }
            let consumed = match best {
                Some((id, len)) => {
                    tokens.push(id);
                    len
                }
                None => {
                    tokens.push(self.unk_token_id);
                    first.len_utf8()
                }
            };
            // `consumed` is a sum of whole-character lengths, so this slice
            // always starts on a character boundary.
            rest = &rest[consumed..];
        }
        tokens
    }
}
// Initializer of the `rust_trie` Python module: registers the `Trie` class on it.
#[pymodule]
fn rust_trie(_py: Python, m: &PyModule) -> PyResult<()> {
    // `add_class` already returns `PyResult<()>`, so its outcome is returned as is.
    m.add_class::<Trie>()
}
#[cfg(test)]
mod tests {
    use super::*;

    // The trie is exercised through its Rust API, which needs no Python
    // interpreter and no module registration.

    // Expected output for a text made of one known token (ID 0) followed by
    // `unknown_chars` characters that match nothing. With no explicit unknown
    // ID and exactly one token added, the unknown ID is 1, and it is emitted
    // once per unmatched character.
    fn one_token_then_unknowns(unknown_chars: usize) -> Vec<usize> {
        let mut expected = vec![0];
        expected.extend(std::iter::repeat(1).take(unknown_chars));
        expected
    }

    #[test]
    fn test_trie() {
        let mut trie = Trie::new(None);
        trie.add("[CLS]");
        let tokens = trie.tokenize("[CLS] This is a test");
        let unknown_chars = " This is a test".chars().count();
        assert_eq!(tokens, one_token_then_unknowns(unknown_chars));
    }

    #[test]
    fn test_trie_angle_bracket_token() {
        let mut trie = Trie::new(None);
        trie.add("<cls>");
        let tokens = trie.tokenize("<cls> This is a test");
        let unknown_chars = " This is a test".chars().count();
        assert_eq!(tokens, one_token_then_unknowns(unknown_chars));
    }
}
|