// Source: CRAYON-tokenizer/src/crayon/c_ext/trainer.cpp
// Uploaded by Phase-Technologies via huggingface_hub (commit 708f4a3, verified).
/**
* CRAYON HYPER-FAST BPE TRAINER (C++17)
* =====================================
*
* A fast exact greedy BPE training algorithm for a single CPU core.
*
* ALGORITHM: Weighted Linked-List + Inverted Index + Lazy Heap
* ============================================================
*
* This implementation targets exact greedy BPE (each step merges the most
* frequent adjacent pair). Instead of rescanning the corpus after every merge,
* it jumps directly to the recorded positions of the chosen pair.
*
* Data Structures:
* 1. PARALLEL ARRAYS (Cache-Optimized Doubly Linked List)
* - tokens[]: The actual token IDs at each position
* - prev_pos[]: Pointer to previous valid index (-1 if start)
* - next_pos[]: Pointer to next valid index (-1 if end)
* - active[]: Is this position still valid? (False after merge)
*
* Why Parallel Arrays?
* - Superior cache locality vs struct-of-pointers
* - Sequential memory access patterns
* - SIMD-friendly data layout
*
* 2. INVERTED INDEX (Pair -> Positions Map)
* - Maps each (TokenA, TokenB) pair to a vector of positions
* - Enables O(1) lookup of all occurrences of any pair
* - No scanning required - jump directly to merge sites
*
* 3. LAZY MAX-HEAP (Priority Queue)
* - Stores {count, pair} tuples
* - "Lazy" means we don't remove invalidated entries
* - Validity checked on pop by comparing with true count
* - Amortized O(log N) operations
*
* COMPLEXITY ANALYSIS:
* ====================
* - Initial Counting: O(N) where N = corpus size
* - Per Merge: O(K * log H) where K = pair frequency, H = heap size
* - Total: O(N + M * K_avg * log H) where M = vocab_size - 256
*
* MEMORY LAYOUT:
* ==============
* - tokens: [int32] x N (4 bytes per position)
* - prev_pos: [int32] x N (4 bytes per position)
* - next_pos: [int32] x N (4 bytes per position)
* - active: [bool] x N (1 byte per position)
* - Total base: ~13 bytes per byte in corpus
*
* OPTIMIZATION TECHNIQUES:
* ========================
* 1. Custom hash for pair keys (the standard library has no std::hash for std::pair)
* 2. Reserve memory upfront to avoid reallocations
* 3. Inline hot-path functions for zero call overhead
* 4. Early termination on min_frequency
* 5. Stale merge sites skipped via cheap validity checks during each merge
*
* @author XERV AI Research
* @version 2.0.0
* @date 2026-02-02
*/
#define PY_SSIZE_T_CLEAN
#include <Python.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <new>
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
// =============================================================================
// 1. OPTIMIZED HASHING - Custom Pair Hasher
// =============================================================================
/**
 * PairHash: hash functor for (int, int) pair keys.
 *
 * The standard library provides no std::hash specialization for std::pair,
 * so the unordered maps in this file need a custom hasher.
 *
 * Both 32-bit components are packed losslessly into one 64-bit word, so
 * distinct pairs never share a packed value. The word is then scrambled with
 * the SplitMix64 finalizer, which is a bijection on 64-bit values, so every
 * output bit depends on both tokens. This matters for bucket tables that use
 * the low bits of the hash.
 *
 * A plain `first * 31 + second` is avoided on purpose: it collides
 * systematically, because (a, b) and (a + 1, b - 31) hash identically. Such
 * collisions are common once merged token IDs run into the thousands.
 */
struct PairHash {
    inline size_t operator()(const std::pair<int, int>& v) const noexcept {
        // Lossless packing: high 32 bits = first, low 32 bits = second.
        uint64_t x = (static_cast<uint64_t>(static_cast<uint32_t>(v.first)) << 32)
                   | static_cast<uint64_t>(static_cast<uint32_t>(v.second));
        // SplitMix64 finalizer. Each step (xor-shift, multiply by an odd
        // constant) is invertible, so distinct inputs stay distinct.
        x ^= x >> 30;
        x *= 0xbf58476d1ce4e5b9ULL;
        x ^= x >> 27;
        x *= 0x94d049bb133111ebULL;
        x ^= x >> 31;
        return static_cast<size_t>(x);
    }
};
/**
 * PairEqual: equality predicate for (int, int) pair keys.
 * Two keys are equal exactly when both components match.
 */
