//! Character-level text tokenizer with special-token support.
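//!
//! A minimal usage sketch (marked `ignore` because the doctest would
//! need the crate's public re-export paths, which aren't shown here):
//!
//! ```ignore
//! let tokenizer = TextTokenizer::new(TokenizerConfig::default())?;
//! let ids = tokenizer.encode("Hello, world!")?;
//! assert_eq!(tokenizer.decode(&ids)?, "Hello, world!");
//! ```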

use crate::{Error, Result};
use std::collections::HashMap;
use std::path::Path;

/// Configuration for the text tokenizer.
#[derive(Debug, Clone)]
pub struct TokenizerConfig {
    /// Path to the tokenizer model file.
    pub model_path: String,
    /// Total vocabulary size.
    pub vocab_size: usize,
    /// Beginning-of-sequence token id.
    pub bos_id: i64,
    /// End-of-sequence token id.
    pub eos_id: i64,
    /// Unknown-token id.
    pub unk_id: i64,
    /// Padding token id.
    pub pad_id: i64,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            model_path: "models/bpe.model".to_string(),
            vocab_size: 6681,
            bos_id: 1,
            eos_id: 2,
            unk_id: 0,
            pad_id: 3,
        }
    }
}

/// Character-level tokenizer with special-token handling.
#[derive(Debug)]
pub struct TextTokenizer {
    /// Tokenizer configuration.
    config: TokenizerConfig,
    /// Maps token strings to ids.
    token_to_id: HashMap<String, i64>,
    /// Maps ids back to token strings.
    id_to_token: HashMap<i64, String>,
    /// Maps single characters to ids (fast path for encoding).
    char_vocab: HashMap<char, i64>,
}

impl TextTokenizer {
    /// Builds a tokenizer with the built-in character vocabulary.
    pub fn new(config: TokenizerConfig) -> Result<Self> {
        let mut token_to_id = HashMap::new();
        let mut id_to_token = HashMap::new();
        let mut char_vocab = HashMap::new();

        // Register the special tokens in both directions.
        token_to_id.insert("<unk>".to_string(), config.unk_id);
        token_to_id.insert("<s>".to_string(), config.bos_id);
        token_to_id.insert("</s>".to_string(), config.eos_id);
        token_to_id.insert("<pad>".to_string(), config.pad_id);

        id_to_token.insert(config.unk_id, "<unk>".to_string());
        id_to_token.insert(config.bos_id, "<s>".to_string());
        id_to_token.insert(config.eos_id, "</s>".to_string());
        id_to_token.insert(config.pad_id, "<pad>".to_string());

        // Printable ASCII characters (space through '~').
        let mut next_id = 4i64;
        for c in ' '..='~' {
            char_vocab.insert(c, next_id);
            token_to_id.insert(c.to_string(), next_id);
            id_to_token.insert(next_id, c.to_string());
            next_id += 1;
        }

        // CJK Unified Ideographs (U+4E00..=U+9FFF), capped at the
        // configured vocabulary size.
        for code_point in 0x4E00u32..=0x9FFF {
            if let Some(c) = char::from_u32(code_point) {
                char_vocab.insert(c, next_id);
                token_to_id.insert(c.to_string(), next_id);
                id_to_token.insert(next_id, c.to_string());
                next_id += 1;

                if next_id >= config.vocab_size as i64 {
                    break;
                }
            }
        }

        Ok(Self {
            config,
            token_to_id,
            id_to_token,
            char_vocab,
        })
    }

    /// Loads a tokenizer, recording `path` as the model path.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        if !path.exists() {
            return Err(Error::FileNotFound(path.display().to_string()));
        }

        // The model file itself is not parsed; only its path is
        // recorded, and the built-in character vocabulary is used.
        let config = TokenizerConfig {
            model_path: path.display().to_string(),
            ..Default::default()
        };

        Self::new(config)
    }

    /// Encodes `text` into token ids framed by BOS and EOS.
    pub fn encode(&self, text: &str) -> Result<Vec<i64>> {
        let mut tokens = Vec::new();

        // Start with the beginning-of-sequence token.
        tokens.push(self.config.bos_id);

        // Encode character by character, falling back to the
        // unknown token for anything outside the vocabulary.
        for ch in text.chars() {
            if let Some(&id) = self.char_vocab.get(&ch) {
                tokens.push(id);
            } else if let Some(&id) = self.token_to_id.get(&ch.to_string()) {
                tokens.push(id);
            } else {
                // Unknown character: map to <unk>.
                tokens.push(self.config.unk_id);
            }
        }

        // Finish with the end-of-sequence token.
        tokens.push(self.config.eos_id);

        Ok(tokens)
    }

    /// Encodes `text` without adding BOS/EOS markers.
    pub fn encode_without_special(&self, text: &str) -> Result<Vec<i64>> {
        let mut tokens = Vec::new();

        for ch in text.chars() {
            if let Some(&id) = self.char_vocab.get(&ch) {
                tokens.push(id);
            } else if let Some(&id) = self.token_to_id.get(&ch.to_string()) {
                tokens.push(id);
            } else {
                tokens.push(self.config.unk_id);
            }
        }

        Ok(tokens)
    }

    /// Decodes token ids back into text, skipping special tokens.
    pub fn decode(&self, tokens: &[i64]) -> Result<String> {
        let mut text = String::new();

        for &token_id in tokens {
            // Skip BOS, EOS, and padding.
            if token_id == self.config.bos_id
                || token_id == self.config.eos_id
                || token_id == self.config.pad_id
            {
                continue;
            }

            if let Some(token) = self.id_to_token.get(&token_id) {
                text.push_str(token);
            } else {
                // Unknown id: emit a placeholder character.
                text.push('?');
            }
        }

        Ok(text)
    }

    /// Returns the configured vocabulary size.
    pub fn vocab_size(&self) -> usize {
        self.config.vocab_size
    }

    /// Returns the beginning-of-sequence token id.
    pub fn bos_id(&self) -> i64 {
        self.config.bos_id
    }

    /// Returns the end-of-sequence token id.
    pub fn eos_id(&self) -> i64 {
        self.config.eos_id
    }

    /// Returns the unknown-token id.
    pub fn unk_id(&self) -> i64 {
        self.config.unk_id
    }

    /// Returns the padding token id.
    pub fn pad_id(&self) -> i64 {
        self.config.pad_id
    }

    /// Pads (or truncates) each sequence to a common length: `max_len`
    /// if given, otherwise the length of the longest sequence.
    pub fn pad_sequences(&self, sequences: &[Vec<i64>], max_len: Option<usize>) -> Vec<Vec<i64>> {
        let max_length = max_len
            .unwrap_or_else(|| sequences.iter().map(|s| s.len()).max().unwrap_or(0));

        sequences
            .iter()
            .map(|seq| {
                let mut padded = seq.clone();
                while padded.len() < max_length {
                    padded.push(self.config.pad_id);
                }
                padded.truncate(max_length);
                padded
            })
            .collect()
    }

    /// Builds an attention mask: 1 for real tokens, 0 for padding.
    pub fn create_attention_mask(&self, tokens: &[i64]) -> Vec<i64> {
        tokens
            .iter()
            .map(|&t| if t == self.config.pad_id { 0 } else { 1 })
            .collect()
    }

    /// Encodes each text in the batch with BOS/EOS markers.
    pub fn batch_encode(&self, texts: &[&str]) -> Result<Vec<Vec<i64>>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }

    /// Encodes a batch and pads all sequences to a common length.
    pub fn batch_encode_padded(
        &self,
        texts: &[&str],
        max_len: Option<usize>,
    ) -> Result<Vec<Vec<i64>>> {
        let encoded: Vec<Vec<i64>> = self.batch_encode(texts)?;
        Ok(self.pad_sequences(&encoded, max_len))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_creation() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();
        assert!(tokenizer.vocab_size() > 0);
    }

    #[test]
    fn test_encode_decode() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let text = "Hello world";
        let tokens = tokenizer.encode(text).unwrap();

        // The sequence is framed by BOS and EOS.
        assert_eq!(tokens[0], tokenizer.bos_id());
        assert_eq!(*tokens.last().unwrap(), tokenizer.eos_id());

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, text);
    }

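    // Sketch of an additional check (not in the original tests),
    // assuming the default config: encode_without_special adds no
    // BOS/EOS framing, so a two-character input yields two ids.
    #[test]
    fn test_encode_without_special() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let tokens = tokenizer.encode_without_special("Hi").unwrap();
        assert_eq!(tokens.len(), 2);
        assert_ne!(tokens[0], tokenizer.bos_id());
    }
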
    #[test]
    fn test_encode_chinese() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let text = "你好";
        let tokens = tokenizer.encode(text).unwrap();

        // BOS + two CJK characters + EOS.
        assert_eq!(tokens.len(), 4);
    }

    #[test]
    fn test_pad_sequences() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let seq1 = vec![1, 2, 3];
        let seq2 = vec![1, 2, 3, 4, 5];

        let padded = tokenizer.pad_sequences(&[seq1, seq2], None);

        assert_eq!(padded[0].len(), 5);
        assert_eq!(padded[1].len(), 5);
        assert_eq!(padded[0][3], tokenizer.pad_id());
    }

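    // Sketch of a batch round-trip (not in the original tests),
    // assuming the default config: shorter sequences are padded up
    // to the longest encoded length in the batch.
    #[test]
    fn test_batch_encode_padded() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let padded = tokenizer
            .batch_encode_padded(&["Hi", "Hello"], None)
            .unwrap();

        // "Hi" -> BOS + 2 chars + EOS = 4 ids, padded to match
        // "Hello" -> BOS + 5 chars + EOS = 7 ids.
        assert_eq!(padded[0].len(), 7);
        assert_eq!(padded[1].len(), 7);
        assert_eq!(*padded[0].last().unwrap(), tokenizer.pad_id());
    }
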
    #[test]
    fn test_attention_mask() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let tokens = vec![1, 2, tokenizer.pad_id(), tokenizer.pad_id()];
        let mask = tokenizer.create_attention_mask(&tokens);

        assert_eq!(mask, vec![1, 1, 0, 0]);
    }
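
    // Hedged extra check (not in the original tests): a character
    // outside both the ASCII and CJK ranges should map to <unk>.
    #[test]
    fn test_unknown_char_maps_to_unk() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        // '😀' (U+1F600) is outside the built-in vocabulary.
        let tokens = tokenizer.encode("😀").unwrap();
        assert_eq!(
            tokens,
            vec![tokenizer.bos_id(), tokenizer.unk_id(), tokenizer.eos_id()]
        );
    }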
}