2ira's picture
offline_compression_graph_code
72c0672 verified
use crate::arc_rwlock_serde;
use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask};
use crate::trainers::Trainer;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use tokenizers as tk;
use tokenizers::models::bpe::{BpeBuilder, Merges, Vocab};
use tokenizers::models::wordlevel::WordLevelBuilder;
use tokenizers::models::wordpiece::WordPieceBuilder;
/// Handle to a tokenizer model, exported through `#[napi]`.
///
/// The underlying `tk::models::ModelWrapper` sits behind an `Arc<RwLock<..>>`, so the derived
/// `Clone` copies the shared handle rather than the model itself.
#[napi]
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
/// Shared, lock-protected model. `None` is the uninitialized state that the `tk::Model`
/// implementation in this file reports as "Uninitialized Model".
/// (De)serialization goes through the crate's `arc_rwlock_serde` module and is flattened.
#[serde(flatten, with = "arc_rwlock_serde")]
pub(crate) model: Option<Arc<RwLock<tk::models::ModelWrapper>>>,
}
impl<M> From<M> for Model
where
M: Into<tk::models::ModelWrapper>,
{
fn from(wrapper: M) -> Self {
Self {
model: Some(Arc::new(RwLock::new(wrapper.into()))),
}
}
}
#[napi(js_name = "BPE")]
pub struct Bpe {}
#[napi]
impl Bpe {
#[napi(factory, ts_return_type = "Model")]
pub fn empty() -> Result<Model> {
let bpe = tk::models::bpe::BPE::default();
Ok(Model {
model: Some(Arc::new(RwLock::new(bpe.into()))),
})
}
#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vocab, merges: Merges, options: Option<BpeOptions>) -> Result<Model> {
let options = options.unwrap_or_default();
let mut builder = tk::models::bpe::BPE::builder().vocab_and_merges(vocab, merges);
builder = options.apply_to_bpe_builder(builder);
let model = builder
.build()
.map_err(|e| Error::from_reason(e.to_string()))?;
Ok(Model {
model: Some(Arc::new(RwLock::new(model.into()))),
})
}
#[napi(ts_return_type = "Promise<Model>")]
pub fn from_file(
vocab: String,
merges: String,
options: Option<BpeOptions>,
) -> AsyncTask<BPEFromFilesTask> {
let options = options.unwrap_or_default();
let mut builder = tk::models::bpe::BPE::from_file(&vocab, &merges);
builder = options.apply_to_bpe_builder(builder);
AsyncTask::new(BPEFromFilesTask {
builder: Some(builder),
})
}
}
/// Implements `tk::Model` by delegating every call to the wrapped `ModelWrapper`.
///
/// Each method takes a read lock on the shared model for the duration of the call.
///
/// # Panics
/// Every method panics if the lock is poisoned (`read().unwrap()`). In addition,
/// `get_vocab`, `get_vocab_size` and `get_trainer` panic with "Uninitialized Model"
/// when no model is set, because their signatures leave no room for an error.
impl tk::Model for Model {
    type Trainer = Trainer;

    /// Tokenizes `sequence` with the wrapped model.
    ///
    /// # Errors
    /// Returns an "Uninitialized Model" error when no model is set; otherwise returns
    /// whatever the wrapped model's `tokenize` returns.
    fn tokenize(&self, sequence: &str) -> tk::Result<Vec<tk::Token>> {
        let shared = match self.model.as_ref() {
            Some(shared) => shared,
            None => return Err("Uninitialized Model".into()),
        };
        let guard = shared.read().unwrap();
        guard.tokenize(sequence)
    }

