| use std::sync::{Arc, RwLock}; |
|
|
| use pyo3::exceptions; |
| use pyo3::prelude::*; |
| use pyo3::types::*; |
| use serde::ser::SerializeStruct; |
| use serde::{Deserialize, Deserializer, Serialize, Serializer}; |
|
|
| use tk::normalizer::SplitDelimiterBehavior; |
| use tk::pre_tokenizers::bert::BertPreTokenizer; |
| use tk::pre_tokenizers::byte_level::ByteLevel; |
| use tk::pre_tokenizers::delimiter::CharDelimiterSplit; |
| use tk::pre_tokenizers::digits::Digits; |
| use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; |
| use tk::pre_tokenizers::punctuation::Punctuation; |
| use tk::pre_tokenizers::split::Split; |
| use tk::pre_tokenizers::unicode_scripts::UnicodeScripts; |
| use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit}; |
| use tk::pre_tokenizers::PreTokenizerWrapper; |
| use tk::tokenizer::Offsets; |
| use tk::{PreTokenizedString, PreTokenizer}; |
| use tokenizers as tk; |
|
|
| use super::error::ToPyResult; |
| use super::utils::*; |
|
|
| |
| |
| |
| |
/// Base class for all pre-tokenizers exposed to Python as
/// `tokenizers.pre_tokenizers.PreTokenizer`.
///
/// Holds the shared `pretok` wrapper, which is either a single pre-tokenizer
/// (native Rust or a custom Python object) or a sequence of them.
/// `#[serde(transparent)]` makes (de)serialization delegate directly to the
/// inner `pretok` field, so the JSON shape matches the native Rust types.
#[pyclass(
    dict,
    module = "tokenizers.pre_tokenizers",
    name = "PreTokenizer",
    subclass
)]
#[derive(Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct PyPreTokenizer {
    // Shared with the Python subclass wrappers below via `Arc<RwLock<...>>`
    // inside the type wrapper, so getters/setters see a single instance.
    pub(crate) pretok: PyPreTokenizerTypeWrapper,
}
|
|
impl PyPreTokenizer {
    #[allow(dead_code)]
    // Wrap an already-built type wrapper; used by other modules in the crate.
    pub(crate) fn new(pretok: PyPreTokenizerTypeWrapper) -> Self {
        PyPreTokenizer { pretok }
    }

