File size: 1,368 Bytes
72c0672 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | use crate::models::Model;
use napi_derive::napi;
use std::sync::{Arc, RwLock};
use tokenizers as tk;
use tokenizers::models::TrainerWrapper;
// Handle type exported through the `#[napi]` attribute that wraps a
// `tokenizers` `TrainerWrapper`.
#[napi]
pub struct Trainer {
// Shared, lock-protected trainer. `None` is treated as "uninitialized" by
// the `tk::Trainer` implementation in this file, which reports or panics
// with the message "Uninitialized Trainer" in that case.
trainer: Option<Arc<RwLock<TrainerWrapper>>>,
}
impl From<TrainerWrapper> for Trainer {
fn from(trainer: TrainerWrapper) -> Self {
Self {
trainer: Some(Arc::new(RwLock::new(trainer))),
}
}
}
/// Forwards the `tokenizers` trainer interface to the wrapped
/// `TrainerWrapper`, taking the appropriate lock for each call.
impl tk::Trainer for Trainer {
    type Model = Model;

    /// Returns the wrapped trainer's `should_show_progress()` value.
    ///
    /// # Panics
    ///
    /// Panics with "Uninitialized Trainer" when no trainer is stored, and
    /// panics if the trainer lock is poisoned. The signature returns a plain
    /// `bool`, so neither condition can be reported as an error here.
    fn should_show_progress(&self) -> bool {
        let shared = self.trainer.as_ref().expect("Uninitialized Trainer");
        let guard = shared.read().unwrap();
        guard.should_show_progress()
    }

    /// Trains `model` with the wrapped trainer and returns the tokens the
    /// wrapped `train` call produces.
    ///
    /// # Errors
    ///
    /// Returns "Uninitialized Trainer" when no trainer is stored,
    /// "Uninitialized Model" when `model.model` is empty, or whatever error
    /// the wrapped trainer's `train` returns.
    ///
    /// # Panics
    ///
    /// Panics if the trainer lock or the model lock is poisoned.
    fn train(&self, model: &mut Self::Model) -> tk::Result<Vec<tk::AddedToken>> {
        // Order matters for which failure is observed first: trainer check,
        // trainer read lock, model check, then model write lock.
        let trainer_guard = self
            .trainer
            .as_ref()
            .ok_or("Uninitialized Trainer")?
            .read()
            .unwrap();
        let mut model_guard = model
            .model
            .as_ref()
            .ok_or("Uninitialized Model")?
            .write()
            .unwrap();

        // The trainer is only read; the model is mutated in place through the
        // write guard.
        trainer_guard.train(&mut model_guard)
    }

    /// Passes `iterator` and `process` to the wrapped trainer's `feed`,
    /// holding the trainer's write lock for the duration of the call.
    ///
    /// # Errors
    ///
    /// Returns "Uninitialized Trainer" when no trainer is stored, or whatever
    /// error the wrapped trainer's `feed` returns.
    ///
    /// # Panics
    ///
    /// Panics if the trainer lock is poisoned.
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
    {
        let shared = self.trainer.as_ref().ok_or("Uninitialized Trainer")?;
        let mut guard = shared.write().unwrap();
        guard.feed(iterator, process)
    }
}
|