struct PairEqual {
    inline bool operator()(const std::pair<int, int>& lhs,
                           const std::pair<int, int>& rhs) const noexcept {
        // std::pair's operator== compares .first, then .second.
        return lhs == rhs;
    }
};
// =============================================================================
// 2. TRAINING STATISTICS STRUCTURE
// =============================================================================
/**
 * TrainingStats: performance counters gathered by the trainer.
 * Every counter starts at zero and every timing at 0.0 ms.
 */
struct TrainingStats {
    // --- Work counters ------------------------------------------------------
    size_t corpus_size{0};          // Length of the input in bytes
    size_t initial_pairs{0};        // Distinct adjacent pairs after the first scan
    size_t merges_performed{0};     // Merge rules emitted
    size_t positions_processed{0};  // Candidate merge sites visited
    size_t heap_pops{0};            // Entries popped from the priority queue
    size_t lazy_skips{0};           // Popped entries whose count was outdated
    // --- Wall-clock timings, in milliseconds ----------------------------------
    double init_time_ms{0.0};       // Constructor: setup + initial pair counting
    double train_time_ms{0.0};      // Merge loop in train()
    double total_time_ms{0.0};      // init_time_ms + train_time_ms
};
// =============================================================================
// 3. CORE TRAINER CLASS
// =============================================================================
/**
* CrayonTrainer: The main BPE training engine.
*
* Implements the optimal Linked-List + Inverted Index + Lazy Heap algorithm.
* Each instance processes one corpus - create new instance for new corpus.
*/
class CrayonTrainer {
private:
// =========================================================================
// PARALLEL ARRAYS - Cache-Optimized Linked List Representation
// =========================================================================
/** Token ID at each position (starts as byte values 0-255) */
std::vector<int32_t> tokens;
/** Index of previous active position (-1 = start of document) */
std::vector<int32_t> prev_pos;
/** Index of next active position (-1 = end of document) */
std::vector<int32_t> next_pos;
/** Position validity flag (false after being merged into neighbor) */
std::vector<bool> active;
// =========================================================================
// INVERTED INDEX - Pair to Positions Mapping
// =========================================================================
/**
* Maps each unique pair (A, B) to all positions where it appears.
* Key: pair<token_a, token_b>
* Value: vector of starting positions (indices into tokens[])
*
* This is the secret sauce - enables O(1) lookup of merge sites
* instead of O(N) scanning.
*/
std::unordered_map<
std::pair<int, int>,
std::vector<int>,
PairHash,
PairEqual
> pair_locations;
/**
* Current frequency count for each pair.
* Updated incrementally during merges - never rescanned.
*/
std::unordered_map<
std::pair<int, int>,
int,
PairHash,
PairEqual
> pair_counts;
// =========================================================================
// LAZY MAX-HEAP - Always Returns Highest Frequency Pair
// =========================================================================
/**
* Priority queue storing {count, pair} ordered by count descending.
*
* "Lazy" Design:
* - We push new entries when counts increase
* - We DON'T remove entries when counts decrease
* - On pop, we validate count against pair_counts map
* - Stale entries (heap count != map count) are discarded
*
* Why Lazy?
* - Removing arbitrary elements from heap is O(N)
* - lazy validation on pop is O(1) average case
* - Total overhead is bounded by O(M * K_avg) extra pops
*/
std::priority_queue<
std::pair<int, std::pair<int, int>>
> heap;
// =========================================================================
// STATISTICS TRACKING
// =========================================================================
TrainingStats stats;
// =========================================================================
// MINIMUM FREQUENCY THRESHOLD
// =========================================================================
int min_frequency = 2;
public:
// =========================================================================
// CONSTRUCTOR - Initializes All Data Structures
// =========================================================================
/**
* Initialize trainer with raw byte corpus.
*
* @param raw_bytes Pointer to corpus bytes
* @param len Length of corpus in bytes
* @param min_freq Minimum frequency threshold (default 2)
*/
CrayonTrainer(const char* raw_bytes, size_t len, int min_freq = 2)
: min_frequency(min_freq)
{
auto start_time = std::chrono::high_resolution_clock::now();
stats.corpus_size = len;
if (len == 0) {
return;
}
// ---------------------------------------------------------------------
// PHASE 1: Allocate Memory (Single Allocation for Each Array)
// ---------------------------------------------------------------------
// Reserve upfront to avoid reallocations during filling
tokens.reserve(len);
prev_pos.reserve(len);
next_pos.reserve(len);
active.resize(len, true); // All positions start active
// Pre-size hash maps based on expected unique pairs
// Heuristic: sqrt(len) unique pairs is typical for natural text
size_t estimated_unique_pairs = std::min(len, (size_t)1000000);
pair_counts.reserve(estimated_unique_pairs);
pair_locations.reserve(estimated_unique_pairs);
// ---------------------------------------------------------------------
// PHASE 2: Initialize Linked List from Bytes
// ---------------------------------------------------------------------
// Each byte becomes an initial token (0-255)
// Linked list connects sequential positions
for (size_t i = 0; i < len; ++i) {
// Store byte value as token ID (0-255 for initial vocab)
tokens.push_back(static_cast<unsigned char>(raw_bytes[i]));
// Link to previous position (or -1 for first position)
prev_pos.push_back(static_cast<int32_t>(i) - 1);
// Link to next position (placeholder, fixed below)
next_pos.push_back(static_cast<int32_t>(i) + 1);
}
// Fix end-of-list marker
next_pos[len - 1] = -1;
// ---------------------------------------------------------------------
// PHASE 3: Initial Pair Counting (Single Pass)
// ---------------------------------------------------------------------
// Scan once to count all adjacent pairs and record their positions
for (size_t i = 0; i < len - 1; ++i) {
record_pair(static_cast<int>(i));
}
stats.initial_pairs = pair_counts.size();
// ---------------------------------------------------------------------
// PHASE 4: Initialize Heap from Pair Counts
// ---------------------------------------------------------------------
// Push all pairs with count >= min_frequency into the heap
for (const auto& [pair, count] : pair_counts) {
if (count >= min_frequency) {
heap.push({count, pair});
}
}
auto end_time = std::chrono::high_resolution_clock::now();
stats.init_time_ms = std::chrono::duration<double, std::milli>(
end_time - start_time
).count();
}
// =========================================================================
// HELPER: Record a Pair at Given Position
// =========================================================================
/**
* Register the pair starting at position `pos` into our data structures.
* Updates both pair_counts and pair_locations.
*
* @param pos Starting position of the pair (tokens[pos], tokens[next_pos[pos]])
*/
inline void record_pair(int pos) {
// Boundary checks
if (pos == -1 || next_pos[pos] == -1) {
return;
}
// Create pair key
std::pair<int, int> p = {tokens[pos], tokens[next_pos[pos]]};
// Increment count
pair_counts[p]++;
// Record position in inverted index
pair_locations[p].push_back(pos);
}
// =========================================================================
// HELPER: Decrement Pair Count (During Merge)
// =========================================================================
/**
* Decrease count for a pair that is being broken.
* Does NOT update heap (lazy design) or locations (handled elsewhere).
*
* @param p The pair being decremented
*/
inline void decrement_pair(const std::pair<int, int>& p) {
auto it = pair_counts.find(p);
if (it != pair_counts.end() && it->second > 0) {
it->second--;
}
}
// =========================================================================
// MAIN TRAINING LOOP
// =========================================================================
/**
* Execute BPE training to build vocabulary up to target size.
*
* @param vocab_size Target vocabulary size (includes initial 256 byte tokens)
* @return Vector of merge operations: {token_a, token_b, new_token_id}
*/
std::vector<std::tuple<int, int, int>> train(int vocab_size) {
auto start_time = std::chrono::high_resolution_clock::now();
std::vector<std::tuple<int, int, int>> merge_history;
// New token IDs start after byte tokens (0-255)
int next_id = 256;
// Reserve space for expected merges
merge_history.reserve(std::min(vocab_size - 256, (int)heap.size()));
// Track which positions were merged in current iteration
// Used to avoid double-processing
std::unordered_set<int> merged_this_round;
merged_this_round.reserve(1000);
// =====================================================================
// MAIN LOOP: Continue until vocab size reached or heap exhausted
// =====================================================================
while (next_id < vocab_size && !heap.empty()) {
// -----------------------------------------------------------------
// STEP A: Lazy Pop - Get Next Best Pair
// -----------------------------------------------------------------
auto top = heap.top();
heap.pop();
stats.heap_pops++;
int heap_count = top.first;
std::pair<int, int> pair = top.second;
// Validate: Is this count still accurate?
// If heap says 500 but map says 400, this is stale - skip it
auto count_it = pair_counts.find(pair);
if (count_it == pair_counts.end() || count_it->second != heap_count) {
stats.lazy_skips++;
continue; // Stale entry, try next
}
int real_count = count_it->second;
// Minimum frequency check
if (real_count < min_frequency) {
// No more pairs above threshold - we're done
break;
}
// -----------------------------------------------------------------
// STEP B: Execute Merge
// -----------------------------------------------------------------
int new_token = next_id++;
merge_history.emplace_back(pair.first, pair.second, new_token);
stats.merges_performed++;
// Get all positions where this pair exists
auto& positions = pair_locations[pair];
// Clear the merged tracker for this round
merged_this_round.clear();
// Process each position
for (int pos : positions) {
stats.positions_processed++;
// ---------------------------------------------------------
// VALIDITY CHECKS
// ---------------------------------------------------------
// Check 1: Position still active?
if (!active[pos]) {
continue;
}
// Check 2: Token at position still matches first of pair?
if (tokens[pos] != pair.first) {
continue;
}
// Check 3: Next position valid and still matches second of pair?
int next_idx = next_pos[pos];
if (next_idx == -1 || !active[next_idx]) {
continue;
}
if (tokens[next_idx] != pair.second) {
continue;
}
// Check 4: Not already merged in this round?
if (merged_this_round.count(pos) || merged_this_round.count(next_idx)) {
continue;
}
// ---------------------------------------------------------
// VALID MERGE SITE FOUND
// ---------------------------------------------------------
// We're merging positions [pos] and [next_idx] into [pos]
// Get neighbor positions
int prev_idx = prev_pos[pos];
int next_next_idx = next_pos[next_idx];
// ---------------------------------------------------------
// STEP B.1: Decrement Old Neighbor Pairs
// ---------------------------------------------------------
// Left neighbor: (tokens[prev], tokens[pos]) is being broken
if (prev_idx != -1 && active[prev_idx]) {
std::pair<int, int> old_left = {tokens[prev_idx], tokens[pos]};
decrement_pair(old_left);
}
// Right neighbor: (tokens[next_idx], tokens[next_next]) is being broken
if (next_next_idx != -1 && active[next_next_idx]) {
std::pair<int, int> old_right = {tokens[next_idx], tokens[next_next_idx]};
decrement_pair(old_right);
}
// ---------------------------------------------------------
// STEP B.2: Update Linked List
// ---------------------------------------------------------
// Transform: pos now holds the new merged token
tokens[pos] = new_token;
// Deactivate: next_idx is "consumed" into pos
active[next_idx] = false;
merged_this_round.insert(next_idx);
merged_this_round.insert(pos);
// Rewire pointers to skip next_idx
next_pos[pos] = next_next_idx;
if (next_next_idx != -1) {
prev_pos[next_next_idx] = pos;
}
// ---------------------------------------------------------
// STEP B.3: Create New Neighbor Pairs
// ---------------------------------------------------------
// New left pair: (tokens[prev], new_token)
if (prev_idx != -1 && active[prev_idx]) {
std::pair<int, int> new_left = {tokens[prev_idx], new_token};
pair_counts[new_left]++;
pair_locations[new_left].push_back(prev_idx);
// Push updated count to heap (lazy - might be duplicate)
if (pair_counts[new_left] >= min_frequency) {
heap.push({pair_counts[new_left], new_left});
}
}
// New right pair: (new_token, tokens[next_next])
if (next_next_idx != -1 && active[next_next_idx]) {
std::pair<int, int> new_right = {new_token, tokens[next_next_idx]};
pair_counts[new_right]++;
pair_locations[new_right].push_back(pos);
// Push updated count to heap
if (pair_counts[new_right] >= min_frequency) {
heap.push({pair_counts[new_right], new_right});
}
}
}
// Mark this pair as exhausted
pair_counts[pair] = 0;
}
auto end_time = std::chrono::high_resolution_clock::now();
stats.train_time_ms = std::chrono::duration<double, std::milli>(
end_time - start_time
).count();
stats.total_time_ms = stats.init_time_ms + stats.train_time_ms;
return merge_history;
}
// =========================================================================
// STATISTICS ACCESSOR
// =========================================================================
const TrainingStats& get_stats() const {
return stats;
}
};
// =============================================================================
// 4. PYTHON BINDING - C Extension Interface
// =============================================================================
/**
 * train_fast: Python-callable entry point for BPE training.
 *
 * Signature: train_fast(corpus: bytes, vocab_size: int, min_freq: int = 2,
 *                       verbose: int = 0) -> list
 *
 * @param corpus     Raw bytes of the training corpus (< 2**31 bytes)
 * @param vocab_size Target vocabulary size (must be > 256)
 * @param min_freq   Minimum pair frequency (optional, default 2, must be >= 1)
 * @param verbose    When non-zero, print training statistics via std::cout
 * @return List of merge tuples [((token_a, token_b), new_id), ...], or NULL
 *         with a Python exception set:
 *         - ValueError for invalid arguments;
 *         - MemoryError if training runs out of memory;
 *         - RuntimeError for any other C++ exception.
 */
