use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use tk::tokenizer::{Offsets, PaddingDirection};
use tk::utils::truncation::TruncationDirection;
use tokenizers as tk;

use crate::error::{deprecation_warning, PyError};
/// The :class:`~tokenizers.Encoding` represents the output of a :class:`~tokenizers.Tokenizer`.
#[pyclass(dict, module = "tokenizers", name = "Encoding")]
#[repr(transparent)]
pub struct PyEncoding {
    pub encoding: tk::tokenizer::Encoding,
}

impl From<tk::tokenizer::Encoding> for PyEncoding {
    fn from(v: tk::tokenizer::Encoding) -> Self {
        Self { encoding: v }
    }
}
#[pymethods]
impl PyEncoding {
    #[new]
    fn new() -> Self {
        Self {
            encoding: tk::tokenizer::Encoding::default(),
        }
    }
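    // `__getstate__`/`__setstate__` below make `Encoding` picklable from Python by
    // round-tripping through its JSON serialization. A minimal sketch of the
    // resulting Python behaviour (the tokenizer setup is illustrative only):
    //
    //     import pickle
    //     from tokenizers import Tokenizer
    //
    //     tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    //     encoding = tokenizer.encode("Hello world")
    //     restored = pickle.loads(pickle.dumps(encoding))
    //     assert restored.ids == encoding.ids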
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.encoding).map_err(|e| {
            exceptions::PyException::new_err(format!(
                "Error while attempting to pickle Encoding: {}",
                e
            ))
        })?;
        Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
    }

    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        match state.extract::<&PyBytes>(py) {
            Ok(s) => {
                self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
                    exceptions::PyException::new_err(format!(
                        "Error while attempting to unpickle Encoding: {}",
                        e
                    ))
                })?;
                Ok(())
            }
            Err(e) => Err(e),
        }
    }
    fn __repr__(&self) -> PyResult<String> {
        Ok(format!(
            "Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
             attention_mask, special_tokens_mask, overflowing])",
            self.encoding.get_ids().len()
        ))
    }

    fn __len__(&self) -> PyResult<usize> {
        Ok(self.encoding.len())
    }
    /// Merge the list of encodings into one final :class:`~tokenizers.Encoding`
    ///
    /// Args:
    ///     encodings (A :obj:`List` of :class:`~tokenizers.Encoding`):
    ///         The list of encodings that should be merged into one
    ///
    ///     growing_offsets (:obj:`bool`, defaults to :obj:`True`):
    ///         Whether the offsets should accumulate while merging
    ///
    /// Returns:
    ///     :class:`~tokenizers.Encoding`: The resulting Encoding
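    ///
    /// Example:
    ///     A minimal sketch from Python; the tokenizer setup is illustrative only::
    ///
    ///         from tokenizers import Encoding, Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         first = tokenizer.encode("Hello")
    ///         second = tokenizer.encode("world")
    ///         merged = Encoding.merge([first, second], growing_offsets=True)
    ///         assert len(merged) == len(first) + len(second)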
    #[staticmethod]
    #[pyo3(signature = (encodings, growing_offsets = true))]
    fn merge(encodings: Vec<PyRef<PyEncoding>>, growing_offsets: bool) -> PyEncoding {
        tk::tokenizer::Encoding::merge(
            encodings.into_iter().map(|e| e.encoding.clone()),
            growing_offsets,
        )
        .into()
    }
    /// The number of sequences represented
    ///
    /// Returns:
    ///     :obj:`int`: The number of sequences in this :class:`~tokenizers.Encoding`
    #[getter]
    fn get_n_sequences(&self) -> usize {
        self.encoding.n_sequences()
    }
    /// Set the given sequence index
    ///
    /// Set the given sequence index for the whole range of tokens contained in this
    /// :class:`~tokenizers.Encoding`.
    fn set_sequence_id(&mut self, sequence_id: usize) {
        self.encoding.set_sequence_id(sequence_id);
    }
    /// The generated IDs
    ///
    /// The IDs are the main input to a Language Model. They are the token indices,
    /// the numerical representations that an LM understands.
    ///
    /// Returns:
    ///     :obj:`List[int]`: The list of IDs
    #[getter]
    fn get_ids(&self) -> Vec<u32> {
        self.encoding.get_ids().to_vec()
    }
    /// The generated tokens
    ///
    /// They are the string representation of the IDs.
    ///
    /// Returns:
    ///     :obj:`List[str]`: The list of tokens
    #[getter]
    fn get_tokens(&self) -> Vec<String> {
        self.encoding.get_tokens().to_vec()
    }
    /// The generated word indices.
    ///
    /// .. warning::
    ///     This is deprecated and will be removed in a future version.
    ///     Please use :obj:`~tokenizers.Encoding.word_ids` instead.
    ///
    /// They represent the index of the word associated with each token.
    /// When the input is pre-tokenized, they correspond to the ID of the given input label,
    /// otherwise they correspond to the word indices as defined by the
    /// :class:`~tokenizers.pre_tokenizers.PreTokenizer` that was used.
    ///
    /// For special tokens and such (any token that was generated from something that was
    /// not part of the input), the output is :obj:`None`
    ///
    /// Returns:
    ///     A :obj:`List` of :obj:`Optional[int]`: A list of optional word indices.
    #[getter]
    fn get_words(&self, py: Python<'_>) -> PyResult<Vec<Option<u32>>> {
        deprecation_warning(
            py,
            "0.9.4",
            "Encoding.words is deprecated, please use Encoding.word_ids instead.",
        )?;
        Ok(self.get_word_ids())
    }
    /// The generated word indices.
    ///
    /// They represent the index of the word associated with each token.
    /// When the input is pre-tokenized, they correspond to the ID of the given input label,
    /// otherwise they correspond to the word indices as defined by the
    /// :class:`~tokenizers.pre_tokenizers.PreTokenizer` that was used.
    ///
    /// For special tokens and such (any token that was generated from something that was
    /// not part of the input), the output is :obj:`None`
    ///
    /// Returns:
    ///     A :obj:`List` of :obj:`Optional[int]`: A list of optional word indices.
    #[getter]
    fn get_word_ids(&self) -> Vec<Option<u32>> {
        self.encoding.get_word_ids().to_vec()
    }
    /// The generated sequence indices.
    ///
    /// They represent the index of the input sequence associated with each token.
    /// The sequence id can be :obj:`None` if the token is not related to any input sequence,
    /// as is the case for special tokens.
    ///
    /// Returns:
    ///     A :obj:`List` of :obj:`Optional[int]`: A list of optional sequence indices.
    #[getter]
    fn get_sequence_ids(&self) -> Vec<Option<usize>> {
        self.encoding.get_sequence_ids()
    }
    /// The generated type IDs
    ///
    /// Generally used for tasks like sequence classification or question answering,
    /// these tokens let the LM know which input sequence each token corresponds to.
    ///
    /// Returns:
    ///     :obj:`List[int]`: The list of type ids
    #[getter]
    fn get_type_ids(&self) -> Vec<u32> {
        self.encoding.get_type_ids().to_vec()
    }
    /// The offsets associated with each token
    ///
    /// These offsets let you slice the input string, and thus retrieve the part
    /// of the original input that produced the corresponding token.
    ///
    /// Returns:
    ///     A :obj:`List` of :obj:`Tuple[int, int]`: The list of offsets
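    ///
    /// Example:
    ///     A minimal sketch of slicing the input with offsets; the tokenizer setup
    ///     is illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         text = "Hello world"
    ///         encoding = tokenizer.encode(text)
    ///         start, end = encoding.offsets[1]
    ///         print(text[start:end])  # the input slice behind the second token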
    #[getter]
    fn get_offsets(&self) -> Vec<(usize, usize)> {
        self.encoding.get_offsets().to_vec()
    }
    /// The special token mask
    ///
    /// This indicates which tokens are special tokens, and which are not.
    ///
    /// Returns:
    ///     :obj:`List[int]`: The special tokens mask
    #[getter]
    fn get_special_tokens_mask(&self) -> Vec<u32> {
        self.encoding.get_special_tokens_mask().to_vec()
    }
    /// The attention mask
    ///
    /// This indicates to the LM which tokens should be attended to, and which should not.
    /// This is especially important when batching sequences, where we need to apply
    /// padding.
    ///
    /// Returns:
    ///     :obj:`List[int]`: The attention mask
    #[getter]
    fn get_attention_mask(&self) -> Vec<u32> {
        self.encoding.get_attention_mask().to_vec()
    }
    /// A :obj:`List` of overflowing :class:`~tokenizers.Encoding`
    ///
    /// When using truncation, the :class:`~tokenizers.Tokenizer` takes care of splitting
    /// the output into as many pieces as required to match the specified maximum length.
    /// This field lets you retrieve all the subsequent pieces.
    ///
    /// When you use pairs of sequences, the overflowing pieces will contain enough
    /// variations to cover all the possible combinations, while respecting the provided
    /// maximum length.
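    ///
    /// Example:
    ///     A minimal sketch; the tokenizer setup is illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         tokenizer.enable_truncation(max_length=4, stride=1)
    ///         encoding = tokenizer.encode("a fairly long sentence that overflows")
    ///         for piece in encoding.overflowing:
    ///             print(piece.tokens)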
    #[getter]
    fn get_overflowing(&self) -> Vec<PyEncoding> {
        self.encoding
            .get_overflowing()
            .clone()
            .into_iter()
            .map(|e| e.into())
            .collect()
    }
    /// Get the encoded tokens corresponding to the word at the given index
    /// in one of the input sequences.
    ///
    /// Args:
    ///     word_index (:obj:`int`):
    ///         The index of a word in one of the input sequences.
    ///     sequence_index (:obj:`int`, defaults to :obj:`0`):
    ///         The index of the sequence that contains the target word
    ///
    /// Returns:
    ///     :obj:`Tuple[int, int]`: The range of tokens: :obj:`(first, last + 1)`
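    ///
    /// Example:
    ///     A minimal sketch; the tokenizer setup is illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         encoding = tokenizer.encode("Hello world")
    ///         span = encoding.word_to_tokens(1)  # tokens covering the second word
    ///         if span is not None:
    ///             first, after_last = span
    ///             print(encoding.tokens[first:after_last])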
    #[pyo3(signature = (word_index, sequence_index = 0))]
    fn word_to_tokens(&self, word_index: u32, sequence_index: usize) -> Option<(usize, usize)> {
        self.encoding.word_to_tokens(word_index, sequence_index)
    }
    /// Get the offsets of the word at the given index in one of the input sequences.
    ///
    /// Args:
    ///     word_index (:obj:`int`):
    ///         The index of a word in one of the input sequences.
    ///     sequence_index (:obj:`int`, defaults to :obj:`0`):
    ///         The index of the sequence that contains the target word
    ///
    /// Returns:
    ///     :obj:`Tuple[int, int]`: The range of characters (span) :obj:`(first, last + 1)`
    #[pyo3(signature = (word_index, sequence_index = 0))]
    fn word_to_chars(&self, word_index: u32, sequence_index: usize) -> Option<Offsets> {
        self.encoding.word_to_chars(word_index, sequence_index)
    }
    /// Get the index of the sequence represented by the given token.
    ///
    /// In the general use case, this method returns :obj:`0` for a single sequence or
    /// the first sequence of a pair, and :obj:`1` for the second sequence of a pair
    ///
    /// Args:
    ///     token_index (:obj:`int`):
    ///         The index of a token in the encoded sequence.
    ///
    /// Returns:
    ///     :obj:`int`: The sequence id of the given token
    fn token_to_sequence(&self, token_index: usize) -> Option<usize> {
        self.encoding.token_to_sequence(token_index)
    }
    /// Get the offsets of the token at the given index.
    ///
    /// The returned offsets are related to the input sequence that contains the
    /// token. In order to determine to which input sequence it belongs, you
    /// must call :meth:`~tokenizers.Encoding.token_to_sequence()`.
    ///
    /// Args:
    ///     token_index (:obj:`int`):
    ///         The index of a token in the encoded sequence.
    ///
    /// Returns:
    ///     :obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
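    ///
    /// Example:
    ///     A minimal sketch combining this with
    ///     :meth:`~tokenizers.Encoding.token_to_sequence()`; the tokenizer setup is
    ///     illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         pair = ("A question?", "An answer!")
    ///         encoding = tokenizer.encode(*pair)
    ///         seq = encoding.token_to_sequence(5)
    ///         span = encoding.token_to_chars(5)
    ///         if seq is not None and span is not None:
    ///             start, end = span
    ///             print(pair[seq][start:end])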
    fn token_to_chars(&self, token_index: usize) -> Option<Offsets> {
        let (_, offsets) = self.encoding.token_to_chars(token_index)?;
        Some(offsets)
    }
    /// Get the index of the word that contains the token in one of the input sequences.
    ///
    /// The returned word index is related to the input sequence that contains
    /// the token. In order to determine to which input sequence it belongs, you
    /// must call :meth:`~tokenizers.Encoding.token_to_sequence()`.
    ///
    /// Args:
    ///     token_index (:obj:`int`):
    ///         The index of a token in the encoded sequence.
    ///
    /// Returns:
    ///     :obj:`int`: The index of the word in the relevant input sequence.
    fn token_to_word(&self, token_index: usize) -> Option<u32> {
        let (_, word_idx) = self.encoding.token_to_word(token_index)?;
        Some(word_idx)
    }
    /// Get the token that contains the char at the given position in the input sequence.
    ///
    /// Args:
    ///     char_pos (:obj:`int`):
    ///         The position of a char in the input string
    ///     sequence_index (:obj:`int`, defaults to :obj:`0`):
    ///         The index of the sequence that contains the target char
    ///
    /// Returns:
    ///     :obj:`int`: The index of the token that contains this char in the encoded sequence
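    ///
    /// Example:
    ///     A minimal sketch; the tokenizer setup is illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         text = "Hello world"
    ///         encoding = tokenizer.encode(text)
    ///         token_index = encoding.char_to_token(text.index("world"))
    ///         if token_index is not None:
    ///             print(encoding.tokens[token_index])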
    #[pyo3(signature = (char_pos, sequence_index = 0))]
    fn char_to_token(&self, char_pos: usize, sequence_index: usize) -> Option<usize> {
        self.encoding.char_to_token(char_pos, sequence_index)
    }
    /// Get the word that contains the char at the given position in the input sequence.
    ///
    /// Args:
    ///     char_pos (:obj:`int`):
    ///         The position of a char in the input string
    ///     sequence_index (:obj:`int`, defaults to :obj:`0`):
    ///         The index of the sequence that contains the target char
    ///
    /// Returns:
    ///     :obj:`int`: The index of the word that contains this char in the input sequence
    #[pyo3(signature = (char_pos, sequence_index = 0))]
    fn char_to_word(&self, char_pos: usize, sequence_index: usize) -> Option<u32> {
        self.encoding.char_to_word(char_pos, sequence_index)
    }
    /// Pad the :class:`~tokenizers.Encoding` at the given length
    ///
    /// Args:
    ///     length (:obj:`int`):
    ///         The desired length
    ///
    ///     direction (:obj:`str`, defaults to :obj:`right`):
    ///         The expected padding direction. Can be either :obj:`right` or :obj:`left`
    ///
    ///     pad_id (:obj:`int`, defaults to :obj:`0`):
    ///         The ID corresponding to the padding token
    ///
    ///     pad_type_id (:obj:`int`, defaults to :obj:`0`):
    ///         The type ID corresponding to the padding token
    ///
    ///     pad_token (:obj:`str`, defaults to :obj:`[PAD]`):
    ///         The pad token to use
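    ///
    /// Example:
    ///     A minimal sketch; the tokenizer setup is illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         encoding = tokenizer.encode("Hello")
    ///         encoding.pad(8, direction="right", pad_id=0, pad_token="[PAD]")
    ///         assert len(encoding) == 8
    ///         assert encoding.attention_mask[-1] == 0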
    #[pyo3(signature = (length, **kwargs))]
    fn pad(&mut self, length: usize, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
        let mut pad_id = 0;
        let mut pad_type_id = 0;
        let mut pad_token = "[PAD]".to_string();
        let mut direction = PaddingDirection::Right;

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "direction" => {
                        let value: &str = value.extract()?;
                        direction = match value {
                            "left" => Ok(PaddingDirection::Left),
                            "right" => Ok(PaddingDirection::Right),
                            other => Err(PyError(format!(
                                "Unknown `direction`: `{}`. Use \
                                 one of `left` or `right`",
                                other
                            ))
                            .into_pyerr::<exceptions::PyValueError>()),
                        }?;
                    }
                    "pad_id" => pad_id = value.extract()?,
                    "pad_type_id" => pad_type_id = value.extract()?,
                    "pad_token" => pad_token = value.extract()?,
                    _ => println!("Ignored unknown kwarg option {}", key),
                }
            }
        }
        self.encoding
            .pad(length, pad_id, pad_type_id, &pad_token, direction);
        Ok(())
    }
    /// Truncate the :class:`~tokenizers.Encoding` at the given length
    ///
    /// If this :class:`~tokenizers.Encoding` represents multiple sequences, truncating
    /// loses this information, and the result is considered to represent a single sequence.
    ///
    /// Args:
    ///     max_length (:obj:`int`):
    ///         The desired length
    ///
    ///     stride (:obj:`int`, defaults to :obj:`0`):
    ///         The length of previous content to be included in each overflowing piece
    ///
    ///     direction (:obj:`str`, defaults to :obj:`right`):
    ///         The truncation direction. Can be either :obj:`right` or :obj:`left`
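    ///
    /// Example:
    ///     A minimal sketch; the tokenizer setup is illustrative only::
    ///
    ///         from tokenizers import Tokenizer
    ///
    ///         tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
    ///         encoding = tokenizer.encode("a sentence that is longer than needed")
    ///         encoding.truncate(4, stride=0, direction="right")
    ///         assert len(encoding) == 4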
    #[pyo3(signature = (max_length, stride = 0, direction = "right"))]
    fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
        let tdir = match direction {
            "left" => Ok(TruncationDirection::Left),
            "right" => Ok(TruncationDirection::Right),
            _ => Err(PyError(format!(
                "Invalid truncation direction value: {}",
                direction
            ))
            .into_pyerr::<exceptions::PyValueError>()),
        }?;

        self.encoding.truncate(max_length, stride, tdir);
        Ok(())
    }
}