File size: 2,905 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
extern crate tokenizers as tk;

use crate::encoding::*;
use crate::tokenizer::Tokenizer;
use napi::bindgen_prelude::*;
use tk::tokenizer::{EncodeInput, Encoding};

pub struct EncodeTask<'s> {
  pub tokenizer: Tokenizer,
  pub input: Option<EncodeInput<'s>>,
  pub add_special_tokens: bool,
}

impl Task for EncodeTask<'static> {
  type Output = Encoding;
  type JsValue = JsEncoding;

  fn compute(&mut self) -> Result<Self::Output> {
    self
      .tokenizer
      .tokenizer
      .read()
      .unwrap()
      .encode_char_offsets(
        self
          .input
          .take()
          .ok_or(Error::from_reason("No provided input"))?,
        self.add_special_tokens,
      )
      .map_err(|e| Error::from_reason(format!("{}", e)))
  }

  fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
    Ok(JsEncoding {
      encoding: Some(output),
    })
  }
}

pub struct DecodeTask {
  pub tokenizer: Tokenizer,
  pub ids: Vec<u32>,
  pub skip_special_tokens: bool,
}

impl Task for DecodeTask {
  type Output = String;
  type JsValue = String;

  fn compute(&mut self) -> Result<Self::Output> {
    self
      .tokenizer
      .tokenizer
      .read()
      .unwrap()
      .decode(&self.ids, self.skip_special_tokens)
      .map_err(|e| Error::from_reason(format!("{}", e)))
  }

  fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
    Ok(output)
  }
}
pub struct EncodeBatchTask<'s> {
  pub tokenizer: Tokenizer,
  pub inputs: Option<Vec<EncodeInput<'s>>>,
  pub add_special_tokens: bool,
}

impl Task for EncodeBatchTask<'static> {
  type Output = Vec<Encoding>;
  type JsValue = Vec<JsEncoding>;

  fn compute(&mut self) -> Result<Self::Output> {
    self
      .tokenizer
      .tokenizer
      .read()
      .unwrap()
      .encode_batch_char_offsets(
        self
          .inputs
          .take()
          .ok_or(Error::from_reason("No provided input"))?,
        self.add_special_tokens,
      )
      .map_err(|e| Error::from_reason(format!("{}", e)))
  }

  fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
    Ok(
      output
        .into_iter()
        .map(|encoding| JsEncoding {
          encoding: Some(encoding),
        })
        .collect(),
    )
  }
}

pub struct DecodeBatchTask {
  pub tokenizer: Tokenizer,
  pub ids: Vec<Vec<u32>>,
  pub skip_special_tokens: bool,
}

impl Task for DecodeBatchTask {
  type Output = Vec<String>;
  type JsValue = Vec<String>;

  fn compute(&mut self) -> Result<Self::Output> {
    let ids: Vec<_> = self.ids.iter().map(|s| s.as_slice()).collect();
    self
      .tokenizer
      .tokenizer
      .read()
      .unwrap()
      .decode_batch(&ids, self.skip_special_tokens)
      .map_err(|e| Error::from_reason(format!("{}", e)))
  }

  fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
    Ok(output)
  }
}