| use crate::parallelism::*; |
| use crate::tokenizer::{Offsets, Token}; |
| use crate::utils::padding::PaddingDirection; |
| use crate::utils::truncation::TruncationDirection; |
| use serde::{Deserialize, Serialize}; |
| use std::collections::HashMap; |
| use std::ops::Range; |
|
|
| |
/// The output of a tokenization pass: parallel per-token vectors (ids, type
/// ids, token strings, word indices, offsets, masks), any overflowing parts
/// produced by truncation, and the token ranges covered by each sequence.
#[derive(Default, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct Encoding {
    // IDs produced by the tokenizer, one per token.
    ids: Vec<u32>,
    // Type id of each token (e.g. segment id for sentence pairs).
    type_ids: Vec<u32>,
    // The string value of each token.
    tokens: Vec<String>,
    // Index of the word each token belongs to, `None` when unknown
    // (e.g. special or padding tokens).
    words: Vec<Option<u32>>,
    // Character offsets of each token in the original input.
    offsets: Vec<Offsets>,
    // 1 for special tokens, 0 otherwise (parallel to `ids`).
    special_tokens_mask: Vec<u32>,
    // 1 for real tokens, 0 for padding (parallel to `ids`).
    attention_mask: Vec<u32>,
    // Overflowing parts generated by `truncate`.
    overflowing: Vec<Encoding>,
    // Token index range of each sequence, keyed by sequence id. When empty,
    // the whole encoding is treated as a single sequence 0 (see
    // `sequence_range` / `n_sequences`).
    sequence_ranges: HashMap<usize, Range<usize>>,
}
| impl Encoding { |
    /// Creates a new `Encoding` directly from all of its parts.
    /// All per-token vectors are expected to have the same length.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ids: Vec<u32>,
        type_ids: Vec<u32>,
        tokens: Vec<String>,
        words: Vec<Option<u32>>,
        offsets: Vec<Offsets>,
        special_tokens_mask: Vec<u32>,
        attention_mask: Vec<u32>,
        overflowing: Vec<Self>,
        sequence_ranges: HashMap<usize, Range<usize>>,
    ) -> Self {
        Self {
            ids,
            type_ids,
            tokens,
            words,
            offsets,
            special_tokens_mask,
            attention_mask,
            overflowing,
            sequence_ranges,
        }
    }
|
|
    /// Creates an empty `Encoding` whose per-token vectors are preallocated
    /// to hold `len` elements.
    pub fn with_capacity(len: usize) -> Self {
        Self {
            ids: Vec::with_capacity(len),
            type_ids: Vec::with_capacity(len),
            tokens: Vec::with_capacity(len),
            words: Vec::with_capacity(len),
            offsets: Vec::with_capacity(len),
            special_tokens_mask: Vec::with_capacity(len),
            attention_mask: Vec::with_capacity(len),
            overflowing: vec![],
            sequence_ranges: HashMap::new(),
        }
    }
|
|
| pub fn from_tokens(tokens: Vec<Token>, type_id: u32) -> Self { |
| let length = tokens.len(); |
| let (ids, tokens, offsets) = tokens.into_iter().fold( |
| ( |
| Vec::with_capacity(length), |
| Vec::with_capacity(length), |
| Vec::with_capacity(length), |
| ), |
| |(mut ids, mut tokens, mut offsets), t| { |
| ids.push(t.id); |
| tokens.push(t.value); |
| offsets.push(t.offsets); |
| (ids, tokens, offsets) |
| }, |
| ); |
|
|
| Self { |
| ids, |
| tokens, |
| offsets, |
| words: vec![None; length], |
| type_ids: vec![type_id; length], |
| attention_mask: vec![1; length], |
| special_tokens_mask: vec![0; length], |
| overflowing: vec![], |
| sequence_ranges: HashMap::new(), |
| } |
| } |
|
|
| |
    /// Whether this `Encoding` contains no tokens.
    pub fn is_empty(&self) -> bool {
        self.ids.is_empty()
    }
|
|
| |
    /// Number of tokens in this `Encoding`.
    pub fn len(&self) -> usize {
        self.ids.len()
    }
|
|
| |
| pub fn n_sequences(&self) -> usize { |
| if self.sequence_ranges.is_empty() { |
| 1 |
| } else { |
| self.sequence_ranges.len() |
| } |
| } |
|
|
| |
    /// Registers the whole encoding (range `0..len()`) as the sequence with
    /// the given id.
    pub fn set_sequence_id(&mut self, sequence_id: usize) {
        self.sequence_ranges.insert(sequence_id, 0..self.len());
    }
