|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import shutil
import tempfile
import unittest
from pathlib import Path

import torch
from transformers import AutoTokenizer

from .tokenization_sapnous import SapnousTokenizer
|
|
|
class TestSapnousTokenizer(unittest.TestCase):
    """Unit tests for SapnousTokenizer.

    Covers vocabulary loading, round-trip encode/decode, special-token
    insertion, vision/image/video prompt wrappers, batched tensor encoding,
    save/load persistence, and AutoTokenizer registration.

    A single tokenizer instance is shared across tests via setUpClass; the
    on-disk vocab/merges fixture lives in a unique temporary directory.
    """

    @classmethod
    def setUpClass(cls):
        """Write a tiny vocab/merges fixture to disk and build the tokenizer."""
        # A unique temp dir (instead of a fixed name) avoids collisions when
        # tests run in parallel or a previous run was interrupted mid-teardown.
        cls.temp_dir = Path(tempfile.mkdtemp(prefix='test_tokenizer_files_'))

        # Minimal vocabulary: the special tokens the tokenizer expects plus
        # three plain words used by the round-trip tests below.
        cls.vocab_file = cls.temp_dir / 'vocab.json'
        cls.vocab = {
            '<|endoftext|>': 0,
            '<|startoftext|>': 1,
            '<|pad|>': 2,
            '<|vision_start|>': 3,
            '<|vision_end|>': 4,
            '<|image|>': 5,
            '<|video|>': 6,
            'hello': 7,
            'world': 8,
            'test': 9,
        }
        with cls.vocab_file.open('w', encoding='utf-8') as f:
            json.dump(cls.vocab, f)

        # Minimal BPE merge rules — just enough to tokenize the fixture words.
        cls.merges_file = cls.temp_dir / 'merges.txt'
        merges_content = "#version: 0.2\nh e\ne l\nl l\no w\nw o\no r\nr l\nl d"
        cls.merges_file.write_text(merges_content, encoding='utf-8')

        cls.tokenizer = SapnousTokenizer(
            str(cls.vocab_file),
            str(cls.merges_file),
        )

    @classmethod
    def tearDownClass(cls):
        """Remove the fixture directory.

        ignore_errors keeps a missing/locked directory from turning a test
        failure into a teardown error.
        """
        shutil.rmtree(cls.temp_dir, ignore_errors=True)

    def test_tokenizer_initialization(self):
        """The tokenizer exposes the fixture vocab and the expected specials."""
        self.assertEqual(self.tokenizer.vocab_size, len(self.vocab))
        self.assertEqual(self.tokenizer.get_vocab(), self.vocab)

        # Special-token wiring: unk and eos share '<|endoftext|>'.
        self.assertEqual(self.tokenizer.unk_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.bos_token, '<|startoftext|>')
        self.assertEqual(self.tokenizer.eos_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.pad_token, '<|pad|>')

    def test_tokenization(self):
        """tokenize/encode/decode round-trips a string of in-vocab words."""
        text = "hello world test"
        tokens = self.tokenizer.tokenize(text)
        self.assertIsInstance(tokens, list)
        self.assertTrue(all(isinstance(token, str) for token in tokens))

        # Each fixture word is a single vocab entry, so three ids come back.
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertIsInstance(input_ids, list)
        self.assertEqual(len(input_ids), 3)

        decoded_text = self.tokenizer.decode(input_ids)
        self.assertEqual(decoded_text.strip(), text)

    def test_special_tokens_handling(self):
        """add_special_tokens controls bos/eos framing of the id sequence."""
        text = "hello world"

        # assertEqual (not assertTrue(a == b)) gives a useful diff on failure.
        tokens_with_special = self.tokenizer.encode(text, add_special_tokens=True)
        self.assertEqual(tokens_with_special[0], self.tokenizer.bos_token_id)
        self.assertEqual(tokens_with_special[-1], self.tokenizer.eos_token_id)

        tokens_without_special = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertNotEqual(tokens_without_special[0], self.tokenizer.bos_token_id)
        self.assertNotEqual(tokens_without_special[-1], self.tokenizer.eos_token_id)

    def test_vision_tokens(self):
        """The prepare_for_* helpers wrap text with the matching markers."""
        text = "This is an image description"
        vision_text = self.tokenizer.prepare_for_vision(text)
        self.assertTrue(vision_text.startswith('<|vision_start|>'))
        self.assertTrue(vision_text.endswith('<|vision_end|>'))

        image_text = self.tokenizer.prepare_for_image(text)
        self.assertTrue(image_text.startswith('<|image|>'))

        video_text = self.tokenizer.prepare_for_video(text)
        self.assertTrue(video_text.startswith('<|video|>'))

    def test_batch_encoding(self):
        """Calling the tokenizer on a list yields padded torch tensors."""
        texts = ["hello world", "test hello"]
        batch_encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        self.assertIsInstance(batch_encoding["input_ids"], torch.Tensor)
        self.assertIsInstance(batch_encoding["attention_mask"], torch.Tensor)
        self.assertEqual(batch_encoding["input_ids"].shape[0], len(texts))
        self.assertEqual(batch_encoding["attention_mask"].shape[0], len(texts))

    def test_save_and_load(self):
        """save_vocabulary output can reconstruct an equivalent tokenizer."""
        # Unique temp dir for the same isolation reasons as setUpClass.
        save_dir = Path(tempfile.mkdtemp(prefix='test_save_tokenizer_'))

        try:
            vocab_files = self.tokenizer.save_vocabulary(str(save_dir))
            self.assertTrue(all(Path(f).exists() for f in vocab_files))

            loaded_tokenizer = SapnousTokenizer(*vocab_files)
            self.assertEqual(loaded_tokenizer.get_vocab(), self.tokenizer.get_vocab())

            # The reloaded tokenizer must encode identically to the original.
            text = "hello world test"
            original_encoding = self.tokenizer.encode(text)
            loaded_encoding = loaded_tokenizer.encode(text)
            self.assertEqual(original_encoding, loaded_encoding)
        finally:
            shutil.rmtree(save_dir, ignore_errors=True)

    def test_auto_tokenizer_registration(self):
        """AutoTokenizer resolves model_type 'sapnous' to SapnousTokenizer."""
        config = {
            "model_type": "sapnous",
            "vocab_file": str(self.vocab_file),
            "merges_file": str(self.merges_file)
        }

        tokenizer = AutoTokenizer.from_pretrained(str(self.temp_dir), **config)
        self.assertIsInstance(tokenizer, SapnousTokenizer)
|
|
|
# Allow the suite to be executed directly (e.g. `python test_tokenization.py`).
if __name__ == "__main__":
    unittest.main()