use crate::arc_rwlock_serde; use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask}; use crate::trainers::Trainer; use napi::bindgen_prelude::*; use napi_derive::napi; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; use tokenizers as tk; use tokenizers::models::bpe::{BpeBuilder, Merges, Vocab}; use tokenizers::models::wordlevel::WordLevelBuilder; use tokenizers::models::wordpiece::WordPieceBuilder; #[napi] #[derive(Clone, Serialize, Deserialize)] pub struct Model { #[serde(flatten, with = "arc_rwlock_serde")] pub(crate) model: Option>>, } impl From for Model where M: Into, { fn from(wrapper: M) -> Self { Self { model: Some(Arc::new(RwLock::new(wrapper.into()))), } } } #[napi(js_name = "BPE")] pub struct Bpe {} #[napi] impl Bpe { #[napi(factory, ts_return_type = "Model")] pub fn empty() -> Result { let bpe = tk::models::bpe::BPE::default(); Ok(Model { model: Some(Arc::new(RwLock::new(bpe.into()))), }) } #[napi(factory, ts_return_type = "Model")] pub fn init(vocab: Vocab, merges: Merges, options: Option) -> Result { let options = options.unwrap_or_default(); let mut builder = tk::models::bpe::BPE::builder().vocab_and_merges(vocab, merges); builder = options.apply_to_bpe_builder(builder); let model = builder .build() .map_err(|e| Error::from_reason(e.to_string()))?; Ok(Model { model: Some(Arc::new(RwLock::new(model.into()))), }) } #[napi(ts_return_type = "Promise")] pub fn from_file( vocab: String, merges: String, options: Option, ) -> AsyncTask { let options = options.unwrap_or_default(); let mut builder = tk::models::bpe::BPE::from_file(&vocab, &merges); builder = options.apply_to_bpe_builder(builder); AsyncTask::new(BPEFromFilesTask { builder: Some(builder), }) } } impl tk::Model for Model { type Trainer = Trainer; fn tokenize(&self, sequence: &str) -> tk::Result> { self .model .as_ref() .ok_or("Uninitialized Model")? .read() .unwrap() .tokenize(sequence) } fn token_to_id(&self, token: &str) -> Option { self.model.as_ref()?.read().unwrap().token_to_id(token) } fn id_to_token(&self, id: u32) -> Option { self.model.as_ref()?.read().unwrap().id_to_token(id) } fn get_vocab(&self) -> HashMap { self .model .as_ref() .expect("Uninitialized Model") .read() .unwrap() .get_vocab() } fn get_vocab_size(&self) -> usize { self .model .as_ref() .expect("Uninitialized Model") .read() .unwrap() .get_vocab_size() } fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result> { self .model .as_ref() .ok_or("Uninitialized Model")? .read() .unwrap() .save(folder, name) } fn get_trainer(&self) -> Self::Trainer { self .model .as_ref() .expect("Uninitialized Model") .read() .unwrap() .get_trainer() .into() } } #[derive(Default)] #[napi(object)] pub struct BpeOptions { pub cache_capacity: Option, pub dropout: Option, pub unk_token: Option, pub continuing_subword_prefix: Option, pub end_of_word_suffix: Option, pub fuse_unk: Option, pub byte_fallback: Option, } impl BpeOptions { fn apply_to_bpe_builder(self, mut builder: BpeBuilder) -> BpeBuilder { if let Some(cache_capacity) = self.cache_capacity { builder = builder.cache_capacity(cache_capacity as usize); } if let Some(dropout) = self.dropout { builder = builder.dropout(dropout as f32); } if let Some(unk_token) = self.unk_token { builder = builder.unk_token(unk_token); } if let Some(continuing_subword_prefix) = self.continuing_subword_prefix { builder = builder.continuing_subword_prefix(continuing_subword_prefix); } if let Some(end_of_word_suffix) = self.end_of_word_suffix { builder = builder.end_of_word_suffix(end_of_word_suffix); } if let Some(fuse_unk) = self.fuse_unk { builder = builder.fuse_unk(fuse_unk); } if let Some(byte_fallback) = self.byte_fallback { builder = builder.byte_fallback(byte_fallback); } builder } } #[derive(Default)] #[napi(object)] pub struct WordPieceOptions { pub unk_token: Option, pub continuing_subword_prefix: Option, pub max_input_chars_per_word: Option, } impl WordPieceOptions { fn apply_to_wordpiece_builder(self, mut builder: WordPieceBuilder) -> WordPieceBuilder { if let Some(token) = self.unk_token { builder = builder.unk_token(token); } if let Some(prefix) = self.continuing_subword_prefix { builder = builder.continuing_subword_prefix(prefix); } if let Some(max) = self.max_input_chars_per_word { builder = builder.max_input_chars_per_word(max as usize); } builder } } #[napi] pub struct WordPiece {} #[napi] impl WordPiece { #[napi(factory, ts_return_type = "Model")] pub fn init(vocab: Vocab, options: Option) -> Result { let options = options.unwrap_or_default(); let mut builder = tk::models::wordpiece::WordPiece::builder().vocab(vocab); builder = options.apply_to_wordpiece_builder(builder); let model = builder .build() .map_err(|e| Error::from_reason(e.to_string()))?; Ok(Model { model: Some(Arc::new(RwLock::new(model.into()))), }) } #[napi(factory)] pub fn empty() -> Model { let wordpiece = tk::models::wordpiece::WordPiece::default(); Model { model: Some(Arc::new(RwLock::new(wordpiece.into()))), } } #[napi(ts_return_type = "Promise")] pub fn from_file( vocab: String, options: Option, ) -> AsyncTask { let options = options.unwrap_or_default(); let mut builder = tk::models::wordpiece::WordPiece::from_file(&vocab); builder = options.apply_to_wordpiece_builder(builder); AsyncTask::new(WordPieceFromFilesTask { builder: Some(builder), }) } } #[derive(Default)] #[napi(object)] pub struct WordLevelOptions { pub unk_token: Option, } impl WordLevelOptions { fn apply_to_wordlevel_builder(self, mut builder: WordLevelBuilder) -> WordLevelBuilder { if let Some(token) = self.unk_token { builder = builder.unk_token(token); } builder } } #[napi] pub struct WordLevel {} #[napi] impl WordLevel { #[napi(factory, ts_return_type = "Model")] pub fn init(vocab: Vocab, options: Option) -> Result { let options = options.unwrap_or_default(); let mut builder = tk::models::wordlevel::WordLevel::builder().vocab(vocab); builder = options.apply_to_wordlevel_builder(builder); let model = builder .build() .map_err(|e| Error::from_reason(e.to_string()))?; Ok(Model { model: Some(Arc::new(RwLock::new(model.into()))), }) } #[napi(factory)] pub fn empty() -> Model { let wordlevel = tk::models::wordlevel::WordLevel::default(); Model { model: Some(Arc::new(RwLock::new(wordlevel.into()))), } } #[napi(ts_return_type = "Promise")] pub fn from_file( vocab: String, options: Option, ) -> AsyncTask { let options = options.unwrap_or_default(); let mut builder = tk::models::wordlevel::WordLevel::builder().files(vocab); builder = options.apply_to_wordlevel_builder(builder); AsyncTask::new(WordLevelFromFilesTask { builder: Some(builder), }) } } #[derive(Default)] #[napi(object)] pub struct UnigramOptions { pub unk_id: Option, pub byte_fallback: Option, } #[napi] pub struct Unigram {} #[napi] impl Unigram { #[napi(factory, ts_return_type = "Model")] pub fn init(vocab: Vec<(String, f64)>, options: Option) -> Result { let options = options.unwrap_or_default(); let unigram = tk::models::unigram::Unigram::from( vocab, options.unk_id.map(|u| u as usize), options.byte_fallback.unwrap_or(false), ) .map_err(|e| Error::from_reason(e.to_string()))?; Ok(Model { model: Some(Arc::new(RwLock::new(unigram.into()))), }) } #[napi(factory, ts_return_type = "Model")] pub fn empty() -> Model { let unigram = tk::models::unigram::Unigram::default(); Model { model: Some(Arc::new(RwLock::new(unigram.into()))), } } }