    /// Re-expose `self` to Python as the most specific subclass available.
    ///
    /// PyO3 subclassing requires constructing `(Subclass, Base)` pairs, so we
    /// clone the base and match on the wrapped variant to pick the matching
    /// Python-visible class. Custom (Python-defined) pre-tokenizers stay as
    /// the plain `PreTokenizer` base class.
    pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
        let base = self.clone();
        Ok(match &self.pretok {
            PyPreTokenizerTypeWrapper::Sequence(_) => {
                Py::new(py, (PySequence {}, base))?.into_py(py)
            }
            PyPreTokenizerTypeWrapper::Single(ref inner) => {
                // Read-lock the shared instance just long enough to inspect
                // which concrete variant it wraps.
                match &*inner.as_ref().read().unwrap() {
                    PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
                    PyPreTokenizerWrapper::Wrapped(inner) => match inner {
                        PreTokenizerWrapper::Whitespace(_) => {
                            Py::new(py, (PyWhitespace {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::Split(_) => {
                            Py::new(py, (PySplit {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::Punctuation(_) => {
                            Py::new(py, (PyPunctuation {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::Sequence(_) => {
                            Py::new(py, (PySequence {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::Metaspace(_) => {
                            Py::new(py, (PyMetaspace {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::Delimiter(_) => {
                            Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::WhitespaceSplit(_) => {
                            Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::ByteLevel(_) => {
                            Py::new(py, (PyByteLevel {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::BertPreTokenizer(_) => {
                            Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::Digits(_) => {
                            Py::new(py, (PyDigits {}, base))?.into_py(py)
                        }
                        PreTokenizerWrapper::UnicodeScripts(_) => {
                            Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py)
                        }
                    },
                }
            }
        })
    }
}
|
|
// Let the Python wrapper be used anywhere the crate expects a `PreTokenizer`;
// simply forwards to the inner type wrapper.
impl PreTokenizer for PyPreTokenizer {
    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
        self.pretok.pre_tokenize(normalized)
    }
}
|
|
#[pymethods]
impl PyPreTokenizer {
    // Wrap an arbitrary Python object as a custom pre-tokenizer. The object
    // is expected to expose a `pre_tokenize` method (see `CustomPreTokenizer`).
    // Using `//` comments (not `///`) here on purpose: doc comments on
    // pymethods would become Python-visible `__doc__` strings.
    #[staticmethod]
    fn custom(pretok: PyObject) -> Self {
        PyPreTokenizer {
            pretok: PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(pretok)).into(),
        }
    }

    // Pickle support: serialize the inner wrapper to JSON bytes. Custom
    // Python pre-tokenizers make `serde_json::to_string` fail, which is
    // surfaced as a Python exception.
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.pretok).map_err(|e| {
            exceptions::PyException::new_err(format!(
                "Error while attempting to pickle PreTokenizer: {}",
                e
            ))
        })?;
        Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
    }

    // Unpickle support: the state must be the JSON bytes produced by
    // `__getstate__`; anything else propagates the extraction error.
    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        match state.extract::<&PyBytes>(py) {
            Ok(s) => {
                let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
                    exceptions::PyException::new_err(format!(
                        "Error while attempting to unpickle PreTokenizer: {}",
                        e
                    ))
                })?;
                self.pretok = unpickled;
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    // Pre-tokenize a `PreTokenizedString` in place, mutating its splits.
    #[pyo3(text_signature = "(self, pretok)")]
    fn pre_tokenize(&self, pretok: &mut PyPreTokenizedString) -> PyResult<()> {
        ToPyResult(self.pretok.pre_tokenize(&mut pretok.pretok)).into()
    }

    // Convenience wrapper: pre-tokenize a plain string and return the
    // resulting `(piece, (start, end))` pairs, with character-level offsets
    // relative to the original input.
    #[pyo3(text_signature = "(self, sequence)")]
    fn pre_tokenize_str(&self, s: &str) -> PyResult<Vec<(String, Offsets)>> {
        let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);

        ToPyResult(self.pretok.pre_tokenize(&mut pretokenized)).into_py()?;

        Ok(pretokenized
            .get_splits(tk::OffsetReferential::Original, tk::OffsetType::Char)
            .into_iter()
            .map(|(s, o, _)| (s.to_owned(), o))
            .collect())
    }

    // `repr()` / `str()` are derived from the serde representation so they
    // stay in sync with the actual configuration.
    fn __repr__(&self) -> PyResult<String> {
        crate::utils::serde_pyo3::repr(self)
            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
    }

    fn __str__(&self) -> PyResult<String> {
        crate::utils::serde_pyo3::to_string(self)
            .map_err(|e| exceptions::PyException::new_err(e.to_string()))
    }
}
|
|
// Read a field (or call a getter) on the concrete pre-tokenizer wrapped by a
// Python subclass: `getter!(self_, Variant, field_or_method_call)`.
// Panics (`unreachable!`) if the wrapped variant doesn't match, which holds
// by construction since each subclass only wraps its own variant.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        // `as_ref()` reaches the `PyPreTokenizer` base of the PyO3 subclass.
        let super_ = $self.as_ref();
        if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok {
            if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref pretok)) =
                *single.read().unwrap() {
                pretok.$($name)+
            } else {
                unreachable!()
            }
        } else {
            unreachable!()
        }
    }};
}
|
|
// Mutate the concrete pre-tokenizer wrapped by a Python subclass through the
// shared `RwLock`. Two forms:
//   setter!(self_, Variant, field, value)    — assigns a field directly
//   setter!(self_, Variant, @method, value)  — calls a setter method
// Unlike `getter!`, a non-matching variant is silently a no-op.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        let super_ = $self.as_ref();
        if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok {
            if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
                *single.write().unwrap()
            {
                pretok.$name = $value;
            }
        }
    }};
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        let super_ = $self.as_ref();
        if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok {
            if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) =
                *single.write().unwrap()
            {
                pretok.$name($value);
            }
        }
    }};
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
// Python binding for the byte-level pre-tokenizer (as used by GPT-2 style
// models): exposes `add_prefix_space` and `use_regex` as Python properties.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "ByteLevel")]
pub struct PyByteLevel {}
#[pymethods]
impl PyByteLevel {
    #[getter]
    fn get_add_prefix_space(self_: PyRef<Self>) -> bool {
        getter!(self_, ByteLevel, add_prefix_space)
    }

    #[setter]
    fn set_add_prefix_space(self_: PyRef<Self>, add_prefix_space: bool) {
        setter!(self_, ByteLevel, add_prefix_space, add_prefix_space);
    }

    #[getter]
    fn get_use_regex(self_: PyRef<Self>) -> bool {
        getter!(self_, ByteLevel, use_regex)
    }

    #[setter]
    fn set_use_regex(self_: PyRef<Self>, use_regex: bool) {
        setter!(self_, ByteLevel, use_regex, use_regex);
    }

    // Extra kwargs are accepted and ignored for backward compatibility with
    // older call sites.
    #[new]
    #[pyo3(signature = (add_prefix_space = true, use_regex = true, **_kwargs), text_signature = "(self, add_prefix_space=True, use_regex=True)")]
    fn new(
        add_prefix_space: bool,
        use_regex: bool,
        _kwargs: Option<&Bound<'_, PyDict>>,
    ) -> (Self, PyPreTokenizer) {
        (
            PyByteLevel {},
            ByteLevel::default()
                .add_prefix_space(add_prefix_space)
                .use_regex(use_regex)
                .into(),
        )
    }

    // The full byte-level alphabet, one single-character string per entry.
    #[staticmethod]
    #[pyo3(text_signature = "()")]
    fn alphabet() -> Vec<String> {
        ByteLevel::alphabet()
            .into_iter()
            .map(|c| c.to_string())
            .collect()
    }
}
|
|
| |
| #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Whitespace")] |
| pub struct PyWhitespace {} |
| #[pymethods] |
| impl PyWhitespace { |
| #[new] |
| #[pyo3(text_signature = "(self)")] |
| fn new() -> (Self, PyPreTokenizer) { |
| (PyWhitespace {}, Whitespace {}.into()) |
| } |
| } |
|
|
| |
| #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "WhitespaceSplit")] |
| pub struct PyWhitespaceSplit {} |
| #[pymethods] |
| impl PyWhitespaceSplit { |
| #[new] |
| #[pyo3(text_signature = "(self)")] |
| fn new() -> (Self, PyPreTokenizer) { |
| (PyWhitespaceSplit {}, WhitespaceSplit.into()) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
// Python binding for the generic `Split` pre-tokenizer: splits on a pattern
// (string or regex) with a configurable `SplitDelimiterBehavior`, optionally
// inverted.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Split")]
pub struct PySplit {}
#[pymethods]
impl PySplit {
    #[new]
    #[pyo3(signature = (pattern, behavior, invert = false), text_signature = "(self, pattern, behavior, invert=False)")]
    fn new(
        pattern: PyPattern,
        behavior: PySplitDelimiterBehavior,
        invert: bool,
    ) -> PyResult<(Self, PyPreTokenizer)> {
        // `Split::new` can fail (e.g. invalid regex); surface that as a
        // Python exception via `ToPyResult`.
        Ok((
            PySplit {},
            ToPyResult(Split::new(pattern, behavior.into(), invert))
                .into_py()?
                .into(),
        ))
    }

    // Pickle support: placeholder constructor args; the real state is then
    // restored through `PyPreTokenizer::__setstate__`.
    fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
        PyTuple::new_bound(py, [" ", "removed"])
    }
}
|
|
| |
| |
| |
| |
| |
// Python binding for `CharDelimiterSplit`: splits on a single delimiter
// character, exposed as a read/write `delimiter` property.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "CharDelimiterSplit")]
pub struct PyCharDelimiterSplit {}
#[pymethods]
impl PyCharDelimiterSplit {
    // Returned as a `String` since Python has no char type.
    #[getter]
    fn get_delimiter(self_: PyRef<Self>) -> String {
        getter!(self_, Delimiter, delimiter.to_string())
    }

    #[setter]
    fn set_delimiter(self_: PyRef<Self>, delimiter: char) {
        setter!(self_, Delimiter, delimiter, delimiter);
    }

    #[new]
    #[pyo3(text_signature = None)]
    pub fn new(delimiter: char) -> PyResult<(Self, PyPreTokenizer)> {
        Ok((
            PyCharDelimiterSplit {},
            CharDelimiterSplit::new(delimiter).into(),
        ))
    }

    // Pickle support: a placeholder delimiter (" ") for re-construction;
    // the actual delimiter is restored afterwards via `__setstate__`.
    fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
        PyTuple::new_bound(py, [" "])
    }
}
|
|
| |
| |
| |
| |
| #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "BertPreTokenizer")] |
| pub struct PyBertPreTokenizer {} |
| #[pymethods] |
| impl PyBertPreTokenizer { |
| #[new] |
| #[pyo3(text_signature = "(self)")] |
| fn new() -> (Self, PyPreTokenizer) { |
| (PyBertPreTokenizer {}, BertPreTokenizer.into()) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Punctuation")] |
| pub struct PyPunctuation {} |
| #[pymethods] |
| impl PyPunctuation { |
| #[new] |
| #[pyo3( signature = (behavior = PySplitDelimiterBehavior(SplitDelimiterBehavior::Isolated)), text_signature = "(self, behavior=\"isolated\")")] |
| fn new(behavior: PySplitDelimiterBehavior) -> (Self, PyPreTokenizer) { |
| (PyPunctuation {}, Punctuation::new(behavior.into()).into()) |
| } |
| } |
|
|
| |
// Python binding for a sequence of pre-tokenizers applied in order.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Sequence")]
pub struct PySequence {}
#[pymethods]
impl PySequence {
    #[new]
    #[pyo3(text_signature = "(self, pretokenizers)")]
    fn new(pre_tokenizers: &Bound<'_, PyList>) -> PyResult<(Self, PyPreTokenizer)> {
        let mut sequence = Vec::with_capacity(pre_tokenizers.len());
        for n in pre_tokenizers.iter() {
            let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
            match &pretokenizer.pretok {
                // Nested sequences are flattened into this one; the `Arc`
                // clones keep each member shared with its original wrapper.
                PyPreTokenizerTypeWrapper::Sequence(inner) => {
                    sequence.extend(inner.iter().cloned())
                }
                PyPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
            }
        }
        Ok((
            PySequence {},
            PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Sequence(sequence)),
        ))
    }

    // Pickle support: reconstruct with an empty list, then `__setstate__`
    // restores the members.
    fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
        PyTuple::new_bound(py, [PyList::empty_bound(py)])
    }

    // Index into the sequence, returning the member as its most specific
    // Python subclass. NOTE(review): the index is `usize`, so Python-style
    // negative indices are rejected before reaching this code — confirm that
    // is the intended behavior.
    fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
        match &self_.as_ref().pretok {
            PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
                Some(item) => {
                    PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item)))
                        .get_as_subtype(py)
                }
                _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                    "Index not found",
                )),
            },
            // Defensive: a `Single` wrapper is returned as-is for any index.
            PyPreTokenizerTypeWrapper::Single(inner) => {
                PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner)))
                    .get_as_subtype(py)
            }
        }
    }
}
|
|
| pub(crate) fn from_string(string: String) -> Result<PrependScheme, PyErr> { |
| let scheme = match string.as_str() { |
| "first" => PrependScheme::First, |
| "never" => PrependScheme::Never, |
| "always" => PrependScheme::Always, |
| _ => { |
| return Err(exceptions::PyValueError::new_err(format!( |
| "{} is an unknown variant, should be one of ['first', 'never', 'always']", |
| string |
| ))); |
| } |
| }; |
| Ok(scheme) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
// Python binding for the `Metaspace` pre-tokenizer: replaces whitespace with
// a replacement character (default "▁") with a configurable prepend scheme.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Metaspace")]
pub struct PyMetaspace {}
#[pymethods]
impl PyMetaspace {
    // Returned as a `String` since Python has no char type.
    #[getter]
    fn get_replacement(self_: PyRef<Self>) -> String {
        getter!(self_, Metaspace, get_replacement().to_string())
    }

    #[setter]
    fn set_replacement(self_: PyRef<Self>, replacement: char) {
        setter!(self_, Metaspace, @set_replacement, replacement);
    }

    #[getter]
    fn get_split(self_: PyRef<Self>) -> bool {
        getter!(self_, Metaspace, get_split())
    }

    #[setter]
    fn set_split(self_: PyRef<Self>, split: bool) {
        setter!(self_, Metaspace, @set_split, split);
    }

    // The prepend scheme crosses the boundary as a lowercase string
    // ("first" / "never" / "always") in both directions.
    #[getter]
    fn get_prepend_scheme(self_: PyRef<Self>) -> String {

        let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme());
        match scheme {
            PrependScheme::First => "first",
            PrependScheme::Never => "never",
            PrependScheme::Always => "always",
        }
        .to_string()
    }

    #[setter]
    fn set_prepend_scheme(self_: PyRef<Self>, prepend_scheme: String) -> PyResult<()> {
        // `from_string` raises ValueError for unknown variants.
        let scheme = from_string(prepend_scheme)?;
        setter!(self_, Metaspace, @set_prepend_scheme, scheme);
        Ok(())
    }

    #[new]
    #[pyo3(signature = (replacement = '▁', prepend_scheme=String::from("always"), split=true), text_signature = "(self, replacement=\"_\", prepend_scheme=\"always\", split=True)")]
    fn new(
        replacement: char,
        prepend_scheme: String,
        split: bool,
    ) -> PyResult<(Self, PyPreTokenizer)> {
        // Validate the scheme string before building the native instance.
        let prepend_scheme = from_string(prepend_scheme)?;
        let new_instance: Metaspace = Metaspace::new(replacement, prepend_scheme, split);
        Ok((PyMetaspace {}, new_instance.into()))
    }
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
// Python binding for the `Digits` pre-tokenizer; `individual_digits` controls
// whether each digit becomes its own split or digit runs stay together.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Digits")]
pub struct PyDigits {}
#[pymethods]
impl PyDigits {
    #[getter]
    fn get_individual_digits(self_: PyRef<Self>) -> bool {
        getter!(self_, Digits, individual_digits)
    }

    #[setter]
    fn set_individual_digits(self_: PyRef<Self>, individual_digits: bool) {
        setter!(self_, Digits, individual_digits, individual_digits);
    }

    #[new]
    #[pyo3(signature = (individual_digits = false), text_signature = "(self, individual_digits=False)")]
    fn new(individual_digits: bool) -> (Self, PyPreTokenizer) {
        (PyDigits {}, Digits::new(individual_digits).into())
    }
}
|
|
| |
| |
| |
| |
| #[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "UnicodeScripts")] |
| pub struct PyUnicodeScripts {} |
| #[pymethods] |
| impl PyUnicodeScripts { |
| #[new] |
| #[pyo3(text_signature = "(self)")] |
| fn new() -> (Self, PyPreTokenizer) { |
| (PyUnicodeScripts {}, UnicodeScripts::new().into()) |
| } |
| } |
|
|
| #[derive(Clone)] |
| pub(crate) struct CustomPreTokenizer { |
| inner: PyObject, |
| } |
|
|
| impl CustomPreTokenizer { |
| pub fn new(inner: PyObject) -> Self { |
| Self { inner } |
| } |
| } |
|
|
| impl tk::tokenizer::PreTokenizer for CustomPreTokenizer { |
| fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> { |
| Python::with_gil(|py| { |
| let pretok = PyPreTokenizedStringRefMut::new(sentence); |
| let py_pretok = self.inner.bind(py); |
| py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?; |
| Ok(()) |
| }) |
| } |
| } |
|
|
| impl Serialize for CustomPreTokenizer { |
| fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error> |
| where |
| S: Serializer, |
| { |
| Err(serde::ser::Error::custom( |
| "Custom PreTokenizer cannot be serialized", |
| )) |
| } |
| } |
|
|
| impl<'de> Deserialize<'de> for CustomPreTokenizer { |
| fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error> |
| where |
| D: Deserializer<'de>, |
| { |
| Err(serde::de::Error::custom( |
| "Custom PreTokenizer cannot be deserialized", |
| )) |
| } |
| } |
|
|
// A single pre-tokenizer as seen from the Python bindings: either a native
// Rust implementation or a user-defined Python object. `untagged` lets the
// deserializer try each variant in order; `Custom` always fails to
// deserialize, so JSON input resolves to `Wrapped`.
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerWrapper {
    Custom(CustomPreTokenizer),
    Wrapped(PreTokenizerWrapper),
}
|
|
| impl Serialize for PyPreTokenizerWrapper { |
| fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error> |
| where |
| S: Serializer, |
| { |
| match self { |
| PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer), |
| PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer), |
| } |
| } |
| } |
|
|
// Either one shared pre-tokenizer or a flat list of them. The
// `Arc<RwLock<...>>` sharing is what lets Python subclass getters/setters and
// `Sequence.__getitem__` observe and mutate the same underlying instance.
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerTypeWrapper {
    Sequence(Vec<Arc<RwLock<PyPreTokenizerWrapper>>>),
    Single(Arc<RwLock<PyPreTokenizerWrapper>>),
}
|
|
| impl Serialize for PyPreTokenizerTypeWrapper { |
| fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
| where |
| S: Serializer, |
| { |
| match self { |
| PyPreTokenizerTypeWrapper::Sequence(seq) => { |
| let mut ser = serializer.serialize_struct("Sequence", 2)?; |
| ser.serialize_field("type", "Sequence")?; |
| ser.serialize_field("pretokenizers", seq)?; |
| ser.end() |
| } |
| PyPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer), |
| } |
| } |
| } |
|
|
| impl<I> From<I> for PyPreTokenizerWrapper |
| where |
| I: Into<PreTokenizerWrapper>, |
| { |
| fn from(pretok: I) -> Self { |
| PyPreTokenizerWrapper::Wrapped(pretok.into()) |
| } |
| } |
|
|
| impl<I> From<I> for PyPreTokenizerTypeWrapper |
| where |
| I: Into<PyPreTokenizerWrapper>, |
| { |
| fn from(pretok: I) -> Self { |
| PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into()))) |
| } |
| } |
|
|
| impl<I> From<I> for PyPreTokenizer |
| where |
| I: Into<PreTokenizerWrapper>, |
| { |
| fn from(pretok: I) -> Self { |
| PyPreTokenizer { |
| pretok: pretok.into().into(), |
| } |
| } |
| } |
|
|
| impl PreTokenizer for PyPreTokenizerTypeWrapper { |
| fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> { |
| match self { |
| PyPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok), |
| PyPreTokenizerTypeWrapper::Sequence(inner) => inner |
| .iter() |
| .try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)), |
| } |
| } |
| } |
|
|
| impl PreTokenizer for PyPreTokenizerWrapper { |
| fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> { |
| match self { |
| PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok), |
| PyPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok), |
| } |
| } |
| } |
|
|
| |
// Registers every pre-tokenizer class on the `tokenizers.pre_tokenizers`
// Python submodule.
#[pymodule]
pub fn pre_tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyPreTokenizer>()?;
    m.add_class::<PyByteLevel>()?;
    m.add_class::<PyWhitespace>()?;
    m.add_class::<PyWhitespaceSplit>()?;
    m.add_class::<PySplit>()?;
    m.add_class::<PyBertPreTokenizer>()?;
    m.add_class::<PyMetaspace>()?;
    m.add_class::<PyCharDelimiterSplit>()?;
    m.add_class::<PyPunctuation>()?;
    m.add_class::<PySequence>()?;
    m.add_class::<PyDigits>()?;
    m.add_class::<PyUnicodeScripts>()?;
    Ok(())
}
|
|
#[cfg(test)]
mod test {
    use pyo3::prelude::*;
    use tk::pre_tokenizers::sequence::Sequence;
    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
    use tk::pre_tokenizers::PreTokenizerWrapper;

