| use serde::Serialize; |
| use std::collections::{hash_map::DefaultHasher, HashMap}; |
| use std::hash::{Hash, Hasher}; |
|
|
| use numpy::{npyffi, PyArray1}; |
| use pyo3::class::basic::CompareOp; |
| use pyo3::exceptions; |
| use pyo3::intern; |
| use pyo3::prelude::*; |
| use pyo3::types::*; |
| use tk::models::bpe::BPE; |
| use tk::tokenizer::{ |
| Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, |
| TruncationDirection, TruncationParams, TruncationStrategy, |
| }; |
| use tk::utils::iter::ResultShunt; |
| use tokenizers as tk; |
|
|
| use super::decoders::PyDecoder; |
| use super::encoding::PyEncoding; |
| use super::error::{PyError, ToPyResult}; |
| use super::models::PyModel; |
| use super::normalizers::PyNormalizer; |
| use super::pre_tokenizers::PyPreTokenizer; |
| use super::trainers::PyTrainer; |
| use crate::processors::PyPostProcessor; |
| use crate::utils::{MaybeSizedIterator, PyBufferedIterator}; |
| use std::collections::BTreeMap; |
|
|
| /// Represents a token that can be added to a :class:`~tokenizers.Tokenizer`. |
| /// It can have special options that define the way it should behave. |
| /// |
| /// Args: |
| ///     content (str): The content of the token. |
| /// |
| ///     single_word (bool, defaults to False): |
| ///         Whether this token should only match against whole single words. |
| /// |
| ///     lstrip (bool, defaults to False): |
| ///         Whether this token should strip whitespace on its left side. |
| /// |
| ///     rstrip (bool, defaults to False): |
| ///         Whether this token should strip whitespace on its right side. |
| /// |
| ///     normalized (bool, defaults to True): |
| ///         Whether this token should match against the normalized version of the input text. |
| /// |
| ///     special (bool, defaults to False): |
| ///         Whether this token is a special token. |
| #[pyclass(dict, module = "tokenizers", name = "AddedToken")] |
| pub struct PyAddedToken { |
| pub content: String, |
| pub special: bool, |
| pub single_word: Option<bool>, |
| pub lstrip: Option<bool>, |
| pub rstrip: Option<bool>, |
| pub normalized: Option<bool>, |
| } |
| impl PyAddedToken { |
| pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self { |
| Self { |
| content: content.into(), |
| special: special.unwrap_or(false), |
| single_word: None, |
| lstrip: None, |
| rstrip: None, |
| normalized: None, |
| } |
| } |
|
|
| pub fn get_token(&self) -> tk::tokenizer::AddedToken { |
| let mut token = tk::AddedToken::from(&self.content, self.special); |
|
|
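| // Only override the library defaults for the options explicitly set on the Python side |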
| if let Some(sw) = self.single_word { |
| token = token.single_word(sw); |
| } |
| if let Some(ls) = self.lstrip { |
| token = token.lstrip(ls); |
| } |
| if let Some(rs) = self.rstrip { |
| token = token.rstrip(rs); |
| } |
| if let Some(n) = self.normalized { |
| token = token.normalized(n); |
| } |
|
|
| token |
| } |
|
|
| pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> { |
| let dict = PyDict::new_bound(py); |
| let token = self.get_token(); |
|
|
| dict.set_item("content", token.content)?; |
| dict.set_item("single_word", token.single_word)?; |
| dict.set_item("lstrip", token.lstrip)?; |
| dict.set_item("rstrip", token.rstrip)?; |
| dict.set_item("normalized", token.normalized)?; |
| dict.set_item("special", token.special)?; |
|
|
| Ok(dict) |
| } |
| } |
|
|
| impl From<tk::AddedToken> for PyAddedToken { |
| fn from(token: tk::AddedToken) -> Self { |
| Self { |
| content: token.content, |
| single_word: Some(token.single_word), |
| lstrip: Some(token.lstrip), |
| rstrip: Some(token.rstrip), |
| normalized: Some(token.normalized), |
| special: token.special, |
| } |
| } |
| } |
|
|
| #[pymethods] |
| impl PyAddedToken { |
| #[new] |
| #[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False)")] |
| fn __new__(content: Option<&str>, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<Self> { |
| let mut token = PyAddedToken::from(content.unwrap_or(""), None); |
|
|
| if let Some(kwargs) = kwargs { |
| for (key, value) in kwargs { |
| let key: &str = key.extract()?; |
| match key { |
| "single_word" => token.single_word = Some(value.extract()?), |
| "lstrip" => token.lstrip = Some(value.extract()?), |
| "rstrip" => token.rstrip = Some(value.extract()?), |
| "normalized" => token.normalized = Some(value.extract()?), |
| "special" => token.special = value.extract()?, |
| _ => println!("Ignored unknown kwarg option {}", key), |
| } |
| } |
| } |
|
|
| Ok(token) |
| } |
|
|
| fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> { |
| self.as_pydict(py) |
| } |
|
|
| fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { |
| match state.extract::<&PyDict>(py) { |
| Ok(state) => { |
| for (key, value) in state { |
| let key: &str = key.extract()?; |
| match key { |
| "content" => self.content = value.extract()?, |
| "single_word" => self.single_word = Some(value.extract()?), |
| "lstrip" => self.lstrip = Some(value.extract()?), |
| "rstrip" => self.rstrip = Some(value.extract()?), |
| "normalized" => self.normalized = Some(value.extract()?), |
| "special" => self.special = value.extract()?, |
| _ => {} |
| } |
| } |
| Ok(()) |
| } |
| Err(e) => Err(e), |
| } |
| } |
|
|
| /// The content of this AddedToken |
| #[getter] |
| fn get_content(&self) -> &str { |
| &self.content |
| } |
|
|
| /// Set the content of this AddedToken |
| #[setter] |
| fn set_content(&mut self, content: String) { |
| self.content = content; |
| } |
|
|
| /// Whether this token strips whitespace on its right side |
| #[getter] |
| fn get_rstrip(&self) -> bool { |
| self.get_token().rstrip |
| } |
|
|
| /// Whether this token strips whitespace on its left side |
| #[getter] |
| fn get_lstrip(&self) -> bool { |
| self.get_token().lstrip |
| } |
|
|
| /// Whether this token only matches against whole single words |
| #[getter] |
| fn get_single_word(&self) -> bool { |
| self.get_token().single_word |
| } |
|
|
| /// Whether this token matches against the normalized version of the input text |
| #[getter] |
| fn get_normalized(&self) -> bool { |
| self.get_token().normalized |
| } |
| /// Whether this is a special token |
| #[getter] |
| fn get_special(&self) -> bool { |
| self.get_token().special |
| } |
|
|
| /// Mark this token as special or not |
| #[setter] |
| fn set_special(&mut self, special: bool) { |
| self.special = special; |
| } |
|
|
| fn __str__(&self) -> PyResult<&str> { |
| Ok(&self.content) |
| } |
|
|
| fn __repr__(&self) -> PyResult<String> { |
| let bool_to_python = |p| match p { |
| true => "True", |
| false => "False", |
| }; |
|
|
| let token = self.get_token(); |
| Ok(format!( |
| "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={}, special={})", |
| self.content, |
| bool_to_python(token.rstrip), |
| bool_to_python(token.lstrip), |
| bool_to_python(token.single_word), |
| bool_to_python(token.normalized), |
| bool_to_python(token.special) |
| )) |
| } |
|
|
| fn __richcmp__(&self, other: Py<PyAddedToken>, op: CompareOp) -> bool { |
| use CompareOp::*; |
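| // Ordering comparisons are not meaningful for AddedToken: only equality is supported |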
| Python::with_gil(|py| match op { |
| Lt | Le | Gt | Ge => false, |
| Eq => self.get_token() == other.borrow(py).get_token(), |
| Ne => self.get_token() != other.borrow(py).get_token(), |
| }) |
| } |
|
|
| fn __hash__(&self) -> u64 { |
| let mut hasher = DefaultHasher::new(); |
| self.get_token().hash(&mut hasher); |
| hasher.finish() |
| } |
| } |
|
|
| struct TextInputSequence<'s>(tk::InputSequence<'s>); |
| impl<'s> FromPyObject<'s> for TextInputSequence<'s> { |
| fn extract(ob: &'s PyAny) -> PyResult<Self> { |
| let err = exceptions::PyTypeError::new_err("TextInputSequence must be str"); |
| if let Ok(s) = ob.downcast::<PyString>() { |
| Ok(Self(s.to_string_lossy().into())) |
| } else { |
| Err(err) |
| } |
| } |
| } |
| impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> { |
| fn from(s: TextInputSequence<'s>) -> Self { |
| s.0 |
| } |
| } |
|
|
| struct PyArrayUnicode(Vec<String>); |
| impl FromPyObject<'_> for PyArrayUnicode { |
| fn extract(ob: &PyAny) -> PyResult<Self> { |
| // Make sure we actually received a NumPy array before touching its raw fields |
| if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 { |
| return Err(exceptions::PyTypeError::new_err("Expected an np.array")); |
| } |
| let arr = ob.as_ptr() as *mut npyffi::PyArrayObject; |
| // Read dtype, element size, alignment and layout information from the array descriptor |
| let (type_num, elsize, alignment, data, nd, flags) = unsafe { |
| let desc = (*arr).descr; |
| ( |
| (*desc).type_num, |
| (*desc).elsize as usize, |
| (*desc).alignment as usize, |
| (*arr).data, |
| (*arr).nd, |
| (*arr).flags, |
| ) |
| }; |
|
|
| if nd != 1 { |
| return Err(exceptions::PyTypeError::new_err( |
| "Expected a 1 dimensional np.array", |
| )); |
| } |
| if flags & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) == 0 { |
| return Err(exceptions::PyTypeError::new_err( |
| "Expected a contiguous np.array", |
| )); |
| } |
| if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 { |
| return Err(exceptions::PyTypeError::new_err( |
| "Expected a np.array[dtype='U']", |
| )); |
| } |
|
|
| // SAFETY: the checks above guarantee a 1-D, contiguous array of fixed-width unicode elements |
| unsafe { |
| let n_elem = *(*arr).dimensions as usize; |
| let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem); |
|
|
| let seq = (0..n_elem) |
| .map(|i| { |
| let bytes = &all_bytes[i * elsize..(i + 1) * elsize]; |
| let unicode = pyo3::ffi::PyUnicode_FromKindAndData( |
| pyo3::ffi::PyUnicode_4BYTE_KIND as _, |
| bytes.as_ptr() as *const _, |
| elsize as isize / alignment as isize, |
| ); |
| let py = ob.py(); |
| let obj = PyObject::from_owned_ptr(py, unicode); |
| let s = obj.downcast_bound::<PyString>(py)?; |
| Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned()) |
| }) |
| .collect::<PyResult<Vec<_>>>()?; |
|
|
| Ok(Self(seq)) |
| } |
| } |
| } |
| impl From<PyArrayUnicode> for tk::InputSequence<'_> { |
| fn from(s: PyArrayUnicode) -> Self { |
| s.0.into() |
| } |
| } |
|
|
| struct PyArrayStr(Vec<String>); |
| impl FromPyObject<'_> for PyArrayStr { |
| fn extract(ob: &PyAny) -> PyResult<Self> { |
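| // Handles NumPy arrays of dtype=object whose elements are Python strings |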
| let array = ob.downcast::<PyArray1<PyObject>>()?; |
| let seq = array |
| .readonly() |
| .as_array() |
| .iter() |
| .map(|obj| { |
| let s = obj.downcast_bound::<PyString>(ob.py())?; |
| Ok(s.to_string_lossy().into_owned()) |
| }) |
| .collect::<PyResult<Vec<_>>>()?; |
|
|
| Ok(Self(seq)) |
| } |
| } |
| impl From<PyArrayStr> for tk::InputSequence<'_> { |
| fn from(s: PyArrayStr) -> Self { |
| s.0.into() |
| } |
| } |
|
|
| struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>); |
| impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> { |
| fn extract(ob: &'s PyAny) -> PyResult<Self> { |
| if let Ok(seq) = ob.extract::<PyArrayUnicode>() { |
| return Ok(Self(seq.into())); |
| } |
| if let Ok(seq) = ob.extract::<PyArrayStr>() { |
| return Ok(Self(seq.into())); |
| } |
| if let Ok(s) = ob.downcast::<PyList>() { |
| if let Ok(seq) = s.extract::<Vec<String>>() { |
| return Ok(Self(seq.into())); |
| } |
| } |
| if let Ok(s) = ob.downcast::<PyTuple>() { |
| if let Ok(seq) = s.extract::<Vec<String>>() { |
| return Ok(Self(seq.into())); |
| } |
| } |
| Err(exceptions::PyTypeError::new_err( |
| "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]", |
| )) |
| } |
| } |
| impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> { |
| fn from(s: PreTokenizedInputSequence<'s>) -> Self { |
| s.0 |
| } |
| } |
|
|
| struct TextEncodeInput<'s>(tk::EncodeInput<'s>); |
| impl<'s> FromPyObject<'s> for TextEncodeInput<'s> { |
| fn extract(ob: &'s PyAny) -> PyResult<Self> { |
| if let Ok(i) = ob.extract::<TextInputSequence>() { |
| return Ok(Self(i.into())); |
| } |
| if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() { |
| return Ok(Self((i1, i2).into())); |
| } |
| if let Ok(arr) = ob.extract::<Vec<&PyAny>>() { |
| if arr.len() == 2 { |
| let first = arr[0].extract::<TextInputSequence>()?; |
| let second = arr[1].extract::<TextInputSequence>()?; |
| return Ok(Self((first, second).into())); |
| } |
| } |
| Err(exceptions::PyTypeError::new_err( |
| "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]", |
| )) |
| } |
| } |
| impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> { |
| fn from(i: TextEncodeInput<'s>) -> Self { |
| i.0 |
| } |
| } |
| struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>); |
| impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> { |
| fn extract(ob: &'s PyAny) -> PyResult<Self> { |
| if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() { |
| return Ok(Self(i.into())); |
| } |
| if let Ok((i1, i2)) = ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>() |
| { |
| return Ok(Self((i1, i2).into())); |
| } |
| if let Ok(arr) = ob.extract::<Vec<&PyAny>>() { |
| if arr.len() == 2 { |
| let first = arr[0].extract::<PreTokenizedInputSequence>()?; |
| let second = arr[1].extract::<PreTokenizedInputSequence>()?; |
| return Ok(Self((first, second).into())); |
| } |
| } |
| Err(exceptions::PyTypeError::new_err( |
| "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \ |
| Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]", |
| )) |
| } |
| } |
| impl<'s> From<PreTokenizedEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> { |
| fn from(i: PreTokenizedEncodeInput<'s>) -> Self { |
| i.0 |
| } |
| } |
|
|
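| // The concrete TokenizerImpl used by these bindings, parameterized with the Python-wrapped components |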
| type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>; |
|
|
| /// A Tokenizer works as a pipeline: it processes some raw text as input |
| /// and outputs an Encoding. |
| /// |
| /// Args: |
| ///     model (Model): The core algorithm that this Tokenizer should be using. |
| #[pyclass(dict, module = "tokenizers", name = "Tokenizer")] |
| #[derive(Clone, Serialize)] |
| #[serde(transparent)] |
| pub struct PyTokenizer { |
| tokenizer: Tokenizer, |
| } |
|
|
| impl PyTokenizer { |
| fn new(tokenizer: Tokenizer) -> Self { |
| PyTokenizer { tokenizer } |
| } |
|
|
| fn from_model(model: PyModel) -> Self { |
| PyTokenizer::new(TokenizerImpl::new(model)) |
| } |
| } |
|
|
| #[pymethods] |
| impl PyTokenizer { |
| #[new] |
| #[pyo3(text_signature = "(self, model)")] |
| fn __new__(model: PyRef<PyModel>) -> Self { |
| PyTokenizer::from_model(model.clone()) |
| } |
|
|
| fn __getstate__(&self, py: Python) -> PyResult<PyObject> { |
| let data = serde_json::to_string(&self.tokenizer).map_err(|e| { |
| exceptions::PyException::new_err(format!( |
| "Error while attempting to pickle Tokenizer: {}", |
| e |
| )) |
| })?; |
| Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py)) |
| } |
|
|
| fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { |
| match state.extract::<&PyBytes>(py) { |
| Ok(s) => { |
| self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| { |
| exceptions::PyException::new_err(format!( |
| "Error while attempting to unpickle Tokenizer: {}", |
| e |
| )) |
| })?; |
| Ok(()) |
| } |
| Err(e) => Err(e), |
| } |
| } |
|
|
| fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> { |
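| // Unpickling support: provide constructor args (a default BPE model); __setstate__ restores the real state |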
| let model = PyModel::from(BPE::default()).into_py(py); |
| PyTuple::new_bound(py, vec![model]) |
| } |
|
|
| /// Instantiate a new Tokenizer from the given JSON string. |
| /// |
| /// Args: |
| ///     json (str): A valid JSON string representing a previously serialized Tokenizer. |
| /// |
| /// Returns: |
| ///     Tokenizer: The new tokenizer. |
| #[staticmethod] |
| #[pyo3(text_signature = "(json)")] |
| fn from_str(json: &str) -> PyResult<Self> { |
| let tokenizer: PyResult<_> = ToPyResult(json.parse()).into(); |
| Ok(Self::new(tokenizer?)) |
| } |
|
|
| /// Instantiate a new Tokenizer from the file at the given path. |
| /// |
| /// Args: |
| ///     path (str): A path to a local JSON file representing a previously serialized Tokenizer. |
| /// |
| /// Returns: |
| ///     Tokenizer: The new tokenizer. |
| #[staticmethod] |
| #[pyo3(text_signature = "(path)")] |
| fn from_file(path: &str) -> PyResult<Self> { |
| let tokenizer: PyResult<_> = ToPyResult(Tokenizer::from_file(path)).into(); |
| Ok(Self::new(tokenizer?)) |
| } |
|
|
| /// Instantiate a new Tokenizer from the given buffer. |
| /// |
| /// Args: |
| ///     buffer (bytes): A buffer containing a previously serialized Tokenizer. |
| /// |
| /// Returns: |
| ///     Tokenizer: The new tokenizer. |
| #[staticmethod] |
| #[pyo3(text_signature = "(buffer)")] |
| fn from_buffer(buffer: &Bound<'_, PyBytes>) -> PyResult<Self> { |
| let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| { |
| exceptions::PyValueError::new_err(format!( |
| "Cannot instantiate Tokenizer from buffer: {}", |
| e |
| )) |
| })?; |
| Ok(Self { tokenizer }) |
| } |
|
|
| /// Instantiate a new Tokenizer from an existing file on the Hugging Face Hub. |
| /// |
| /// Args: |
| ///     identifier (str): The identifier of a Model on the Hugging Face Hub that contains a tokenizer.json file. |
| /// |
| ///     revision (str, defaults to "main"): A branch or commit id. |
| /// |
| ///     auth_token (str, optional): An optional auth token used to access private repositories on the Hub. |
| /// |
| /// Returns: |
| ///     Tokenizer: The new tokenizer. |
| #[staticmethod] |
| #[pyo3(signature = (identifier, revision = String::from("main"), auth_token = None))] |
| #[pyo3(text_signature = "(identifier, revision=\"main\", auth_token=None)")] |
| fn from_pretrained( |
| identifier: &str, |
| revision: String, |
| auth_token: Option<String>, |
| ) -> PyResult<Self> { |
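| // Download tokenizer.json from the Hub with huggingface_hub.hf_hub_download, then load it from the local path |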
| let path = Python::with_gil(|py| -> PyResult<String> { |
| let huggingface_hub = PyModule::import_bound(py, intern!(py, "huggingface_hub"))?; |
| let hf_hub_download = huggingface_hub.getattr(intern!(py, "hf_hub_download"))?; |
| let kwargs = [ |
| (intern!(py, "repo_id"), identifier), |
| (intern!(py, "filename"), "tokenizer.json"), |
| (intern!(py, "revision"), &revision), |
| ] |
| .into_py_dict_bound(py); |
| if let Some(auth_token) = auth_token { |
| kwargs.set_item(intern!(py, "token"), auth_token)?; |
| } |
| let path: String = hf_hub_download.call((), Some(&kwargs))?.extract()?; |
| Ok(path) |
| })?; |
|
|
| let tokenizer: PyResult<_> = ToPyResult(Tokenizer::from_file(path)).into(); |
| Ok(Self::new(tokenizer?)) |
| } |
|
|
| /// Get a serialized string version of this Tokenizer. |
| /// |
| /// Args: |
| ///     pretty (bool, defaults to False): Whether the JSON string should be pretty formatted. |
| /// |
| /// Returns: |
| ///     str: The serialized Tokenizer. |
| #[pyo3(signature = (pretty = false))] |
| #[pyo3(text_signature = "(self, pretty=False)")] |
| fn to_str(&self, pretty: bool) -> PyResult<String> { |
| ToPyResult(self.tokenizer.to_string(pretty)).into() |
| } |
|
|
| /// Save this Tokenizer to the file at the given path. |
| /// |
| /// Args: |
| ///     path (str): A path to the file in which to save the serialized Tokenizer. |
| /// |
| ///     pretty (bool, defaults to True): Whether the JSON file should be pretty formatted. |
| #[pyo3(signature = (path, pretty = true))] |
| #[pyo3(text_signature = "(self, path, pretty=True)")] |
| fn save(&self, path: &str, pretty: bool) -> PyResult<()> { |
| ToPyResult(self.tokenizer.save(path, pretty)).into() |
| } |
|
|
| fn __repr__(&self) -> PyResult<String> { |
| crate::utils::serde_pyo3::repr(self) |
| .map_err(|e| exceptions::PyException::new_err(e.to_string())) |
| } |
|
|
| fn __str__(&self) -> PyResult<String> { |
| crate::utils::serde_pyo3::to_string(self) |
| .map_err(|e| exceptions::PyException::new_err(e.to_string())) |
| } |
|
|
| /// Return the number of special tokens that would be added for single or paired sequences. |
| /// |
| /// Args: |
| ///     is_pair (bool): Whether the computation should account for a pair of sequences. |
| #[pyo3(text_signature = "(self, is_pair)")] |
| fn num_special_tokens_to_add(&self, is_pair: bool) -> usize { |
| self.tokenizer |
| .get_post_processor() |
| .map_or(0, |p| p.added_tokens(is_pair)) |
| } |
|
|
| /// Get the underlying vocabulary. |
| /// |
| /// Args: |
| ///     with_added_tokens (bool, defaults to True): Whether to include the added tokens. |
| /// |
| /// Returns: |
| ///     Dict[str, int]: The vocabulary. |
| #[pyo3(signature = (with_added_tokens = true))] |
| #[pyo3(text_signature = "(self, with_added_tokens=True)")] |
| fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> { |
| self.tokenizer.get_vocab(with_added_tokens) |
| } |
|
|
| /// Get the underlying vocabulary of added tokens. |
| /// |
| /// Returns: |
| ///     Dict[int, AddedToken]: The added tokens, sorted by id. |
| #[pyo3(signature = ())] |
| #[pyo3(text_signature = "(self)")] |
| fn get_added_tokens_decoder(&self) -> BTreeMap<u32, PyAddedToken> { |
| let mut sorted_map = BTreeMap::new(); |
|
|
| for (key, value) in self.tokenizer.get_added_tokens_decoder() { |
| sorted_map.insert(key, value.into()); |
| } |
|
|
| sorted_map |
| } |
|
|
| /// Get the size of the underlying vocabulary. |
| /// |
| /// Args: |
| ///     with_added_tokens (bool, defaults to True): Whether to include the added tokens. |
| /// |
| /// Returns: |
| ///     int: The size of the vocabulary. |
| #[pyo3(signature = (with_added_tokens = true))] |
| #[pyo3(text_signature = "(self, with_added_tokens=True)")] |
| fn get_vocab_size(&self, with_added_tokens: bool) -> usize { |
| self.tokenizer.get_vocab_size(with_added_tokens) |
| } |
|
|
| /// Enable truncation on this Tokenizer. |
| /// |
| /// Args: |
| ///     max_length (int): The maximum length at which to truncate. |
| /// |
| ///     stride (int, optional): The length of the previous first sequence |
| ///         to be included in the overflowing sequence. |
| /// |
| ///     strategy (str, defaults to "longest_first"): The strategy used to truncate. |
| ///         Can be one of "longest_first", "only_first" or "only_second". |
| /// |
| ///     direction (str, defaults to "right"): The truncation direction, |
| ///         either "left" or "right". |
| #[pyo3(signature = (max_length, **kwargs))] |
| #[pyo3( |
| text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')" |
| )] |
| fn enable_truncation( |
| &mut self, |
| max_length: usize, |
| kwargs: Option<&Bound<'_, PyDict>>, |
| ) -> PyResult<()> { |
| let mut params = TruncationParams { |
| max_length, |
| ..Default::default() |
| }; |
|
|
| if let Some(kwargs) = kwargs { |
| for (key, value) in kwargs { |
| let key: &str = key.extract()?; |
| match key { |
| "stride" => params.stride = value.extract()?, |
| "strategy" => { |
| let value: &str = value.extract()?; |
| params.strategy = match value { |
| "longest_first" => Ok(TruncationStrategy::LongestFirst), |
| "only_first" => Ok(TruncationStrategy::OnlyFirst), |
| "only_second" => Ok(TruncationStrategy::OnlySecond), |
| _ => Err(PyError(format!( |
| "Unknown `strategy`: `{}`. Use \ |
| one of `longest_first`, `only_first`, or `only_second`", |
| value |
| )) |
| .into_pyerr::<exceptions::PyValueError>()), |
| }? |
| } |
| "direction" => { |
| let value: &str = value.extract()?; |
| params.direction = match value { |
| "left" => Ok(TruncationDirection::Left), |
| "right" => Ok(TruncationDirection::Right), |
| _ => Err(PyError(format!( |
| "Unknown `direction`: `{}`. Use \ |
| one of `left` or `right`.", |
| value |
| )) |
| .into_pyerr::<exceptions::PyValueError>()), |
| }? |
| } |
| _ => println!("Ignored unknown kwarg option {}", key), |
| } |
| } |
| } |
|
|
| if let Err(error_message) = self.tokenizer.with_truncation(Some(params)) { |
| return Err(PyError(error_message.to_string()).into_pyerr::<exceptions::PyValueError>()); |
| } |
| Ok(()) |
| } |
|
|
| /// Disable truncation on this Tokenizer. |
| #[pyo3(text_signature = "(self)")] |
| fn no_truncation(&mut self) { |
| self.tokenizer |
| .with_truncation(None) |
| .expect("Failed to set truncation to `None`! This should never happen"); |
| } |
|
|
| /// Get the currently set truncation parameters. |
| /// |
| /// Returns: |
| ///     (dict, optional): A dict with the current truncation parameters if truncation is enabled, None otherwise. |
| #[getter] |
| fn get_truncation<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyDict>>> { |
| self.tokenizer.get_truncation().map_or(Ok(None), |params| { |
| let dict = PyDict::new_bound(py); |
|
|
| dict.set_item("max_length", params.max_length)?; |
| dict.set_item("stride", params.stride)?; |
| dict.set_item("strategy", params.strategy.as_ref())?; |
| dict.set_item("direction", params.direction.as_ref())?; |
|
|
| Ok(Some(dict)) |
| }) |
| } |
|
|
| /// Enable padding on this Tokenizer. |
| /// |
| /// Args: |
| ///     direction (str, defaults to "right"): The padding direction, either "left" or "right". |
| /// |
| ///     pad_id (int, defaults to 0): The id used when padding. |
| /// |
| ///     pad_type_id (int, defaults to 0): The type id used when padding. |
| /// |
| ///     pad_token (str, defaults to "[PAD]"): The pad token used when padding. |
| /// |
| ///     length (int, optional): If specified, pad every sequence to this fixed length. |
| ///         Otherwise, pad to the longest sequence in the batch. |
| /// |
| ///     pad_to_multiple_of (int, optional): If specified, pad so that the padded length |
| ///         is a multiple of this value. |
| #[pyo3(signature = (**kwargs))] |
| #[pyo3( |
| text_signature = "(self, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]', length=None, pad_to_multiple_of=None)" |
| )] |
| fn enable_padding(&mut self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> { |
| let mut params = PaddingParams::default(); |
|
|
| if let Some(kwargs) = kwargs { |
| for (key, value) in kwargs { |
| let key: &str = key.extract()?; |
| match key { |
| "direction" => { |
| let value: &str = value.extract()?; |
| params.direction = match value { |
| "left" => Ok(PaddingDirection::Left), |
| "right" => Ok(PaddingDirection::Right), |
| other => Err(PyError(format!( |
| "Unknown `direction`: `{}`. Use \ |
| one of `left` or `right`", |
| other |
| )) |
| .into_pyerr::<exceptions::PyValueError>()), |
| }?; |
| } |
| "pad_to_multiple_of" => { |
| if let Some(multiple) = value.extract()? { |
| params.pad_to_multiple_of = multiple; |
| } |
| } |
| "pad_id" => params.pad_id = value.extract()?, |
| "pad_type_id" => params.pad_type_id = value.extract()?, |
| "pad_token" => params.pad_token = value.extract()?, |
| "max_length" => { |
| println!( |
| "enable_padding(max_length=X) is deprecated, \ |
| use enable_padding(length=X) instead" |
| ); |
| if let Some(l) = value.extract()? { |
| params.strategy = PaddingStrategy::Fixed(l); |
| } else { |
| params.strategy = PaddingStrategy::BatchLongest; |
| } |
| } |
| "length" => { |
| if let Some(l) = value.extract()? { |
| params.strategy = PaddingStrategy::Fixed(l); |
| } else { |
| params.strategy = PaddingStrategy::BatchLongest; |
| } |
| } |
| _ => println!("Ignored unknown kwarg option {}", key), |
| } |
| } |
| } |
|
|
| self.tokenizer.with_padding(Some(params)); |
|
|
| Ok(()) |
| } |
|
|
| /// Disable padding on this Tokenizer. |
| #[pyo3(text_signature = "(self)")] |
| fn no_padding(&mut self) { |
| self.tokenizer.with_padding(None); |
| } |
|
|
| /// Get the currently set padding parameters. |
| /// |
| /// Returns: |
| ///     (dict, optional): A dict with the current padding parameters if padding is enabled, None otherwise. |
| #[getter] |
| fn get_padding<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyDict>>> { |
| self.tokenizer.get_padding().map_or(Ok(None), |params| { |
| let dict = PyDict::new_bound(py); |
|
|
| dict.set_item( |
| "length", |
| match params.strategy { |
| tk::PaddingStrategy::BatchLongest => None, |
| tk::PaddingStrategy::Fixed(size) => Some(size), |
| }, |
| )?; |
| dict.set_item("pad_to_multiple_of", params.pad_to_multiple_of)?; |
| dict.set_item("pad_id", params.pad_id)?; |
| dict.set_item("pad_token", ¶ms.pad_token)?; |
| dict.set_item("pad_type_id", params.pad_type_id)?; |
| dict.set_item("direction", params.direction.as_ref())?; |
|
|
| Ok(Some(dict)) |
| }) |
| } |
|
|
| /// Encode the given sequence and pair. This method can process raw text sequences |
| /// as well as already pre-tokenized sequences. |
| /// |
| /// Args: |
| ///     sequence (InputSequence): The main input sequence to encode. A raw string, |
| ///         or a list/tuple of strings when is_pretokenized=True. |
| /// |
| ///     pair (InputSequence, optional): An optional input sequence to encode as a pair |
| ///         with the first one. |
| /// |
| ///     is_pretokenized (bool, defaults to False): Whether the input is already pre-tokenized. |
| /// |
| ///     add_special_tokens (bool, defaults to True): Whether to add the special tokens. |
| /// |
| /// Returns: |
| ///     Encoding: The encoded result. |
| #[pyo3(signature = (sequence, pair = None, is_pretokenized = false, add_special_tokens = true))] |
| #[pyo3( |
| text_signature = "(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True)" |
| )] |
| fn encode( |
| &self, |
| sequence: &Bound<'_, PyAny>, |
| pair: Option<&Bound<'_, PyAny>>, |
| is_pretokenized: bool, |
| add_special_tokens: bool, |
| ) -> PyResult<PyEncoding> { |
| let sequence: tk::InputSequence = if is_pretokenized { |
| sequence.extract::<PreTokenizedInputSequence>()?.into() |
| } else { |
| sequence.extract::<TextInputSequence>()?.into() |
| }; |
| let input = match pair { |
| Some(pair) => { |
| let pair: tk::InputSequence = if is_pretokenized { |
| pair.extract::<PreTokenizedInputSequence>()?.into() |
| } else { |
| pair.extract::<TextInputSequence>()?.into() |
| }; |
| tk::EncodeInput::Dual(sequence, pair) |
| } |
| None => tk::EncodeInput::Single(sequence), |
| }; |
|
|
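| // encode_char_offsets makes the offsets of the resulting Encoding relative to chars instead of bytes |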
| ToPyResult( |
| self.tokenizer |
| .encode_char_offsets(input, add_special_tokens) |
| .map(|e| e.into()), |
| ) |
| .into() |
| } |
|
|
| /// Encode the given batch of inputs. This method accepts both raw text sequences |
| /// and already pre-tokenized sequences. |
| /// |
| /// Args: |
| ///     input (List): A list of single sequences or pairs of sequences to encode. |
| /// |
| ///     is_pretokenized (bool, defaults to False): Whether the input is already pre-tokenized. |
| /// |
| ///     add_special_tokens (bool, defaults to True): Whether to add the special tokens. |
| /// |
| /// Returns: |
| ///     List[Encoding]: The encoded batch. |
| #[pyo3(signature = (input, is_pretokenized = false, add_special_tokens = true))] |
| #[pyo3(text_signature = "(self, input, is_pretokenized=False, add_special_tokens=True)")] |
| fn encode_batch( |
| &self, |
| py: Python<'_>, |
| input: Vec<&PyAny>, |
| is_pretokenized: bool, |
| add_special_tokens: bool, |
| ) -> PyResult<Vec<PyEncoding>> { |
| let input: Vec<tk::EncodeInput> = input |
| .into_iter() |
| .map(|o| { |
| let input: tk::EncodeInput = if is_pretokenized { |
| o.extract::<PreTokenizedEncodeInput>()?.into() |
| } else { |
| o.extract::<TextEncodeInput>()?.into() |
| }; |
| Ok(input) |
| }) |
| .collect::<PyResult<Vec<tk::EncodeInput>>>()?; |
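| // Release the GIL so the whole batch can be encoded in parallel |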
| py.allow_threads(|| { |
| ToPyResult( |
| self.tokenizer |
| .encode_batch_char_offsets(input, add_special_tokens) |
| .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()), |
| ) |
| .into() |
| }) |
| } |
|
|
| /// Encode the given batch of inputs. This variant of encode_batch is faster because |
| /// it does not keep track of offsets in the resulting encodings. |
| /// |
| /// Args: |
| ///     input (List): A list of single sequences or pairs of sequences to encode. |
| /// |
| ///     is_pretokenized (bool, defaults to False): Whether the input is already pre-tokenized. |
| /// |
| ///     add_special_tokens (bool, defaults to True): Whether to add the special tokens. |
| /// |
| /// Returns: |
| ///     List[Encoding]: The encoded batch. |
| #[pyo3(signature = (input, is_pretokenized = false, add_special_tokens = true))] |
| #[pyo3(text_signature = "(self, input, is_pretokenized=False, add_special_tokens=True)")] |
| fn encode_batch_fast( |
| &self, |
| py: Python<'_>, |
| input: Vec<&PyAny>, |
| is_pretokenized: bool, |
| add_special_tokens: bool, |
| ) -> PyResult<Vec<PyEncoding>> { |
| let input: Vec<tk::EncodeInput> = input |
| .into_iter() |
| .map(|o| { |
| let input: tk::EncodeInput = if is_pretokenized { |
| o.extract::<PreTokenizedEncodeInput>()?.into() |
| } else { |
| o.extract::<TextEncodeInput>()?.into() |
| }; |
| Ok(input) |
| }) |
| .collect::<PyResult<Vec<tk::EncodeInput>>>()?; |
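| // Release the GIL; the fast path trades offset information for encoding speed |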
| py.allow_threads(|| { |
| ToPyResult( |
| self.tokenizer |
| .encode_batch_fast(input, add_special_tokens) |
| .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()), |
| ) |
| .into() |
| }) |
| } |
|
|
| /// Decode the given list of ids back to a string. |
| /// |
| /// Args: |
| ///     ids (List[int]): The list of ids to decode. |
| /// |
| ///     skip_special_tokens (bool, defaults to True): Whether the special tokens should be removed from the decoded string. |
| /// |
| /// Returns: |
| ///     str: The decoded string. |
| #[pyo3(signature = (ids, skip_special_tokens = true))] |
| #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")] |
| fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> { |
| ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into() |
| } |
|
|
| /// Decode a batch of ids back to their corresponding strings. |
| /// |
| /// Args: |
| ///     sequences (List[List[int]]): The batch of sequences to decode. |
| /// |
| ///     skip_special_tokens (bool, defaults to True): Whether the special tokens should be removed from the decoded strings. |
| /// |
| /// Returns: |
| ///     List[str]: The decoded strings. |
| #[pyo3(signature = (sequences, skip_special_tokens = true))] |
| #[pyo3(text_signature = "(self, sequences, skip_special_tokens=True)")] |
| fn decode_batch( |
| &self, |
| py: Python<'_>, |
| sequences: Vec<Vec<u32>>, |
| skip_special_tokens: bool, |
| ) -> PyResult<Vec<String>> { |
| py.allow_threads(|| { |
| let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>(); |
| ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into() |
| }) |
| } |
|
|
| /// Convert the given token to its corresponding id if it exists. |
| /// |
| /// Args: |
| ///     token (str): The token to convert. |
| /// |
| /// Returns: |
| ///     (int, optional): The corresponding id if it exists, None otherwise. |
| #[pyo3(text_signature = "(self, token)")] |
| fn token_to_id(&self, token: &str) -> Option<u32> { |
| self.tokenizer.token_to_id(token) |
| } |
|
|
| /// Convert the given id to its corresponding token if it exists. |
| /// |
| /// Args: |
| ///     id (int): The id to convert. |
| /// |
| /// Returns: |
| ///     (str, optional): The corresponding token if it exists, None otherwise. |
| #[pyo3(text_signature = "(self, id)")] |
| fn id_to_token(&self, id: u32) -> Option<String> { |
| self.tokenizer.id_to_token(id) |
| } |
|
|
| /// Set the encode_special_tokens option. When enabled, the special tokens found in the |
| /// input text are tokenized like regular text instead of being matched as whole tokens. |
| /// |
| /// Args: |
| ///     value (bool): The new value of the option. |
| #[setter] |
| fn set_encode_special_tokens(&mut self, value: bool) { |
| self.tokenizer.set_encode_special_tokens(value); |
| } |
| /// Get the current value of the encode_special_tokens option. |
| #[getter] |
| fn get_encode_special_tokens(&self) -> bool { |
| self.tokenizer.get_encode_special_tokens() |
| } |
| /// Add the given tokens to the vocabulary. |
| /// |
| /// The given tokens are added only if they don't already exist in the vocabulary. |
| /// Each token then gets a new attributed id. |
| /// |
| /// Args: |
| ///     tokens (List[Union[str, AddedToken]]): The list of tokens to add. |
| /// |
| /// Returns: |
| ///     int: The number of tokens that were created in the vocabulary. |
| #[pyo3(text_signature = "(self, tokens)")] |
| fn add_tokens(&mut self, tokens: &Bound<'_, PyList>) -> PyResult<usize> { |
| let tokens = tokens |
| .into_iter() |
| .map(|token| { |
| if let Ok(content) = token.extract::<String>() { |
| Ok(PyAddedToken::from(content, Some(false)).get_token()) |
| } else if let Ok(token) = token.extract::<PyRefMut<PyAddedToken>>() { |
| Ok(token.get_token()) |
| } else { |
| Err(exceptions::PyTypeError::new_err( |
| "Input must be a List[Union[str, AddedToken]]", |
| )) |
| } |
| }) |
| .collect::<PyResult<Vec<_>>>()?; |
|
|
| Ok(self.tokenizer.add_tokens(&tokens)) |
| } |
|
|
| /// Add the given special tokens to the Tokenizer. |
| /// |
| /// If these tokens are already part of the vocabulary, it just lets the Tokenizer know about |
| /// them. If they don't exist, the Tokenizer creates them, giving them a new id. |
| /// These special tokens can then be skipped during decoding with skip_special_tokens. |
| /// |
| /// Args: |
| ///     tokens (List[Union[str, AddedToken]]): The list of special tokens to add. |
| /// |
| /// Returns: |
| ///     int: The number of tokens that were created in the vocabulary. |
| #[pyo3(text_signature = "(self, tokens)")] |
| fn add_special_tokens(&mut self, tokens: &Bound<'_, PyList>) -> PyResult<usize> { |
| let tokens = tokens |
| .into_iter() |
| .map(|token| { |
| if let Ok(content) = token.extract::<String>() { |
| Ok(tk::tokenizer::AddedToken::from(content, true)) |
| } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() { |
| token.special = true; |
| Ok(token.get_token()) |
| } else { |
| Err(exceptions::PyTypeError::new_err( |
| "Input must be a List[Union[str, AddedToken]]", |
| )) |
| } |
| }) |
| .collect::<PyResult<Vec<_>>>()?; |
|
|
| Ok(self.tokenizer.add_special_tokens(&tokens)) |
| } |
|
|
| /// Train the Tokenizer using the given files. |
| /// |
| /// Args: |
| ///     files (List[str]): A list of paths to the files to use for training. |
| /// |
| ///     trainer (Trainer, optional): An optional trainer to use for training; |
| ///         defaults to the trainer associated with the current model. |
| #[pyo3(signature = (files, trainer = None))] |
| #[pyo3(text_signature = "(self, files, trainer = None)")] |
| fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> { |
| let mut trainer = |
| trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone()); |
| Python::with_gil(|py| { |
| py.allow_threads(|| { |
| ToPyResult( |
| self.tokenizer |
| .train_from_files(&mut trainer, files) |
| .map(|_| {}), |
| ) |
| .into() |
| }) |
| }) |
| } |
|
| /// Train the Tokenizer using the provided iterator. |
| /// |
| /// The iterator can yield either raw strings or lists of strings. |
| /// |
| /// Args: |
| ///     iterator: An iterator over str or List[str]. |
| /// |
| ///     trainer (Trainer, optional): An optional trainer to use for training; |
| ///         defaults to the trainer associated with the current model. |
| /// |
| ///     length (int, optional): The total number of sequences in the iterator, |
| ///         used to provide meaningful progress tracking. |
| #[pyo3(signature = (iterator, trainer = None, length = None))] |
| #[pyo3(text_signature = "(self, iterator, trainer=None, length=None)")] |
| fn train_from_iterator( |
| &mut self, |
| py: Python, |
| iterator: &Bound<'_, PyAny>, |
| trainer: Option<&mut PyTrainer>, |
| length: Option<usize>, |
| ) -> PyResult<()> { |
| let mut trainer = |
| trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone()); |
|
|
| let buffered_iter = PyBufferedIterator::new( |
| iterator, |
| |element| { |
| // Each element yielded by the Python iterator can be either a single string |
| // or an iterable of strings (for example a list of lines). |
| if let Ok(s) = element.downcast::<PyString>() { |
| itertools::Either::Right(std::iter::once(s.to_str().map(|s| s.to_owned()))) |
| } else { |
| match element.iter() { |
| Ok(iter) => itertools::Either::Left( |
| iter.map(|i| i?.extract::<String>()) |
| .collect::<Vec<_>>() |
| .into_iter(), |
| ), |
| Err(e) => itertools::Either::Right(std::iter::once(Err(e))), |
| } |
| } |
| }, |
| 256, |
| )?; |
|
|
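| // Run the training without the GIL; ResultShunt surfaces the first error raised by the Python iterator |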
| py.allow_threads(|| { |
| ResultShunt::process(buffered_iter, |iter| { |
| self.tokenizer |
| .train(&mut trainer, MaybeSizedIterator::new(iter, length)) |
| .map(|_| {}) |
| .map_err(|e| exceptions::PyException::new_err(e.to_string())) |
| })? |
| }) |
| } |
|
|
| /// Apply all the post-processing steps to the given encodings. |
| /// |
| /// The various steps are: |
| ///     1. Truncate according to the set truncation params (provided with enable_truncation) |
| ///     2. Apply the PostProcessor |
| ///     3. Pad according to the set padding params (provided with enable_padding) |
| /// |
| /// Args: |
| ///     encoding (Encoding): The Encoding corresponding to the main sequence. |
| /// |
| ///     pair (Encoding, optional): An optional Encoding corresponding to the pair sequence. |
| /// |
| ///     add_special_tokens (bool, defaults to True): Whether to add the special tokens. |
| /// |
| /// Returns: |
| ///     Encoding: The final post-processed Encoding. |
| #[pyo3(signature = (encoding, pair = None, add_special_tokens = true))] |
| #[pyo3(text_signature = "(self, encoding, pair=None, add_special_tokens=True)")] |
| fn post_process( |
| &self, |
| encoding: &PyEncoding, |
| pair: Option<&PyEncoding>, |
| add_special_tokens: bool, |
| ) -> PyResult<PyEncoding> { |
| ToPyResult( |
| self.tokenizer |
| .post_process( |
| encoding.encoding.clone(), |
| pair.map(|p| p.encoding.clone()), |
| add_special_tokens, |
| ) |
| .map(|e| e.into()), |
| ) |
| .into() |
| } |
|
|
| /// The Model in use by the Tokenizer |
| #[getter] |
| fn get_model(&self, py: Python<'_>) -> PyResult<PyObject> { |
| self.tokenizer.get_model().get_as_subtype(py) |
| } |
|
|
| /// Set the Model to use |
| #[setter] |
| fn set_model(&mut self, model: PyRef<PyModel>) { |
| self.tokenizer.with_model(model.clone()); |
| } |
|
|
| /// The optional Normalizer in use by the Tokenizer |
| #[getter] |
| fn get_normalizer(&self, py: Python<'_>) -> PyResult<PyObject> { |
| if let Some(n) = self.tokenizer.get_normalizer() { |
| n.get_as_subtype(py) |
| } else { |
| Ok(py.None()) |
| } |
| } |
|
|
| /// Set the Normalizer to use |
| #[setter] |
| fn set_normalizer(&mut self, normalizer: Option<PyRef<PyNormalizer>>) { |
| let normalizer_option = normalizer.map(|norm| norm.clone()); |
| self.tokenizer.with_normalizer(normalizer_option); |
| } |
|
|
| /// The optional PreTokenizer in use by the Tokenizer |
| #[getter] |
| fn get_pre_tokenizer(&self, py: Python<'_>) -> PyResult<PyObject> { |
| if let Some(pt) = self.tokenizer.get_pre_tokenizer() { |
| pt.get_as_subtype(py) |
| } else { |
| Ok(py.None()) |
| } |
| } |
|
|
| /// Set the PreTokenizer to use |
| #[setter] |
| fn set_pre_tokenizer(&mut self, pretok: Option<PyRef<PyPreTokenizer>>) { |
| self.tokenizer |
| .with_pre_tokenizer(pretok.map(|pre| pre.clone())); |
| } |
|
|
| /// The optional PostProcessor in use by the Tokenizer |
| #[getter] |
| fn get_post_processor(&self, py: Python<'_>) -> PyResult<PyObject> { |
| if let Some(n) = self.tokenizer.get_post_processor() { |
| n.get_as_subtype(py) |
| } else { |
| Ok(py.None()) |
| } |
| } |
|
|
| /// Set the PostProcessor to use |
| #[setter] |
| fn set_post_processor(&mut self, processor: Option<PyRef<PyPostProcessor>>) { |
| self.tokenizer |
| .with_post_processor(processor.map(|p| p.clone())); |
| } |
|
|
| /// The optional Decoder in use by the Tokenizer |
| #[getter] |
| fn get_decoder(&self, py: Python<'_>) -> PyResult<PyObject> { |
| if let Some(dec) = self.tokenizer.get_decoder() { |
| dec.get_as_subtype(py) |
| } else { |
| Ok(py.None()) |
| } |
| } |
|
|
| /// Set the Decoder to use |
| #[setter] |
| fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) { |
| self.tokenizer.with_decoder(decoder.map(|d| d.clone())); |
| } |
| } |
|
|
| #[cfg(test)] |
| mod test { |
| use super::*; |
| use crate::models::PyModel; |
| use crate::normalizers::{PyNormalizer, PyNormalizerTypeWrapper}; |
| use std::sync::{Arc, RwLock}; |
| use tempfile::NamedTempFile; |
| use tk::normalizers::{Lowercase, NFKC}; |
|
|
| #[test] |
| fn serialize() { |
| let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default())); |
| tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence( |
| vec![ |
| Arc::new(RwLock::new(NFKC.into())), |
| Arc::new(RwLock::new(Lowercase.into())), |
| ], |
| )))); |
|
|
| let tmp = NamedTempFile::new().unwrap().into_temp_path(); |
| tokenizer.save(&tmp, false).unwrap(); |
|
|
| Tokenizer::from_file(&tmp).unwrap(); |
| } |
|
|
| #[test] |
| fn serde_pyo3() { |
| let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default())); |
| tokenizer.with_normalizer(Some(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence( |
| vec![ |
| Arc::new(RwLock::new(NFKC.into())), |
| Arc::new(RwLock::new(Lowercase.into())), |
| ], |
| )))); |
|
|
| let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap(); |
| assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))"); |
| } |
| } |
|
|