|
|
| pub fn get_tokens(&self) -> &[String] { |
| &self.tokens[..] |
| } |
|
|
    /// Returns the word index of each token, `None` when unknown.
    pub fn get_word_ids(&self) -> &[Option<u32>] {
        &self.words
    }
|
|
    /// Returns a mutable view over the word index of each token.
    pub fn get_word_ids_mut(&mut self) -> &mut [Option<u32>] {
        &mut self.words
    }
|
|
| pub fn get_sequence_ids(&self) -> Vec<Option<usize>> { |
| let mut sequences = vec![None; self.len()]; |
| for seq_id in 0..self.n_sequences() { |
| let range = self.sequence_range(seq_id); |
| let seq_len = range.len(); |
| sequences.splice(range, std::iter::repeat(Some(seq_id)).take(seq_len)); |
| } |
| sequences |
| } |
|
|
    /// Returns the token ids as a slice.
    pub fn get_ids(&self) -> &[u32] {
        &self.ids
    }
|
|
    /// Returns the type ids as a slice.
    pub fn get_type_ids(&self) -> &[u32] {
        &self.type_ids
    }
|
|
    /// Replaces the type ids wholesale.
    pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
        self.type_ids = type_ids;
    }
|
|
    /// Returns the character offsets as a slice.
    pub fn get_offsets(&self) -> &[Offsets] {
        &self.offsets
    }
|
|
    /// Returns a mutable view over the character offsets.
    pub fn get_offsets_mut(&mut self) -> &mut [Offsets] {
        &mut self.offsets
    }
|
|
    /// Returns the special-tokens mask (1 for special tokens) as a slice.
    pub fn get_special_tokens_mask(&self) -> &[u32] {
        &self.special_tokens_mask
    }
|
|
    /// Returns the attention mask (0 for padding) as a slice.
    pub fn get_attention_mask(&self) -> &[u32] {
        &self.attention_mask
    }
|
|
    /// Returns the overflowing parts produced by truncation.
    pub fn get_overflowing(&self) -> &Vec<Encoding> {
        &self.overflowing
    }
|
|
    /// Replaces the overflowing parts wholesale.
    pub fn set_overflowing(&mut self, overflowing: Vec<Encoding>) {
        self.overflowing = overflowing;
    }
|
|
    /// Returns a mutable view over the overflowing parts.
    pub fn get_overflowing_mut(&mut self) -> &mut Vec<Encoding> {
        &mut self.overflowing
    }
|
|
    /// Takes the overflowing parts out of this encoding, leaving an empty
    /// `Vec` in their place.
    pub fn take_overflowing(&mut self) -> Vec<Encoding> {
        std::mem::take(&mut self.overflowing)
    }
|
|
    /// Applies `func` to every `(index, (token, &mut offsets))` pair,
    /// allowing callers to rewrite the offsets in place while reading the
    /// corresponding token string.
    pub(crate) fn process_tokens_with_offsets_mut<F>(&mut self, func: F)
    where
        F: FnMut((usize, (&String, &mut Offsets))),
    {
        self.tokens
            .iter()
            .zip(self.offsets.iter_mut())
            .enumerate()
            .for_each(func)
    }
|
|
| |
| |
| fn sequence_range(&self, sequence_id: usize) -> Range<usize> { |
| self.sequence_ranges |
| .get(&sequence_id) |
| .cloned() |
| .unwrap_or(0..self.len()) |
| } |
|
|
| |
| pub fn token_to_sequence(&self, token: usize) -> Option<usize> { |
| if token > self.len() { |
| None |
| } else if self.sequence_ranges.is_empty() { |
| Some(0) |
| } else { |
| self.sequence_ranges.iter().find_map(|(seq_id, range)| { |
| if range.contains(&token) { |
| Some(*seq_id) |
| } else { |
| None |
| } |
| }) |
| } |
| } |
|
|
| |
| |
    /// Returns the absolute token range `(start, end)` (end exclusive)
    /// covering the given `word` in sequence `sequence_id`, or `None` if the
    /// word is not found in that sequence.
    pub fn word_to_tokens(&self, word: u32, sequence_id: usize) -> Option<(usize, usize)> {
        let (mut start, mut end) = (None, None);
        let sequence_range = self.sequence_range(sequence_id);

        self.words
            .get(sequence_range.clone())?
            .iter()
            .enumerate()
            // Assumes word ids are non-decreasing within a sequence, so the
            // scan can stop once we pass `word`. Note that in Rust's `Option`
            // ordering `None < Some(_)`, so tokens without a word id (special
            // tokens) do not end the scan early.
            .take_while(|(_, w)| **w <= Some(word))
            .filter(|(_, w)| **w == Some(word))
            .for_each(|(i, _)| {
                if start.is_none() || Some(i) < start {
                    start = Some(i);
                }
                if end.is_none() || Some(i) >= end {
                    // `end` is exclusive, hence `i + 1`.
                    end = Some(i + 1);
                }
            });

        if let (Some(start), Some(end)) = (start, end) {
            // Convert sequence-relative indices back to absolute ones.
            Some((sequence_range.start + start, sequence_range.start + end))
        } else {
            None
        }
    }