static PyObject* train_fast(PyObject* self, PyObject* args, PyObject* kwargs) {
    const char* corpus = nullptr;
    Py_ssize_t corpus_len = 0;
    int vocab_size = 0;
    int min_freq = 2;   // Default minimum frequency
    int verbose = 0;    // Default: no stats output
    static char* kwlist[] = {
        const_cast<char*>("corpus"),
        const_cast<char*>("vocab_size"),
        const_cast<char*>("min_freq"),
        const_cast<char*>("verbose"),
        nullptr
    };

    // Parse arguments: bytes, int, optional int, optional int
    if (!PyArg_ParseTupleAndKeywords(
            args, kwargs, "y#i|ii", kwlist,
            &corpus, &corpus_len, &vocab_size, &min_freq, &verbose)) {
        return nullptr;
    }

    // Validate every argument before the empty-corpus shortcut, so that
    // invalid parameters are rejected regardless of the corpus.
    if (vocab_size <= 256) {
        PyErr_SetString(PyExc_ValueError,
                        "vocab_size must be > 256 (byte tokens occupy 0-255)");
        return nullptr;
    }
    if (min_freq < 1) {
        PyErr_SetString(PyExc_ValueError, "min_freq must be >= 1");
        return nullptr;
    }
    // CrayonTrainer stores positions as int32_t; larger inputs would overflow.
    if (corpus_len > static_cast<Py_ssize_t>(INT32_MAX)) {
        PyErr_SetString(PyExc_ValueError,
                        "corpus must be smaller than 2**31 bytes "
                        "(positions are stored as 32-bit integers)");
        return nullptr;
    }
    if (corpus_len == 0) {
        return PyList_New(0);  // Empty corpus -> empty merges
    }

    // =========================================================================
    // Execute Training (GIL Released for CPU-Bound Work)
    // =========================================================================
    std::vector<std::tuple<int, int, int>> merges;
    TrainingStats stats;
    // 0 = success, 1 = out of memory, 2 = other C++ exception.
    int failure = 0;

    // No Python API may be called while the GIL is released. C++ exceptions
    // must not propagate into the CPython interpreter (plain C code), so they
    // are captured here and converted to Python exceptions once the GIL is
    // held again.
    Py_BEGIN_ALLOW_THREADS
    try {
        CrayonTrainer trainer(corpus, static_cast<size_t>(corpus_len), min_freq);
        merges = trainer.train(vocab_size);
        stats = trainer.get_stats();
    } catch (const std::bad_alloc&) {
        failure = 1;
    } catch (...) {
        failure = 2;
    }
    Py_END_ALLOW_THREADS

    if (failure == 1) {
        return PyErr_NoMemory();
    }
    if (failure == 2) {
        PyErr_SetString(PyExc_RuntimeError,
                        "crayon_trainer: training failed with an internal C++ exception");
        return nullptr;
    }

    // Print stats if verbose. The output is written line by line with '\n'
    // and flushed once at the end.
    if (verbose) {
        std::cout << "\n=== CRAYON TRAINER STATS ===" << '\n'
                  << "Corpus Size: " << stats.corpus_size << " bytes" << '\n'
                  << "Initial Pairs: " << stats.initial_pairs << '\n'
                  << "Merges Performed: " << stats.merges_performed << '\n'
                  << "Positions Scanned: " << stats.positions_processed << '\n'
                  << "Heap Pops: " << stats.heap_pops << '\n'
                  << "Lazy Skips: " << stats.lazy_skips << '\n'
                  << "Init Time: " << stats.init_time_ms << " ms" << '\n'
                  << "Train Time: " << stats.train_time_ms << " ms" << '\n'
                  << "Total Time: " << stats.total_time_ms << " ms" << '\n'
                  << "===========================\n" << '\n';
        std::cout.flush();
    }

    // =========================================================================
    // Convert Result to Python Objects
    // =========================================================================
    PyObject* py_list = PyList_New(static_cast<Py_ssize_t>(merges.size()));
    if (!py_list) {
        return nullptr;
    }
    for (size_t i = 0; i < merges.size(); ++i) {
        const auto& [a, b, new_id] = merges[i];
        // Py_BuildValue creates and owns every intermediate object, so nothing
        // leaks on success or failure. PyTuple_Pack, by contrast, does not
        // steal references to its arguments.
        PyObject* merge_entry = Py_BuildValue("((ii)i)", a, b, new_id);
        if (!merge_entry) {
            Py_DECREF(py_list);  // Unfilled slots are NULL; dealloc skips them.
            return nullptr;
        }
        // PyList_SET_ITEM steals the reference to merge_entry.
        PyList_SET_ITEM(py_list, static_cast<Py_ssize_t>(i), merge_entry);
    }
    return py_list;
}
/**
 * get_version: return the trainer version string as a Python str.
 * Registered with METH_NOARGS, so `args` carries no arguments.
 */
