use crate::decoders::Decoder; use crate::encoding::{JsEncoding, JsTruncationDirection, JsTruncationStrategy}; use crate::models::Model; use crate::normalizers::Normalizer; use crate::pre_tokenizers::PreTokenizer; use crate::processors::Processor; use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask}; use crate::trainers::Trainer; use std::collections::HashMap; use tokenizers::Model as ModelTrait; use napi::bindgen_prelude::*; use napi_derive::napi; use std::sync::{Arc, RwLock}; use tokenizers as tk; #[napi] #[derive(Default)] pub enum PaddingDirection { #[default] Left, Right, } impl From for tk::PaddingDirection { fn from(w: PaddingDirection) -> Self { match w { PaddingDirection::Left => tk::PaddingDirection::Left, PaddingDirection::Right => tk::PaddingDirection::Right, } } } impl TryFrom for PaddingDirection { type Error = Error; fn try_from(w: String) -> Result { match w.as_str() { "left" => Ok(PaddingDirection::Left), "right" => Ok(PaddingDirection::Right), s => Err(Error::from_reason(format!( "{s:?} is not a valid direction" ))), } } } #[napi(object)] #[derive(Default)] pub struct PaddingOptions { pub max_length: Option, pub direction: Option>, pub pad_to_multiple_of: Option, pub pad_id: Option, pub pad_type_id: Option, pub pad_token: Option, } impl TryFrom for tk::PaddingParams { type Error = Error; fn try_from(value: PaddingOptions) -> Result { let direction = match value.direction { Some(either) => match either { Either::A(string) => { let direction: PaddingDirection = string.try_into()?; direction.into() } Either::B(direction) => direction.into(), }, None => tk::PaddingDirection::Right, }; Ok(Self { pad_to_multiple_of: value.pad_to_multiple_of.map(|s| s as usize), pad_id: value.pad_id.unwrap_or_default(), pad_type_id: value.pad_type_id.unwrap_or_default(), pad_token: value.pad_token.unwrap_or("[PAD]".to_string()), direction, strategy: match value.max_length { Some(length) => tk::PaddingStrategy::Fixed(length as usize), None => tk::PaddingStrategy::BatchLongest, }, }) } } #[napi(object)] #[derive(Default)] pub struct EncodeOptions { pub is_pretokenized: Option, pub add_special_tokens: Option, } #[derive(Default)] struct EncodeOptionsDef { // TODO // is_pretokenized: bool, add_special_tokens: bool, } impl From for EncodeOptionsDef { fn from(value: EncodeOptions) -> Self { EncodeOptionsDef { // TODO // is_pretokenized: value.is_pretokenized.unwrap_or(false), add_special_tokens: value.add_special_tokens.unwrap_or(true), } } } #[napi(object)] #[derive(Default)] pub struct TruncationOptions { pub max_length: Option, pub strategy: Option, pub direction: Option>, pub stride: Option, } impl TryFrom for tk::TruncationParams { type Error = Error; fn try_from(value: TruncationOptions) -> Result { let direction = match value.direction { Some(either) => match either { Either::A(string) => { let direction: JsTruncationDirection = string.try_into()?; direction.into() } Either::B(direction) => direction.into(), }, None => Default::default(), }; Ok(Self { max_length: value.max_length.unwrap_or(0) as usize, strategy: value.strategy.map(|s| s.into()).unwrap_or_default(), direction, stride: value.stride.unwrap_or_default() as usize, }) } } #[napi(object)] pub struct AddedTokenOptions { pub single_word: Option, pub left_strip: Option, pub right_strip: Option, pub normalized: Option, } #[napi] #[derive(Clone)] pub struct AddedToken { token: tk::AddedToken, } #[napi] impl AddedToken { #[napi(constructor)] pub fn from(token: String, is_special: bool, options: Option) -> Self { let mut token = tk::AddedToken::from(token, is_special); if let Some(options) = options { if let Some(sw) = options.single_word { token = token.single_word(sw); } if let Some(ls) = options.left_strip { token = token.lstrip(ls); } if let Some(rs) = options.right_strip { token = token.rstrip(rs); } if let Some(n) = options.normalized { token = token.normalized(n); } } Self { token } } #[napi] pub fn get_content(&self) -> String { self.token.content.clone() } } impl From for tk::AddedToken { fn from(v: AddedToken) -> Self { v.token } } type RsTokenizer = tk::TokenizerImpl; #[napi] #[derive(Clone)] pub struct Tokenizer { pub(crate) tokenizer: Arc>, } #[napi] impl Tokenizer { #[napi(constructor)] pub fn new(model: &Model) -> Self { Self { tokenizer: Arc::new(RwLock::new(tk::TokenizerImpl::new((*model).clone()))), } } #[napi] pub fn set_pre_tokenizer(&mut self, pre_tokenizer: &PreTokenizer) { self .tokenizer .write() .unwrap() .with_pre_tokenizer(Some((*pre_tokenizer).clone())); } #[napi] pub fn set_decoder(&mut self, decoder: &Decoder) { self .tokenizer .write() .unwrap() .with_decoder(Some((*decoder).clone())); } #[napi] pub fn set_model(&mut self, model: &Model) { self.tokenizer.write().unwrap().with_model((*model).clone()); } #[napi] pub fn set_post_processor(&mut self, post_processor: &Processor) { self .tokenizer .write() .unwrap() .with_post_processor(Some((*post_processor).clone())); } #[napi] pub fn set_normalizer(&mut self, normalizer: &Normalizer) { self .tokenizer .write() .unwrap() .with_normalizer(Some((*normalizer).clone())); } #[napi] pub fn save(&self, path: String, pretty: Option) -> Result<()> { let pretty = pretty.unwrap_or(false); self .tokenizer .read() .unwrap() .save(path, pretty) .map_err(|e| Error::from_reason(format!("{}", e))) } #[napi] pub fn add_added_tokens(&mut self, tokens: Vec<&AddedToken>) -> u32 { let tokens: Vec<_> = tokens .into_iter() .map(|tok| (*tok).clone().into()) .collect(); self.tokenizer.write().unwrap().add_tokens(&tokens) as u32 } #[napi] pub fn add_tokens(&mut self, tokens: Vec) -> u32 { let tokens: Vec<_> = tokens .into_iter() .map(|tok| tk::AddedToken::from(tok, false)) .collect(); self.tokenizer.write().unwrap().add_tokens(&tokens) as u32 } #[napi(ts_return_type = "Promise")] pub fn encode( &self, #[napi(ts_arg_type = "InputSequence")] sentence: String, #[napi(ts_arg_type = "InputSequence | null")] pair: Option, encode_options: Option, ) -> AsyncTask> { let options: EncodeOptionsDef = encode_options.unwrap_or_default().into(); let input: tk::EncodeInput = match pair { Some(pair) => (sentence, pair).into(), None => sentence.into(), }; AsyncTask::new(EncodeTask { tokenizer: (*self).clone(), input: Some(input), add_special_tokens: options.add_special_tokens, }) } #[napi(ts_return_type = "Promise")] pub fn encode_batch( &self, #[napi(ts_arg_type = "EncodeInput[]")] sentences: Vec, encode_options: Option, ) -> AsyncTask> { let options: EncodeOptionsDef = encode_options.unwrap_or_default().into(); let inputs: Vec = sentences .into_iter() .map(|sentence| sentence.into()) .collect(); AsyncTask::new(EncodeBatchTask { tokenizer: (*self).clone(), inputs: Some(inputs), add_special_tokens: options.add_special_tokens, }) } #[napi(ts_return_type = "Promise")] pub fn decode(&self, ids: Vec, skip_special_tokens: bool) -> AsyncTask { AsyncTask::new(DecodeTask { tokenizer: (*self).clone(), ids, skip_special_tokens, }) } #[napi(ts_return_type = "Promise")] pub fn decode_batch( &self, ids: Vec>, skip_special_tokens: bool, ) -> AsyncTask { AsyncTask::new(DecodeBatchTask { tokenizer: (*self).clone(), ids, skip_special_tokens, }) } #[napi(factory)] pub fn from_string(s: String) -> Result { let tokenizer: tk::tokenizer::TokenizerImpl< Model, Normalizer, PreTokenizer, Processor, Decoder, > = s .parse() .map_err(|e| Error::from_reason(format!("{}", e)))?; Ok(Self { tokenizer: Arc::new(RwLock::new(tokenizer)), }) } #[napi(factory)] pub fn from_file(file: String) -> Result { let tokenizer = tk::tokenizer::TokenizerImpl::from_file(file) .map_err(|e| Error::from_reason(format!("Error loading from file{}", e)))?; Ok(Self { tokenizer: Arc::new(RwLock::new(tokenizer)), }) } #[napi] pub fn add_special_tokens(&mut self, tokens: Vec) { let tokens: Vec<_> = tokens .into_iter() .map(|s| tk::AddedToken::from(s, true)) .collect(); self.tokenizer.write().unwrap().add_special_tokens(&tokens); } #[napi] pub fn set_truncation( &mut self, max_length: u32, options: Option, ) -> Result<()> { let mut options: tk::TruncationParams = if let Some(options) = options { options.try_into()? } else { Default::default() }; options.max_length = max_length as usize; self .tokenizer .write() .unwrap() .with_truncation(Some(options)) .unwrap(); Ok(()) } #[napi] pub fn disable_truncation(&mut self) { self .tokenizer .write() .unwrap() .with_truncation(None) .unwrap(); } #[napi] pub fn set_padding(&mut self, options: Option) -> Result<()> { let options = if let Some(options) = options { Some(options.try_into()?) } else { None }; self.tokenizer.write().unwrap().with_padding(options); Ok(()) } #[napi] pub fn disable_padding(&mut self) { self.tokenizer.write().unwrap().with_padding(None); } #[napi] pub fn get_decoder(&self) -> Option { self.tokenizer.read().unwrap().get_decoder().cloned() } #[napi] pub fn get_normalizer(&self) -> Option { self.tokenizer.read().unwrap().get_normalizer().cloned() } #[napi] pub fn get_pre_tokenizer(&self) -> Option { self.tokenizer.read().unwrap().get_pre_tokenizer().cloned() } #[napi] pub fn get_post_processor(&self) -> Option { self.tokenizer.read().unwrap().get_post_processor().cloned() } #[napi] pub fn get_vocab(&self, with_added_tokens: Option) -> HashMap { let with_added_tokens = with_added_tokens.unwrap_or(true); self.tokenizer.read().unwrap().get_vocab(with_added_tokens) } #[napi] pub fn get_vocab_size(&self, with_added_tokens: Option) -> u32 { self.get_vocab(with_added_tokens).len() as u32 } #[napi] pub fn id_to_token(&self, id: u32) -> Option { self.tokenizer.read().unwrap().id_to_token(id) } #[napi] pub fn token_to_id(&self, token: String) -> Option { self.tokenizer.read().unwrap().token_to_id(&token) } #[napi] pub fn train(&mut self, files: Vec) -> Result<()> { let mut trainer: Trainer = self .tokenizer .read() .unwrap() .get_model() .model .as_ref() .unwrap() .read() .unwrap() .get_trainer() .into(); self .tokenizer .write() .unwrap() .train_from_files(&mut trainer, files) .map_err(|e| Error::from_reason(format!("{}", e)))?; Ok(()) } #[napi] pub fn running_tasks(&self) -> u32 { std::sync::Arc::strong_count(&self.tokenizer) as u32 } #[napi] pub fn post_process( &self, encoding: &JsEncoding, pair: Option<&JsEncoding>, add_special_tokens: Option, ) -> Result { let add_special_tokens = add_special_tokens.unwrap_or(true); Ok( self .tokenizer .read() .unwrap() .post_process( (*encoding).clone().try_into()?, if let Some(pair) = pair { Some((*pair).clone().try_into()?) } else { None }, add_special_tokens, ) .map_err(|e| Error::from_reason(format!("{}", e)))? .into(), ) } } #[napi(object)] #[derive(Default)] pub struct JsFromPretrainedParameters { pub revision: Option, pub auth_token: Option, }