|
|
| |
| pub fn word_to_chars(&self, word: u32, sequence_id: usize) -> Option<Offsets> { |
| self.word_to_tokens(word, sequence_id) |
| .and_then(|(start, end)| { |
| if end == 0 { |
| None |
| } else { |
| Some((self.offsets[start].0, self.offsets[end - 1].1)) |
| } |
| }) |
| } |
|
|
| |
| pub fn token_to_chars(&self, token: usize) -> Option<(usize, Offsets)> { |
| Some(( |
| self.token_to_sequence(token)?, |
| self.offsets.get(token).copied()?, |
| )) |
| } |
|
|
| |
| pub fn token_to_word(&self, token: usize) -> Option<(usize, u32)> { |
| Some(( |
| self.token_to_sequence(token)?, |
| self.words.get(token).copied().flatten()?, |
| )) |
| } |
|
|
| |
| pub fn char_to_token(&self, pos: usize, sequence_id: usize) -> Option<usize> { |
| let sequence_range = self.sequence_range(sequence_id); |
|
|
| self.offsets |
| .get(sequence_range.clone())? |
| .iter() |
| .position(|(start, end)| pos >= *start && pos < *end) |
| .map(|pos| sequence_range.start + pos) |
| } |
|
|
| |
| pub fn char_to_word(&self, pos: usize, sequence_id: usize) -> Option<u32> { |
| Some( |
| self.char_to_token(pos, sequence_id) |
| .and_then(|token| self.token_to_word(token))? |
| .1, |
| ) |
| } |
|
|
| |
| |
| |
    /// Truncates this encoding to at most `max_len` tokens, moving the
    /// removed parts into `overflowing`. Consecutive overflowing parts share
    /// `stride` tokens of overlap. `direction` selects whether the kept part
    /// is taken from the left or the right end.
    ///
    /// # Panics
    /// Panics when `0 < max_len <= stride`.
    pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncationDirection) {
        let encoding_len = self.ids.len();
        if max_len >= encoding_len {
            // Already short enough: nothing to do.
            return;
        }

        if max_len == 0 {
            // Everything overflows: swap ourselves out for an empty encoding
            // and keep the previous content as the single overflowing part.
            let o = std::mem::replace(self, Encoding::with_capacity(0));
            self.overflowing.push(o);
            return;
        }

        assert!(stride < max_len, "`stride` must be strictly less than `max_len={}` (note that `max_len` may be shorter than the max length of the original model, as it subtracts the number of special characters", max_len);

        // After truncation the registered sequence ranges no longer apply.
        self.sequence_ranges.clear();

        // Each part starts `max_len - stride` tokens after the previous one,
        // so consecutive parts overlap by `stride` tokens.
        let offset = max_len - stride;
        let mut end = false;
        let parts_ranges: Vec<(usize, usize)> = match direction {
            TruncationDirection::Right => (0..encoding_len)
                .step_by(offset)
                .filter_map(|start| {
                    if !end {
                        let stop = std::cmp::min(start + max_len, encoding_len);
                        // Stop generating parts once the end is reached.
                        end = stop == encoding_len;
                        Some((start, stop))
                    } else {
                        None
                    }
                })
                .collect(),
            TruncationDirection::Left => (0..encoding_len)
                .rev()
                .step_by(offset)
                .filter_map(|stop| {
                    // `stop` is the inclusive last index; make it exclusive.
                    let stop = stop + 1;
                    let start = if stop < max_len { 0 } else { stop - max_len };
                    if start < stop && !end {
                        // Stop generating parts once the start is reached.
                        end = start == 0;
                        Some((start, stop))
                    } else {
                        None
                    }
                })
                .collect(),
        };

        let mut i = 0;
        // The first range becomes the kept encoding...
        let (start, stop) = parts_ranges[i];
        let mut new_encoding = Encoding {
            ids: self.ids[start..stop].to_vec(),
            type_ids: self.type_ids[start..stop].to_vec(),
            tokens: self.tokens[start..stop].to_vec(),
            words: self.words[start..stop].to_vec(),
            offsets: self.offsets[start..stop].to_vec(),
            special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
            attention_mask: self.attention_mask[start..stop].to_vec(),
            overflowing: vec![],
            sequence_ranges: HashMap::new(),
        };

        // ...and every following range becomes an overflowing part.
        loop {
            if i == parts_ranges.len() - 1 {
                break;
            }
            i += 1;
            let (start, stop) = parts_ranges[i];
            new_encoding.overflowing.push(Encoding {
                ids: self.ids[start..stop].to_vec(),
                type_ids: self.type_ids[start..stop].to_vec(),
                tokens: self.tokens[start..stop].to_vec(),
                words: self.words[start..stop].to_vec(),
                offsets: self.offsets[start..stop].to_vec(),
                special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
                attention_mask: self.attention_mask[start..stop].to_vec(),
                overflowing: vec![],
                sequence_ranges: HashMap::new(),
            });
        }
        *self = new_encoding;
    }
