File size: 2,905 Bytes
72c0672 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | extern crate tokenizers as tk;
use crate::encoding::*;
use crate::tokenizer::Tokenizer;
use napi::bindgen_prelude::*;
use tk::tokenizer::{EncodeInput, Encoding};
pub struct EncodeTask<'s> {
pub tokenizer: Tokenizer,
pub input: Option<EncodeInput<'s>>,
pub add_special_tokens: bool,
}
impl Task for EncodeTask<'static> {
type Output = Encoding;
type JsValue = JsEncoding;
fn compute(&mut self) -> Result<Self::Output> {
self
.tokenizer
.tokenizer
.read()
.unwrap()
.encode_char_offsets(
self
.input
.take()
.ok_or(Error::from_reason("No provided input"))?,
self.add_special_tokens,
)
.map_err(|e| Error::from_reason(format!("{}", e)))
}
fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
Ok(JsEncoding {
encoding: Some(output),
})
}
}
pub struct DecodeTask {
pub tokenizer: Tokenizer,
pub ids: Vec<u32>,
pub skip_special_tokens: bool,
}
impl Task for DecodeTask {
type Output = String;
type JsValue = String;
fn compute(&mut self) -> Result<Self::Output> {
self
.tokenizer
.tokenizer
.read()
.unwrap()
.decode(&self.ids, self.skip_special_tokens)
.map_err(|e| Error::from_reason(format!("{}", e)))
}
fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
Ok(output)
}
}
pub struct EncodeBatchTask<'s> {
pub tokenizer: Tokenizer,
pub inputs: Option<Vec<EncodeInput<'s>>>,
pub add_special_tokens: bool,
}
impl Task for EncodeBatchTask<'static> {
type Output = Vec<Encoding>;
type JsValue = Vec<JsEncoding>;
fn compute(&mut self) -> Result<Self::Output> {
self
.tokenizer
.tokenizer
.read()
.unwrap()
.encode_batch_char_offsets(
self
.inputs
.take()
.ok_or(Error::from_reason("No provided input"))?,
self.add_special_tokens,
)
.map_err(|e| Error::from_reason(format!("{}", e)))
}
fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
Ok(
output
.into_iter()
.map(|encoding| JsEncoding {
encoding: Some(encoding),
})
.collect(),
)
}
}
pub struct DecodeBatchTask {
pub tokenizer: Tokenizer,
pub ids: Vec<Vec<u32>>,
pub skip_special_tokens: bool,
}
impl Task for DecodeBatchTask {
type Output = Vec<String>;
type JsValue = Vec<String>;
fn compute(&mut self) -> Result<Self::Output> {
let ids: Vec<_> = self.ids.iter().map(|s| s.as_slice()).collect();
self
.tokenizer
.tokenizer
.read()
.unwrap()
.decode_batch(&ids, self.skip_special_tokens)
.map_err(|e| Error::from_reason(format!("{}", e)))
}
fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
Ok(output)
}
}
|