File size: 3,278 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use crate::arc_rwlock_serde;
use serde::{Deserialize, Serialize};
extern crate tokenizers as tk;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use std::sync::{Arc, RwLock};

use tk::processors::PostProcessorWrapper;
use tk::Encoding;

/// Handle around a `tokenizers` post-processor, exported through `#[napi]`.
///
/// The wrapped `PostProcessorWrapper` lives behind `Arc<RwLock<..>>`, so
/// cloning a `Processor` clones the `Arc` and both copies refer to the same
/// underlying post-processor. A `None` value is treated as "uninitialized" by
/// the `tk::PostProcessor` implementation for this type.
#[derive(Clone, Serialize, Deserialize)]
#[napi]
pub struct Processor {
  // Flattened into the surrounding serialized form and (de)serialized via the
  // crate-local `arc_rwlock_serde` module rather than serde's default handling.
  #[serde(flatten, with = "arc_rwlock_serde")]
  processor: Option<Arc<RwLock<PostProcessorWrapper>>>,
}

/// Delegates the `tk::PostProcessor` contract to the wrapped
/// `PostProcessorWrapper`, taking a read lock for the duration of each call.
impl tk::PostProcessor for Processor {
  /// Forwards to the wrapped post-processor's `added_tokens`.
  ///
  /// # Panics
  ///
  /// Panics with "Uninitialized PostProcessor" when no inner processor is set
  /// (the signature leaves no room for an error), and when the read lock
  /// cannot be acquired.
  fn added_tokens(&self, is_pair: bool) -> usize {
    let shared = self.processor.as_ref().expect("Uninitialized PostProcessor");
    let guard = shared.read().unwrap();
    guard.added_tokens(is_pair)
  }

  /// Forwards to the wrapped post-processor's `process_encodings`.
  ///
  /// # Errors
  ///
  /// Returns an "Uninitialized PostProcessor" error when no inner processor is
  /// set, or whatever error the wrapped post-processor reports.
  ///
  /// # Panics
  ///
  /// Panics when the read lock cannot be acquired.
  fn process_encodings(
    &self,
    encodings: Vec<Encoding>,
    add_special_tokens: bool,
  ) -> tk::Result<Vec<Encoding>> {
    let shared = match self.processor.as_ref() {
      Some(shared) => shared,
      None => return Err("Uninitialized PostProcessor".into()),
    };
    let guard = shared.read().unwrap();
    guard.process_encodings(encodings, add_special_tokens)
  }
}

#[napi]
pub fn bert_processing(sep: (String, u32), cls: (String, u32)) -> Result<Processor> {
  Ok(Processor {
    processor: Some(Arc::new(RwLock::new(
      tk::processors::bert::BertProcessing::new(sep, cls).into(),
    ))),
  })
}

#[napi]
pub fn roberta_processing(
  sep: (String, u32),
  cls: (String, u32),
  trim_offsets: Option<bool>,
  add_prefix_space: Option<bool>,
) -> Result<Processor> {
  let trim_offsets = trim_offsets.unwrap_or(true);
  let add_prefix_space = add_prefix_space.unwrap_or(true);

  let mut processor = tk::processors::roberta::RobertaProcessing::new(sep, cls);
  processor = processor.trim_offsets(trim_offsets);
  processor = processor.add_prefix_space(add_prefix_space);

  Ok(Processor {
    processor: Some(Arc::new(RwLock::new(processor.into()))),
  })
}

#[napi]
pub fn byte_level_processing(trim_offsets: Option<bool>) -> Result<Processor> {
  let mut byte_level = tk::processors::byte_level::ByteLevel::default();

  if let Some(trim_offsets) = trim_offsets {
    byte_level = byte_level.trim_offsets(trim_offsets);
  }

  Ok(Processor {
    processor: Some(Arc::new(RwLock::new(byte_level.into()))),
  })
}

#[napi]
pub fn template_processing(
  single: String,
  pair: Option<String>,
  special_tokens: Option<Vec<(String, u32)>>,
) -> Result<Processor> {
  let special_tokens = special_tokens.unwrap_or_default();
  let mut builder = tk::processors::template::TemplateProcessing::builder();
  builder.try_single(single).map_err(Error::from_reason)?;
  builder.special_tokens(special_tokens);
  if let Some(pair) = pair {
    builder.try_pair(pair).map_err(Error::from_reason)?;
  }
  let processor = builder
    .build()
    .map_err(|e| Error::from_reason(e.to_string()))?;

  Ok(Processor {
    processor: Some(Arc::new(RwLock::new(processor.into()))),
  })
}

/// Builds a `Processor` wrapping a `tk::processors::sequence::Sequence` made
/// from the given processors, in the order supplied.
///
/// Each inner `PostProcessorWrapper` is cloned while holding its read lock, so
/// the resulting sequence owns independent copies rather than sharing the
/// originals. Entries whose inner processor is `None` are skipped.
///
/// # Panics
///
/// Panics when the read lock of any supplied processor cannot be acquired.
#[napi]
pub fn sequence_processing(processors: Vec<&Processor>) -> Processor {
  let sequence: Vec<tk::PostProcessorWrapper> = processors
    .into_iter()
    .filter_map(|processor| {
      processor
        .processor
        .as_ref()
        // `clone` resolves through the read guard to the wrapped value.
        .map(|shared| shared.read().unwrap().clone())
    })
    .collect();

  Processor {
    processor: Some(Arc::new(RwLock::new(PostProcessorWrapper::Sequence(
      tk::processors::sequence::Sequence::new(sequence),
    )))),
  }
}