use std::cmp::Ordering;
use std::collections::HashMap as StdHashMap;

use dary_heap::OctonaryHeap;
use fancy_regex::Regex;
use pyo3::prelude::*;

use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use rayon::prelude::*;

/// The GPT-4 pre-tokenization pattern: it splits text into contractions,
/// runs of letters, runs of 1-3 digits, punctuation, and whitespace, so
/// BPE merges never cross these category boundaries.
const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";

/// A candidate merge is a pair of adjacent token ids.
type Pair = (u32, u32);

/// A byte-level BPE tokenizer, exposed to Python via pyo3.
#[pyclass]
pub struct Tokenizer {
    /// Maps a pair of token ids to the id of the token they merge into.
    pub merges: StdHashMap<Pair, u32>,
    /// The split pattern as a string (kept so it can be handed back to Python).
    pub pattern: String,
    /// The compiled form of `pattern`.
    compiled_pattern: Regex,
}

/// One pre-tokenized chunk of text, stored as a sequence of token ids.
#[derive(Clone, Debug)]
struct Word {
    ids: Vec<u32>,
}

impl Word {
    #[inline]
    fn new(ids: Vec<u32>) -> Self {
        Self { ids }
    }

    /// Iterate over all adjacent pairs of ids in this word.
    #[inline]
    fn pairs(&self) -> impl Iterator<Item = Pair> + '_ {
        self.ids.windows(2).map(|w| (w[0], w[1]))
    }

    /// Replace every non-overlapping occurrence of `pair` (scanning left to
    /// right) with `new_id`, and return the pair-count deltas caused by the
    /// rewrite. Each delta has weight +/-1; the caller scales it by the
    /// word's occurrence count.
    fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
        let (a, b) = pair;
        let n = self.ids.len();
        if n < 2 {
            return Vec::new();
        }

        let mut out: Vec<u32> = Vec::with_capacity(n);
        let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);

        let mut i = 0;
        while i < n {
            if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
                let left = out.last().copied();
                let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };

                // (left, a) disappears and (left, new_id) appears.
                if let Some(x) = left {
                    deltas.push(((x, a), -1));
                    deltas.push(((x, new_id), 1));
                }
                // The merged pair itself disappears.
                deltas.push(((a, b), -1));
                // (b, right) disappears and (new_id, right) appears.
                if let Some(y) = right {
                    deltas.push(((b, y), -1));
                    deltas.push(((new_id, y), 1));
                }

                out.push(new_id);
                i += 2;
            } else {
                out.push(self.ids[i]);
                i += 1;
            }
        }

        self.ids = out;
        deltas
    }
}
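
// A worked example of the delta bookkeeping above (a sketch; the byte values
// and the new id 256 are arbitrary): merging (108, 108) -> 256 in
// [104, 101, 108, 108, 111] ("hello") leaves ids = [104, 101, 256, 111] and
// deltas ((101,108),-1), ((101,256),+1), ((108,108),-1), ((108,111),-1),
// ((256,111),+1). The `tests` module at the bottom of this file checks this case.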

/// A candidate merge in the priority queue.
#[derive(Debug, Eq)]
struct MergeJob {
    pair: Pair,
    count: u64,
    /// Indices of the words this pair occurs in, so only those words need
    /// to be rewritten when the pair wins.
    pos: AHashSet<usize>,
}

impl PartialEq for MergeJob {
    fn eq(&self, other: &Self) -> bool {
        self.count == other.count && self.pair == other.pair
    }
}

impl PartialOrd for MergeJob {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeJob {
    fn cmp(&self, other: &Self) -> Ordering {
        // Order by count; on ties, prefer the lexicographically smaller pair
        // (note the deliberately swapped operands), which makes training
        // deterministic in a max-heap.
        if self.count != other.count {
            self.count.cmp(&other.count)
        } else {
            other.pair.cmp(&self.pair)
        }
    }
}
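
// Example of the ordering (a sketch with made-up pairs): at equal counts,
// (0, 9) outranks (1, 2) because `other.pair.cmp(&self.pair)` reverses the
// comparison, so the smaller pair is "greater" and surfaces first in a max-heap.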

/// Count adjacent pairs across all words in parallel. Returns the pair counts
/// (weighted by each word's occurrence count) and, for each pair, the set of
/// word indices it occurs in.
#[inline]
fn count_pairs_parallel(
    words: &[Word],
    counts: &[i32],
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
    words
        .par_iter()
        .enumerate()
        .map(|(i, w)| {
            let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
            let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
            if w.ids.len() >= 2 && counts[i] != 0 {
                for (a, b) in w.pairs() {
                    *local_pc.entry((a, b)).or_default() += counts[i];
                    local_wtu.entry((a, b)).or_default().insert(i);
                }
            }
            (local_pc, local_wtu)
        })
        .reduce(
            || (AHashMap::new(), AHashMap::new()),
            |(mut acc_pc, mut acc_wtu), (pc, wtu)| {
                for (k, v) in pc {
                    *acc_pc.entry(k).or_default() += v;
                }
                for (k, s) in wtu {
                    acc_wtu.entry(k).or_default().extend(s);
                }
                (acc_pc, acc_wtu)
            },
        )
}

impl Tokenizer {
    /// Core BPE training loop. Rather than recounting all pairs after every
    /// merge, it applies each winning merge only to the words that contain
    /// the pair and updates the counts incrementally, using a max-heap of
    /// candidates with lazy deletion of stale entries.
    fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
        assert!(vocab_size >= 256, "vocab_size must be at least 256");
        let num_merges = vocab_size - 256;
        log::info!("Starting BPE training: {} merges to compute", num_merges);
        self.merges.clear();

        // Initial pair statistics.
        log::info!("Computing initial pair counts from {} unique sequences", words.len());
        let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);

        // Seed the heap with every pair that actually occurs.
        log::info!("Building heap with {} unique pairs", pair_counts.len());
        let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
        for (pair, pos) in where_to_update.drain() {
            let c = *pair_counts.get(&pair).unwrap_or(&0);
            if c > 0 {
                heap.push(MergeJob {
                    pair,
                    count: c as u64,
                    pos,
                });
            }
        }

        log::info!("Starting merge loop");
        let mut merges_done = 0u32;
        let mut last_log_percent = 0u32;

        while merges_done < num_merges {
            let Some(mut top) = heap.pop() else { break; };

            // Lazy deletion: if the popped entry's count is stale, refresh
            // it and push it back instead of merging.
            let current = *pair_counts.get(&top.pair).unwrap_or(&0);
            if top.count != current as u64 {
                top.count = current as u64;
                if top.count > 0 {
                    heap.push(top);
                }
                continue;
            }
            if top.count == 0 {
                break;
            }

            // Record the merge rule.
            let new_id = 256 + merges_done;
            self.merges.insert(top.pair, new_id);

            // Rewrite only the words containing this pair, accumulating
            // count deltas and the positions of newly created pairs.
            let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
            for &word_idx in &top.pos {
                let changes = words[word_idx].merge_pair(top.pair, new_id);
                for (pair, delta) in changes {
                    let delta_total = delta * counts[word_idx];
                    if delta_total != 0 {
                        *pair_counts.entry(pair).or_default() += delta_total;
                        if delta > 0 {
                            local_pos_updates.entry(pair).or_default().insert(word_idx);
                        }
                    }
                }
            }

            // Push refreshed heap entries for the pairs whose counts grew.
            for (pair, pos) in local_pos_updates {
                let cnt = *pair_counts.get(&pair).unwrap_or(&0);
                if cnt > 0 {
                    heap.push(MergeJob {
                        pair,
                        count: cnt as u64,
                        pos,
                    });
                }
            }

            merges_done += 1;

            // Log progress at most once per percent.
            let current_percent = (merges_done * 100) / num_merges;
            if current_percent > last_log_percent {
                log::info!(
                    "Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
                    current_percent, merges_done, num_merges, top.pair, new_id, top.count
                );
                last_log_percent = current_percent;
            }
        }

        log::info!("Finished training: {} merges completed", merges_done);
    }
}

#[pymethods]
impl Tokenizer {
    #[new]
    pub fn new() -> Self {
        Self {
            merges: StdHashMap::new(),
            pattern: String::new(),
            compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
        }
    }

    /// Train the tokenizer from a Python iterator of strings.
    ///
    /// Streams the iterator in buffers of `buffer_size` strings, counts the
    /// regex-split chunks in parallel with the GIL released, then runs the
    /// incremental merge loop.
    #[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
    #[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
    pub fn train_from_iterator(
        &mut self,
        py: pyo3::Python<'_>,
        iterator: &pyo3::Bound<'_, pyo3::PyAny>,
        vocab_size: u32,
        buffer_size: usize,
        pattern: Option<String>,
    ) -> PyResult<()> {
        // Default to the GPT-4 split pattern if none is given.
        let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
        self.pattern = pattern_str.clone();
        self.compiled_pattern = Regex::new(&pattern_str)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;

        // Get a true Python iterator object out of the argument.
        let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
            pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
        };

        // Occurrence count per unique chunk.
        let mut counts: AHashMap<CompactString, i32> = AHashMap::new();

        // Reusable buffer of strings pulled from the iterator.
        let mut buf: Vec<String> = Vec::with_capacity(buffer_size);

        log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
        let mut total_sequences = 0u64;

        // Pull up to `buffer_size` strings from the Python iterator.
        // Returns Ok(true) once the iterator is exhausted.
        let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
            pyo3::Python::with_gil(|py| {
                buf.clear();
                let it = py_iter.bind(py);
                loop {
                    if buf.len() >= buffer_size {
                        return Ok(false);
                    }

                    let next_obj = unsafe {
                        pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
                    };
                    match next_obj {
                        Some(obj) => {
                            let s: String = obj.extract()?;
                            buf.push(s);
                        }
                        None => {
                            // PyIter_Next returns NULL both on exhaustion and
                            // on error; distinguish the two cases.
                            if pyo3::PyErr::occurred(py) {
                                return Err(pyo3::PyErr::fetch(py));
                            } else {
                                return Ok(true);
                            }
                        }
                    }
                }
            })
        };

        loop {
            let exhausted = refill(&mut buf)?;
            if buf.is_empty() && exhausted {
                break;
            }

            total_sequences += buf.len() as u64;

            // Split and count chunks in parallel with the GIL released.
            let pattern = self.compiled_pattern.clone();
            let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
                buf.par_iter()
                    .map(|s| {
                        let mut m: AHashMap<CompactString, i32> = AHashMap::new();
                        for mat in pattern.find_iter(s) {
                            let piece = mat.expect("regex match failed").as_str();
                            *m.entry(CompactString::from(piece)).or_default() += 1;
                        }
                        m
                    })
                    .reduce(
                        || AHashMap::new(),
                        |mut a, b| {
                            for (k, v) in b {
                                *a.entry(k).or_default() += v;
                            }
                            a
                        },
                    )
            });

            // Fold the batch counts into the global table.
            for (k, v) in local {
                *counts.entry(k).or_default() += v;
            }

            if exhausted {
                break;
            }
        }
        log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());

        // Materialize each unique chunk as a byte-level id sequence.
        let mut words = Vec::with_capacity(counts.len());
        let mut cvec = Vec::with_capacity(counts.len());
        for (chunk, c) in counts.into_iter() {
            words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
            cvec.push(c);
        }

        self.train_core_incremental(words, cvec, vocab_size);
        Ok(())
    }

    /// Return the split pattern string.
    pub fn get_pattern(&self) -> String {
        self.pattern.clone()
    }

    /// Return the vocabulary in "mergeable ranks" form: a list of
    /// (token bytes, token id) pairs.
    pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
        let mut mergeable_ranks = Vec::new();

        // The first 256 tokens are the raw bytes.
        let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
        for (i, bytes) in token_bytes.iter().enumerate() {
            mergeable_ranks.push((bytes.clone(), i as u32));
        }

        // Replay the merges in id order to reconstruct each token's bytes;
        // merges only ever reference ids created earlier, so lookups succeed.
        let mut sorted_merges: Vec<_> = self.merges.iter().collect();
        sorted_merges.sort_by_key(|&(_, &token_id)| token_id);

        for (&pair, &merged_id) in sorted_merges {
            let (left, right) = pair;
            let mut merged_bytes = token_bytes[left as usize].clone();
            merged_bytes.extend(&token_bytes[right as usize]);

            if token_bytes.len() <= merged_id as usize {
                token_bytes.resize(merged_id as usize + 1, Vec::new());
            }
            token_bytes[merged_id as usize] = merged_bytes.clone();

            mergeable_ranks.push((merged_bytes, merged_id));
        }

        mergeable_ranks
    }

    /// Encode `text` into token ids by splitting with the pattern and then
    /// repeatedly applying the lowest-rank (earliest-learned) merge available.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        let mut all_ids = Vec::new();

        for m in self.compiled_pattern.find_iter(text) {
            let chunk = m.expect("regex match failed").as_str();

            // Start from the chunk's raw bytes.
            let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();

            while ids.len() >= 2 {
                // Find the adjacent pair with the lowest merge rank.
                let mut best_pair: Option<(usize, Pair, u32)> = None;
                for i in 0..ids.len() - 1 {
                    let pair: Pair = (ids[i], ids[i + 1]);
                    if let Some(&new_id) = self.merges.get(&pair) {
                        if best_pair.is_none() || new_id < best_pair.unwrap().2 {
                            best_pair = Some((i, pair, new_id));
                        }
                    }
                }

                if let Some((idx, _pair, new_id)) = best_pair {
                    ids[idx] = new_id;
                    ids.remove(idx + 1);
                } else {
                    // No applicable merges remain in this chunk.
                    break;
                }
            }

            all_ids.extend(ids);
        }

        all_ids
    }
}

#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
    pyo3_log::init();
    m.add_class::<Tokenizer>()?;
    Ok(())
}
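
// Minimal sanity-test sketches for the pieces above. These are illustrative
// additions: the ids 256/257 and the "hello"/"abc" inputs are arbitrary
// choices, not fixtures from the original crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn merge_pair_reports_neighbor_deltas() {
        // "hello" as raw bytes.
        let mut w = Word::new(vec![104, 101, 108, 108, 111]);
        let deltas = w.merge_pair((108, 108), 256);
        assert_eq!(w.ids, vec![104, 101, 256, 111]);
        assert!(deltas.contains(&((101, 108), -1)));
        assert!(deltas.contains(&((101, 256), 1)));
        assert!(deltas.contains(&((108, 108), -1)));
        assert!(deltas.contains(&((108, 111), -1)));
        assert!(deltas.contains(&((256, 111), 1)));
    }

    #[test]
    fn merge_job_orders_by_count_then_smaller_pair() {
        let lo = MergeJob { pair: (1, 2), count: 5, pos: AHashSet::new() };
        let tie = MergeJob { pair: (0, 9), count: 5, pos: AHashSet::new() };
        let hi = MergeJob { pair: (7, 7), count: 9, pos: AHashSet::new() };
        assert!(hi > lo); // higher count wins
        assert!(tie > lo); // equal counts: lexicographically smaller pair wins
    }

    #[test]
    fn encode_applies_lowest_rank_merge_first() {
        let mut tok = Tokenizer::new();
        tok.pattern = GPT4_PATTERN.to_string();
        tok.compiled_pattern = Regex::new(GPT4_PATTERN).unwrap();
        // Hypothetical merge table: ('a','b') -> 256, then (256,'c') -> 257.
        tok.merges.insert((97, 98), 256);
        tok.merges.insert((256, 99), 257);
        assert_eq!(tok.encode("abc"), vec![257]);
    }
}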