import gradio as gr
import pickle
from collections import Counter
from typing import List, Tuple, Dict

# Load the trained tokenizer
def load_tokenizer(tokenizer_path="ecoli_bpe_tokenizer.pkl"):
    """Load the trained BPE tokenizer"""
    try:
        with open(tokenizer_path, 'rb') as f:
            tokenizer = pickle.load(f)
        return tokenizer['vocab'], tokenizer['merge_rules']
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        return None, None

# Initialize tokenizer
vocab, merge_rules = load_tokenizer()
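
# Expected pickle layout (the key names come from how load_tokenizer unpacks
# the dict; the value shapes are an assumption inferred from their use below):
#   {"vocab": {token_str: token_id, ...},
#    "merge_rules": [(left_str, right_str), ...]}   # in training order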

def get_pair_counts(tokens: List[str]) -> Counter:
    """Count the frequency of each adjacent token pair.

    (Training-time helper; not called during inference, kept for reference.)
    """
    pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
    return Counter(pairs)
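
# Illustrative:
#   get_pair_counts(["A", "T", "G", "A", "T"])
#   -> Counter({("A", "T"): 2, ("T", "G"): 1, ("G", "A"): 1})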

def merge_pair(tokens: List[str], pair: Tuple[str, str], new_token: str) -> List[str]:
    """Replace all occurrences of pair with new_token"""
    new_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens
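
# Illustrative single merge step (assuming the pair ("A", "T") was learned
# as a merge rule):
#   merge_pair(["A", "T", "G", "A", "T"], ("A", "T"), "AT")
#   -> ["AT", "G", "AT"]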

def encode_sequence(text: str, vocab: Dict[str, int], merge_rules: List[Tuple[str, str]]) -> List[int]:
    """Encode DNA sequence to token IDs"""
    if not text:
        return []
    
    # Validate DNA sequence
    valid_bases = set('ACGTN')
    text = text.upper().strip()
    
    if not all(c in valid_bases for c in text):
        invalid = set(text) - valid_bases
        raise ValueError(f"Invalid characters in sequence: {invalid}. Only A, C, G, T, N allowed.")
    
    # Start with character-level tokenization
    tokens = list(text)
    
    # Apply merge rules in order
    for pair in merge_rules:
        merged_token = pair[0] + pair[1]
        tokens = merge_pair(tokens, pair, merged_token)
    
    # Convert to IDs
    token_ids = []
    for token in tokens:
        if token in vocab:
            token_ids.append(vocab[token])
        else:
            # Unknown tokens are skipped with a warning. This can happen for
            # bases absent from the training data (e.g. 'N'), in which case
            # the round trip is no longer lossless.
            print(f"Warning: Unknown token '{token}'")
    
    return token_ids
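
# Illustrative round trip (actual token IDs depend on the trained vocabulary,
# so the values shown here are hypothetical):
#   ids = encode_sequence("ATGAAA", vocab, merge_rules)   # e.g. [20, 112]
#   decode_sequence(ids, vocab)                           # -> "ATGAAA"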

def decode_sequence(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """Decode token IDs back to DNA sequence"""
    if not token_ids:
        return ""
    
    # Create reverse vocabulary
    id_to_token = {idx: token for token, idx in vocab.items()}
    
    # Look up each ID
    tokens = []
    for token_id in token_ids:
        if token_id in id_to_token:
            tokens.append(id_to_token[token_id])
        else:
            print(f"Warning: Unknown token ID {token_id}")
    
    # Concatenate
    return ''.join(tokens)
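
# Illustrative: per the pattern table rendered in the UI below, ID 20 maps to
# "ATG" in this vocabulary, so decode_sequence([20], vocab) -> "ATG".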

def analyze_sequence(sequence: str):
    """Analyze and encode a DNA sequence"""
    if not vocab or not merge_rules:
        return "❌ Tokenizer failed to load.", "", "", "❌ Tokenizer not loaded"
    if not sequence:
        return "Please enter a valid DNA sequence.", "", "", ""
    
    try:
        # Normalize once so the length statistics and the lossless check use
        # exactly the string that gets encoded (encode_sequence also strips
        # and upper-cases internally)
        cleaned = sequence.upper().strip()

        # Encode
        encoded_ids = encode_sequence(cleaned, vocab, merge_rules)

        # Decode to verify
        decoded = decode_sequence(encoded_ids, vocab)

        # Get token strings
        id_to_token = {idx: token for token, idx in vocab.items()}
        token_strings = [id_to_token[tid] for tid in encoded_ids if tid in id_to_token]

        # Calculate statistics
        original_length = len(cleaned)
        compressed_length = len(encoded_ids)
        compression_ratio = original_length / compressed_length if compressed_length > 0 else 0
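
        # Worked example (illustrative numbers): a 37-base input that encodes
        # to 9 tokens gives 37 / 9 β‰ˆ 4.111x compression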
        
        # Format results
        stats = f"""
## πŸ“Š Compression Statistics

- **Original Length**: {original_length:,} bases
- **Compressed Length**: {compressed_length:,} tokens
- **Compression Ratio**: {compression_ratio:.3f}x
- **Lossless Check**: {'βœ… PASS' if cleaned == decoded else '❌ FAIL'}

## πŸ”’ Token IDs
```
{encoded_ids[:50]}{'...' if len(encoded_ids) > 50 else ''}
```

## 🧬 Learned Tokens
```
{token_strings[:20]}{'...' if len(token_strings) > 20 else ''}
```
        """
        
        # Format encoded output
        encoded_output = f"Token IDs: {encoded_ids}"
        
        # Format decoded output
        decoded_output = decoded
        
        return stats, encoded_output, decoded_output, "βœ… Success!"
        
    except Exception as e:
        return f"❌ Error: {str(e)}", "", "", f"❌ Error: {str(e)}"

def tokenize_and_display(sequence: str):
    """Main function for the Gradio interface"""
    return analyze_sequence(sequence)

# Example DNA sequences
EXAMPLES = [
    ["ATGAAACGCATTAGCACCACCATTACCACCACCATCA"],  # Random sequence
    ["ATGATGATGATG"],  # Start codon repeats
    ["TATAATATATATAA"],  # TATA box variations
    ["GCGCGCGCGCGC"],  # CpG islands
    ["AAAAAAAAAAAAAAAA"],  # Poly-A tail
    ["AGGAGGTAAATG"],  # Shine-Dalgarno + Start codon
    ["ATGCGGCGTGAACGCCTTATCCGGCC"],  # Longest learned token
    ["AGCTTTTCATTCTGACTGCAACGGGCAATATGTCTCTGTGTGGATTAAAAAAAGAGTGTCTGATAGCAGC"],  # E. coli actual sequence
]

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="🧬 BPE DNA Tokenizer") as demo:
    gr.Markdown("""
    # 🧬 BPE DNA Tokenizer - Interactive Demo
    
    ### Byte Pair Encoding for Genomic Sequences
    
    This tokenizer was trained on the *E. coli* K-12 genome (4.6M base pairs) and achieves **5.208x compression**.
    
    **Key Features:**
    - 🎯 5,000 token vocabulary
    - πŸš€ 5.21x compression ratio
    - πŸ”¬ Discovers biological patterns (codons, TATA boxes, etc.)
    - βœ… 100% lossless encoding/decoding
    
    ---
    """)
    
    with gr.Row():
        with gr.Column(scale=2):
            input_sequence = gr.Textbox(
                label="🧬 Input DNA Sequence",
                placeholder="Enter DNA sequence (A, C, G, T, N)...\nExample: ATGAAACGCATTAGC",
                lines=5,
                max_lines=10
            )
            
            submit_btn = gr.Button("πŸ”¬ Tokenize Sequence", variant="primary", size="lg")
            
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="πŸ“Š Status", lines=1)
    
    with gr.Row():
        with gr.Column():
            stats_output = gr.Markdown(label="Statistics")
        
    with gr.Row():
        with gr.Column():
            encoded_output = gr.Textbox(label="πŸ”’ Encoded (Token IDs)", lines=3)
        with gr.Column():
            decoded_output = gr.Textbox(label="🧬 Decoded (Verification)", lines=3)
    
    gr.Markdown("""
    ---
    ## πŸ“š Example Sequences
    
    Click any example below to try it out:
    """)
    
    gr.Examples(
        examples=EXAMPLES,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output],
        fn=tokenize_and_display,
        cache_examples=True,
        label="Example DNA Sequences"
    )
    
    gr.Markdown("""
    ---
    ## πŸ”¬ Biological Patterns Discovered
    
    The tokenizer automatically learned these biological motifs:
    
    | Pattern | Token ID | Biological Function |
    |---------|----------|---------------------|
    | ATG | 20 | Start codon (translation initiation) |
    | TAA | 25 | Stop codon (translation termination) |
    | TAG | 65 | Stop codon (translation termination) |
    | TATAA | 279 | TATA box (transcription promoter) |
    | AGGAGG | 3642 | Shine-Dalgarno (ribosome binding) |
    | GCGC | 35 | CpG island (gene regulation) |
    
    ---
    
    ## πŸ“Š Model Information
    
    - **Training Dataset**: *E. coli* K-12 (GCF_000005845.2)
    - **Vocabulary Size**: 5,000 tokens
    - **Compression Ratio**: 5.208x
    - **Token Length Range**: 1-26 bases
    - **Training Time**: 88 minutes
    
    ---
    
    ## πŸ”— Links
    
    - [GitHub Repository](https://github.com/abi2024/bpe-dna-tokenizer)
    - [Paper: Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/abs/1508.07909)
    - [E. coli K-12 Reference](https://www.ncbi.nlm.nih.gov/genome/?term=escherichia+coli+K12)
    
    ---
    
    **Built with 🧬 for genomics and πŸ€– for machine learning**
    """)
    
    # Set up the button click
    submit_btn.click(
        fn=tokenize_and_display,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()