|
|
| |
| pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self { |
| let mut encoding = Encoding::default(); |
|
|
| |
| |
| |
|
|
| |
| |
| for sub in encodings { |
| encoding.merge_with(sub, growing_offsets); |
| } |
|
|
| encoding |
| } |
|
|
| |
    /// Merges `pair` into `self`, appending every per-token field. When
    /// `growing_offsets` is `true`, `pair`'s offsets are shifted so they
    /// continue after `self`'s last offset; otherwise they are kept as-is.
    /// The overflowing parts of both sides are combined pairwise so that
    /// every original combination is represented.
    pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
        // Build the merged overflowing list first, before `self` is mutated.
        // In the common case `pair.overflowing` is empty.
        let mut overflowings = vec![];

        // 1. Each of our overflowing parts merged with:
        for self_o in &self.overflowing {
            //    a. the pair itself,
            let mut n_encoding = self_o.clone();
            n_encoding.merge_with(pair.clone(), growing_offsets);
            overflowings.push(n_encoding);

            //    b. each of the pair's overflowing parts.
            for other_o in &pair.overflowing {
                let mut n_encoding = self_o.clone();
                n_encoding.merge_with(other_o.clone(), growing_offsets);
                overflowings.push(n_encoding);
            }
        }
        // 2. Ourself merged with each of the pair's overflowing parts.
        for other_o in &pair.overflowing {
            let mut n_encoding = self.clone();
            n_encoding.merge_with(other_o.clone(), growing_offsets);
            overflowings.push(n_encoding);
        }

        // Capture our length before extending so the incoming sequence
        // ranges can be re-based onto the merged token indices.
        let original_self_len = self.len();

        self.sequence_ranges
            .extend(pair.sequence_ranges.into_iter().map(|(seq_id, range)| {
                (
                    seq_id,
                    original_self_len + range.start..original_self_len + range.end,
                )
            }));
        self.ids.extend(pair.ids);
        self.type_ids.extend(pair.type_ids);
        self.tokens.extend(pair.tokens);
        self.words.extend(pair.words);

        // When growing, shift the pair's offsets to start right after ours.
        let starting_offset = if growing_offsets {
            self.offsets.last().map_or(0, |o| o.1)
        } else {
            0
        };
        self.offsets.extend(
            pair.offsets
                .into_iter()
                .map(|(start, end)| (start + starting_offset, end + starting_offset))
                .collect::<Vec<_>>(),
        );
        self.special_tokens_mask.extend(pair.special_tokens_mask);
        self.attention_mask.extend(pair.attention_mask);
        self.overflowing = overflowings;
    }
