# mythos-rdt / shared/tokenizer.py  (Raidone, commit 4cf6c82 "verified", 7.62 kB)
# MYTHOS-RDT — Recurrent-Depth Transformer. بسم الله الرحمن الرحيم
"""
Shared Tokenizer for Raid Models — BPE from scratch
Trained on Italian + code corpus
"""
import json, os, regex as re
from collections import defaultdict
# Byte-level BPE in the GPT-2 style: text is pre-tokenized with a regex, each
# piece is encoded to UTF-8 bytes (base ids 0-255) and bytes are merged pairwise.
# Pre-tokenization relies on the third-party `regex` package (imported as `re`
# at the top of this file) for its \p{...} Unicode classes.
VOCAB_SIZE = 2 ** 14  # default target vocabulary size (16384)

# Reserved control tokens, numbered 0..8 in declaration order.
# NOTE(review): these ids coincide with the raw-byte ids 0x00-0x08 of the
# byte-level base vocabulary — confirm that overlap is intended.
SPECIAL_TOKENS = {
    token_name: token_id
    for token_id, token_name in enumerate((
        "<|pad|>",
        "<|bos|>",
        "<|eos|>",
        "<|unk|>",
        "<|im_start|>",
        "<|im_end|>",
        "<|routing|>",
        "<|tool_call|>",
        "<|tool_response|>",
    ))
}
# Pre-tokenization pattern (the GPT-2 split regex). Alternatives, tried in order:
#   - English contractions: 's 't 're 've 'm 'll 'd
#   - a run of letters, of digits, or of other non-space symbols, each
#     optionally keeping one leading space
#   - whitespace: `\s+(?!\S)` backtracks so a run followed by text gives up its
#     last space (which then attaches to the next word); `\s+` takes the rest.
# \p{L} / \p{N} are Unicode property classes supported by the third-party
# `regex` module (imported as `re` at the top of this file); the stdlib `re`
# module would reject them.
GPT2_PAT = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is any sliceable sequence of symbols (this module passes tuples of
    token ids). Empty or single-symbol input has no adjacent pairs and yields an
    empty set rather than raising.

    >>> sorted(get_pairs((1, 2, 3)))
    [(1, 2), (2, 3)]
    """
    # zip stops at the shorter argument, so len(word) < 2 naturally gives set().
    return set(zip(word, word[1:]))
class RaidTokenizer:
    """Byte-level BPE tokenizer for RAI models, built from scratch.

    Id layout produced by ``train``:
      * 0-255: the single raw byte ``bytes([id])``.
      * ``SPECIAL_TOKENS`` ids (currently 0-8): control tokens such as
        ``<|bos|>``. NOTE(review): these share ids with raw bytes 0x00-0x08,
        so those bytes cannot be told apart from special tokens and are
        skipped by ``decode``. Ids ``256 .. 256 + len(SPECIAL_TOKENS) - 1`` are
        never assigned; confirm whether the specials were meant to live there.
      * ``256 + len(SPECIAL_TOKENS)`` and up: learned merges, in merge order.
    """

    def __init__(self):
        self.merges = {}  # (int, int) -> int: pair of ids -> merged token id
        self.vocab = {}  # int -> bytes
        self.special_tokens = SPECIAL_TOKENS
        self.pat = GPT2_PAT  # pre-tokenization regex
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.pad_token_id = 0

    @staticmethod
    def _merge_pair(word, pair, new_id):
        """Return ``word`` as a list with each left-to-right, non-overlapping
        occurrence of ``pair`` replaced by ``new_id``."""
        first, second = pair
        last_start = len(word) - 1  # last index where a pair can begin
        merged = []
        pos = 0
        while pos < len(word):
            if pos < last_start and word[pos] == first and word[pos + 1] == second:
                merged.append(new_id)
                pos += 2
            else:
                merged.append(word[pos])
                pos += 1
        return merged

    def train(self, texts: list[str], vocab_size: int = VOCAB_SIZE):
        """Learn BPE merges from ``texts`` and install them on this tokenizer.

        Args:
            texts: Training documents. Each is split with ``self.pat`` and every
                piece is encoded to UTF-8 bytes before merging.
            vocab_size: Target vocabulary size. Up to
                ``vocab_size - 256 - len(SPECIAL_TOKENS)`` merges are learned;
                training stops early once no adjacent pair remains.

        Replaces ``self.merges`` and ``self.vocab`` and rebuilds the reverse map.
        Prints a progress line every 500 merges.
        """
        print(f"[TOKENIZER] Training BPE (vocab={vocab_size}) on {len(texts)} texts...")
        # Base vocabulary: one token per byte value.
        vocab = {i: bytes([i]) for i in range(256)}
        merges = {}

        # Frequency of each distinct pre-token, stored as a tuple of token ids.
        word_freqs = defaultdict(int)
        for text in texts:
            for word in re.findall(self.pat, text):
                word_freqs[tuple(word.encode('utf-8'))] += 1

        first_merge_id = 256 + len(SPECIAL_TOKENS)
        num_merges = vocab_size - first_merge_id
        for merge_idx in range(num_merges):
            # Count every adjacent pair, weighted by word frequency.
            pairs = defaultdict(int)
            for word, freq in word_freqs.items():
                if len(word) < 2:
                    continue
                for pair in get_pairs(word):
                    pairs[pair] += freq
            if not pairs:
                break

            best_pair = max(pairs, key=pairs.get)
            new_token_id = first_merge_id + merge_idx
            merges[best_pair] = new_token_id
            vocab[new_token_id] = vocab[best_pair[0]] + vocab[best_pair[1]]

            # Re-spell every word with the new merge applied.
            new_word_freqs = defaultdict(int)
            for word, freq in word_freqs.items():
                new_word_freqs[tuple(self._merge_pair(word, best_pair, new_token_id))] += freq
            word_freqs = new_word_freqs

            # Keyed on the merge index; it must not be shadowed by any
            # per-word cursor, or this progress check stops firing.
            if (merge_idx + 1) % 500 == 0:
                merged_text = vocab[new_token_id].decode('utf-8', errors='replace')
                print(f" Merge {merge_idx + 1}/{num_merges}: {best_pair} -> '{merged_text}'")

        self.merges = merges
        self.vocab = vocab
        self._build_reverse_vocab()
        print(f"[TOKENIZER] Done: {len(self.vocab)} tokens")

    def _build_reverse_vocab(self):
        """Rebuild ``self.token_to_bytes``.

        Despite its name, it maps byte string -> token id (the inverse of
        ``self.vocab``). If two ids share a byte string, the later id wins.
        """
        self.token_to_bytes = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str) -> list[int]:
        """Encode ``text`` to token ids, framed as ``[bos, ..., eos]``."""
        tokens = [self.bos_token_id]
        for word in re.findall(self.pat, text):
            word_tokens = list(word.encode('utf-8'))
            while len(word_tokens) >= 2:
                # Merge ids are the merge ranks: the smallest id among the
                # applicable pairs is the earliest-learned merge.
                best_pair = min(
                    get_pairs(tuple(word_tokens)),
                    key=lambda p: self.merges.get(p, float('inf')),
                )
                if best_pair not in self.merges:
                    break  # no learned merge applies to this word any more
                word_tokens = self._merge_pair(word_tokens, best_pair, self.merges[best_pair])
            tokens.extend(word_tokens)
        tokens.append(self.eos_token_id)
        return tokens

    def decode(self, ids: list[int]) -> str:
        """Decode token ids back to text.

        Special-token ids and ids missing from the vocabulary are skipped.
        Invalid UTF-8 sequences are replaced with U+FFFD.
        """
        special_ids = set(self.special_tokens.values())  # O(1) membership per id
        text_bytes = b"".join(
            self.vocab[tid]
            for tid in ids
            if tid not in special_ids and tid in self.vocab
        )
        return text_bytes.decode('utf-8', errors='replace')

    def save(self, path: str):
        """Write the merge table as JSON to ``path``.

        Pairs are serialized as ``"a,b"`` keys. ``vocab_size`` is written for
        information only; ``load`` does not read it. Special tokens are not
        stored; ``load`` re-adds the module-level ``SPECIAL_TOKENS``.
        """
        data = {
            "merges": {f"{a},{b}": new_id for (a, b), new_id in self.merges.items()},
            "vocab_size": len(self.vocab),
        }
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f)
        print(f"[TOKENIZER] Saved to {path}")

    def load(self, path: str):
        """Load a merge table written by ``save`` and rebuild the vocabulary.

        Raises:
            OSError: if ``path`` cannot be opened.
            json.JSONDecodeError / KeyError / ValueError: on malformed content.

        On any error, the tokenizer's current state is left unchanged.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        merges = {tuple(map(int, k.split(','))): v for k, v in data["merges"].items()}

        # Rebuild byte strings in merge-id order so both parts of every merge
        # already exist, independent of the key order in the JSON file.
        vocab = {i: bytes([i]) for i in range(256)}
        for (a, b), new_id in sorted(merges.items(), key=lambda item: item[1]):
            vocab[new_id] = vocab[a] + vocab[b]
        # NOTE(review): special ids overlap raw-byte ids 0-8, so this replaces
        # those byte entries with the special-token names (``train`` does not).
        for name, tid in SPECIAL_TOKENS.items():
            vocab[tid] = name.encode('utf-8')

        self.merges = merges
        self.vocab = vocab
        self._build_reverse_vocab()

    def __len__(self):
        """Number of entries in ``self.vocab``."""
        return len(self.vocab)