    use crate::pre_tokenizers::{
        CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerTypeWrapper, PyPreTokenizerWrapper,
    };

    // `get_as_subtype` must surface the most specific Python class name.
    #[test]
    fn get_subtype() {
        Python::with_gil(|py| {
            let py_norm = PyPreTokenizer::new(Whitespace {}.into());
            let py_wsp = py_norm.get_as_subtype(py).unwrap();
            assert_eq!("Whitespace", py_wsp.bind(py).get_type().qualname().unwrap());
        })
    }

    #[test]
    fn serialize() {
        // The Python-side wrappers must serialize identically to the native
        // Rust types, and deserialize back into `Wrapped` variants.
        let py_wrapped: PyPreTokenizerWrapper = Whitespace {}.into();
        let py_ser = serde_json::to_string(&py_wrapped).unwrap();
        let rs_wrapped = PreTokenizerWrapper::Whitespace(Whitespace {});
        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
        assert_eq!(py_ser, rs_ser);
        let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
        match py_pretok.pretok {
            PyPreTokenizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
                PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
                _ => panic!("Expected Whitespace"),
            },
            _ => panic!("Expected wrapped, not custom."),
        }

        // Same equivalence for sequences of pre-tokenizers.
        let py_seq: PyPreTokenizerWrapper =
            Sequence::new(vec![Whitespace {}.into(), WhitespaceSplit.into()]).into();
        let py_wrapper_ser = serde_json::to_string(&py_seq).unwrap();
        let rs_wrapped = PreTokenizerWrapper::Sequence(Sequence::new(vec![
            Whitespace {}.into(),
            WhitespaceSplit.into(),
        ]));
        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
        assert_eq!(py_wrapper_ser, rs_ser);

        // Wrapping in the top-level `PyPreTokenizer` must not change the
        // serialized form (thanks to `#[serde(transparent)]`).
        let py_seq = PyPreTokenizer::new(py_seq.into());
        let py_ser = serde_json::to_string(&py_seq).unwrap();
        assert_eq!(py_wrapper_ser, py_ser);

        // Custom (Python-defined) pre-tokenizers must refuse to serialize.
        let obj = Python::with_gil(|py| {
            let py_wsp = PyPreTokenizer::new(Whitespace {}.into());
            let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
            obj
        });
        let py_seq: PyPreTokenizerWrapper =
            PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(obj));
        assert!(serde_json::to_string(&py_seq).is_err());
    }
}
|
|