import regex as re
import json
from tqdm import tqdm

class StockBPE:
    """BPE Tokenizer optimized for stock market time-series data"""
    
    def __init__(self):
        self.merges = {}
        self.vocab = {}
        # OPTIMIZATION: Treat the entire line as a single chunk to allow merging 
        # labels with delimiters (e.g., "OPEN" + ":" -> "OPEN:")
        self.pattern = re.compile(r'[^\n]+|\n')
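        # e.g., a (hypothetical) line like "OPEN:101.25 HIGH:102.10 CLOSE:101.90"
        # stays a single chunk, while each "\n" becomes its own chunk.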
    
    def get_stats(self, ids):
        """Count frequency of adjacent pairs"""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts
    
    def merge(self, ids, pair, idx):
        """Merge all occurrences of a pair"""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids
    
    def train(self, text, vocab_size, verbose=True):
        """Train BPE on stock market data"""
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        
        # Pre-tokenize using pattern
        text_chunks = re.findall(self.pattern, text)
        
        # Convert to UTF-8 bytes
        ids = [list(chunk.encode("utf-8")) for chunk in text_chunks]
        
        # Training loop with progress bar
        for i in tqdm(range(num_merges), desc="Training Stock BPE",
                      unit="merge", disable=not verbose):
            stats = {}
            for chunk_ids in ids:
                chunk_stats = self.get_stats(chunk_ids)
                for pair, count in chunk_stats.items():
                    stats[pair] = stats.get(pair, 0) + count
            
            if not stats:
                if verbose:
                    print(f"\nNo more pairs to merge. Stopping at {i} merges.")
                break
            
            pair = max(stats, key=stats.get)
            idx = 256 + i
            
            # Apply merge
            ids = [self.merge(chunk_ids, pair, idx) for chunk_ids in ids]
            
            self.merges[pair] = idx
        
        # Build vocabulary
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
        
        print(f"\nTraining complete. Final vocab size: {len(self.vocab)}")
    
    def encode(self, text):
        """Encode text to token IDs"""
        text_chunks = re.findall(self.pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_ids = list(chunk.encode("utf-8"))
            while len(chunk_ids) >= 2:
                stats = self.get_stats(chunk_ids)
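                # Pick the pair that was learned earliest (lowest merge index);
                # pairs never learned rank as infinity and stop the loop below.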
                pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
                if pair not in self.merges:
                    break
                idx = self.merges[pair]
                chunk_ids = self.merge(chunk_ids, pair, idx)
            ids.extend(chunk_ids)
        return ids
    
    def decode(self, ids):
        """Decode token IDs back to text"""
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")
    
    def save(self, prefix):
        """Save tokenizer to files"""
        # Save merges
        with open(f"{prefix}.merges", "w", encoding="utf-8") as f:
            for (p0, p1), idx in self.merges.items():
                f.write(f"{p0} {p1} {idx}\n")
        
        # Save vocab
        vocab_str = {idx: token.decode("utf-8", errors="replace") 
                     for idx, token in self.vocab.items()}
        with open(f"{prefix}.vocab", "w", encoding="utf-8") as f:
            json.dump(vocab_str, f, ensure_ascii=False, indent=2)
    
    def load(self, prefix):
        """Load tokenizer from files"""
        self.merges = {}
        with open(f"{prefix}.merges", "r", encoding="utf-8") as f:
            for line in f:
                p0, p1, idx = map(int, line.split())
                self.merges[(p0, p1)] = idx
        
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
    
    def calculate_compression_ratio(self, text):
        """Calculate compression ratio"""
        encoded = self.encode(text)
        original_bytes = len(text.encode("utf-8"))
        compressed_tokens = len(encoded)
        return original_bytes / compressed_tokens if compressed_tokens > 0 else 0
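

# ---------------------------------------------------------------------------
# Illustrative usage sketch. The OHLCV-style sample lines generated below are
# hypothetical demonstration data, not a format the tokenizer requires; the
# file prefix "stock_bpe" is likewise arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import random

    random.seed(0)
    lines = []
    for _ in range(500):
        o = random.uniform(90, 110)
        lines.append(
            f"OPEN:{o:.2f} HIGH:{o + 1.2:.2f} LOW:{o - 0.8:.2f} "
            f"CLOSE:{o + 0.4:.2f} VOL:{random.randint(10_000, 90_000)}"
        )
    sample_text = "\n".join(lines) + "\n"

    tokenizer = StockBPE()
    tokenizer.train(sample_text, vocab_size=512)

    sample_line = "OPEN:101.25 HIGH:102.10\n"
    encoded = tokenizer.encode(sample_line)
    assert tokenizer.decode(encoded) == sample_line  # byte-level BPE is lossless

    print(f"Token IDs: {encoded}")
    print(f"Compression ratio: {tokenizer.calculate_compression_ratio(sample_text):.2f}x")

    # Round-trip the tokenizer through disk and confirm identical encodings.
    tokenizer.save("stock_bpe")
    restored = StockBPE()
    restored.load("stock_bpe")
    assert restored.encode(sample_line) == encoded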