# ------------------------------------------------------------
# Smoke test: train a tiny vocabulary and round-trip one sentence
# ------------------------------------------------------------
if __name__ == "__main__":
    # Small Italian + code sample. Each sentence is repeated 100 times so that
    # pair frequencies are large enough for a basic BPE run.
    sample_sentences = [
        "Ciao, sono Raiai 0.1, l'orchestratore dell'ecosistema Raid.",
        "Devo coordinare gli agenti Raiax, Raikai e Raiops per completare il task.",
        "Il piano di orchestrazione prevede l'analisi preliminare del problema.",
        "La proprietà intellettuale di Raid1969/// deve essere protetta.",
        "Ecco il workflow: 1) Analisi 2) Delega 3) Verifica 4) Report finale.",
        "I modelli dell'ecosistema Raid sono addestrati su hardware locale.",
        "L'architettura Transformer con Grouped Query Attention è ottimizzata.",
        "import torch; model = RaiaiModel(); output = model.generate(input_ids)",
        "Il routing intelligente distribuisce i task agli agenti specializzati.",
    ]
    corpus = sample_sentences * 100

    tokenizer = RaidTokenizer()
    tokenizer.train(corpus, vocab_size=2048)  # a small vocabulary keeps the run quick

    # Encode, decode, and report each stage of the round trip.
    text = "Ciao, sono Raiai 0.1! Coordino Raiax e Raikai."
    ids = tokenizer.encode(text)
    decoded = tokenizer.decode(ids)
    report_lines = (
        f"\nTest: '{text}'",
        f" Tokens: {ids}",
        f" Count: {len(ids)}",
        f" Decoded: '{decoded}'",
    )
    for line in report_lines:
        print(line)