| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use std::{ |
| collections::HashMap, |
| fs::{read_to_string, File}, |
| io::prelude::*, |
| io::BufReader, |
| ops::{Deref, DerefMut}, |
| path::{Path, PathBuf}, |
| }; |
|
|
| use serde::de::DeserializeOwned; |
| use serde::{Deserialize, Serialize}; |
|
|
| use crate::utils::iter::ResultShunt; |
| use crate::utils::parallelism::*; |
| use crate::utils::progress::{ProgressBar, ProgressStyle}; |
|
|
| mod added_vocabulary; |
| mod encoding; |
| pub mod normalizer; |
| pub mod pattern; |
| pub mod pre_tokenizer; |
| mod serialization; |
|
|
| |
| pub use crate::decoders::DecoderWrapper; |
| pub use crate::models::ModelWrapper; |
| pub use crate::normalizers::NormalizerWrapper; |
| pub use crate::pre_tokenizers::PreTokenizerWrapper; |
| pub use crate::processors::PostProcessorWrapper; |
| |
| pub use crate::utils::iter::LinesWithEnding; |
| pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy}; |
| pub use crate::utils::truncation::{ |
| truncate_encodings, TruncationDirection, TruncationParams, TruncationStrategy, |
| }; |
| pub use added_vocabulary::*; |
| pub use encoding::*; |
| pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior}; |
| pub use pre_tokenizer::*; |
|
|
| pub type Error = Box<dyn std::error::Error + Send + Sync>; |
| pub type Result<T> = std::result::Result<T, Error>; |
| pub type Offsets = (usize, usize); |
|
|
| |
| pub trait Normalizer { |
| fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>; |
| } |
|
|
| |
| |
| |
| |
| |
| pub trait PreTokenizer { |
| fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()>; |
| } |
|
|
| |
| pub trait Model { |
| type Trainer: Trainer + Sync; |
| |
| |
| fn tokenize(&self, sequence: &str) -> Result<Vec<Token>>; |
| |
| fn token_to_id(&self, token: &str) -> Option<u32>; |
| |
| fn id_to_token(&self, id: u32) -> Option<String>; |
| |
| fn get_vocab(&self) -> HashMap<String, u32>; |
| |
| fn get_vocab_size(&self) -> usize; |
| |
| |
| fn save(&self, folder: &Path, prefix: Option<&str>) -> Result<Vec<PathBuf>>; |
| |
| fn get_trainer(&self) -> <Self as Model>::Trainer; |
| } |
|
|
| |
| |
| pub trait PostProcessor { |
| |
| fn added_tokens(&self, is_pair: bool) -> usize; |
| |
| fn process( |
| &self, |
| encoding: Encoding, |
| pair_encoding: Option<Encoding>, |
| add_special_tokens: bool, |
| ) -> Result<Encoding> { |
| let mut encodings = if let Some(pair_encoding) = pair_encoding { |
| vec![encoding, pair_encoding] |
| } else { |
| vec![encoding] |
| }; |
| encodings.iter_mut().enumerate().for_each(|(i, encoding)| { |
| encoding.set_sequence_id(i); |
| encoding |
| .get_overflowing_mut() |
| .iter_mut() |
| .for_each(|encoding| encoding.set_sequence_id(i)); |
| encoding.set_type_ids(vec![i as u32; encoding.len()]); |
| }); |
|
|
| let encodings = self.process_encodings(encodings, add_special_tokens)?; |
| Ok(Encoding::merge(encodings, false)) |
| } |
|
|
| |
| fn process_encodings( |
| &self, |
| encodings: Vec<Encoding>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>>; |
| } |
| impl dyn PostProcessor { |
| pub fn default_process( |
| encodings: Vec<Encoding>, |
| _add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> { |
| match encodings.len() { |
| 1 => Ok(encodings), |
| _ => { |
| let mut final_encoding = Encoding::default(); |
| for (i, mut encoding) in encodings.into_iter().enumerate() { |
| encoding.set_sequence_id(i); |
| final_encoding.merge_with(encoding, false); |
| } |
| Ok(vec![final_encoding]) |
| } |
| } |
| } |
| } |
|
|
| #[derive(thiserror::Error, Debug)] |
| pub enum ProcessorError { |
| #[error("encodings vector length must be either 1 or 2")] |
| InvalidEncodingsVecLength, |
| } |
|
|
| |
| pub trait Decoder { |
| fn decode(&self, tokens: Vec<String>) -> Result<String> { |
| let results = self.decode_chain(tokens)?; |
| Ok(results.join("")) |
| } |
| fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>>; |
| } |
|
|
| |
| |
| pub trait Trainer { |
| type Model: Model + Sized; |
| |
| fn should_show_progress(&self) -> bool; |
| |
| |
| fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>; |
| |
| |
| fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()> |
| where |
| I: Iterator<Item = S> + Send, |
| S: AsRef<str> + Send, |
| F: Fn(&str) -> Result<Vec<String>> + Sync; |
| } |
|
|
| #[derive(Debug, Clone, PartialEq, Eq)] |
| pub struct Token { |
| pub id: u32, |
| pub value: String, |
| pub offsets: (usize, usize), |
| } |
| impl Token { |
| pub fn new(id: u32, value: String, offsets: (usize, usize)) -> Self { |
| Self { id, value, offsets } |
| } |
| } |
|
|
| use std::borrow::Cow; |
| #[derive(Debug, Clone)] |
| pub enum InputSequence<'s> { |
| Raw(Cow<'s, str>), |
| PreTokenized(Cow<'s, [&'s str]>), |
| PreTokenizedOwned(Cow<'s, [String]>), |
| PreTokenizedCow(Cow<'s, [Cow<'s, str>]>), |
| } |
|
|
| impl<'s> From<Cow<'s, str>> for InputSequence<'s> { |
| fn from(input: Cow<'s, str>) -> Self { |
| Self::Raw(input) |
| } |
| } |
|
|
| impl<'s> From<&'s str> for InputSequence<'s> { |
| fn from(input: &'s str) -> Self { |
| Self::Raw(Cow::Borrowed(input)) |
| } |
| } |
|
|
| impl From<String> for InputSequence<'_> { |
| fn from(input: String) -> Self { |
| Self::Raw(Cow::Owned(input)) |
| } |
| } |
|
|
| impl<'s> From<&'s [&'s str]> for InputSequence<'s> { |
| fn from(input: &'s [&'s str]) -> Self { |
| Self::PreTokenized(Cow::Borrowed(input)) |
| } |
| } |
|
|
| impl<'s> From<Vec<&'s str>> for InputSequence<'s> { |
| fn from(input: Vec<&'s str>) -> Self { |
| Self::PreTokenized(Cow::Owned(input)) |
| } |
| } |
|
|
| impl<'s> From<&'s [String]> for InputSequence<'s> { |
| fn from(input: &'s [String]) -> Self { |
| Self::PreTokenizedOwned(Cow::Borrowed(input)) |
| } |
| } |
|
|
| impl<'s> From<Vec<String>> for InputSequence<'s> { |
| fn from(input: Vec<String>) -> Self { |
| Self::PreTokenizedOwned(Cow::Owned(input)) |
| } |
| } |
|
|
| impl<'s> From<Vec<Cow<'s, str>>> for InputSequence<'s> { |
| fn from(input: Vec<Cow<'s, str>>) -> Self { |
| Self::PreTokenizedCow(Cow::Owned(input)) |
| } |
| } |
|
|
| impl<'s> From<&'s [Cow<'s, str>]> for InputSequence<'s> { |
| fn from(input: &'s [Cow<'s, str>]) -> Self { |
| Self::PreTokenizedCow(Cow::Borrowed(input)) |
| } |
| } |
|
|
| #[derive(Debug, Clone)] |
| pub enum EncodeInput<'s> { |
| Single(InputSequence<'s>), |
| Dual(InputSequence<'s>, InputSequence<'s>), |
| } |
|
|
| impl<'s, I: Into<InputSequence<'s>>> From<I> for EncodeInput<'s> { |
| fn from(input: I) -> Self { |
| Self::Single(input.into()) |
| } |
| } |
|
|
| impl<'s, I1, I2> From<(I1, I2)> for EncodeInput<'s> |
| where |
| I1: Into<InputSequence<'s>>, |
| I2: Into<InputSequence<'s>>, |
| { |
| fn from(input: (I1, I2)) -> Self { |
| Self::Dual(input.0.into(), input.1.into()) |
| } |
| } |
|
|
| #[derive(thiserror::Error, Debug)] |
| #[error("{0}")] |
| pub struct BuilderError(String); |
|
|
| |
| |
| |
| pub struct TokenizerBuilder<M, N, PT, PP, D> { |
| model: Option<M>, |
| normalizer: Option<N>, |
| pre_tokenizer: Option<PT>, |
| post_processor: Option<PP>, |
| decoder: Option<D>, |
|
|
| added_vocabulary: AddedVocabulary, |
|
|
| truncation: Option<TruncationParams>, |
| padding: Option<PaddingParams>, |
| } |
|
|
| impl<M, N, PT, PP, D> Default for TokenizerBuilder<M, N, PT, PP, D> |
| where |
| M: Model, |
| N: Normalizer, |
| PT: PreTokenizer, |
| PP: PostProcessor, |
| D: Decoder, |
| { |
| fn default() -> Self { |
| Self::new() |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerBuilder<M, N, PT, PP, D> |
| where |
| M: Model, |
| N: Normalizer, |
| PT: PreTokenizer, |
| PP: PostProcessor, |
| D: Decoder, |
| { |
| |
| pub fn new() -> Self { |
| Self { |
| model: None, |
| normalizer: None, |
| pre_tokenizer: None, |
| post_processor: None, |
| decoder: None, |
| added_vocabulary: AddedVocabulary::new(), |
| truncation: None, |
| padding: None, |
| } |
| } |
|
|
| |
| |
| |
| pub fn build(self) -> Result<TokenizerImpl<M, N, PT, PP, D>> { |
| let model = self |
| .model |
| .ok_or_else(|| Box::new(BuilderError("Model missing.".into())))?; |
| Ok(TokenizerImpl { |
| normalizer: self.normalizer, |
| pre_tokenizer: self.pre_tokenizer, |
| model, |
|
|
| post_processor: self.post_processor, |
| decoder: self.decoder, |
| added_vocabulary: self.added_vocabulary, |
| truncation: self.truncation, |
| padding: self.padding, |
| }) |
| } |
|
|
| |
| #[must_use] |
| pub fn with_model(mut self, model: M) -> Self { |
| self.model = Some(model); |
| self |
| } |
|
|
| |
| #[must_use] |
| pub fn with_normalizer(mut self, normalizer: Option<N>) -> Self { |
| self.normalizer = normalizer; |
| self |
| } |
|
|
| |
| #[must_use] |
| pub fn with_pre_tokenizer(mut self, pretokenizer: Option<PT>) -> Self { |
| self.pre_tokenizer = pretokenizer; |
| self |
| } |
|
|
| |
| #[must_use] |
| pub fn with_post_processor(mut self, post_processor: Option<PP>) -> Self { |
| self.post_processor = post_processor; |
| self |
| } |
|
|
| |
| #[must_use] |
| pub fn with_decoder(mut self, decoder: Option<D>) -> Self { |
| self.decoder = decoder; |
| self |
| } |
|
|
| |
| pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self { |
| self.added_vocabulary = added_vocabulary; |
| self |
| } |
|
|
| |
| #[must_use] |
| pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self { |
| self.truncation = trunc; |
| self |
| } |
|
|
| |
| #[must_use] |
| pub fn with_padding(mut self, padding: Option<PaddingParams>) -> Self { |
| self.padding = padding; |
| self |
| } |
| } |
|
|
| #[derive(Serialize, Deserialize, Debug, Clone)] |
| pub struct Tokenizer( |
| TokenizerImpl< |
| ModelWrapper, |
| NormalizerWrapper, |
| PreTokenizerWrapper, |
| PostProcessorWrapper, |
| DecoderWrapper, |
| >, |
| ); |
|
|
| impl Tokenizer { |
| |
| pub fn new(model: impl Into<ModelWrapper>) -> Self { |
| Self(TokenizerImpl::new(model.into())) |
| } |
|
|
| |
| pub fn into_inner( |
| self, |
| ) -> TokenizerImpl< |
| ModelWrapper, |
| NormalizerWrapper, |
| PreTokenizerWrapper, |
| PostProcessorWrapper, |
| DecoderWrapper, |
| > { |
| self.0 |
| } |
| pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> { |
| let content = read_to_string(file)?; |
| let tokenizer = serde_json::from_str(&content)?; |
| Ok(tokenizer) |
| } |
| pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> Result<Self> { |
| let tokenizer = serde_json::from_slice(bytes.as_ref())?; |
| Ok(tokenizer) |
| } |
| #[cfg(feature = "http")] |
| pub fn from_pretrained<S: AsRef<str>>( |
| identifier: S, |
| params: Option<crate::utils::from_pretrained::FromPretrainedParameters>, |
| ) -> Result<Self> { |
| let tokenizer_file = crate::utils::from_pretrained::from_pretrained(identifier, params)?; |
| Tokenizer::from_file(tokenizer_file) |
| } |
| } |
|
|
| impl std::str::FromStr for Tokenizer { |
| type Err = Box<dyn std::error::Error + Send + Sync>; |
|
|
| fn from_str(s: &str) -> Result<Self> { |
| Ok(serde_json::from_str(s)?) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> From<TokenizerImpl<M, N, PT, PP, D>> for Tokenizer |
| where |
| M: Into<ModelWrapper>, |
| N: Into<NormalizerWrapper>, |
| PT: Into<PreTokenizerWrapper>, |
| PP: Into<PostProcessorWrapper>, |
| D: Into<DecoderWrapper>, |
| { |
| fn from(t: TokenizerImpl<M, N, PT, PP, D>) -> Self { |
| Self(TokenizerImpl { |
| model: t.model.into(), |
| normalizer: t.normalizer.map(Into::into), |
| pre_tokenizer: t.pre_tokenizer.map(Into::into), |
| post_processor: t.post_processor.map(Into::into), |
| decoder: t.decoder.map(Into::into), |
| added_vocabulary: t.added_vocabulary, |
| padding: t.padding, |
| truncation: t.truncation, |
| }) |
| } |
| } |
|
|
| impl Deref for Tokenizer { |
| type Target = TokenizerImpl< |
| ModelWrapper, |
| NormalizerWrapper, |
| PreTokenizerWrapper, |
| PostProcessorWrapper, |
| DecoderWrapper, |
| >; |
|
|
| fn deref(&self) -> &Self::Target { |
| &self.0 |
| } |
| } |
|
|
| impl DerefMut for Tokenizer { |
| fn deref_mut(&mut self) -> &mut Self::Target { |
| &mut self.0 |
| } |
| } |
|
|
| #[derive(thiserror::Error, Debug)] |
| #[error("{0}")] |
| pub struct TruncationParamError(String); |
|
|
| |
| #[derive(Clone, Debug)] |
| pub struct TokenizerImpl<M, N, PT, PP, D> { |
| |
| normalizer: Option<N>, |
| pre_tokenizer: Option<PT>, |
| model: M, |
| post_processor: Option<PP>, |
| decoder: Option<D>, |
|
|
| |
| added_vocabulary: AddedVocabulary, |
|
|
| |
| truncation: Option<TruncationParams>, |
| padding: Option<PaddingParams>, |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: Model, |
| N: Normalizer, |
| PT: PreTokenizer, |
| PP: PostProcessor, |
| D: Decoder, |
| { |
| |
| pub fn new(model: M) -> Self { |
| Self { |
| normalizer: None, |
| pre_tokenizer: None, |
| model, |
| post_processor: None, |
| decoder: None, |
|
|
| added_vocabulary: AddedVocabulary::new(), |
|
|
| truncation: None, |
| padding: None, |
| } |
| } |
|
|
| |
| pub fn with_normalizer(&mut self, normalizer: Option<impl Into<N>>) -> &mut Self { |
| self.normalizer = normalizer.map(|norm| norm.into()); |
| self |
| } |
| |
| pub fn get_normalizer(&self) -> Option<&N> { |
| self.normalizer.as_ref() |
| } |
|
|
| |
| pub fn with_pre_tokenizer(&mut self, pre_tokenizer: Option<impl Into<PT>>) -> &mut Self { |
| self.pre_tokenizer = pre_tokenizer.map(|tok| tok.into()); |
| self |
| } |
|
|
| |
| pub fn get_pre_tokenizer(&self) -> Option<&PT> { |
| self.pre_tokenizer.as_ref() |
| } |
|
|
| |
| pub fn with_post_processor(&mut self, post_processor: Option<impl Into<PP>>) -> &mut Self { |
| self.post_processor = post_processor.map(|post_proc| post_proc.into()); |
| self |
| } |
|
|
| |
| pub fn get_post_processor(&self) -> Option<&PP> { |
| self.post_processor.as_ref() |
| } |
|
|
| |
| pub fn with_decoder(&mut self, decoder: Option<impl Into<D>>) -> &mut Self { |
| self.decoder = decoder.map(|dec| dec.into()); |
| self |
| } |
|
|
| |
| pub fn get_decoder(&self) -> Option<&D> { |
| self.decoder.as_ref() |
| } |
|
|
| |
| pub fn with_model(&mut self, model: impl Into<M>) -> &mut Self { |
| self.model = model.into(); |
| self |
| } |
|
|
| |
| pub fn get_model(&self) -> &M { |
| &self.model |
| } |
|
|
| |
| pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self { |
| self.added_vocabulary = added_vocabulary; |
| self |
| } |
|
|
| |
| pub fn get_added_vocabulary(&self) -> &AddedVocabulary { |
| &self.added_vocabulary |
| } |
|
|
| |
| |
| |
| pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> { |
| if let Some(trunc_params) = &trunc { |
| let n_added_tokens = self.get_n_added_tokens(false); |
| let effective_max_length = trunc_params.max_length - n_added_tokens; |
| if effective_max_length < trunc_params.stride { |
| return Err(Box::new(TruncationParamError(format!( |
| "tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ", |
| trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens |
| )))); |
| } |
| } |
| self.truncation = trunc; |
| Ok(self) |
| } |
|
|
| |
| pub fn get_truncation(&self) -> Option<&TruncationParams> { |
| self.truncation.as_ref() |
| } |
|
|
| |
| pub fn get_truncation_mut(&mut self) -> Option<&mut TruncationParams> { |
| self.truncation.as_mut() |
| } |
|
|
| |
| pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &mut Self { |
| self.padding = padding; |
| self |
| } |
|
|
| |
| pub fn get_padding(&self) -> Option<&PaddingParams> { |
| self.padding.as_ref() |
| } |
|
|
| |
| pub fn get_padding_mut(&mut self) -> Option<&mut PaddingParams> { |
| self.padding.as_mut() |
| } |
|
|
| |
| pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> { |
| let mut final_vocab = self.model.get_vocab(); |
|
|
| if with_added_tokens { |
| let added_vocab = self.added_vocabulary.get_vocab(); |
| if !added_vocab.is_empty() { |
| final_vocab.reserve(added_vocab.len()); |
| for (token, id) in added_vocab { |
| final_vocab.insert(token.clone(), *id); |
| } |
| } |
| } |
|
|
| final_vocab |
| } |
|
|
| |
| pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> { |
| self.added_vocabulary.get_added_tokens_decoder().clone() |
| } |
|
|
| |
| pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize { |
| |
| |
| if with_added_tokens { |
| self.get_vocab(true).len() |
| } else { |
| self.model.get_vocab_size() |
| } |
| } |
|
|
| |
| pub fn token_to_id(&self, token: &str) -> Option<u32> { |
| self.added_vocabulary.token_to_id(token, &self.model) |
| } |
|
|
| |
| pub fn id_to_token(&self, id: u32) -> Option<String> { |
| self.added_vocabulary |
| .simple_id_to_token(id) |
| .or_else(|| self.model.id_to_token(id)) |
| } |
|
|
| |
| pub fn set_encode_special_tokens(&mut self, value: bool) { |
| self.added_vocabulary.set_encode_special_tokens(value); |
| } |
|
|
| |
| pub fn get_encode_special_tokens(&self) -> bool { |
| self.added_vocabulary.get_encode_special_tokens() |
| } |
|
|
| |
| fn encode_single_sequence( |
| &self, |
| sequence: InputSequence, |
| type_id: u32, |
| offsets_type: OffsetType, |
| ) -> Result<Encoding> { |
| let encode = |is_pre_tokenized, subseq_idx, subseq| -> Result<Encoding> { |
| let normalized = self |
| .added_vocabulary |
| .extract_and_normalize(self.normalizer.as_ref(), subseq); |
| let pre_tokenized = self.do_pre_tokenize(normalized)?; |
| let subseq_encoding = self.do_tokenize( |
| pre_tokenized, |
| type_id, |
| if is_pre_tokenized { |
| Some(subseq_idx as u32) |
| } else { |
| None |
| }, |
| offsets_type, |
| )?; |
|
|
| Ok(subseq_encoding) |
| }; |
|
|
| match sequence { |
| InputSequence::PreTokenized(seq) => seq |
| .iter() |
| .enumerate() |
| .map(|(i, sequence)| encode(true, i, sequence)) |
| .collect(), |
| InputSequence::PreTokenizedOwned(seq) => seq |
| .iter() |
| .enumerate() |
| .map(|(i, sequence)| encode(true, i, sequence)) |
| .collect(), |
| InputSequence::PreTokenizedCow(seq) => seq |
| .iter() |
| .enumerate() |
| .map(|(i, sequence)| encode(true, i, sequence)) |
| .collect(), |
| InputSequence::Raw(seq) => encode(false, 0, seq.as_ref()), |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn encode_fast<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding> |
| where |
| E: Into<EncodeInput<'s>>, |
| { |
| |
| let (sequence, pair) = match input.into() { |
| EncodeInput::Single(s1) => (s1, None), |
| EncodeInput::Dual(s1, s2) => (s1, Some(s2)), |
| }; |
|
|
| |
| let encoding = self.encode_single_sequence(sequence, 0, OffsetType::None)?; |
| let pair_encoding = pair |
| .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::None)) |
| .transpose()?; |
|
|
| |
| self.post_process(encoding, pair_encoding, add_special_tokens) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn encode<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding> |
| where |
| E: Into<EncodeInput<'s>>, |
| { |
| |
| let (sequence, pair) = match input.into() { |
| EncodeInput::Single(s1) => (s1, None), |
| EncodeInput::Dual(s1, s2) => (s1, Some(s2)), |
| }; |
|
|
| |
| let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Byte)?; |
| let pair_encoding = pair |
| .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::Byte)) |
| .transpose()?; |
|
|
| |
| self.post_process(encoding, pair_encoding, add_special_tokens) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn encode_char_offsets<'s, E>(&self, input: E, add_special_tokens: bool) -> Result<Encoding> |
| where |
| E: Into<EncodeInput<'s>>, |
| { |
| |
| let (sequence, pair) = match input.into() { |
| EncodeInput::Single(s1) => (s1, None), |
| EncodeInput::Dual(s1, s2) => (s1, Some(s2)), |
| }; |
|
|
| |
| let encoding = self.encode_single_sequence(sequence, 0, OffsetType::Char)?; |
| let pair_encoding = pair |
| .map(|sequence| self.encode_single_sequence(sequence, 1, OffsetType::Char)) |
| .transpose()?; |
|
|
| |
| self.post_process(encoding, pair_encoding, add_special_tokens) |
| } |
|
|
| |
| pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> { |
| let tokens = ids |
| .iter() |
| .filter_map(|id| { |
| self.added_vocabulary |
| .simple_id_to_token(*id) |
| .or_else(|| self.model.id_to_token(*id)) |
| .filter(|token| { |
| !skip_special_tokens || !self.added_vocabulary.is_special_token(token) |
| }) |
| }) |
| .collect::<Vec<_>>(); |
|
|
| if let Some(decoder) = &self.decoder { |
| decoder.decode(tokens) |
| } else { |
| Ok(tokens.join(" ")) |
| } |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: Model, |
| { |
| |
| |
| fn do_tokenize<P: Into<PreTokenizedString>>( |
| &self, |
| pretokenized: P, |
| type_id: u32, |
| word_idx: Option<u32>, |
| offsets_type: OffsetType, |
| ) -> Result<Encoding> { |
| let mut pretokenized: PreTokenizedString = pretokenized.into(); |
| pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?; |
| pretokenized.into_encoding(word_idx, type_id, offsets_type) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| N: Normalizer, |
| { |
| |
| fn do_normalize<V: Into<NormalizedString>>(&self, normalized: V) -> Result<NormalizedString> { |
| let mut normalized: NormalizedString = normalized.into(); |
|
|
| if let Some(ref normalizer) = self.normalizer { |
| normalizer.normalize(&mut normalized)?; |
| } |
|
|
| Ok(normalized) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| N: Normalizer, |
| M: Model, |
| { |
| |
| |
| pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize { |
| self.added_vocabulary |
| .add_special_tokens(tokens, &self.model, self.normalizer.as_ref()) |
| } |
|
|
| |
| pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize { |
| self.added_vocabulary |
| .add_tokens(tokens, &self.model, self.normalizer.as_ref()) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| PT: PreTokenizer, |
| { |
| |
| fn do_pre_tokenize<P: Into<PreTokenizedString>>( |
| &self, |
| pretokenized: P, |
| ) -> Result<PreTokenizedString> { |
| let mut pretokenized: PreTokenizedString = pretokenized.into(); |
| if let Some(ref pretok) = self.pre_tokenizer { |
| pretok.pre_tokenize(&mut pretokenized)?; |
| } |
|
|
| Ok(pretokenized) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| PP: PostProcessor, |
| { |
| |
| pub fn post_process( |
| &self, |
| encoding: Encoding, |
| pair_encoding: Option<Encoding>, |
| add_special_tokens: bool, |
| ) -> Result<Encoding> { |
| |
| let (encoding, pair_encoding) = { |
| if let Some(trunc) = &self.truncation { |
| let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some()); |
|
|
| if add_special_tokens && n_added_tokens > 0 { |
| let params = TruncationParams { |
| max_length: trunc.max_length - n_added_tokens, |
| ..*trunc |
| }; |
| truncate_encodings(encoding, pair_encoding, ¶ms)? |
| } else { |
| truncate_encodings(encoding, pair_encoding, trunc)? |
| } |
| } else { |
| (encoding, pair_encoding) |
| } |
| }; |
|
|
| |
| let final_encoding = if let Some(processor) = &self.post_processor { |
| processor.process(encoding, pair_encoding, add_special_tokens)? |
| } else { |
| let encodings = if let Some(pair_encoding) = pair_encoding { |
| vec![encoding, pair_encoding] |
| } else { |
| vec![encoding] |
| }; |
| let mut encodings = |
| <dyn PostProcessor>::default_process(encodings, add_special_tokens)?; |
| if encodings.len() != 1 { |
| panic!("We haven't reduced the encodings like we should have"); |
| } |
| encodings.pop().unwrap() |
| }; |
|
|
| |
| let [final_encoding] = if let Some(params) = &self.padding { |
| let mut arr = [final_encoding]; |
| pad_encodings(&mut arr, params)?; |
| arr |
| } else { |
| [final_encoding] |
| }; |
|
|
| Ok(final_encoding) |
| } |
|
|
| fn get_n_added_tokens(&self, is_pair: bool) -> usize { |
| if let Some(processor) = &self.post_processor { |
| processor.added_tokens(is_pair) |
| } else { |
| 0 |
| } |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: Model + Send + Sync, |
| N: Normalizer + Send + Sync, |
| PT: PreTokenizer + Send + Sync, |
| PP: PostProcessor + Send + Sync, |
| D: Decoder + Send + Sync, |
| { |
| |
| pub fn encode_batch<'s, E>( |
| &self, |
| inputs: Vec<E>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> |
| where |
| E: Into<EncodeInput<'s>> + Send, |
| { |
| let mut encodings = inputs |
| .into_maybe_par_iter() |
| .map(|input| self.encode(input, add_special_tokens)) |
| .collect::<Result<Vec<Encoding>>>()?; |
|
|
| if let Some(params) = &self.padding { |
| |
| pad_encodings(&mut encodings, params)?; |
| } |
|
|
| Ok(encodings) |
| } |
|
|
| |
| |
| pub fn encode_batch_char_offsets<'s, E>( |
| &self, |
| inputs: Vec<E>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> |
| where |
| E: Into<EncodeInput<'s>> + Send, |
| { |
| let mut encodings = inputs |
| .into_maybe_par_iter() |
| .map(|input| self.encode_char_offsets(input, add_special_tokens)) |
| .collect::<Result<Vec<Encoding>>>()?; |
|
|
| if let Some(params) = &self.padding { |
| |
| pad_encodings(&mut encodings, params)?; |
| } |
|
|
| Ok(encodings) |
| } |
|
|
| |
| pub fn encode_batch_fast<'s, E>( |
| &self, |
| inputs: Vec<E>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> |
| where |
| E: Into<EncodeInput<'s>> + Send, |
| { |
| let mut encodings = inputs |
| .into_maybe_par_iter() |
| .map(|input| self.encode_fast(input, add_special_tokens)) |
| .collect::<Result<Vec<Encoding>>>()?; |
|
|
| if let Some(params) = &self.padding { |
| |
| pad_encodings(&mut encodings, params)?; |
| } |
|
|
| Ok(encodings) |
| } |
|
|
| |
| pub fn decode_batch( |
| &self, |
| sentences: &[&[u32]], |
| skip_special_tokens: bool, |
| ) -> Result<Vec<String>> |
| where |
| M: Send + Sync, |
| { |
| sentences |
| .into_maybe_par_iter() |
| .map(|sentence| self.decode(sentence, skip_special_tokens)) |
| .collect() |
| } |
|
|
| |
| pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self> |
| where |
| T: Trainer<Model = M> + Sync, |
| { |
| let mut len = 0; |
| for file in files.iter() { |
| len += File::open(file) |
| .and_then(|f| f.metadata()) |
| .map(|m| m.len())?; |
| } |
|
|
| let max_read = 1_000_000; |
|
|
| ResultShunt::process( |
| files.into_iter().flat_map(|filename| { |
| match File::open(filename) { |
| Ok(file) => { |
| let file = BufReader::with_capacity(max_read, file); |
| |
| |
| |
| itertools::Either::Left(file.lines_with_ending()) |
| } |
| Err(e) => itertools::Either::Right(std::iter::once(Err(e))), |
| } |
| }), |
| |sequences| -> Result<()> { |
| let progress = if trainer.should_show_progress() { |
| let progress = ProgressBar::new(len); |
| progress.set_style( |
| ProgressStyle::default_bar() |
| .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {percent:>18!}%") |
| .expect("Invalid progress template"), |
| ); |
| progress |
| .set_message(format!("Pre-processing files ({:.2} Mo)", len / 1_000_000)); |
| Some(progress) |
| } else { |
| None |
| }; |
|
|
| trainer.feed( |
| sequences.inspect(|s| { |
| if let Some(progress) = &progress { |
| progress.inc(s.len() as u64) |
| } |
| }), |
| |seq| { |
| let normalized = self.do_normalize(seq.as_ref())?; |
| let pre_tokenized = self.do_pre_tokenize(normalized)?; |
| Ok(pre_tokenized |
| .get_splits(OffsetReferential::Original, OffsetType::Byte) |
| .into_iter() |
| .map(|(s, _, _)| s.to_owned()) |
| .collect()) |
| }, |
| )?; |
|
|
| if let Some(pbar) = progress { |
| pbar.finish(); |
| } |
| let special_tokens = trainer.train(&mut self.model)?; |
| self.add_special_tokens(&special_tokens); |
|
|
| Ok(()) |
| }, |
| )??; |
| Ok(self) |
| } |
|
|
| |
| pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self> |
| where |
| T: Trainer<Model = M> + Sync, |
| I: Iterator<Item = S> + Send, |
| S: AsRef<str> + Send, |
| { |
| let (lower, upper) = sequences.size_hint(); |
| let len = upper.unwrap_or(lower) as u64; |
| let progress = if trainer.should_show_progress() { |
| let progress = ProgressBar::new(len); |
| progress.set_style( |
| ProgressStyle::default_bar() |
| .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}") |
| .expect("Invalid progress template"), |
| ); |
| progress.set_message("Pre-processing sequences"); |
| Some(progress) |
| } else { |
| None |
| }; |
|
|
| trainer.feed( |
| sequences.inspect(|_s| { |
| if let Some(progress) = &progress { |
| progress.inc(1) |
| } |
| }), |
| |seq| { |
| let normalized = self.do_normalize(seq.as_ref())?; |
| let pre_tokenized = self.do_pre_tokenize(normalized)?; |
| Ok(pre_tokenized |
| .get_splits(OffsetReferential::Original, OffsetType::Byte) |
| .into_iter() |
| .map(|(s, _, _)| s.to_owned()) |
| .collect()) |
| }, |
| )?; |
| if let Some(pbar) = progress { |
| pbar.finish(); |
| } |
|
|
| let special_tokens = trainer.train(&mut self.model)?; |
| self.add_special_tokens(&special_tokens); |
|
|
| Ok(self) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: for<'de> Deserialize<'de> + Model, |
| N: for<'de> Deserialize<'de> + Normalizer, |
| PT: for<'de> Deserialize<'de> + PreTokenizer, |
| PP: for<'de> Deserialize<'de> + PostProcessor, |
| D: for<'de> Deserialize<'de> + Decoder, |
| { |
| type Err = Error; |
|
|
| fn from_str(s: &str) -> Result<Self> { |
| Ok(serde_json::from_str(s)?) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: DeserializeOwned + Model, |
| N: DeserializeOwned + Normalizer, |
| PT: DeserializeOwned + PreTokenizer, |
| PP: DeserializeOwned + PostProcessor, |
| D: DeserializeOwned + Decoder, |
| { |
| |
| pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> { |
| let content = read_to_string(file)?; |
| let tokenizer = serde_json::from_str(&content)?; |
| Ok(tokenizer) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: DeserializeOwned + Model, |
| N: DeserializeOwned + Normalizer, |
| PT: DeserializeOwned + PreTokenizer, |
| PP: DeserializeOwned + PostProcessor, |
| D: DeserializeOwned + Decoder, |
| { |
| |
| pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> Result<Self> { |
| let tokenizer = serde_json::from_slice(bytes.as_ref())?; |
| Ok(tokenizer) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: DeserializeOwned + Model, |
| N: DeserializeOwned + Normalizer, |
| PT: DeserializeOwned + PreTokenizer, |
| PP: DeserializeOwned + PostProcessor, |
| D: DeserializeOwned + Decoder, |
| { |
| #[deprecated( |
| since = "0.14.0", |
| note = "Users should download the file separately using https://github.com/huggingface/hf-hub instead, which splits concerns of accessing the web, and should use the new cache layout" |
| )] |
| #[cfg(feature = "http")] |
| |
| |
| pub fn from_pretrained<S: AsRef<str>>( |
| identifier: S, |
| params: Option<crate::utils::from_pretrained::FromPretrainedParameters>, |
| ) -> Result<Self> { |
| let tokenizer_file = crate::utils::from_pretrained::from_pretrained(identifier, params)?; |
| TokenizerImpl::from_file(tokenizer_file) |
| } |
| } |
|
|
| impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> |
| where |
| M: Serialize, |
| N: Serialize, |
| PT: Serialize, |
| PP: Serialize, |
| D: Serialize, |
| { |
| |
| pub fn to_string(&self, pretty: bool) -> Result<String> { |
| Ok(if pretty { |
| serde_json::to_string_pretty(self)? |
| } else { |
| serde_json::to_string(self)? |
| }) |
| } |
|
|
| |
| pub fn save<P: AsRef<Path>>(&self, path: P, pretty: bool) -> Result<()> { |
| let serialized = self.to_string(pretty)?; |
|
|
| let mut file = File::create(path)?; |
| file.write_all(serialized.as_bytes())?; |
|
|
| Ok(()) |
| } |
| } |
|
|
| #[cfg(test)] |
| mod test { |
| #[cfg(feature = "http")] |
| #[test] |
| fn test_decoding_with_added_bpe() { |
| use crate::{ |
| normalizers, |
| pre_tokenizers::split::{Split, SplitPattern}, |
| AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer, |
| }; |
|
|
| let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap(); |
| tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new())); |
| tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split( |
| Split::new( |
| SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()), |
| SplitDelimiterBehavior::Isolated, |
| false, |
| ) |
| .unwrap(), |
| )); |
| tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]); |
| let encoded = tokenizer |
| .encode("Hey! how is this token: 嗎", false) |
| .unwrap(); |
| assert_eq!( |
| encoded.get_ids(), |
| [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256] |
| ); |
| assert_eq!( |
| encoded.get_tokens(), |
| ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"] |
| ); |
|
|
| let decoded = tokenizer.decode(encoded.get_ids(), false); |
| assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎"); |
|
|
| tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]); |
| let encoded = tokenizer |
| .encode("Hey! how is this token: д", false) |
| .unwrap(); |
| assert_eq!( |
| encoded.get_ids(), |
| [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257] |
| ); |
| assert_eq!( |
| encoded.get_tokens(), |
| ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"] |
| ); |
| let decoded = tokenizer.decode(encoded.get_ids(), false); |
| assert_eq!(decoded.unwrap(), "Hey! how is this token: д") |
| } |
| } |
|
|