File size: 1,465 Bytes
use crate::decoders::DecoderWrapper;
use crate::tokenizer::{Decoder, Result};
use crate::utils::macro_rules_attribute;
use serde::{Deserialize, Serialize};
/// A decoder built from an ordered list of other decoders.
///
/// The `Decoder` implementation for this type passes the tokens through each
/// wrapped decoder in list order, handing every decoder's output to the next.
// NOTE(review): serialization support presumably comes from the
// `impl_serde_type!` macro applied below — confirm against its definition.
#[derive(Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Sequence {
// Wrapped decoders, stored in the order in which they are applied.
decoders: Vec<DecoderWrapper>,
}
impl Sequence {
pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
Self { decoders }
}
pub fn get_decoders(&self) -> &[DecoderWrapper] {
&self.decoders
}
pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
&mut self.decoders
}
}
impl Decoder for Sequence {
    /// Threads `tokens` through every wrapped decoder, in order.
    ///
    /// Each decoder receives the output of the previous one. The first error
    /// stops the chain and is returned as-is; with no decoders configured the
    /// input tokens are returned unchanged.
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        self.decoders
            .iter()
            .try_fold(tokens, |current, decoder| decoder.decode_chain(current))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoders::ctc::CTC;
    use crate::pre_tokenizers::metaspace::Metaspace;

    /// Chains a default CTC decoder with a default Metaspace decoder and checks
    /// that the sample tokens decode to `Hi you`.
    #[test]
    fn sequence_basic() {
        let sequence = Sequence::new(vec![
            DecoderWrapper::CTC(CTC::default()),
            DecoderWrapper::Metaspace(Metaspace::default()),
        ]);

        let pieces = ["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"];
        let input: Vec<String> = pieces.iter().map(|piece| piece.to_string()).collect();

        let decoded = sequence.decode(input).unwrap();
        assert_eq!(decoded, "Hi you");
    }
}
|