|
|
    /// Pads this encoding (and, recursively, all of its overflowing parts)
    /// to `target_length` tokens using the given padding token, on the side
    /// selected by `direction`. Padding positions get attention 0, special
    /// mask 1, word `None`, and offsets `(0, 0)`. Does nothing when the
    /// encoding is already at least `target_length` long.
    pub fn pad(
        &mut self,
        target_length: usize,
        pad_id: u32,
        pad_type_id: u32,
        pad_token: &str,
        direction: PaddingDirection,
    ) {
        // Pad the overflowing parts first (possibly in parallel).
        self.overflowing.maybe_par_iter_mut().for_each(|encoding| {
            encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction)
        });

        // Nothing to do if we are already long enough.
        if self.ids.len() >= target_length {
            return;
        }
        let pad_length = target_length - self.ids.len();

        match direction {
            PaddingDirection::Left => {
                // Rebuild every vector with the padding prepended.
                self.ids = (0..pad_length)
                    .map(|_| pad_id)
                    .chain(self.ids.drain(..))
                    .collect();
                self.type_ids = (0..pad_length)
                    .map(|_| pad_type_id)
                    .chain(self.type_ids.drain(..))
                    .collect();
                self.tokens = (0..pad_length)
                    .map(|_| pad_token.to_owned())
                    .chain(self.tokens.drain(..))
                    .collect();
                self.words = (0..pad_length)
                    .map(|_| None)
                    .chain(self.words.drain(..))
                    .collect();
                self.attention_mask = (0..pad_length)
                    .map(|_| 0)
                    .chain(self.attention_mask.drain(..))
                    .collect();
                self.special_tokens_mask = (0..pad_length)
                    .map(|_| 1)
                    .chain(self.special_tokens_mask.drain(..))
                    .collect();
                self.offsets = (0..pad_length)
                    .map(|_| (0, 0))
                    .chain(self.offsets.drain(..))
                    .collect();
                // Left padding shifts every token, so the sequence ranges
                // must be shifted by the same amount.
                self.sequence_ranges
                    .iter_mut()
                    .for_each(|(_seq_id, range)| {
                        *range = (range.start + pad_length)..(range.end + pad_length)
                    });
            }
            PaddingDirection::Right => {
                // Right padding appends; existing indices stay valid.
                self.ids.extend((0..pad_length).map(|_| pad_id));
                self.type_ids.extend((0..pad_length).map(|_| pad_type_id));
                self.tokens
                    .extend((0..pad_length).map(|_| pad_token.to_owned()));
                self.words.extend((0..pad_length).map(|_| None));
                self.attention_mask.extend((0..pad_length).map(|_| 0));
                self.special_tokens_mask.extend((0..pad_length).map(|_| 1));
                self.offsets.extend((0..pad_length).map(|_| (0, 0)));
            }
        }
    }
