| use std::collections::{HashMap, HashSet}; |
|
|
| use crate::utils::SysRegex; |
| use serde::{Deserialize, Serialize}; |
|
|
| use crate::tokenizer::{ |
| Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result, |
| SplitDelimiterBehavior, |
| }; |
| use crate::utils::macro_rules_attribute; |
|
|
| |
| |
/// Builds the byte -> character table used by the byte-level encoding.
///
/// Each of the 256 possible byte values is assigned a distinct `char`:
///
/// * bytes in `b'!'..=b'~'`, `0xA1..=0xAC` and `0xAE..=0xFF` map to the code point that has the
///   same numeric value as the byte;
/// * every other byte maps to `U+0100 + k`, where `k` is the rank of that byte (in ascending
///   byte order) among the bytes that are not in the ranges above. For example the space byte
///   (`0x20`) becomes `U+0120` (`Ġ`) and the newline byte (`0x0A`) becomes `U+010A` (`Ċ`).
///
/// The returned map always has exactly 256 entries.
pub(crate) fn bytes_char() -> HashMap<u8, char> {
    // Next code point handed out to a byte that does not keep its own value.
    let mut next_remapped: u32 = 256;

    (0..=255u8)
        .map(|byte| {
            let code_point = match byte {
                // These bytes are represented by the character with the same value.
                b'!'..=b'~' | b'\xA1'..=b'\xAC' | b'\xAE'..=b'\xFF' => u32::from(byte),
                // Everything else is shifted above U+00FF, in ascending byte order.
                _ => {
                    let assigned = next_remapped;
                    next_remapped += 1;
                    assigned
                }
            };
            // Code points stay within 0x21..=0x143, far below the surrogate range, so the
            // checked conversion cannot fail and no `unsafe` is required.
            let ch = char::from_u32(code_point)
                .expect("byte-level code points are always valid Unicode scalar values");
            (byte, ch)
        })
        .collect()
}
|
|
| lazy_static! { |
| |
| |
| static ref RE: SysRegex = SysRegex::new( |
| r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" |
| ) |
| .unwrap(); |
| static ref BYTES_CHAR: HashMap<u8, char> = bytes_char(); |
| static ref CHAR_BYTES: HashMap<char, u8> = |
| bytes_char().into_iter().map(|(c, b)| (b, c)).collect(); |
| } |
|
|
| #[derive(Copy, Clone, Debug, PartialEq, Eq)] |
| |
| |
| |
| #[macro_rules_attribute(impl_serde_type!)] |
| #[non_exhaustive] |
| pub struct ByteLevel { |
| |
| |
| pub add_prefix_space: bool, |
| |
| pub trim_offsets: bool, |
|
|
| |
| |
| #[serde(default = "default_true")] |
| pub use_regex: bool, |
| } |
|
|
/// Serde default provider for `ByteLevel::use_regex`: the option is on unless stated otherwise.
fn default_true() -> bool {
    true
}
|
|
| impl Default for ByteLevel { |
| fn default() -> Self { |
| Self { |
| add_prefix_space: true, |
| trim_offsets: true, |
| use_regex: true, |
| } |
| } |
| } |
|
|
| impl ByteLevel { |
| pub fn new(add_prefix_space: bool, trim_offsets: bool, use_regex: bool) -> Self { |
| Self { |
| add_prefix_space, |
| trim_offsets, |
| use_regex, |
| } |
| } |
|
|
| pub fn alphabet() -> HashSet<char> { |
| BYTES_CHAR.values().copied().collect() |
| } |
|
|
| #[must_use] |
| pub fn add_prefix_space(mut self, v: bool) -> Self { |
| self.add_prefix_space = v; |
| self |
| } |
|
|
| #[must_use] |
| pub fn trim_offsets(mut self, v: bool) -> Self { |
| self.trim_offsets = v; |
| self |
| } |
|
|
| #[must_use] |
| pub fn use_regex(mut self, v: bool) -> Self { |
| self.use_regex = v; |
| self |
| } |
| } |
|
|
| |
| |
| |
| impl PreTokenizer for ByteLevel { |
| fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { |
| let re_ref: &SysRegex = &RE; |
| pretokenized.split(|_, mut normalized| { |
| if self.add_prefix_space && !normalized.get().starts_with(' ') { |
| normalized.prepend(" "); |
| } |
| if self.use_regex { |
| normalized.split(re_ref, SplitDelimiterBehavior::Isolated) |
| } else { |
| Ok(vec![normalized]) |
| } |
| })?; |
| pretokenized.normalize(|normalized| { |
| let s = normalized.get(); |
| let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len()); |
| let mut i = 0; |
| for cur_char in s.chars() { |
| let size = cur_char.len_utf8(); |
| let bytes = s[i..i + size].as_bytes(); |
| i += size; |
| transformations.extend( |
| bytes |
| .iter() |
| .enumerate() |
| .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))), |
| ); |
| } |
| normalized.transform(transformations, 0); |
| Ok(()) |
| }) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| impl Decoder for ByteLevel { |
| fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> { |
| let toks = tokens |
| .into_iter() |
| .flat_map(|t| { |
| t.chars() |
| .try_fold(vec![], |mut acc, c| { |
| CHAR_BYTES.get(&c).map(|b| { |
| acc.push(*b); |
| acc |
| }) |
| }) |
| .unwrap_or_else(|| t.as_bytes().to_vec()) |
| }) |
| .collect::<Vec<u8>>(); |
| Ok(vec![String::from_utf8_lossy(&toks).to_string()]) |
| } |
| } |
|
|
| |
| impl PostProcessor for ByteLevel { |
| fn added_tokens(&self, _is_pair: bool) -> usize { |
| 0 |
| } |
|
|
| fn process_encodings( |
| &self, |
| mut encodings: Vec<Encoding>, |
| _add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> { |
| if self.trim_offsets { |
| for encoding in encodings.iter_mut() { |
| process_offsets(encoding, self.add_prefix_space); |
| encoding |
| .get_overflowing_mut() |
| .iter_mut() |
| .for_each(|encoding| process_offsets(encoding, self.add_prefix_space)); |
| } |
| } |
| for (i, encoding) in encodings.iter_mut().enumerate() { |
| encoding.set_sequence_id(i); |
| } |
| Ok(encodings) |
| |
| } |
| } |
|
|
| pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) { |
| encoding.process_tokens_with_offsets_mut(|(i, (token, offsets))| { |
| let mut leading_spaces = token |
| .chars() |
| .take_while(|c| *c == BYTES_CHAR[&b' '] || c.is_whitespace()) |
| .count(); |
| let trailing_spaces = token |
| .chars() |
| .rev() |
| .take_while(|c| *c == BYTES_CHAR[&b' '] || c.is_whitespace()) |
| .count(); |
|
|
| if leading_spaces > 0 || trailing_spaces > 0 { |
| if leading_spaces > 0 { |
| |
| |
| |
| let is_first = i == 0 || offsets.0 == 0; |
| if is_first && add_prefix_space && leading_spaces == 1 { |
| |
| |
| |
| |
| leading_spaces = 0; |
| } |
| offsets.0 = std::cmp::min(offsets.0 + leading_spaces, offsets.1); |
| } |
| if trailing_spaces > 0 && offsets.1 >= trailing_spaces { |
| offsets.1 = std::cmp::max(offsets.1 - trailing_spaces, offsets.0); |
| } |
| } |
| }); |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::{
        Decoder, Encoding, OffsetReferential, OffsetType, PostProcessor, PreTokenizedString,
        PreTokenizer,
    };
    use std::iter::FromIterator;

    /// Returns `(text, byte offsets)` for every split, expressed in the given referential.
    fn splits(
        pretokenized: &PreTokenizedString,
        referential: OffsetReferential,
    ) -> Vec<(&str, (usize, usize))> {
        pretokenized
            .get_splits(referential, OffsetType::Byte)
            .into_iter()
            .map(|(text, offsets, _)| (text, offsets))
            .collect()
    }

    /// With the pattern enabled and no prefix space, the sentence is split into words and
    /// punctuation, each space staying attached to the word that follows it.
    #[test]
    fn pre_tokenization() {
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![
                ("Hello", (0, 5)),
                ("Ġmy", (5, 8)),
                ("Ġfriend", (8, 15)),
                (",", (15, 16)),
                ("Ġhow", (16, 20)),
                ("Ġis", (20, 23)),
                ("Ġyour", (23, 28)),
                ("Ġday", (28, 32)),
                ("Ġgoing", (32, 38)),
                ("?", (38, 39))
            ]
        );
    }

    /// With the pattern disabled the whole (space-prefixed) input stays a single split.
    #[test]
    fn pre_tokenization_no_regex() {
        let bytelevel = ByteLevel::default().use_regex(false);
        let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![("ĠHelloĠmyĠfriend,ĠhowĠisĠyourĠdayĠgoing?", (0, 39))]
        );
    }

    /// Decoding the byte-level tokens gives back the original sentence as one string.
    #[test]
    fn decoding() {
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        let tokens = vec![
            "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?",
        ]
        .into_iter()
        .map(|s| s.into())
        .collect::<Vec<String>>();
        assert_eq!(
            bytelevel.decode_chain(tokens).unwrap(),
            vec!["Hello my friend, how is your day going?"]
        );
    }

    /// Inputs with and without a leading space produce the same splits when a prefix space
    /// is requested.
    #[test]
    fn add_prefix_space() {
        let bytelevel = ByteLevel::default().add_prefix_space(true);
        for s in &[
            " Hello my friend, how is your day going?",
            "Hello my friend, how is your day going?",
        ] {
            let mut pretokenized = PreTokenizedString::from(*s);
            bytelevel.pre_tokenize(&mut pretokenized).unwrap();
            assert_eq!(
                splits(&pretokenized, OffsetReferential::Normalized),
                vec![
                    ("ĠHello", (0, 7)),
                    ("Ġmy", (7, 11)),
                    ("Ġfriend", (11, 19)),
                    (",", (19, 20)),
                    ("Ġhow", (20, 25)),
                    ("Ġis", (25, 29)),
                    ("Ġyour", (29, 35)),
                    ("Ġday", (35, 40)),
                    ("Ġgoing", (40, 47)),
                    ("?", (47, 48))
                ]
            );
        }
    }

    /// Decoding still reconstructs the input when every split is further broken up into
    /// single-character tokens.
    #[test]
    fn decode_works_on_separated_tokens() {
        let samples = vec![
            "A Nuskhuri abbreviation of იესუ ქრისტე ( iesu kriste ) \" Jesus Christ \"",
            "An equal number have descenders , like p or q in English \
                 : გ , დ , ე , ვ , კ , ლ , ჟ , ტ , უ , ფ , ღ , ყ , ც",
        ];

        let bytelevel = ByteLevel::default().add_prefix_space(false);
        for sample in samples {
            let mut pretokenized = PreTokenizedString::from(sample);
            bytelevel.pre_tokenize(&mut pretokenized).unwrap();
            let separated_tokens = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .iter()
                .flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
                .collect::<Vec<_>>();
            assert_eq!(
                sample,
                bytelevel.decode_chain(separated_tokens).unwrap().join("")
            );
        }
    }

    /// A newline becomes its own split.
    #[test]
    fn handling_of_newlines() {
        let mut pretokenized = PreTokenizedString::from("Hello there\nHello there");
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![
                ("Hello", (0, 5)),
                ("Ġthere", (5, 11)),
                ("Ċ", (11, 12)),
                ("Hello", (12, 17)),
                ("Ġthere", (17, 23))
            ]
        );
    }

    /// In a run of seven spaces the first six form their own split and the last one stays
    /// attached to the following word.
    #[test]
    fn handling_of_multiple_whitespaces() {
        // The gap must be exactly seven spaces to match the expected offsets below; it is
        // spelled out with `repeat` so the count cannot be lost to whitespace normalization.
        let input = format!("Hello there{}dear", " ".repeat(7));
        let mut pretokenized = PreTokenizedString::from(input.as_str());
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![
                ("Hello", (0, 5)),
                ("Ġthere", (5, 11)),
                ("ĠĠĠĠĠĠ", (11, 17)),
                ("Ġdear", (17, 22))
            ]
        );
    }

    /// A multi-byte character expands to one table character per byte; original offsets
    /// still cover the whole character.
    #[test]
    fn offsets_when_char_split_up() {
        let input = "i⭢j";
        let mut pretokenized = PreTokenizedString::from(input);
        let bytelevel = ByteLevel::default().add_prefix_space(false);
        bytelevel.pre_tokenize(&mut pretokenized).unwrap();

        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![("i", (0, 1)), ("âŃ¢", (1, 4)), ("j", (4, 5))]
        );
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![("i", (0, 1)), ("âŃ¢", (1, 7)), ("j", (7, 8))]
        );
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original)
                .into_iter()
                .map(|(_, o)| &input[o.0..o.1])
                .collect::<Vec<_>>(),
            vec!["i", "⭢", "j"]
        );
    }

    /// Tokens that start at offset 0 with a single leading space keep their offsets when a
    /// prefix space was requested, even when they are not the first token.
    #[test]
    fn processor_trims_offsets_pre_tokenized() {
        let mut encoding = Encoding::new(
            vec![0; 5],
            vec![],
            vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()],
            vec![],
            vec![(0, 1), (1, 4), (0, 1), (1, 4)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );
        process_offsets(&mut encoding, true);
        assert_eq!(
            encoding,
            Encoding::new(
                vec![0; 5],
                vec![],
                vec!["Ġl".into(), "ove".into(), "Ġl".into(), "ove".into()],
                vec![],
                vec![(0, 1), (1, 4), (0, 1), (1, 4)],
                vec![],
                vec![],
                vec![],
                HashMap::new(),
            )
        );
    }

    /// Offsets are trimmed for single encodings and for pairs.
    #[test]
    fn processor_trims_offsets() {
        let start = Encoding::new(
            vec![0; 5],
            vec![],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );
        let expected = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5)]),
        );

        let bytelevel = ByteLevel::default().trim_offsets(true);
        assert_eq!(
            expected,
            bytelevel.process(start.clone(), None, false).unwrap()
        );

        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
            ],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
        );
        assert_eq!(
            pair_expected,
            bytelevel
                .process(start.clone(), Some(start), false)
                .unwrap()
        );
    }

    /// A token containing a character outside the table (here a plain space) is passed
    /// through unchanged.
    #[test]
    fn decode_unknown_characters() {
        let byte_level = ByteLevel::default();
        assert_eq!(
            byte_level
                .decode_chain(vec![
                    "Hello".into(),
                    "Ġthere".into(),
                    "Ġdear".into(),
                    "Ġfriend!".into(),
                    "Ġ".into(),
                    "[PA D]".into()
                ])
                .unwrap(),
            vec!["Hello there dear friend! [PA D]"]
        );
    }

    /// `use_regex` defaults to `true` when absent and otherwise follows the serialized value.
    #[test]
    fn deserialization() {
        // Field missing: falls back to the default.
        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false}"#,
        )
        .unwrap();
        assert!(byte_level.use_regex);

        // Field present: the explicit value wins.
        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#,
        )
        .unwrap();
        assert!(byte_level.use_regex);

        let byte_level: ByteLevel = serde_json::from_str(
            r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
        )
        .unwrap();
        assert!(!byte_level.use_regex);
    }
}
|
|