| use crate::parallelism::*; |
| use crate::tokenizer::{Encoding, Result}; |
| use serde::{Deserialize, Serialize}; |
|
|
| |
| #[derive(Debug, Clone, Copy, Serialize, Deserialize)] |
| pub enum PaddingDirection { |
| Left, |
| Right, |
| } |
|
|
| impl std::convert::AsRef<str> for PaddingDirection { |
| fn as_ref(&self) -> &str { |
| match self { |
| PaddingDirection::Left => "left", |
| PaddingDirection::Right => "right", |
| } |
| } |
| } |
|
|
| #[derive(Debug, Clone, Serialize, Deserialize)] |
| pub struct PaddingParams { |
| pub strategy: PaddingStrategy, |
| pub direction: PaddingDirection, |
| pub pad_to_multiple_of: Option<usize>, |
| pub pad_id: u32, |
| pub pad_type_id: u32, |
| pub pad_token: String, |
| } |
|
|
| impl Default for PaddingParams { |
| fn default() -> Self { |
| Self { |
| strategy: PaddingStrategy::BatchLongest, |
| direction: PaddingDirection::Right, |
| pad_to_multiple_of: None, |
| pad_id: 0, |
| pad_type_id: 0, |
| pad_token: String::from("[PAD]"), |
| } |
| } |
| } |
|
|
| #[derive(Debug, Clone, Serialize, Deserialize)] |
| pub enum PaddingStrategy { |
| BatchLongest, |
| Fixed(usize), |
| } |
|
|
| pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Result<()> { |
| if encodings.is_empty() { |
| return Ok(()); |
| } |
|
|
| let mut pad_length = match params.strategy { |
| PaddingStrategy::Fixed(size) => size, |
| PaddingStrategy::BatchLongest => encodings |
| .maybe_par_iter() |
| .map(|e| e.get_ids().len()) |
| .max() |
| .unwrap(), |
| }; |
|
|
| if let Some(multiple) = params.pad_to_multiple_of { |
| if multiple > 0 && pad_length % multiple > 0 { |
| pad_length += multiple - pad_length % multiple; |
| } |
| } |
|
|
| encodings.maybe_par_iter_mut().for_each(|encoding| { |
| encoding.pad( |
| pad_length, |
| params.pad_id, |
| params.pad_type_id, |
| ¶ms.pad_token, |
| params.direction, |
| ) |
| }); |
|
|
| Ok(()) |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    /// Two encodings with 5 and 3 ids respectively; all other fields are empty.
    fn get_encodings() -> [Encoding; 2] {
        [
            Encoding::new(
                vec![0, 1, 2, 3, 4],
                vec![],
                vec![],
                vec![],
                vec![],
                vec![],
                vec![],
                vec![],
                HashMap::new(),
            ),
            Encoding::new(
                vec![0, 1, 2],
                vec![],
                vec![],
                vec![],
                vec![],
                vec![],
                vec![],
                vec![],
                HashMap::new(),
            ),
        ]
    }

    #[test]
    fn pad_to_multiple() {
        // Fixed length 7, rounded up to a multiple of 8.
        let mut encodings = get_encodings();
        let mut params = PaddingParams {
            strategy: PaddingStrategy::Fixed(7),
            direction: PaddingDirection::Right,
            pad_to_multiple_of: Some(8),
            pad_id: 0,
            pad_type_id: 0,
            pad_token: String::from("[PAD]"),
        };
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 8));

        // Longest in batch (5), rounded up to a multiple of 6.
        let mut encodings = get_encodings();
        params.strategy = PaddingStrategy::BatchLongest;
        params.pad_to_multiple_of = Some(6);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 6));

        // A multiple of 0 disables rounding and must not panic; the batch is
        // already 6 long, so the lengths stay at 6.
        params.pad_to_multiple_of = Some(0);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 6));
    }

    #[test]
    fn pad_fixed_left_without_multiple() {
        let mut encodings = get_encodings();
        let params = PaddingParams {
            strategy: PaddingStrategy::Fixed(7),
            direction: PaddingDirection::Left,
            ..PaddingParams::default()
        };
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 7));
    }

    #[test]
    fn pad_empty_batch_is_noop() {
        let mut encodings: [Encoding; 0] = [];
        assert!(pad_encodings(&mut encodings, &PaddingParams::default()).is_ok());
    }
}
|
|