File size: 3,228 Bytes
708f4a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import unittest
import sys
import os
import tempfile
import mmap
import json
from pathlib import Path


# The native extensions are optional: when `crayon.c_ext` cannot be imported,
# EXTENSIONS_AVAILABLE stays False and the extension test class is skipped.
EXTENSIONS_AVAILABLE = False
try:
    from crayon.c_ext import crayon_cpu, crayon_trainer, crayon_compiler
except ImportError:
    pass
else:
    EXTENSIONS_AVAILABLE = True

@unittest.skipUnless(EXTENSIONS_AVAILABLE, "C++ extensions not available")
class TestCrayonExtensions(unittest.TestCase):
    """Smoke tests for the ``crayon.c_ext`` native modules.

    A tiny vocabulary is compiled once per class into a temporary ``.dat``
    file with ``crayon_compiler``, memory-mapped read-only, and handed to
    ``crayon_cpu.load_dat`` so all tokenizer tests share one loaded model.
    """

    @classmethod
    def setUpClass(cls):
        """Compile the test vocabulary to a temp DAT file and load it."""
        # id mapping: 0:"a", 1:"ab", 2:"abc", 3:"b", 4:"c", 5:" ", 6:"def"
        cls.test_vocab = ["a", "ab", "abc", "b", "c", " ", "def"]

        fd, cls.temp_dat = tempfile.mkstemp(suffix=".dat")
        os.close(fd)

        # unittest does not call tearDownClass when setUpClass raises, so
        # release the temp file (and the mmap, if already created) ourselves
        # before propagating any failure.
        try:
            # The compiler's return value (presumably build stats) is unused.
            crayon_compiler.compile_dat(cls.test_vocab, cls.temp_dat)

            with open(cls.temp_dat, "rb") as f:
                cls.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
                crayon_cpu.load_dat(cls.mmap_obj)
        except BaseException:
            cls.tearDownClass()
            raise

    @classmethod
    def tearDownClass(cls):
        """Release the mmap held by the engine and delete the temp DAT file."""
        if hasattr(cls, 'mmap_obj'):
            # The C extension holds a reference to the mmap's buffer; closing
            # the mmap while that reference exists raises BufferError. Loading
            # an empty buffer is a best-effort way to make the engine let go.
            try:
                crayon_cpu.load_dat(b"")
            except Exception:
                # Best effort only: the engine may reject an empty buffer; the
                # close() below is attempted either way.
                pass
            try:
                cls.mmap_obj.close()
            except BufferError:
                # The buffer is still exported; leave the mapping to be
                # reclaimed when the last reference goes away.
                pass
        if hasattr(cls, 'temp_dat'):
            try:
                os.unlink(cls.temp_dat)
            except FileNotFoundError:
                pass

    def test_compiler_version(self):
        self.assertEqual(crayon_compiler.get_version(), "2.0.0-hyperfast")

    def test_cpu_hardware_info(self):
        info = crayon_cpu.get_hardware_info()
        self.assertIsInstance(info, str)
        # Loose format check only: the exact content is hardware-dependent.
        self.assertIn("[", info)

    def test_tokenize_simple(self):
        # "abc" is a whole vocabulary entry (id 2).
        tokens = crayon_cpu.tokenize("abc")
        self.assertEqual(tokens, [2])

    def test_tokenize_longest_match(self):
        """The tokenizer prefers the longest vocabulary entry that matches."""
        cases = {
            "abc": [2],  # "abc" (id 2) rather than "ab" + "c" or "a" + "b" + "c"
            "ab": [1],   # "ab" (id 1) rather than "a" + "b"
        }
        for text, expected in cases.items():
            with self.subTest(text=text):
                self.assertEqual(crayon_cpu.tokenize(text), expected)

    def test_tokenize_fallback_unk(self):
        # "x" is not in the vocabulary, so the engine appends its hard-coded
        # fallback id, 1.
        # NOTE(review): id 1 is also "ab" in this vocabulary, so this check
        # cannot distinguish the UNK fallback from an "ab" match.
        tokens = crayon_cpu.tokenize("x")
        self.assertEqual(tokens, [1])

    def test_trainer_basic(self):
        corpus = b"banana banana banana"
        # Train a tiny BPE model; the target vocab size must be > 256
        # (presumably the byte-level base alphabet).
        merges = crayon_trainer.train_fast(corpus, 260, min_freq=1, verbose=0)
        self.assertIsInstance(merges, list)
        self.assertGreater(len(merges), 0)
        # Each merge is expected to be ((token_a, token_b), new_id); only the
        # outer 2-tuple shape is checked here.
        for m in merges:
            with self.subTest(merge=m):
                self.assertIsInstance(m, tuple)
                self.assertEqual(len(m), 2)

if __name__ == "__main__":
    # Direct execution runs every TestCase in this module through the stock
    # unittest runner (module="__main__" is the runner's default target).
    unittest.main(module="__main__")