static PyObject* get_version(PyObject* self, PyObject* args) {
    static const char kVersion[] = "2.0.0-hyperfast";
    return PyUnicode_FromString(kVersion);
}
/**
 * get_algorithm_info: return a short multi-line description of the
 * training algorithm and its stated complexity as a Python str.
 */
static PyObject* get_algorithm_info(PyObject* self, PyObject* args) {
    const char* const description =
        "Linked-List + Inverted Index + Lazy Heap BPE\n"
        "Complexity: O(N + M * K_avg * log H)\n"
        "where N=corpus, M=merges, K_avg=avg pair freq, H=heap size";
    return PyUnicode_FromString(description);
}
// =============================================================================
// 5. MODULE DEFINITION
// =============================================================================
// Method table of the crayon_trainer module. train_fast accepts keyword
// arguments, so it is registered with METH_VARARGS | METH_KEYWORDS and cast
// to the generic PyCFunction pointer type expected by PyMethodDef.
static PyMethodDef TrainerMethods[] = {
    {"train_fast",
     reinterpret_cast<PyCFunction>(train_fast),
     METH_VARARGS | METH_KEYWORDS,
     "Hyper-optimized BPE training.\n\n"
     "Args:\n"
     " corpus (bytes): Raw corpus bytes\n"
     " vocab_size (int): Target vocabulary size (> 256)\n"
     " min_freq (int, optional): Minimum pair frequency (default 2)\n"
     " verbose (int, optional): Print stats (default 0)\n\n"
     "Returns:\n"
     " list: [((token_a, token_b), new_id), ...] merge operations\n\n"
     "Example:\n"
     " >>> import crayon_trainer\n"
     " >>> with open('corpus.txt', 'rb') as f:\n"
     " ... data = f.read()\n"
     " >>> merges = crayon_trainer.train_fast(data, 30000)\n"
     " >>> print(f'Generated {len(merges)} merge rules')"},
    {"get_version", get_version, METH_NOARGS,
     "Get trainer version string."},
    {"get_algorithm_info", get_algorithm_info, METH_NOARGS,
     "Get algorithm description."},
    {nullptr, nullptr, 0, nullptr}  // Sentinel: end of table
};
// Module definition for crayon_trainer. m_size = -1 means no per-module state
// block is allocated. The trailing slot/GC/free hooks are left empty.
static struct PyModuleDef trainer_module = {
    PyModuleDef_HEAD_INIT,
    "crayon_trainer",
    "CRAYON Hyper-Fast BPE Training Engine\n\n"
    "Implements the mathematically optimal algorithm for\n"
    "Exact Greedy BPE on a single CPU core.\n\n"
    "Algorithm: Linked-List + Inverted Index + Lazy Heap\n"
    "Author: XERV AI Research\n"
    "Version: 2.0.0",
    -1,              // m_size
    TrainerMethods,  // m_methods
    nullptr,         // m_slots
    nullptr,         // m_traverse
    nullptr,         // m_clear
    nullptr          // m_free
};
// Module entry point called by the import system for `import crayon_trainer`.
// Returns the new module object, or NULL with an exception set on failure.
PyMODINIT_FUNC PyInit_crayon_trainer(void) {
    PyObject* module = PyModule_Create(&trainer_module);
    return module;
}