use super::{super::OrderedVocabIter, WordLevel, WordLevelBuilder}; use serde::{ de::{MapAccess, Visitor}, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer, }; use std::collections::HashSet; impl Serialize for WordLevel { fn serialize(&self, serializer: S) -> Result where S: Serializer, { let mut model = serializer.serialize_struct("WordLevel", 3)?; let ordered_vocab = OrderedVocabIter::new(&self.vocab_r); model.serialize_field("type", "WordLevel")?; model.serialize_field("vocab", &ordered_vocab)?; model.serialize_field("unk_token", &self.unk_token)?; model.end() } } impl<'de> Deserialize<'de> for WordLevel { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { deserializer.deserialize_struct( "WordLevel", &["type", "vocab", "unk_token"], WordLevelVisitor, ) } } struct WordLevelVisitor; impl<'de> Visitor<'de> for WordLevelVisitor { type Value = WordLevel; fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { write!(fmt, "struct WordLevel") } fn visit_map(self, mut map: V) -> std::result::Result where V: MapAccess<'de>, { let mut builder = WordLevelBuilder::new(); let mut missing_fields = vec![ // for retrocompatibility the "type" field is not mandatory "unk_token", "vocab", ] .into_iter() .collect::>(); while let Some(key) = map.next_key::()? { match key.as_ref() { "vocab" => builder = builder.vocab(map.next_value()?), "unk_token" => builder = builder.unk_token(map.next_value()?), "type" => match map.next_value()? { "WordLevel" => {} u => { return Err(serde::de::Error::invalid_value( serde::de::Unexpected::Str(u), &"WordLevel", )) } }, _ => {} } missing_fields.remove::(&key); } if !missing_fields.is_empty() { Err(serde::de::Error::missing_field( missing_fields.iter().next().unwrap(), )) } else { Ok(builder.build().map_err(serde::de::Error::custom)?) } } } #[cfg(test)] mod tests { use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder}; #[test] fn serde() { let wl = WordLevel::default(); let wl_s = r#"{"type":"WordLevel","vocab":{},"unk_token":""}"#; assert_eq!(serde_json::to_string(&wl).unwrap(), wl_s); assert_eq!(serde_json::from_str::(wl_s).unwrap(), wl); } #[test] fn incomplete_vocab() { let vocab: Vocab = [("".into(), 0), ("b".into(), 2)] .iter() .cloned() .collect(); let wordlevel = WordLevelBuilder::default() .vocab(vocab) .unk_token("".to_string()) .build() .unwrap(); let wl_s = r#"{"type":"WordLevel","vocab":{"":0,"b":2},"unk_token":""}"#; assert_eq!(serde_json::to_string(&wordlevel).unwrap(), wl_s); assert_eq!(serde_json::from_str::(wl_s).unwrap(), wordlevel); } #[test] fn deserialization_should_fail() { let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#; assert!(serde_json::from_str::(missing_unk) .unwrap_err() .to_string() .starts_with("missing field `unk_token`")); let wrong_type = r#"{"type":"WordPiece","vocab":{}}"#; assert!(serde_json::from_str::(wrong_type) .unwrap_err() .to_string() .starts_with("invalid value: string \"WordPiece\", expected WordLevel")); } }