| use crate::processors::byte_level::process_offsets; |
| use crate::tokenizer::{Encoding, PostProcessor, Result}; |
| use serde::{Deserialize, Serialize}; |
| use std::collections::HashMap; |
| use std::iter::FromIterator; |
|
|
| #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] |
| #[serde(tag = "type")] |
| pub struct RobertaProcessing { |
| sep: (String, u32), |
| cls: (String, u32), |
| trim_offsets: bool, |
| add_prefix_space: bool, |
| } |
|
|
| impl Default for RobertaProcessing { |
| fn default() -> Self { |
| Self { |
| sep: ("</s>".into(), 2), |
| cls: ("<s>".into(), 0), |
| trim_offsets: true, |
| add_prefix_space: true, |
| } |
| } |
| } |
|
|
| impl RobertaProcessing { |
| pub fn new(sep: (String, u32), cls: (String, u32)) -> Self { |
| Self { |
| sep, |
| cls, |
| ..Default::default() |
| } |
| } |
|
|
| #[must_use] |
| pub fn trim_offsets(mut self, v: bool) -> Self { |
| self.trim_offsets = v; |
| self |
| } |
|
|
| #[must_use] |
| pub fn add_prefix_space(mut self, v: bool) -> Self { |
| self.add_prefix_space = v; |
| self |
| } |
| } |
|
|
| impl PostProcessor for RobertaProcessing { |
| fn added_tokens(&self, is_pair: bool) -> usize { |
| if is_pair { |
| 4 |
| } else { |
| 2 |
| } |
| } |
|
|
| fn process_encodings( |
| &self, |
| mut encodings: Vec<Encoding>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> { |
| if self.trim_offsets { |
| for encoding in encodings.iter_mut() { |
| process_offsets(encoding, self.add_prefix_space); |
| encoding |
| .get_overflowing_mut() |
| .iter_mut() |
| .for_each(|encoding| process_offsets(encoding, self.add_prefix_space)); |
| } |
| } |
|
|
| |
| encodings |
| .iter_mut() |
| .for_each(|encoding| encoding.set_type_ids(vec![0; encoding.len()])); |
|
|
| if !add_special_tokens { |
| return Ok(encodings); |
| } |
|
|
| let encodings: Vec<Encoding> = encodings |
| .iter_mut() |
| .enumerate() |
| .map(|(i, encoding)| { |
| if i == 0 { |
| let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat(); |
| let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat(); |
| let tokens = [ |
| &[self.cls.0.clone()], |
| encoding.get_tokens(), |
| &[self.sep.0.clone()], |
| ] |
| .concat(); |
| let words = [&[None], encoding.get_word_ids(), &[None]].concat(); |
| let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat(); |
| let special_tokens = |
| [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat(); |
| let attention_mask = vec![1; ids.len()]; |
|
|
| |
| |
| let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]); |
| Encoding::new( |
| ids, |
| type_ids, |
| tokens, |
| words, |
| offsets, |
| special_tokens, |
| attention_mask, |
| encoding |
| .take_overflowing() |
| .into_iter() |
| .map(|encoding| { |
| let ids = |
| [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat(); |
| let type_ids = vec![0; encoding.get_ids().len() + 2]; |
| let tokens = [ |
| &[self.cls.0.clone()], |
| encoding.get_tokens(), |
| &[self.sep.0.clone()], |
| ] |
| .concat(); |
| let words = [&[None], encoding.get_word_ids(), &[None]].concat(); |
| let offsets = |
| [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat(); |
| let special_tokens = |
| [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]] |
| .concat(); |
| let attention_mask = vec![1; ids.len()]; |
|
|
| |
| |
| let sequence_ranges = |
| HashMap::from_iter(vec![(0, 1..ids.len() - 1)]); |
| Encoding::new( |
| ids, |
| type_ids, |
| tokens, |
| words, |
| offsets, |
| special_tokens, |
| attention_mask, |
| vec![], |
| sequence_ranges, |
| ) |
| }) |
| .collect(), |
| sequence_ranges, |
| ) |
| } else { |
| let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat(); |
| let pair_type_ids = vec![0; encoding.get_ids().len() + 2]; |
| let pair_tokens = [ |
| &[self.sep.0.clone()], |
| encoding.get_tokens(), |
| &[self.sep.0.clone()], |
| ] |
| .concat(); |
| let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat(); |
| let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat(); |
| let pair_special_tokens = |
| [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat(); |
| let pair_attention_mask = vec![1; pair_ids.len()]; |
|
|
| |
| |
| let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]); |
| Encoding::new( |
| pair_ids, |
| pair_type_ids, |
| pair_tokens, |
| pair_words, |
| pair_offsets, |
| pair_special_tokens, |
| pair_attention_mask, |
| encoding |
| .take_overflowing() |
| .into_iter() |
| .map(|encoding| { |
| let pair_ids = |
| [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat(); |
| let pair_type_ids = vec![0; encoding.get_ids().len() + 2]; |
| let pair_tokens = [ |
| &[self.sep.0.clone()], |
| encoding.get_tokens(), |
| &[self.sep.0.clone()], |
| ] |
| .concat(); |
| let pair_words = |
| [&[None], encoding.get_word_ids(), &[None]].concat(); |
| let pair_offsets = |
| [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat(); |
| let pair_special_tokens = |
| [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]] |
| .concat(); |
| let pair_attention_mask = vec![1; pair_ids.len()]; |
|
|
| |
| |
| let pair_sequence_ranges = |
| HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]); |
| Encoding::new( |
| pair_ids, |
| pair_type_ids, |
| pair_tokens, |
| pair_words, |
| pair_offsets, |
| pair_special_tokens, |
| pair_attention_mask, |
| vec![], |
| pair_sequence_ranges, |
| ) |
| }) |
| .collect(), |
| pair_sequence_ranges, |
| ) |
| } |
| }) |
| .collect(); |
|
|
| Ok(encodings) |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Token;

    /// The default processor serializes to the tagged JSON form and that JSON
    /// deserializes back to an equal value.
    #[test]
    fn serde() {
        let processor = RobertaProcessing::default();
        let expected_json = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .replace(char::is_whitespace, "");

        let serialized = serde_json::to_string(&processor).unwrap();
        assert_eq!(serialized, expected_json);

        let deserialized: RobertaProcessing = serde_json::from_str(&expected_json).unwrap();
        assert_eq!(deserialized, processor);
    }

    /// Covers a single sequence, a pair with special tokens, and a pair
    /// without special tokens.
    #[test]
    fn roberta_processing() {
        let processor = RobertaProcessing::default();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 4);

        // Small helper to spell out expected token lists.
        let strings = |items: &[&str]| -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        };

        let first = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let second = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);

        // Single sequence, special tokens added.
        let single = processor.process(first.clone(), None, true).unwrap();
        let expected_single = Encoding::new(
            vec![0, 12, 14, 2],
            vec![0, 0, 0, 0],
            strings(&["<s>", "Hello", "there", "</s>"]),
            vec![None, None, None, None],
            vec![(0, 0), (0, 5), (6, 11), (0, 0)],
            vec![1, 0, 0, 1],
            vec![1, 1, 1, 1],
            vec![],
            HashMap::from_iter(vec![(0, 1..3)]),
        );
        assert_eq!(single, expected_single);
        assert_eq!(single.token_to_sequence(2), Some(0));
        assert_eq!(single.token_to_sequence(3), None);

        // Pair of sequences, special tokens added.
        let with_specials = processor
            .process(first.clone(), Some(second.clone()), true)
            .unwrap();
        let expected_with_specials = Encoding::new(
            vec![0, 12, 14, 2, 2, 15, 2],
            vec![0, 0, 0, 0, 0, 0, 0],
            strings(&["<s>", "Hello", "there", "</s>", "</s>", "pair", "</s>"]),
            vec![None, None, None, None, None, None, None],
            vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 0), (0, 4), (0, 0)],
            vec![1, 0, 0, 1, 1, 0, 1],
            vec![1, 1, 1, 1, 1, 1, 1],
            vec![],
            HashMap::from_iter(vec![(0, 1..3), (1, 5..6)]),
        );
        assert_eq!(with_specials, expected_with_specials);
        assert_eq!(with_specials.token_to_sequence(2), Some(0));
        assert_eq!(with_specials.token_to_sequence(3), None);
        assert_eq!(with_specials.token_to_sequence(4), None);
        assert_eq!(with_specials.token_to_sequence(5), Some(1));
        assert_eq!(with_specials.token_to_sequence(6), None);

        // Pair of sequences, no special tokens.
        let without_specials = processor.process(first, Some(second), false).unwrap();
        let expected_without_specials = Encoding::new(
            vec![12, 14, 15],
            vec![0, 0, 0],
            strings(&["Hello", "there", "pair"]),
            vec![None, None, None],
            vec![(0, 5), (6, 11), (0, 4)],
            vec![0, 0, 0],
            vec![1, 1, 1],
            vec![],
            HashMap::from_iter(vec![(0, 0..2), (1, 2..3)]),
        );
        assert_eq!(without_specials, expected_without_specials);
        assert_eq!(without_specials.token_to_sequence(0), Some(0));
        assert_eq!(without_specials.token_to_sequence(1), Some(0));
        assert_eq!(without_specials.token_to_sequence(2), Some(1));
    }
}
|
|