    /// Looks up the id of `token`; yields `None` when no model is set.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        let shared = self.model.as_ref()?;
        let guard = shared.read().unwrap();
        guard.token_to_id(token)
    }

    /// Looks up the token for `id`; yields `None` when no model is set.
    fn id_to_token(&self, id: u32) -> Option<String> {
        let shared = self.model.as_ref()?;
        let guard = shared.read().unwrap();
        guard.id_to_token(id)
    }

    /// Returns the wrapped model's vocabulary. Panics when no model is set.
    fn get_vocab(&self) -> HashMap<String, u32> {
        let shared = self.model.as_ref().expect("Uninitialized Model");
        let guard = shared.read().unwrap();
        guard.get_vocab()
    }

    /// Returns the wrapped model's vocabulary size. Panics when no model is set.
    fn get_vocab_size(&self) -> usize {
        let shared = self.model.as_ref().expect("Uninitialized Model");
        let guard = shared.read().unwrap();
        guard.get_vocab_size()
    }

    /// Saves the wrapped model by forwarding `folder` and `name` to its `save`.
    ///
    /// # Errors
    /// Returns an "Uninitialized Model" error when no model is set; otherwise returns
    /// whatever the wrapped model's `save` returns.
    fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result<Vec<PathBuf>> {
        let shared = match self.model.as_ref() {
            Some(shared) => shared,
            None => return Err("Uninitialized Model".into()),
        };
        let guard = shared.read().unwrap();
        guard.save(folder, name)
    }

    /// Returns the wrapped model's trainer converted into this crate's `Trainer`.
    /// Panics when no model is set.
    fn get_trainer(&self) -> Self::Trainer {
        let shared = self.model.as_ref().expect("Uninitialized Model");
        let guard = shared.read().unwrap();
        let inner_trainer = guard.get_trainer();
        inner_trainer.into()
    }
}
/// Optional settings for constructing a BPE model, declared as a `#[napi(object)]`.
///
/// Every field is optional. `apply_to_bpe_builder` forwards only the fields that are
/// set, so an all-`None` value (the `Default`) leaves the builder untouched.
#[derive(Default)]
#[napi(object)]
pub struct BpeOptions {
/// Forwarded to `BpeBuilder::cache_capacity` (converted to `usize`) when set.
pub cache_capacity: Option<u32>,
/// Forwarded to `BpeBuilder::dropout` (narrowed from `f64` to `f32`) when set.
pub dropout: Option<f64>,
/// Forwarded to `BpeBuilder::unk_token` when set.
pub unk_token: Option<String>,
/// Forwarded to `BpeBuilder::continuing_subword_prefix` when set.
pub continuing_subword_prefix: Option<String>,
/// Forwarded to `BpeBuilder::end_of_word_suffix` when set.
pub end_of_word_suffix: Option<String>,
/// Forwarded to `BpeBuilder::fuse_unk` when set.
pub fuse_unk: Option<bool>,
/// Forwarded to `BpeBuilder::byte_fallback` when set.
pub byte_fallback: Option<bool>,
}
impl BpeOptions {
    /// Consumes the options and applies every field that is set to `builder`.
    ///
    /// Fields are applied in declaration order; unset fields leave the builder as it was.
    /// `cache_capacity` is converted to `usize` and `dropout` is narrowed to `f32`.
    fn apply_to_bpe_builder(self, builder: BpeBuilder) -> BpeBuilder {
        let builder = match self.cache_capacity {
            Some(capacity) => builder.cache_capacity(capacity as usize),
            None => builder,
        };
        let builder = match self.dropout {
            Some(rate) => builder.dropout(rate as f32),
            None => builder,
        };
        let builder = match self.unk_token {
            Some(token) => builder.unk_token(token),
            None => builder,
        };
        let builder = match self.continuing_subword_prefix {
            Some(prefix) => builder.continuing_subword_prefix(prefix),
            None => builder,
        };
        let builder = match self.end_of_word_suffix {
            Some(suffix) => builder.end_of_word_suffix(suffix),
            None => builder,
        };
        let builder = match self.fuse_unk {
            Some(fuse) => builder.fuse_unk(fuse),
            None => builder,
        };
        match self.byte_fallback {
            Some(fallback) => builder.byte_fallback(fallback),
            None => builder,
        }
    }
}
/// Optional settings for constructing a WordPiece model, declared as a `#[napi(object)]`.
///
/// `apply_to_wordpiece_builder` forwards only the fields that are set.
#[derive(Default)]
#[napi(object)]
pub struct WordPieceOptions {
/// Forwarded to `WordPieceBuilder::unk_token` when set.
pub unk_token: Option<String>,
/// Forwarded to `WordPieceBuilder::continuing_subword_prefix` when set.
pub continuing_subword_prefix: Option<String>,
/// Forwarded to `WordPieceBuilder::max_input_chars_per_word` (converted to `usize`) when set.
pub max_input_chars_per_word: Option<u32>,
}
impl WordPieceOptions {
    /// Consumes the options and applies every field that is set to `builder`.
    ///
    /// Fields are applied in declaration order; unset fields leave the builder as it was.
    /// `max_input_chars_per_word` is converted to `usize`.
    fn apply_to_wordpiece_builder(self, builder: WordPieceBuilder) -> WordPieceBuilder {
        let builder = match self.unk_token {
            Some(unk) => builder.unk_token(unk),
            None => builder,
        };
        let builder = match self.continuing_subword_prefix {
            Some(subword_prefix) => builder.continuing_subword_prefix(subword_prefix),
            None => builder,
        };
        match self.max_input_chars_per_word {
            Some(limit) => builder.max_input_chars_per_word(limit as usize),
            None => builder,
        }
    }
}
#[napi]
pub struct WordPiece {}
#[napi]
impl WordPiece {
#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vocab, options: Option<WordPieceOptions>) -> Result<Model> {
let options = options.unwrap_or_default();
let mut builder = tk::models::wordpiece::WordPiece::builder().vocab(vocab);
builder = options.apply_to_wordpiece_builder(builder);
let model = builder
.build()
.map_err(|e| Error::from_reason(e.to_string()))?;
Ok(Model {
model: Some(Arc::new(RwLock::new(model.into()))),
})
}
#[napi(factory)]
pub fn empty() -> Model {
let wordpiece = tk::models::wordpiece::WordPiece::default();
Model {
model: Some(Arc::new(RwLock::new(wordpiece.into()))),
}
}
#[napi(ts_return_type = "Promise<Model>")]
pub fn from_file(
vocab: String,
options: Option<WordPieceOptions>,
) -> AsyncTask<WordPieceFromFilesTask> {
let options = options.unwrap_or_default();
let mut builder = tk::models::wordpiece::WordPiece::from_file(&vocab);
builder = options.apply_to_wordpiece_builder(builder);
AsyncTask::new(WordPieceFromFilesTask {
builder: Some(builder),
})
}
}
/// Optional settings for constructing a WordLevel model, declared as a `#[napi(object)]`.
#[derive(Default)]
#[napi(object)]
pub struct WordLevelOptions {
/// Forwarded to `WordLevelBuilder::unk_token` by `apply_to_wordlevel_builder` when set.
pub unk_token: Option<String>,
}
impl WordLevelOptions {
    /// Consumes the options and sets the builder's unknown token when one was provided;
    /// otherwise returns `builder` unchanged.
    fn apply_to_wordlevel_builder(self, builder: WordLevelBuilder) -> WordLevelBuilder {
        match self.unk_token {
            Some(unk) => builder.unk_token(unk),
            None => builder,
        }
    }
}
#[napi]
pub struct WordLevel {}
#[napi]
impl WordLevel {
#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vocab, options: Option<WordLevelOptions>) -> Result<Model> {
let options = options.unwrap_or_default();
let mut builder = tk::models::wordlevel::WordLevel::builder().vocab(vocab);
builder = options.apply_to_wordlevel_builder(builder);
let model = builder
.build()
.map_err(|e| Error::from_reason(e.to_string()))?;
Ok(Model {
model: Some(Arc::new(RwLock::new(model.into()))),
})
}
#[napi(factory)]
pub fn empty() -> Model {
let wordlevel = tk::models::wordlevel::WordLevel::default();
Model {
model: Some(Arc::new(RwLock::new(wordlevel.into()))),
}
}
#[napi(ts_return_type = "Promise<Model>")]
pub fn from_file(
vocab: String,
options: Option<WordLevelOptions>,
) -> AsyncTask<WordLevelFromFilesTask> {
let options = options.unwrap_or_default();
let mut builder = tk::models::wordlevel::WordLevel::builder().files(vocab);
builder = options.apply_to_wordlevel_builder(builder);
AsyncTask::new(WordLevelFromFilesTask {
builder: Some(builder),
})
}
}
/// Optional settings for constructing a Unigram model, declared as a `#[napi(object)]`.
#[derive(Default)]
#[napi(object)]
pub struct UnigramOptions {
/// Converted to `usize` and passed to `Unigram::from` by `Unigram::init`; `None` is passed
/// through as `None`.
pub unk_id: Option<u32>,
/// Passed to `Unigram::from` by `Unigram::init`; treated as `false` when unset.
pub byte_fallback: Option<bool>,
}
#[napi]
pub struct Unigram {}
#[napi]
impl Unigram {
#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vec<(String, f64)>, options: Option<UnigramOptions>) -> Result<Model> {
let options = options.unwrap_or_default();
let unigram = tk::models::unigram::Unigram::from(
vocab,
options.unk_id.map(|u| u as usize),
options.byte_fallback.unwrap_or(false),
)
.map_err(|e| Error::from_reason(e.to_string()))?;
Ok(Model {
model: Some(Arc::new(RwLock::new(unigram.into()))),
})
}
#[napi(factory, ts_return_type = "Model")]
pub fn empty() -> Model {
let unigram = tk::models::unigram::Unigram::default();
Model {
model: Some(Arc::new(RwLock::new(unigram.into()))),
}
}
}