| use pyo3::prelude::*; |
| use std::collections::HashMap; |
|
|
/// One node of the trie: outgoing edges keyed by character, plus the id of
/// the token that ends at this node, if any.
#[derive(Default)]
struct TrieNode {
    /// Child nodes, keyed by the next character of a word.
    children: HashMap<char, TrieNode>,
    /// `Some(id)` when a complete word ends at this node, otherwise `None`.
    token_id: Option<usize>,
}

impl TrieNode {
    /// Returns a node with no children and no token id.
    fn new() -> Self {
        Self::default()
    }
}
|
|
| #[pyclass] |
| pub struct Trie { |
| root: TrieNode, |
| next_id: usize, |
| unk_token_set: bool, |
| unk_token_id: usize, |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #[pymethods] |
| impl Trie { |
| #[new] |
| pub fn new(unk_token_id: Option<usize>) -> Self { |
| Trie { |
| root: TrieNode::new(), |
| next_id: 0, |
| unk_token_set: unk_token_id.is_some(), |
| unk_token_id: unk_token_id.unwrap_or(0), |
| } |
| } |
|
|
| |
| |
| |
| |
| pub fn add(&mut self, word: &str) { |
| let mut node = &mut self.root; |
| for ch in word.chars() { |
| node = node.children.entry(ch).or_insert_with(TrieNode::new); |
| } |
| if node.token_id.is_none() { |
| node.token_id = Some(self.next_id); |
| self.next_id += 1; |
| if !self.unk_token_set { |
| self.unk_token_id = self.next_id; |
| } |
| } |
| } |
|
|
| |
| |
| |
| pub fn tokenize(&self, text: &str) -> Vec<usize> { |
| let mut tokens = vec![]; |
| let mut start = 0; |
| |
| while start < text.len() { |
| let mut node = &self.root; |
| let mut matched = false; |
| let mut end = start; |
| for ch in text[start..].chars() { |
| if let Some(next_node) = node.children.get(&ch) { |
| |
| node = next_node; |
| end += ch.len_utf8(); |
| if node.token_id.is_some() { |
| matched = true; |
| break; |
| } |
| } else { |
| break; |
| } |
| } |
| |
| if matched { |
| tokens.push(node.token_id.unwrap()); |
| start = end; |
| } else { |
| tokens.push(self.unk_token_id); |
| start += text[start..].chars().next().unwrap().len_utf8(); |
| } |
| } |
|
|
| tokens |
| } |
| } |
|
|
| #[pymodule] |
| fn rust_trie(_py: Python, m: &PyModule) -> PyResult<()> { |
| m.add_class::<Trie>()?; |
| Ok(()) |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // These tests drive the Rust API directly. `Trie`'s methods are ordinary
    // inherent methods, so no Python interpreter or module registration is
    // needed to exercise them.

    /// Expected output when `word` (the only vocabulary entry, id 0) is
    /// followed by `rest`, none of whose characters start a known word:
    /// id 0, then one unknown id per character of `rest`.
    fn word_then_unknowns(unk_id: usize, rest: &str) -> Vec<usize> {
        let mut expected = vec![0];
        expected.extend(std::iter::repeat(unk_id).take(rest.chars().count()));
        expected
    }

    #[test]
    fn test_trie() {
        let mut trie = Trie::new(None);
        trie.add("[CLS]");
        // With one word added and no explicit unknown id, unknown text maps to 1.
        assert_eq!(
            trie.tokenize("[CLS] This is a test"),
            word_then_unknowns(1, " This is a test")
        );
    }

    #[test]
    fn test_trie_angle_bracket_token() {
        let mut trie = Trie::new(None);
        trie.add("<cls>");
        assert_eq!(
            trie.tokenize("<cls> This is a test"),
            word_then_unknowns(1, " This is a test")
        );
    }

    #[test]
    fn test_trie_explicit_unknown_id() {
        let mut trie = Trie::new(Some(99));
        trie.add("[CLS]");
        // Re-adding a word must not allocate a second id.
        trie.add("[CLS]");
        assert_eq!(trie.tokenize("[CLS]!?"), word_then_unknowns(99, "!?"));
        assert_eq!(trie.tokenize(""), Vec::<usize>::new());
    }
}
|
|
|
|