use crate::arc_rwlock_serde; use serde::{Deserialize, Serialize}; extern crate tokenizers as tk; use napi::bindgen_prelude::*; use napi_derive::napi; use std::sync::{Arc, RwLock}; use tk::decoders::DecoderWrapper; /// Decoder #[derive(Clone, Serialize, Deserialize)] #[napi] pub struct Decoder { #[serde(flatten, with = "arc_rwlock_serde")] decoder: Option>>, } #[napi] impl Decoder { #[napi] pub fn decode(&self, tokens: Vec) -> Result { use tk::Decoder; self .decoder .as_ref() .unwrap() .read() .unwrap() .decode(tokens) .map_err(|e| Error::from_reason(format!("{}", e))) } } impl tk::Decoder for Decoder { fn decode_chain(&self, tokens: Vec) -> tk::Result> { self .decoder .as_ref() .ok_or("Uninitialized Decoder")? .read() .unwrap() .decode_chain(tokens) } } #[napi] pub fn bpe_decoder(suffix: Option) -> Decoder { let suffix = suffix.unwrap_or("".to_string()); let decoder = Some(Arc::new(RwLock::new( tk::decoders::bpe::BPEDecoder::new(suffix).into(), ))); Decoder { decoder } } #[napi] pub fn byte_fallback_decoder() -> Decoder { Decoder { decoder: Some(Arc::new(RwLock::new( tk::decoders::byte_fallback::ByteFallback::new().into(), ))), } } #[napi] pub fn ctc_decoder( #[napi(ts_arg_type = "string = ''")] pad_token: Option, word_delimiter_token: Option, cleanup: Option, ) -> Decoder { let pad_token = pad_token.unwrap_or("".to_string()); let word_delimiter_token = word_delimiter_token.unwrap_or("|".to_string()); let cleanup = cleanup.unwrap_or(true); let decoder = Some(Arc::new(RwLock::new( tk::decoders::ctc::CTC::new(pad_token, word_delimiter_token, cleanup).into(), ))); Decoder { decoder } } #[napi] pub fn fuse_decoder() -> Decoder { Decoder { decoder: Some(Arc::new(RwLock::new( tk::decoders::fuse::Fuse::new().into(), ))), } } #[napi] pub fn metaspace_decoder( #[napi(ts_arg_type = "string = '▁'")] replacement: Option, #[napi(ts_arg_type = "prepend_scheme = 'always'")] prepend_scheme: Option, #[napi(ts_arg_type = "split = true")] split: Option, ) -> Result { use tk::pre_tokenizers::metaspace::PrependScheme; let split = split.unwrap_or(true); let replacement = replacement.unwrap_or("▁".to_string()); if replacement.chars().count() != 1 { return Err(Error::from_reason( "replacement is supposed to be a single char", )); } let replacement = replacement.chars().next().unwrap(); let prepend_scheme: PrependScheme = match prepend_scheme.unwrap_or(String::from("always")).as_str() { "always" => PrependScheme::Always, "first" => PrependScheme::First, "never" => PrependScheme::Never, _ => { return Err(Error::from_reason( "prepend_scheme is supposed to be either 'always', 'first' or 'never'", )); } }; Ok(Decoder { decoder: Some(Arc::new(RwLock::new( tk::decoders::metaspace::Metaspace::new(replacement, prepend_scheme, split).into(), ))), }) } #[napi] pub fn replace_decoder(pattern: String, content: String) -> Result { Ok(Decoder { decoder: Some(Arc::new(RwLock::new( tk::normalizers::replace::Replace::new(pattern, content) .map_err(|e| Error::from_reason(e.to_string()))? .into(), ))), }) } #[napi] pub fn sequence_decoder(decoders: Vec<&Decoder>) -> Decoder { let sequence: Vec = decoders .into_iter() .filter_map(|decoder| { decoder .decoder .as_ref() .map(|decoder| (**decoder).read().unwrap().clone()) }) .clone() .collect(); Decoder { decoder: Some(Arc::new(RwLock::new(tk::DecoderWrapper::Sequence( tk::decoders::sequence::Sequence::new(sequence), )))), } } #[napi] pub fn strip_decoder(content: String, left: u32, right: u32) -> Result { let content: char = content.chars().next().ok_or(Error::from_reason( "Expected non empty string for strip pattern", ))?; Ok(Decoder { decoder: Some(Arc::new(RwLock::new( tk::decoders::strip::Strip::new(content, left as usize, right as usize).into(), ))), }) } #[napi] pub fn word_piece_decoder( #[napi(ts_arg_type = "string = '##'")] prefix: Option, #[napi(ts_arg_type = "bool = true")] cleanup: Option, ) -> Decoder { let prefix = prefix.unwrap_or("##".to_string()); let cleanup = cleanup.unwrap_or(true); Decoder { decoder: Some(Arc::new(RwLock::new( tk::decoders::wordpiece::WordPiece::new(prefix, cleanup).into(), ))), } }