|
|
| import unittest |
| import sys |
| import os |
| import tempfile |
| import mmap |
| import json |
| from pathlib import Path |
|
|
|
|
# The compiled extension modules are optional: record whether they could be
# imported so the test class below can be skipped instead of erroring out.
try:
    from crayon.c_ext import crayon_cpu, crayon_trainer, crayon_compiler
except ImportError:
    EXTENSIONS_AVAILABLE = False
else:
    EXTENSIONS_AVAILABLE = True
|
|
@unittest.skipUnless(EXTENSIONS_AVAILABLE, "C++ extensions not available")
class TestCrayonExtensions(unittest.TestCase):
    """Smoke tests for the ``crayon_compiler``, ``crayon_cpu`` and
    ``crayon_trainer`` extension modules.

    A small vocabulary is compiled to a temporary ``.dat`` file once per
    class, memory-mapped, and loaded into ``crayon_cpu``; the tokenizer tests
    then run against that shared state.
    """

    @classmethod
    def setUpClass(cls):
        """Compile ``test_vocab`` to a temp file, mmap it and load it.

        If any step fails, the partially created resources are released
        before the error is re-raised.
        """
        cls.test_vocab = ["a", "ab", "abc", "b", "c", " ", "def"]

        # mkstemp is used only to reserve a unique path; the descriptor is
        # closed right away and the path is handed to compile_dat.
        fd, cls.temp_dat = tempfile.mkstemp(suffix=".dat")
        os.close(fd)

        try:
            crayon_compiler.compile_dat(cls.test_vocab, cls.temp_dat)

            with open(cls.temp_dat, "rb") as f:
                cls.mmap_obj = mmap.mmap(
                    f.fileno(), 0, access=mmap.ACCESS_READ
                )
            crayon_cpu.load_dat(cls.mmap_obj)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # without this the temp file (and possibly the mmap) would leak.
            # tearDownClass guards each step with hasattr, so it is safe to
            # call on a half-initialised class.
            cls.tearDownClass()
            raise

    @classmethod
    def tearDownClass(cls):
        """Release the mmap and delete the temporary ``.dat`` file."""
        if hasattr(cls, 'mmap_obj'):
            # Best effort: load an empty buffer first, presumably so the
            # extension drops its reference to the mmap before it is closed
            # (TODO confirm against crayon_cpu.load_dat). Whether load_dat
            # accepts empty input is not known here, so any error is ignored
            # on purpose rather than failing the teardown.
            try:
                crayon_cpu.load_dat(b"")
            except Exception:
                pass
            try:
                cls.mmap_obj.close()
            except BufferError:
                # Something still holds an exported view of the mmap; leave
                # it open instead of failing the whole test class.
                pass
        if hasattr(cls, 'temp_dat') and os.path.exists(cls.temp_dat):
            os.unlink(cls.temp_dat)

    def test_compiler_version(self):
        """The compiler module reports the expected version string."""
        self.assertEqual(crayon_compiler.get_version(), "2.0.0-hyperfast")

    def test_cpu_hardware_info(self):
        """``get_hardware_info`` returns a string containing a ``[``."""
        info = crayon_cpu.get_hardware_info()
        self.assertIsInstance(info, str)
        self.assertIn("[", info)

    def test_tokenize_simple(self):
        """A string that is exactly one vocab entry yields a single ID."""
        # "abc" sits at index 2 of test_vocab.
        tokens = crayon_cpu.tokenize("abc")
        self.assertEqual(tokens, [2])

    def test_tokenize_longest_match(self):
        """The longest vocab entry wins over its shorter prefixes."""
        # "a", "ab" and "abc" are all in the vocab; only "abc" (index 2)
        # should be emitted.
        tokens = crayon_cpu.tokenize("abc")
        self.assertEqual(tokens, [2])

        # Likewise "ab" (index 1) must win over "a".
        tokens = crayon_cpu.tokenize("ab")
        self.assertEqual(tokens, [1])

    def test_tokenize_fallback_unk(self):
        """Input that is not in the vocab falls back to a single ID."""
        # NOTE(review): the expected ID 1 is also what tokenize("ab") returns
        # in test_tokenize_longest_match, so this assertion cannot tell an
        # UNK fallback apart from a spurious "ab" match - confirm the UNK ID
        # used by the extension.
        tokens = crayon_cpu.tokenize("x")
        self.assertEqual(tokens, [1])

    def test_trainer_basic(self):
        """``train_fast`` returns a non-empty list of 2-tuples."""
        corpus = b"banana banana banana"

        merges = crayon_trainer.train_fast(corpus, 260, min_freq=1, verbose=0)
        self.assertIsInstance(merges, list)
        self.assertGreater(len(merges), 0)

        for m in merges:
            self.assertIsInstance(m, tuple)
            self.assertEqual(len(m), 2)
|
# Allow running this file directly as a script in addition to test discovery.
if __name__ == "__main__":
    unittest.main()