use crate::processors::PostProcessorWrapper; use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::utils::macro_rules_attribute; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Eq)] #[macro_rules_attribute(impl_serde_type!)] pub struct Sequence { processors: Vec, } impl Sequence { pub fn new(processors: Vec) -> Self { Self { processors } } } impl PostProcessor for Sequence { fn added_tokens(&self, is_pair: bool) -> usize { self.processors .iter() .map(|p| p.added_tokens(is_pair)) .sum::() } fn process_encodings( &self, mut encodings: Vec, add_special_tokens: bool, ) -> Result> { for processor in &self.processors { encodings = processor.process_encodings(encodings, add_special_tokens)?; } Ok(encodings) } } #[cfg(test)] mod tests { use super::*; use crate::processors::{ByteLevel, PostProcessorWrapper}; use crate::tokenizer::{Encoding, PostProcessor}; use std::collections::HashMap; use std::iter::FromIterator; #[test] fn process_chain() { let start = Encoding::new( vec![0; 5], vec![0; 5], vec![ "Ġ".into(), "ĠĠĠĠHelloĠĠ".into(), "ĠĠHello".into(), "HelloĠĠ".into(), "ĠĠĠĠ".into(), ], vec![], vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)], vec![], vec![], vec![], HashMap::new(), ); let bytelevel = ByteLevel::default().trim_offsets(true); let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]); let expected = Encoding::new( vec![0; 5], vec![0; 5], vec![ "Ġ".into(), "ĠĠĠĠHelloĠĠ".into(), "ĠĠHello".into(), "HelloĠĠ".into(), "ĠĠĠĠ".into(), ], vec![], vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)], vec![], vec![], vec![], HashMap::from_iter(vec![(0, 0..5)]), ); assert_eq!( expected, bytelevel.process(start.clone(), None, false).unwrap() ); assert_eq!( expected, sequence.process(start.clone(), None, false).unwrap() ); let pair_expected = Encoding::new( vec![0; 10], vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1], vec![ "Ġ".into(), "ĠĠĠĠHelloĠĠ".into(), "ĠĠHello".into(), "HelloĠĠ".into(), "ĠĠĠĠ".into(), "Ġ".into(), "ĠĠĠĠHelloĠĠ".into(), "ĠĠHello".into(), "HelloĠĠ".into(), "ĠĠĠĠ".into(), ], vec![], vec![ (0, 0), (4, 9), (13, 18), (18, 23), (29, 29), (0, 0), (4, 9), (13, 18), (18, 23), (29, 29), ], vec![], vec![], vec![], HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]), ); assert_eq!( pair_expected, bytelevel .process(start.clone(), Some(start.clone()), false) .unwrap() ); assert_eq!( pair_expected, sequence.process(start.clone(), Some(start), false).unwrap() ); } }