| } |
|
|
impl std::iter::FromIterator<Encoding> for Encoding {
    /// Collects encodings by merging them in order, without growing offsets.
    fn from_iter<I: IntoIterator<Item = Encoding>>(iter: I) -> Self {
        Self::merge(iter, false)
    }
}
|
|
| impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> for Encoding { |
| fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize), Option<u32>, u32)>>( |
| iter: I, |
| ) -> Self { |
| let items = iter.into_iter(); |
| let (lower, upper) = items.size_hint(); |
| let length = upper.unwrap_or(lower); |
| let mut encoding = Self::with_capacity(length); |
|
|
| for (id, token, offsets, word, type_id) in items { |
| encoding.ids.push(id); |
| encoding.tokens.push(token); |
| encoding.offsets.push(offsets); |
| encoding.type_ids.push(type_id); |
| encoding.words.push(word); |
| encoding.special_tokens_mask.push(0); |
| encoding.attention_mask.push(1); |
| } |
|
|
| encoding |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::iter::FromIterator;

    // Merging concatenates every field; with `growing_offsets` the second
    // encoding's offsets are shifted past the first's last offset.
    #[test]
    fn merge_encodings() {
        let mut a = Encoding {
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        };
        let b = Encoding {
            ids: vec![2],
            type_ids: vec![1],
            tokens: vec![String::from("World!")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        };
        a.merge_with(b, true);

        assert_eq!(
            a,
            Encoding {
                ids: vec![1, 2],
                type_ids: vec![0, 1],
                tokens: vec![String::from("Hello "), String::from("World!")],
                words: vec![Some(0), Some(0)],
                // b's (0, 6) is shifted to (6, 12) by growing_offsets.
                offsets: vec![(0, 6), (6, 12)],
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                ..Default::default()
            }
        );
    }

    // Right truncation keeps the first `max_len` tokens and moves the rest
    // into a single overflowing part (stride 0 = no overlap).
    #[test]
    fn truncate() {
        let mut a = Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
                String::from("Hello"),
                String::from("World"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        };
        a.truncate(2, 0, TruncationDirection::Right);

        assert_eq!(
            a,
            Encoding {
                ids: vec![1, 2],
                type_ids: vec![0, 0],
                tokens: vec![String::from("Hello"), String::from("World")],
                words: vec![Some(0), Some(1)],
                offsets: vec![(0, 5), (6, 11)],
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                overflowing: vec![Encoding {
                    ids: vec![3],
                    type_ids: vec![0],
                    tokens: vec![String::from("!")],
                    words: vec![Some(2)],
                    offsets: vec![(11, 12)],
                    special_tokens_mask: vec![0],
                    attention_mask: vec![1],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }

    // Truncating to max_len == 0 empties the encoding and moves the whole
    // original content into one overflowing part.
    #[test]
    fn truncate_to_empty() {
        let mut a = Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
                String::from("Hello"),
                String::from("World"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        };
        a.truncate(0, 0, TruncationDirection::Right);

        assert_eq!(
            a,
            Encoding {
                overflowing: vec![Encoding {
                    ids: vec![1, 2, 3],
                    type_ids: vec![0, 0, 0],
                    tokens: vec![
                        String::from("Hello"),
                        String::from("World"),
                        String::from("!"),
                    ],
                    words: vec![Some(0), Some(1), Some(2)],
                    offsets: vec![(0, 5), (6, 11), (11, 12)],
                    special_tokens_mask: vec![0, 0, 0],
                    attention_mask: vec![1, 1, 1],
                    overflowing: vec![],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }

    // With stride 2, the overflowing part overlaps the kept part by the
    // last two tokens ([3, 4] appear in both).
    #[test]
    fn truncate_overflow_with_stride() {
        let mut enc = Encoding {
            ids: vec![1, 2, 3, 4, 5],
            type_ids: vec![0, 0, 0, 0, 0],
            tokens: vec![
                String::from("42"),
                String::from("is"),
                String::from("the"),
                String::from("answer"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
            offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)],
            special_tokens_mask: vec![0, 0, 0, 0, 0],
            attention_mask: vec![1, 1, 1, 1, 1],
            overflowing: vec![],
            ..Default::default()
        };
        enc.truncate(4, 2, TruncationDirection::Right);

        assert_eq!(
            enc,
            Encoding {
                ids: vec![1, 2, 3, 4],
                type_ids: vec![0, 0, 0, 0],
                tokens: vec![
                    String::from("42"),
                    String::from("is"),
                    String::from("the"),
                    String::from("answer"),
                ],
                words: vec![Some(0), Some(1), Some(2), Some(3)],
                offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13)],
                special_tokens_mask: vec![0, 0, 0, 0],
                attention_mask: vec![1, 1, 1, 1],
                overflowing: vec![Encoding {
                    ids: vec![3, 4, 5],
                    type_ids: vec![0, 0, 0],
                    tokens: vec![
                        String::from("the"),
                        String::from("answer"),
                        String::from("!"),
                    ],
                    words: vec![Some(2), Some(3), Some(4)],
                    offsets: vec![(4, 7), (7, 13), (13, 14)],
                    special_tokens_mask: vec![0, 0, 0],
                    attention_mask: vec![1, 1, 1],
                    overflowing: vec![],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }

    // Left truncation keeps the last `max_len` tokens; the removed prefix
    // becomes the overflowing part.
    #[test]
    fn truncate_left() {
        let mut a = Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
                String::from("Hello"),
                String::from("World"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        };
        a.truncate(2, 0, TruncationDirection::Left);

        assert_eq!(
            a,
            Encoding {
                ids: vec![2, 3],
                type_ids: vec![0, 0],
                tokens: vec![String::from("World"), String::from("!")],
                words: vec![Some(1), Some(2)],
                offsets: vec![(6, 11), (11, 12)],
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                overflowing: vec![Encoding {
                    ids: vec![1],
                    type_ids: vec![0],
                    tokens: vec![String::from("Hello")],
                    words: vec![Some(0)],
                    offsets: vec![(0, 5)],
                    special_tokens_mask: vec![0],
                    attention_mask: vec![1],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }

    // Exercises the word/token/char mapping helpers on a two-sequence
    // encoding (tokens 0..7 are sequence 0, tokens 7..11 are sequence 1).
    #[test]
    fn mappings() {
        let encoding = Encoding {
            ids: vec![0; 11],
            tokens: vec![
                // First sequence:
                "He".into(),
                "llo".into(),
                "won".into(),
                "der".into(),
                "ful".into(),
                "friend".into(),
                "!".into(),
                // Second sequence:
                "How".into(),
                "are".into(),
                "you".into(),
                "?".into(),
            ],
            offsets: vec![
                // First sequence:
                (0, 2),
                (2, 5),
                (7, 10),
                (10, 13),
                (13, 16),
                (17, 23),
                (23, 24),
                // Second sequence:
                (0, 3),
                (4, 7),
                (8, 11),
                (11, 12),
            ],
            words: vec![
                // First sequence:
                Some(0),
                Some(0),
                Some(1),
                Some(1),
                Some(1),
                Some(2),
                Some(3),
                // Second sequence:
                Some(0),
                Some(1),
                Some(2),
                Some(3),
            ],
            sequence_ranges: HashMap::from_iter(vec![(0, 0..7), (1, 7..11)]),
            ..Default::default()
        };
        assert_eq!(encoding.word_to_tokens(0, 0), Some((0, 2)));
        assert_eq!(encoding.word_to_tokens(1, 0), Some((2, 5)));
        assert_eq!(encoding.word_to_tokens(2, 0), Some((5, 6)));
        assert_eq!(encoding.word_to_tokens(3, 0), Some((6, 7)));
        assert_eq!(encoding.word_to_tokens(0, 1), Some((7, 8)));
        assert_eq!(encoding.word_to_tokens(1, 1), Some((8, 9)));
        assert_eq!(encoding.word_to_tokens(2, 1), Some((9, 10)));
        assert_eq!(encoding.word_to_tokens(3, 1), Some((10, 11)));

        assert_eq!(encoding.word_to_chars(0, 0), Some((0, 5)));
        assert_eq!(encoding.word_to_chars(1, 0), Some((7, 16)));
        assert_eq!(encoding.word_to_chars(0, 1), Some((0, 3)));
        assert_eq!(encoding.word_to_chars(1, 1), Some((4, 7)));

        assert_eq!(encoding.token_to_chars(0), Some((0, (0, 2))));
        assert_eq!(encoding.token_to_chars(1), Some((0, (2, 5))));
        assert_eq!(encoding.token_to_chars(7), Some((1, (0, 3))));
        assert_eq!(encoding.token_to_chars(9), Some((1, (8, 11))));

        assert_eq!(encoding.token_to_word(1), Some((0, 0)));
        assert_eq!(encoding.token_to_word(2), Some((0, 1)));
        assert_eq!(encoding.token_to_word(7), Some((1, 0)));
        assert_eq!(encoding.token_to_word(9), Some((1, 2)));
        assert_eq!(encoding.token_to_word(11), None);

        assert_eq!(encoding.char_to_token(3, 0), Some(1));
        assert_eq!(encoding.char_to_token(8, 0), Some(2));
        assert_eq!(encoding.char_to_token(16, 0), None);
        assert_eq!(encoding.char_to_token(23, 0), Some(6));
        assert_eq!(encoding.char_to_token(2, 1), Some(7));
        assert_eq!(encoding.char_to_token(9, 1), Some(9));

        assert_eq!(encoding.char_to_word(3, 0), Some(0));
        assert_eq!(encoding.char_to_word(8, 0), Some(1));
        assert_eq!(encoding.char_to_word(16, 0), None);
        assert_eq!(encoding.char_to_word(23, 0), Some(3));
        assert_eq!(encoding.char_to_word(2, 1), Some(0));
        assert_eq!(encoding.char_to_word(9, 1), Some(2));
    }

    // Left padding must shift the registered sequence ranges by the number
    // of padding tokens prepended.
    #[test]
    fn padding() {
        let mut a = Encoding {
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            sequence_ranges: HashMap::from([(0, 0..1)]),
            ..Default::default()
        };
        let target_length = 2;
        let pad_id = 99;
        let pad_type_id = 0;
        let pad_token = "[PAD]";
        a.pad(
            target_length,
            pad_id,
            pad_type_id,
            pad_token,
            PaddingDirection::Left,
        );
        assert_eq!(a.sequence_ranges, HashMap::from([(0, 1..2)]));
    }
}
|
|