import regex as re
import json
from tqdm import tqdm


class StockBPE:
    """BPE Tokenizer optimized for stock market time-series data"""

    def __init__(self):
        self.merges = {}
        self.vocab = {}
        # OPTIMIZATION: Treat the entire line as a single chunk to allow merging
        # labels with delimiters (e.g., "OPEN" + ":" -> "OPEN:")
        self.pattern = re.compile(r'[^\n]+|\n')
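
    # Illustrative example (not from the original): on "OPEN:123.45\nCLOSE:124.00\n"
    # the pattern yields ["OPEN:123.45", "\n", "CLOSE:124.00", "\n"], so merges can
    # cross the label/delimiter boundary within a line but never span newlines.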

    def get_stats(self, ids):
        """Count frequency of adjacent pairs"""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts
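
    # e.g., get_stats([79, 80, 79, 80, 69]) -> {(79, 80): 2, (80, 79): 1, (80, 69): 1}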

    def merge(self, ids, pair, idx):
        """Merge all occurrences of a pair"""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids
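
    # e.g., merge([79, 80, 69, 78], (79, 80), 256) -> [256, 69, 78]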

    def train(self, text, vocab_size, verbose=True):
        """Train BPE on stock market data"""
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        # Pre-tokenize using the line-level pattern
        text_chunks = re.findall(self.pattern, text)
        # Convert each chunk to its UTF-8 byte sequence
        ids = [list(chunk.encode("utf-8")) for chunk in text_chunks]
        # Training loop with progress bar (honors the verbose flag)
        for i in tqdm(range(num_merges), desc="Training Stock BPE", unit="merge",
                      disable=not verbose):
            # Aggregate pair counts across all chunks
            stats = {}
            for chunk_ids in ids:
                chunk_stats = self.get_stats(chunk_ids)
                for pair, count in chunk_stats.items():
                    stats[pair] = stats.get(pair, 0) + count
            if not stats:
                print(f"\nNo more pairs to merge. Stopping at {i} merges.")
                break
            # Greedily take the most frequent pair and assign it the next token id
            pair = max(stats, key=stats.get)
            idx = 256 + i
            # Apply the merge to every chunk
            ids = [self.merge(chunk_ids, pair, idx) for chunk_ids in ids]
            self.merges[pair] = idx
        # Build vocabulary: 256 base bytes plus one entry per learned merge
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
        print(f"\nTraining complete. Final vocab size: {len(self.vocab)}")

    def encode(self, text):
        """Encode text to token IDs"""
        text_chunks = re.findall(self.pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_ids = list(chunk.encode("utf-8"))
            while len(chunk_ids) >= 2:
                stats = self.get_stats(chunk_ids)
                # Apply merges in the order they were learned: pick the pair
                # with the lowest merge index, stopping when none remain
                pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
                if pair not in self.merges:
                    break
                idx = self.merges[pair]
                chunk_ids = self.merge(chunk_ids, pair, idx)
            ids.extend(chunk_ids)
        return ids

    def decode(self, ids):
        """Decode token IDs back to text"""
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")
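
    # Round-trip property: decode(encode(s)) == s, since merges only regroup
    # the original UTF-8 bytes without ever altering them.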

    def save(self, prefix):
        """Save tokenizer to files"""
        # Save merges, one "p0 p1 idx" triple per line, in training order
        with open(f"{prefix}.merges", "w", encoding="utf-8") as f:
            for (p0, p1), idx in self.merges.items():
                f.write(f"{p0} {p1} {idx}\n")
        # Save vocab for human inspection (lossy: errors="replace" can mangle
        # tokens that are not valid UTF-8 on their own)
        vocab_str = {idx: token.decode("utf-8", errors="replace")
                     for idx, token in self.vocab.items()}
        with open(f"{prefix}.vocab", "w", encoding="utf-8") as f:
            json.dump(vocab_str, f, ensure_ascii=False, indent=2)
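
    # A .merges file written by save() looks like this (ids are illustrative):
    #   79 80 256
    #   256 58 257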

    def load(self, prefix):
        """Load tokenizer from files"""
        self.merges = {}
        with open(f"{prefix}.merges", "r", encoding="utf-8") as f:
            for line in f:
                p0, p1, idx = map(int, line.split())
                self.merges[(p0, p1)] = idx
        # Rebuild the vocab; because merges are stored in training order, each
        # pair's components already exist in the vocab when it is processed
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

    def calculate_compression_ratio(self, text):
        """Return the compression ratio as UTF-8 bytes per token"""
        encoded = self.encode(text)
        original_bytes = len(text.encode("utf-8"))
        compressed_tokens = len(encoded)
        return original_bytes / compressed_tokens if compressed_tokens > 0 else 0
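

# Minimal usage sketch (illustrative: the sample data, vocab size, and file
# prefix below are assumptions, not part of the original tokenizer):
if __name__ == "__main__":
    sample = "OPEN:123.45 HIGH:125.10 LOW:122.90 CLOSE:124.00\n" * 100
    tokenizer = StockBPE()
    tokenizer.train(sample, vocab_size=300)
    ids = tokenizer.encode(sample)
    assert tokenizer.decode(ids) == sample  # lossless round trip
    ratio = tokenizer.calculate_compression_ratio(sample)
    print(f"Compression ratio: {ratio:.2f} bytes/token")
    tokenizer.save("stock_bpe")  # writes stock_bpe.merges and stock_bpe.vocab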