use crate::tokenizer::PaddingOptions; use napi::bindgen_prelude::*; use napi_derive::napi; use tokenizers::utils::truncation::TruncationDirection; use tokenizers::Encoding; #[napi(js_name = "Encoding")] #[derive(Clone, Default)] pub struct JsEncoding { pub(crate) encoding: Option, } impl From for JsEncoding { fn from(value: Encoding) -> Self { Self { encoding: Some(value), } } } impl TryFrom for Encoding { type Error = Error; fn try_from(value: JsEncoding) -> Result { value .encoding .ok_or(Error::from_reason("Uninitialized encoding".to_string())) } } #[napi(string_enum, js_name = "TruncationDirection")] pub enum JsTruncationDirection { Left, Right, } impl From for TruncationDirection { fn from(value: JsTruncationDirection) -> Self { match value { JsTruncationDirection::Left => TruncationDirection::Left, JsTruncationDirection::Right => TruncationDirection::Right, } } } impl TryFrom for JsTruncationDirection { type Error = Error; fn try_from(value: String) -> Result { match value.as_str() { "left" => Ok(JsTruncationDirection::Left), "right" => Ok(JsTruncationDirection::Right), s => Err(Error::from_reason(format!( "{s:?} is not a valid direction" ))), } } } #[napi(string_enum, js_name = "TruncationStrategy")] pub enum JsTruncationStrategy { LongestFirst, OnlyFirst, OnlySecond, } impl From for tokenizers::TruncationStrategy { fn from(value: JsTruncationStrategy) -> Self { match value { JsTruncationStrategy::LongestFirst => tokenizers::TruncationStrategy::LongestFirst, JsTruncationStrategy::OnlyFirst => tokenizers::TruncationStrategy::OnlyFirst, JsTruncationStrategy::OnlySecond => tokenizers::TruncationStrategy::OnlySecond, } } } #[napi] impl JsEncoding { #[napi(constructor)] pub fn new() -> Self { Self { encoding: None } } #[napi] pub fn get_length(&self) -> u32 { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_ids() .len() as u32 } #[napi] pub fn get_n_sequences(&self) -> u32 { self .encoding .as_ref() .expect("Uninitialized Encoding") .n_sequences() as u32 } #[napi] pub fn get_ids(&self) -> Vec { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_ids() .to_vec() } #[napi] pub fn get_type_ids(&self) -> Vec { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_type_ids() .to_vec() } #[napi] pub fn get_attention_mask(&self) -> Vec { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_attention_mask() .to_vec() } #[napi] pub fn get_special_tokens_mask(&self) -> Vec { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_special_tokens_mask() .to_vec() } #[napi] pub fn get_tokens(&self) -> Vec { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_tokens() .to_vec() } #[napi] pub fn get_offsets(&self) -> Vec> { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_offsets() .iter() .map(|(a, b)| vec![*a as u32, *b as u32]) .collect() } #[napi] pub fn get_word_ids(&self) -> Vec> { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_word_ids() .to_vec() } #[napi] pub fn char_to_token(&self, pos: u32, seq_id: Option) -> Option { let seq_id = seq_id.unwrap_or(0); self .encoding .as_ref() .expect("Uninitialized Encoding") .char_to_token(pos as usize, seq_id as usize) .map(|i| i as u32) } #[napi] pub fn char_to_word(&self, pos: u32, seq_id: Option) -> Option { let seq_id = seq_id.unwrap_or(0); self .encoding .as_ref() .expect("Uninitialized Encoding") .char_to_word(pos as usize, seq_id as usize) } #[napi] pub fn pad(&mut self, length: u32, options: Option) -> Result<()> { let params: tokenizers::PaddingParams = options.unwrap_or_default().try_into()?; self.encoding.as_mut().expect("Uninitialized Encoding").pad( length as usize, params.pad_id, params.pad_type_id, ¶ms.pad_token, params.direction, ); Ok(()) } #[napi] pub fn truncate( &mut self, length: u32, stride: Option, direction: Option>, ) -> Result<()> { let stride = stride.unwrap_or_default(); let direction = match direction { None => TruncationDirection::Left, Some(Either::A(s)) => match s.as_str() { "left" => TruncationDirection::Left, "right" => TruncationDirection::Right, d => { return Err(Error::from_reason(format!( "{d} is not a valid truncation direction" ))); } }, Some(Either::B(t)) => t.into(), }; self .encoding .as_mut() .expect("Uninitialized Encoding") .truncate(length as usize, stride as usize, direction); Ok(()) } #[napi(ts_return_type = "[number, number] | null | undefined")] pub fn word_to_tokens(&self, env: Env, word: u32, seq_id: Option) -> Result> { let seq_id = seq_id.unwrap_or(0); if let Some((a, b)) = self .encoding .as_ref() .expect("Uninitialized Encoding") .word_to_tokens(word, seq_id as usize) { let mut arr = env.create_array(2)?; arr.set(0, env.create_uint32(a as u32)?)?; arr.set(1, env.create_uint32(b as u32)?)?; Ok(Some(arr)) } else { Ok(None) } } #[napi(ts_return_type = "[number, number] | null | undefined")] pub fn word_to_chars(&self, env: Env, word: u32, seq_id: Option) -> Result> { let seq_id = seq_id.unwrap_or(0); if let Some((a, b)) = self .encoding .as_ref() .expect("Uninitialized Encoding") .word_to_chars(word, seq_id as usize) { let mut arr = env.create_array(2)?; arr.set(0, env.create_uint32(a as u32)?)?; arr.set(1, env.create_uint32(b as u32)?)?; Ok(Some(arr)) } else { Ok(None) } } #[napi(ts_return_type = "[number, [number, number]] | null | undefined")] pub fn token_to_chars(&self, env: Env, token: u32) -> Result> { if let Some((_, (start, stop))) = self .encoding .as_ref() .expect("Uninitialized Encoding") .token_to_chars(token as usize) { let mut offsets = env.create_array(2)?; offsets.set(0, env.create_uint32(start as u32)?)?; offsets.set(1, env.create_uint32(stop as u32)?)?; Ok(Some(offsets)) } else { Ok(None) } } #[napi] pub fn token_to_word(&self, token: u32) -> Result> { if let Some((_, index)) = self .encoding .as_ref() .expect("Uninitialized Encoding") .token_to_word(token as usize) { Ok(Some(index)) } else { Ok(None) } } #[napi] pub fn get_overflowing(&self) -> Vec { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_overflowing() .clone() .into_iter() .map(|enc| JsEncoding { encoding: Some(enc), }) .collect() } #[napi] pub fn get_sequence_ids(&self) -> Vec> { self .encoding .as_ref() .expect("Uninitialized Encoding") .get_sequence_ids() .into_iter() .map(|s| s.map(|id| id as u32)) .collect() } #[napi] pub fn token_to_sequence(&self, token: u32) -> Option { self .encoding .as_ref() .expect("Uninitialized Encoding") .token_to_sequence(token as usize) .map(|s| s as u32) } }