| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| use crate::{Encoding, PostProcessor, Result}; |
| use itertools::Itertools; |
| use serde::{Deserialize, Serialize}; |
| use std::collections::{HashMap, HashSet}; |
| use std::convert::{TryFrom, TryInto}; |
| use std::result::Result as StdResult; |
|
|
| |
/// Identifies which input sequence a [`Piece::Sequence`] refers to.
///
/// When a template is applied, `A` selects the first encoding and `B` the
/// second one.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Sequence {
    /// The first sequence (template syntax `$`, `$A`, `$a`, or `$<type_id>`).
    A,
    /// The second sequence (template syntax `$B` or `$b`).
    B,
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/// One element of a [`Template`].
///
/// Parsed from strings of the form `ID` or `ID:TYPE_ID`.
///
/// - An `ID` starting with `$` denotes an input sequence.
/// - Any other `ID` names a special token, which is looked up by that id among
///   the processor's special tokens.
///
/// `type_id` is the type id assigned to every token this piece contributes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Piece {
    /// One of the input sequences, with the type id to stamp on its tokens.
    Sequence { id: Sequence, type_id: u32 },
    /// A special token referenced by its string id, with its type id.
    SpecialToken { id: String, type_id: u32 },
}
|
|
| impl Piece { |
| fn extract_id(s: &str) -> Option<Self> { |
| if s.starts_with('$') { |
| let rest = &s['$'.len_utf8()..]; |
|
|
| |
| match rest { |
| "" => Some(Self::Sequence { |
| id: Sequence::A, |
| type_id: 0, |
| }), |
| "A" | "a" => Some(Self::Sequence { |
| id: Sequence::A, |
| type_id: 0, |
| }), |
| "B" | "b" => Some(Self::Sequence { |
| id: Sequence::B, |
| type_id: 0, |
| }), |
| n => { |
| if let Ok(type_id) = n.parse::<u32>() { |
| Some(Self::Sequence { |
| id: Sequence::A, |
| type_id, |
| }) |
| } else { |
| None |
| } |
| } |
| } |
| } else { |
| Some(Self::SpecialToken { |
| id: s.to_owned(), |
| type_id: 0, |
| }) |
| } |
| } |
|
|
| fn with_type_id(self, type_id: u32) -> Self { |
| match self { |
| Self::Sequence { id, .. } => Self::Sequence { id, type_id }, |
| Self::SpecialToken { id, .. } => Self::SpecialToken { id, type_id }, |
| } |
| } |
| } |
|
|
| impl TryFrom<String> for Piece { |
| type Error = String; |
|
|
| fn try_from(s: String) -> StdResult<Self, Self::Error> { |
| let parts = s.split(':').collect::<Vec<_>>(); |
|
|
| let err = || format!("Cannot build Piece from string \"{s}\""); |
| match parts.as_slice() { |
| [id, type_id] => { |
| let type_id: u32 = type_id.parse().map_err(|_| err())?; |
| let piece = Self::extract_id(id).ok_or_else(err)?; |
| Ok(piece.with_type_id(type_id)) |
| } |
| [id] => Self::extract_id(id).ok_or_else(err), |
| _ => Err(err()), |
| } |
| } |
| } |
|
|
| impl TryFrom<&str> for Piece { |
| type Error = String; |
|
|
| fn try_from(s: &str) -> StdResult<Self, Self::Error> { |
| Piece::try_from(s.to_owned()) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/// A special token that a template can insert, possibly spanning several tokens.
///
/// `ids` and `tokens` are expected to have the same length.
///
/// - [`SpecialToken::new`] checks this.
/// - The tuple `From` conversions always build single-token values.
///
/// NOTE(review): the derived `Deserialize` does not re-check the lengths.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub struct SpecialToken {
    /// Key used to reference this token from a template (e.g. `[CLS]`).
    id: String,
    /// Token ids emitted in place of the special token.
    ids: Vec<u32>,
    /// Token strings emitted alongside `ids`, one per id.
    tokens: Vec<String>,
}
|
|
| impl From<(String, u32)> for SpecialToken { |
| fn from(v: (String, u32)) -> Self { |
| Self { |
| id: v.0.clone(), |
| ids: vec![v.1], |
| tokens: vec![v.0], |
| } |
| } |
| } |
| impl From<(&str, u32)> for SpecialToken { |
| fn from(v: (&str, u32)) -> Self { |
| Self::from((v.0.to_owned(), v.1)) |
| } |
| } |
| impl From<(u32, String)> for SpecialToken { |
| fn from(v: (u32, String)) -> Self { |
| Self::from((v.1, v.0)) |
| } |
| } |
| impl From<(u32, &str)> for SpecialToken { |
| fn from(v: (u32, &str)) -> Self { |
| Self::from((v.1.to_owned(), v.0)) |
| } |
| } |
|
|
| impl SpecialToken { |
| pub fn new(id: String, ids: Vec<u32>, tokens: Vec<String>) -> Result<Self> { |
| if ids.len() != tokens.len() { |
| Err("SpecialToken: ids and tokens must be of the same length".into()) |
| } else { |
| Ok(Self { id, ids, tokens }) |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/// An ordered list of [`Piece`]s describing the layout of a processed encoding.
///
/// Serialized transparently as a list of pieces. It can be built from a
/// space-separated string such as `"[CLS] $A:0 [SEP]"`, or from a `Vec` of
/// values convertible into `Piece`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
#[serde(transparent)]
pub struct Template(Vec<Piece>);
|
|
| impl<T> TryFrom<Vec<T>> for Template |
| where |
| T: TryInto<Piece, Error = String>, |
| { |
| type Error = String; |
|
|
| fn try_from(v: Vec<T>) -> StdResult<Self, Self::Error> { |
| Ok(Self( |
| v.into_iter() |
| .map(|p| p.try_into()) |
| .collect::<StdResult<Vec<_>, Self::Error>>()?, |
| )) |
| } |
| } |
|
|
| impl TryFrom<String> for Template { |
| type Error = String; |
|
|
| fn try_from(s: String) -> StdResult<Self, Self::Error> { |
| Self::try_from(s.as_ref()) |
| } |
| } |
|
|
| impl TryFrom<&str> for Template { |
| type Error = String; |
|
|
| fn try_from(s: &str) -> StdResult<Self, Self::Error> { |
| Self::try_from(s.split(' ').collect::<Vec<_>>()) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
/// Special tokens available to a template processor, keyed by their `id`.
///
/// Serialized transparently as a map. `crate::utils::ordered_map` is used as the
/// serializer; presumably it emits keys in a stable (sorted) order — TODO confirm.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, Eq)]
#[serde(transparent)]
pub struct Tokens(
    #[serde(serialize_with = "crate::utils::ordered_map")] pub HashMap<String, SpecialToken>,
);
|
|
| impl<T: Into<SpecialToken>> From<Vec<T>> for Tokens { |
| fn from(v: Vec<T>) -> Self { |
| Self( |
| v.into_iter() |
| .map(|t| { |
| let token: SpecialToken = t.into(); |
| (token.id.clone(), token) |
| }) |
| .collect(), |
| ) |
| } |
| } |
|
|
| impl From<HashMap<String, SpecialToken>> for Tokens { |
| fn from(v: HashMap<String, SpecialToken>) -> Self { |
| Self(v) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/// Post-processor that lays out the input encoding(s) and special tokens
/// according to one [`Template`] for single inputs and another for pairs.
///
/// Build it with `TemplateProcessing::builder()`. The builder's `validate`
/// checks that:
/// - the `pair` template uses both sequences;
/// - every special token referenced by a template is present in `special_tokens`.
///
/// Deserialization goes through `TemplateProcessingDeserializer`, which
/// recomputes the `added_*` counts.
#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq)]
#[serde(tag = "type", from = "TemplateProcessingDeserializer")]
#[builder(build_fn(validate = "Self::validate"))]
pub struct TemplateProcessing {
    /// Template applied to a single encoding (builder default: `$0`).
    #[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
    single: Template,
    /// Template applied to a pair of encodings (builder default: `$A:0 $B:1`).
    #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
    pair: Template,
    /// Number of special-token ids `single` adds. Computed, never serialized.
    #[builder(setter(skip), default = "self.default_added(true)")]
    #[serde(skip)]
    added_single: usize,
    /// Number of special-token ids `pair` adds. Computed, never serialized.
    #[builder(setter(skip), default = "self.default_added(false)")]
    #[serde(skip)]
    added_pair: usize,
    /// Special tokens referenced by the templates, keyed by id.
    #[builder(setter(into), default)]
    special_tokens: Tokens,
}
|
|
| impl From<&str> for TemplateProcessingBuilderError { |
| fn from(e: &str) -> Self { |
| e.to_string().into() |
| } |
| } |
|
|
| impl PartialEq for TemplateProcessingBuilderError { |
| fn eq(&self, other: &Self) -> bool { |
| self.to_string() == other.to_string() |
| } |
| } |
|
|
| |
| |
| #[doc(hidden)] |
| #[derive(Deserialize)] |
| #[serde(tag = "type")] |
| struct TemplateProcessingDeserializer { |
| single: Template, |
| pair: Template, |
| special_tokens: Tokens, |
| } |
| impl From<TemplateProcessingDeserializer> for TemplateProcessing { |
| fn from(t: TemplateProcessingDeserializer) -> Self { |
| let added_single = count_added(&t.single, Some(&t.special_tokens)); |
| let added_pair = count_added(&t.pair, Some(&t.special_tokens)); |
| Self { |
| single: t.single, |
| pair: t.pair, |
| added_single, |
| added_pair, |
| special_tokens: t.special_tokens, |
| } |
| } |
| } |
|
|
| |
| fn count_added(container: &Template, special_tokens: Option<&Tokens>) -> usize { |
| container |
| .0 |
| .iter() |
| .map(|p| match p { |
| Piece::Sequence { .. } => 0, |
| Piece::SpecialToken { id, .. } => { |
| special_tokens.map_or(0, |spt| spt.0.get(id).map_or(0, |s| s.ids.len())) |
| } |
| }) |
| .sum() |
| } |
|
|
| impl TemplateProcessingBuilder { |
| fn default_added(&self, is_single: bool) -> usize { |
| let container = if is_single { |
| self.single.as_ref() |
| } else { |
| self.pair.as_ref() |
| }; |
| container.map_or(0, |pieces| { |
| count_added(pieces, self.special_tokens.as_ref()) |
| }) |
| } |
|
|
| fn validate(&self) -> std::result::Result<(), String> { |
| let pair_has_both = self.pair.as_ref().map_or(true, |pair| { |
| let mut has_a = false; |
| let mut has_b = false; |
| for piece in &pair.0 { |
| if let Piece::Sequence { |
| id: Sequence::A, .. |
| } = piece |
| { |
| has_a = true; |
| } |
| if let Piece::Sequence { |
| id: Sequence::B, .. |
| } = piece |
| { |
| has_b = true; |
| } |
| } |
| has_a && has_b |
| }); |
| if !pair_has_both { |
| return Err("Template for `pair` must use both sequences".into()); |
| } |
|
|
| let check = |sp| { |
| let exist = self |
| .special_tokens |
| .as_ref() |
| .map_or(false, |map| map.0.contains_key(sp)); |
|
|
| match exist { |
| false => Some(sp), |
| true => None, |
| } |
| }; |
|
|
| let empty = []; |
| let missing: HashSet<&str> = self |
| .single |
| .as_ref() |
| .map_or(empty.iter(), |s| s.0.iter()) |
| .chain(self.pair.as_ref().map_or(empty.iter(), |s| s.0.iter())) |
| .filter_map(|piece| match piece { |
| Piece::Sequence { .. } => None, |
| Piece::SpecialToken { id, .. } => check(id.as_ref()), |
| }) |
| .collect::<HashSet<_>>(); |
|
|
| if missing.is_empty() { |
| Ok(()) |
| } else { |
| Err(format!( |
| "Missing SpecialToken(s) with id(s) `{}`", |
| missing.iter().join(", ") |
| )) |
| } |
| } |
| } |
|
|
| impl Default for TemplateProcessing { |
| fn default() -> Self { |
| Self { |
| single: "$0".try_into().unwrap(), |
| pair: "$1".try_into().unwrap(), |
| added_single: 0, |
| added_pair: 0, |
| special_tokens: Tokens::default(), |
| } |
| } |
| } |
|
|
| impl TemplateProcessing { |
| pub fn builder() -> TemplateProcessingBuilder { |
| TemplateProcessingBuilder::default() |
| } |
|
|
| fn apply_template( |
| &self, |
| template: &[Piece], |
| mut encodings: Vec<Encoding>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> { |
| let final_encodings: Vec<Encoding> = template |
| .iter() |
| .flat_map(|piece| { |
| match piece { |
| Piece::Sequence { id, type_id } => { |
| let i = usize::from(*id != Sequence::A); |
| let encoding = &mut encodings[i]; |
| encoding.set_type_ids(vec![*type_id; encoding.len()]); |
| encoding.set_sequence_id(i); |
| Some(encoding.clone()) |
| } |
| Piece::SpecialToken { id, type_id } => { |
| if add_special_tokens { |
| let tok = &self.special_tokens.0[id]; |
| let len = tok.ids.len(); |
|
|
| let encoding = Encoding::new( |
| tok.ids.clone(), |
| std::iter::repeat(*type_id).take(len).collect(), |
| tok.tokens.clone(), |
| |
| std::iter::repeat(None).take(len).collect(), |
| |
| std::iter::repeat((0, 0)).take(len).collect(), |
| |
| std::iter::repeat(1).take(len).collect(), |
| |
| std::iter::repeat(1).take(len).collect(), |
| |
| vec![], |
| |
| HashMap::new(), |
| ); |
| Some(encoding) |
| } else { |
| None |
| } |
| } |
| } |
| }) |
| .collect(); |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| Ok(final_encodings) |
| } |
| } |
|
|
| impl PostProcessor for TemplateProcessing { |
| fn added_tokens(&self, is_pair: bool) -> usize { |
| if is_pair { |
| self.added_pair |
| } else { |
| self.added_single |
| } |
| } |
|
|
| fn process_encodings( |
| &self, |
| encodings: Vec<Encoding>, |
| add_special_tokens: bool, |
| ) -> Result<Vec<Encoding>> { |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| let template = match encodings.len() { |
| 2 => &self.pair.0, |
| 1 => &self.single.0, |
| _ => todo!(), |
| }; |
| let encodings = self.apply_template(template, encodings, add_special_tokens)?; |
| Ok(encodings) |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;
    use std::iter::FromIterator;

    // Each `Piece` variant round-trips through serde_json.
    #[test]
    fn piece_serde() {
        let seq_0 = Piece::Sequence {
            id: Sequence::A,
            type_id: 0,
        };
        let seq_0_s = r#"{"Sequence":{"id":"A","type_id":0}}"#;

        assert_eq!(serde_json::to_string(&seq_0).unwrap(), seq_0_s);
        assert_eq!(serde_json::from_str::<Piece>(seq_0_s).unwrap(), seq_0);

        let seq_1 = Piece::Sequence {
            id: Sequence::B,
            type_id: 1,
        };
        let seq_1_s = r#"{"Sequence":{"id":"B","type_id":1}}"#;
        assert_eq!(serde_json::to_string(&seq_1).unwrap(), seq_1_s);
        assert_eq!(serde_json::from_str::<Piece>(seq_1_s).unwrap(), seq_1);

        let spe = Piece::SpecialToken {
            id: "[CLS]".into(),
            type_id: 0,
        };
        let spe_s = r#"{"SpecialToken":{"id":"[CLS]","type_id":0}}"#;
        assert_eq!(serde_json::to_string(&spe).unwrap(), spe_s);
        assert_eq!(serde_json::from_str::<Piece>(spe_s).unwrap(), spe);
    }

    // Parsing of the `ID[:TYPE_ID]` template syntax.
    #[test]
    fn piece() {
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 0
            }),
            "$".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::B,
                type_id: 0
            }),
            "$B".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 1
            }),
            "$1".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::B,
                type_id: 2
            }),
            "$B:2".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 1
            }),
            "$:1".try_into()
        );
        assert!(Piece::try_from("$C:1").is_err());
        assert!(Piece::try_from("$A:").is_err());
    }

    // SpecialToken serde round-trip, plus the length check in `SpecialToken::new`.
    #[test]
    fn special_token_serde() {
        let simple = SpecialToken::from(("[CLS]", 0));
        let simple_s = r#"{"id":"[CLS]","ids":[0],"tokens":["[CLS]"]}"#;
        assert_eq!(serde_json::to_string(&simple).unwrap(), simple_s);
        assert_eq!(
            serde_json::from_str::<SpecialToken>(simple_s).unwrap(),
            simple
        );

        let complete = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2, 3],
            vec!["convert".into(), "to".into(), "FR".into()],
        )
        .unwrap();
        let complete_s = r#"{"id":"[2FR]","ids":[1,2,3],"tokens":["convert","to","FR"]}"#;
        assert_eq!(serde_json::to_string(&complete).unwrap(), complete_s);
        assert_eq!(
            serde_json::from_str::<SpecialToken>(complete_s).unwrap(),
            complete
        );

        let malformed = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2],
            vec!["convert".into(), "to".into(), "FR".into()],
        );
        assert!(malformed.is_err());
        let malformed = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2, 3],
            vec!["convert".into(), "FR".into()],
        );
        assert!(malformed.is_err());
    }

    // Template serializes transparently as a list of pieces.
    #[test]
    fn template_serde() {
        let template = Template(vec![
            Piece::Sequence {
                id: Sequence::A,
                type_id: 0,
            },
            Piece::SpecialToken {
                id: "[CLS]".into(),
                type_id: 0,
            },
        ]);
        let template_s =
            r#"[{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[CLS]","type_id":0}}]"#;
        assert_eq!(serde_json::to_string(&template).unwrap(), template_s);
        assert_eq!(
            serde_json::from_str::<Template>(template_s).unwrap(),
            template
        );
    }

    // Tokens serializes as a map with keys in a stable order.
    #[test]
    fn tokens_serde() {
        let tokens = Tokens::from(vec![("[CLS]", 1), ("[SEP]", 0)]);
        let tokens_s = r#"{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[0],"tokens":["[SEP]"]}}"#;
        let tokens_ser = serde_json::to_string(&tokens).unwrap();
        assert_eq!(tokens_ser, tokens_s);
        assert_eq!(serde_json::from_str::<Tokens>(tokens_s).unwrap(), tokens);
    }

    // BERT-style processor shared by several tests below.
    fn get_bert_template() -> TemplateProcessing {
        TemplateProcessing::builder()
            .try_single(vec!["[CLS]", "$0", "[SEP]"])
            .unwrap()
            .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
            .unwrap()
            .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
            .build()
            .unwrap()
    }

    // Full processor serde round-trip, including the "type" tag.
    #[test]
    fn template_processing_serde() {
        let template = tests::get_bert_template();
        let template_s = "{\
            \"type\":\"TemplateProcessing\",\
            \"single\":[\
                {\"SpecialToken\":{\"id\":\"[CLS]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"A\",\"type_id\":0}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":0}}\
            ],\
            \"pair\":[\
                {\"SpecialToken\":{\"id\":\"[CLS]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"A\",\"type_id\":0}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"B\",\"type_id\":1}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":1}}\
            ],\
            \"special_tokens\":{\
                \"[CLS]\":{\
                    \"id\":\"[CLS]\",\"ids\":[1],\"tokens\":[\"[CLS]\"]\
                },\
                \"[SEP]\":{\
                    \"id\":\"[SEP]\",\"ids\":[0],\"tokens\":[\"[SEP]\"]\
                }\
            }}";
        let template_ser = serde_json::to_string(&template).unwrap();
        assert_eq!(template_ser, template_s);
        assert_eq!(
            serde_json::from_str::<TemplateProcessing>(template_s).unwrap(),
            template
        );
    }

    // The builder rejects templates that reference unregistered special tokens.
    #[test]
    fn missing_special_tokens() {
        let processor = TemplateProcessing::builder()
            .try_single("[CLS] $0 [SEP]")
            .unwrap()
            .try_pair("[CLS] $A:0 [SEP] $B:1 [SEP]")
            .unwrap()
            .build();

        let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into());
        let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into());
        assert!(processor == err_a || processor == err_b);
    }

    // Single and pair processing without overflowing encodings.
    #[test]
    fn template_processing() {
        let processor = tests::get_bert_template();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![1, 12, 14, 0],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![1, 12, 14, 0, 15, 0],
                vec![0, 0, 0, 0, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)],
                vec![1, 0, 0, 1, 0, 1],
                vec![1, 1, 1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), None);
    }

    // Overflowing encodings of both inputs are combined and templated as well.
    #[test]
    fn template_processing_overflowing() {
        let processor = tests::get_bert_template();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let mut encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let overflowing = Encoding::from_tokens(vec![Token::new(13, "you".into(), (12, 15))], 0);
        encoding.set_overflowing(vec![overflowing]);

        let mut pair = Encoding::from_tokens(
            vec![
                Token::new(15, "pair".into(), (0, 4)),
                Token::new(16, "with".into(), (5, 9)),
            ],
            0,
        );
        let pair_overflowing =
            Encoding::from_tokens(vec![Token::new(17, "info".into(), (10, 14))], 0);
        pair.set_overflowing(vec![pair_overflowing]);

        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![1, 12, 14, 0],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![Encoding::new(
                    vec![1, 13, 0],
                    vec![0, 0, 0],
                    vec!["[CLS]".into(), "you".into(), "[SEP]".into()],
                    vec![None, None, None],
                    vec![(0, 0), (12, 15), (0, 0)],
                    vec![1, 0, 1],
                    vec![1, 1, 1],
                    vec![],
                    HashMap::from_iter(vec![(0, 1..2)]),
                )],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![1, 12, 14, 0, 15, 16, 0],
                vec![0, 0, 0, 0, 1, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "with".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (5, 9), (0, 0)],
                vec![1, 0, 0, 1, 0, 0, 1],
                vec![1, 1, 1, 1, 1, 1, 1],
                vec![
                    Encoding::new(
                        vec![1, 13, 0, 15, 16, 0],
                        vec![0, 0, 0, 1, 1, 1],
                        vec![
                            "[CLS]".into(),
                            "you".into(),
                            "[SEP]".into(),
                            "pair".into(),
                            "with".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None, None],
                        vec![(0, 0), (12, 15), (0, 0), (0, 4), (5, 9), (0, 0)],
                        vec![1, 0, 1, 0, 0, 1],
                        vec![1, 1, 1, 1, 1, 1],
                        vec![Encoding::new(
                            vec![1, 13, 0, 17, 0],
                            vec![0, 0, 0, 0, 1],
                            vec![
                                "[CLS]".into(),
                                "you".into(),
                                "[SEP]".into(),
                                "info".into(),
                                "[SEP]".into()
                            ],
                            vec![None, None, None, None, None],
                            vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                            vec![1, 0, 1, 0, 1],
                            vec![1, 1, 1, 1, 1],
                            vec![],
                            HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                        )],
                        HashMap::from_iter(vec![(1, 3..5), (0, 1..2)]),
                    ),
                    Encoding::new(
                        vec![1, 13, 0, 17, 0],
                        vec![0, 0, 0, 0, 1],
                        vec![
                            "[CLS]".into(),
                            "you".into(),
                            "[SEP]".into(),
                            "info".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None],
                        vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                        vec![1, 0, 1, 0, 1],
                        vec![1, 1, 1, 1, 1],
                        vec![],
                        HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                    ),
                    Encoding::new(
                        vec![1, 12, 14, 0, 17, 0],
                        vec![0, 0, 0, 0, 0, 1],
                        vec![
                            "[CLS]".into(),
                            "Hello".into(),
                            "there".into(),
                            "[SEP]".into(),
                            "info".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None, None],
                        vec![(0, 0), (0, 5), (6, 11), (0, 0), (10, 14), (0, 0)],
                        vec![1, 0, 0, 1, 0, 1],
                        vec![1, 1, 1, 1, 1, 1],
                        vec![Encoding::new(
                            vec![1, 13, 0, 17, 0],
                            vec![0, 0, 0, 0, 1],
                            vec![
                                "[CLS]".into(),
                                "you".into(),
                                "[SEP]".into(),
                                "info".into(),
                                "[SEP]".into()
                            ],
                            vec![None, None, None, None, None],
                            vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                            vec![1, 0, 1, 0, 1],
                            vec![1, 1, 1, 1, 1],
                            vec![],
                            HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                        )],
                        HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
                    )
                ],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..6)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(6), None);
    }

    // `$0 $1` means `$A:0 $A:1`, so the pair template never uses sequence B.
    #[test]
    fn pair_must_use_both_sequences() {
        let processor = TemplateProcessing::builder()
            .try_single("$0")
            .unwrap()
            .try_pair("$0 $1")
            .unwrap()
            .build();
        assert_eq!(
            processor,
            Err("Template for `pair` must use both sequences".into())
        );
    }

    // Exercises the negative case of the builder error's message-based PartialEq.
    #[test]
    fn expect_wrong_error_message() {
        let processor = TemplateProcessing::builder()
            .try_single("$0")
            .unwrap()
            .try_pair("$0 $1")
            .unwrap()
            .build();
        assert_ne!(
            processor,
            Err("Expect the left side error message to be different from the right side!".into())
        );
    }
}
|
|