# File size: 3,228 Bytes
# 708f4a3
import unittest
import sys
import os
import tempfile
import mmap
import json
from pathlib import Path
# Probe for the compiled extension modules; tests that need them are skipped
# (rather than erroring at import time) when they are missing.
try:
    from crayon.c_ext import crayon_cpu, crayon_trainer, crayon_compiler
except ImportError:
    EXTENSIONS_AVAILABLE = False
else:
    EXTENSIONS_AVAILABLE = True
@unittest.skipUnless(EXTENSIONS_AVAILABLE, "C++ extensions not available")
class TestCrayonExtensions(unittest.TestCase):
    """Smoke tests for the crayon_cpu, crayon_trainer and crayon_compiler extensions.

    A small vocabulary is compiled to a temporary ``.dat`` file once per class,
    memory-mapped, and loaded into ``crayon_cpu``; the tokenizer tests run
    against that shared state. The whole class is skipped when the extension
    modules cannot be imported.
    """

    @classmethod
    def setUpClass(cls):
        """Compile the test vocabulary to a temp DAT file and load it into the CPU engine.

        unittest does not call ``tearDownClass`` when ``setUpClass`` raises, so
        any failure here runs the cleanup explicitly before re-raising;
        otherwise the temp file (and possibly the mmap) would be leaked.
        """
        # Expected id mapping: 0:a, 1:ab, 2:abc, 3:b, 4:c, 5:" ", 6:def
        cls.test_vocab = ["a", "ab", "abc", "b", "c", " ", "def"]

        # Only the path is needed; the compiler writes the file itself.
        fd, cls.temp_dat = tempfile.mkstemp(suffix=".dat")
        os.close(fd)

        try:
            # Build the DAT with the compiler extension.
            crayon_compiler.compile_dat(cls.test_vocab, cls.temp_dat)

            # Map the compiled file read-only and hand the mapping to the CPU engine.
            with open(cls.temp_dat, "rb") as f:
                cls.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
            crayon_cpu.load_dat(cls.mmap_obj)
        except Exception:
            # tearDownClass tolerates partially-initialised state.
            cls.tearDownClass()
            raise

    @classmethod
    def tearDownClass(cls):
        """Release the mmap (best effort) and delete the temporary DAT file."""
        mapped = getattr(cls, "mmap_obj", None)
        if mapped is not None:
            # The C extension holds a reference to the mmap object; closing the
            # mmap while that reference is alive raises BufferError. Loading an
            # empty buffer is a workaround intended to make the engine drop it.
            try:
                crayon_cpu.load_dat(b"")
            except Exception:
                # Deliberately best-effort: the engine may reject an empty
                # buffer, and that must not fail the whole test class.
                pass
            try:
                mapped.close()
            except BufferError:
                # The engine still exports the buffer; leave it to process exit.
                pass

        dat_path = getattr(cls, "temp_dat", None)
        if dat_path is not None:
            try:
                os.unlink(dat_path)
            except FileNotFoundError:
                pass

    def test_compiler_version(self):
        """The compiler extension reports the expected version string."""
        self.assertEqual(crayon_compiler.get_version(), "2.0.0-hyperfast")

    def test_cpu_hardware_info(self):
        """The hardware info is a string containing a ``[`` character."""
        info = crayon_cpu.get_hardware_info()
        self.assertIsInstance(info, str)
        self.assertIn("[", info)

    def test_tokenize_simple(self):
        """An exact vocabulary entry tokenizes to its own id."""
        # "abc" should be its own token (id 2)
        tokens = crayon_cpu.tokenize("abc")
        self.assertEqual(tokens, [2])

    def test_tokenize_longest_match(self):
        """The longest matching vocabulary entry wins over shorter prefixes."""
        # "ab" + "c" vs "abc" -> should pick "abc" (id 2)
        tokens = crayon_cpu.tokenize("abc")
        self.assertEqual(tokens, [2])
        # "a" + "b" -> should pick "ab" (id 1)
        tokens = crayon_cpu.tokenize("ab")
        self.assertEqual(tokens, [1])

    def test_tokenize_fallback_unk(self):
        """Input with no vocabulary match yields the fallback id 1."""
        # "x" is not in the vocabulary; the engine is expected to append its
        # hardcoded fallback id 1 when no match is found.
        # NOTE(review): id 1 is also "ab" in this vocabulary, so this assertion
        # cannot tell the fallback apart from a real "ab" match - confirm intended.
        tokens = crayon_cpu.tokenize("x")
        self.assertEqual(tokens, [1])

    def test_trainer_basic(self):
        """Training on a tiny corpus returns a non-empty list of 2-tuples."""
        corpus = b"banana banana banana"
        # Train a small BPE. Vocab size must be > 256.
        merges = crayon_trainer.train_fast(corpus, 260, min_freq=1, verbose=0)
        self.assertIsInstance(merges, list)
        self.assertGreater(len(merges), 0)
        # Each merge is a tuple of ((token_a, token_b), new_id)
        for merge in merges:
            self.assertIsInstance(merge, tuple)
            self.assertEqual(len(merge), 2)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()