# app.py — BPE DNA Tokenizer Gradio demo
# (Hugging Face Space upload by abi96062; commit 969bf17, verified)
import gradio as gr
import pickle
import json
from typing import List, Tuple, Dict
# Load the trained tokenizer
def load_tokenizer(tokenizer_path="ecoli_bpe_tokenizer.pkl"):
    """Load the trained BPE tokenizer from a pickle file.

    Returns:
        A ``(vocab, merge_rules)`` pair on success, or ``(None, None)``
        when the file is missing, unreadable, or lacks the expected keys.
    """
    try:
        with open(tokenizer_path, 'rb') as handle:
            payload = pickle.load(handle)
            # Key lookups stay inside the try so a malformed pickle also
            # degrades to (None, None) instead of crashing the app.
            return payload['vocab'], payload['merge_rules']
    except Exception as exc:  # best-effort: UI degrades gracefully without a tokenizer
        print(f"Error loading tokenizer: {exc}")
        return None, None
# Initialize tokenizer once at import time.
# Both globals are None when the pickle file is absent; analyze_sequence()
# checks for that before using them.
vocab, merge_rules = load_tokenizer()
def get_pair_counts(tokens: List[str]) -> Dict:
    """Return a Counter of every adjacent token pair in *tokens*."""
    from collections import Counter
    # zip(tokens, tokens[1:]) yields the same (tokens[i], tokens[i+1])
    # pairs as an index loop, without manual bounds arithmetic.
    return Counter(zip(tokens, tokens[1:]))
def merge_pair(tokens: List[str], pair: Tuple[str, str], new_token: str) -> List[str]:
    """Return a copy of *tokens* with each left-to-right, non-overlapping
    occurrence of *pair* collapsed into *new_token*."""
    first, second = pair
    total = len(tokens)
    merged: List[str] = []
    idx = 0
    while idx < total:
        if idx + 1 < total and tokens[idx] == first and tokens[idx + 1] == second:
            merged.append(new_token)
            idx += 2  # skip both members of the matched pair
        else:
            merged.append(tokens[idx])
            idx += 1
    return merged
def encode_sequence(text: str, vocab: Dict[str, int], merge_rules: List[Tuple[str, str]]) -> List[int]:
    """Encode a DNA sequence into BPE token IDs.

    The input is upper-cased and stripped, validated against the DNA
    alphabet, split into single characters, and then each learned merge
    rule is replayed in training order.  Tokens missing from *vocab*
    are dropped with a printed warning.

    Raises:
        ValueError: if the sequence contains characters outside A/C/G/T/N.
    """
    if not text:
        return []
    text = text.upper().strip()
    # Validate against the DNA alphabet before doing any work.
    invalid = set(text) - set('ACGTN')
    if invalid:
        raise ValueError(f"Invalid characters in sequence: {invalid}. Only A, C, G, T, N allowed.")
    # Character-level start; replay every merge rule in the order learned.
    tokens = list(text)
    for left, right in merge_rules:
        fused = left + right
        rebuilt: List[str] = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                rebuilt.append(fused)
                i += 2
            else:
                rebuilt.append(tokens[i])
                i += 1
        tokens = rebuilt
    # Map surviving tokens to IDs; unknown tokens are warned about and skipped.
    token_ids: List[int] = []
    for token in tokens:
        try:
            token_ids.append(vocab[token])
        except KeyError:
            print(f"Warning: Unknown token '{token}'")
    return token_ids
def decode_sequence(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """Map a list of token IDs back to the DNA string they encode.

    Unknown IDs are skipped with a printed warning; an empty input
    yields an empty string.
    """
    if not token_ids:
        return ""
    # Invert the vocabulary: id -> token string.
    id_lookup = {tid: tok for tok, tid in vocab.items()}
    pieces = []
    for token_id in token_ids:
        try:
            pieces.append(id_lookup[token_id])
        except KeyError:
            print(f"Warning: Unknown token ID {token_id}")
    return ''.join(pieces)
def analyze_sequence(sequence: str):
    """Encode *sequence*, verify the round trip, and format results for the UI.

    Returns:
        A 4-tuple ``(stats_markdown, encoded_output, decoded_output, status)``
        matching the interface's output components.  On failure the first
        and last elements carry the error message.
    """
    if not sequence or not vocab or not merge_rules:
        return "Please enter a valid DNA sequence.", "", "", ""
    try:
        # BUG FIX: encode_sequence() upper-cases and strips the input, so the
        # lossless check and the length statistics must use the same
        # normalized form.  Previously, input with surrounding whitespace
        # falsely reported "❌ FAIL" and an inflated original length.
        cleaned = sequence.upper().strip()
        # Encode
        encoded_ids = encode_sequence(sequence, vocab, merge_rules)
        # Decode to verify
        decoded = decode_sequence(encoded_ids, vocab)
        # Get token strings for display
        id_to_token = {idx: token for token, idx in vocab.items()}
        token_strings = [id_to_token[tid] for tid in encoded_ids if tid in id_to_token]
        # Calculate statistics
        original_length = len(cleaned)
        compressed_length = len(encoded_ids)
        compression_ratio = original_length / compressed_length if compressed_length > 0 else 0
        # Format results (markdown shown in the stats panel)
        stats = f"""
## 📊 Compression Statistics
- **Original Length**: {original_length:,} bases
- **Compressed Length**: {compressed_length:,} tokens
- **Compression Ratio**: {compression_ratio:.3f}x
- **Lossless Check**: {'✅ PASS' if cleaned == decoded else '❌ FAIL'}
## 🔢 Token IDs
```
{encoded_ids[:50]}{'...' if len(encoded_ids) > 50 else ''}
```
## 🧬 Learned Tokens
```
{token_strings[:20]}{'...' if len(token_strings) > 20 else ''}
```
"""
        # Format encoded output
        encoded_output = f"Token IDs: {encoded_ids}"
        # Format decoded output
        decoded_output = decoded
        return stats, encoded_output, decoded_output, "✅ Success!"
    except Exception as e:
        return f"❌ Error: {str(e)}", "", "", f"❌ Error: {str(e)}"
def tokenize_and_display(sequence: str):
    """Main function for the Gradio interface.

    Thin wrapper so both the submit button and gr.Examples share one
    callable; returns analyze_sequence()'s 4-tuple of outputs.
    """
    return analyze_sequence(sequence)
# Example DNA sequences.
# Each entry is a one-element list: one value per input component
# (the single sequence Textbox) as gr.Examples expects.
EXAMPLES = [
    ["ATGAAACGCATTAGCACCACCATTACCACCACCATCA"],  # Random sequence
    ["ATGATGATGATG"],  # Start codon repeats
    ["TATAATATATATAA"],  # TATA box variations
    ["GCGCGCGCGCGC"],  # CpG islands
    ["AAAAAAAAAAAAAAAA"],  # Poly-A tail
    ["AGGAGGTAAATG"],  # Shine-Dalgarno + Start codon
    ["ATGCGGCGTGAACGCCTTATCCGGCC"],  # Longest learned token
    ["AGCTTTTCATTCTGACTGCAACGGGCAATATGTCTCTGTGTGGATTAAAAAAAGAGTGTCTGATAGCAGC"],  # E. coli actual sequence
]
# Create Gradio interface.
# Layout: intro markdown, input row (sequence box + button | status),
# results (stats markdown, encoded/decoded boxes), examples, footer docs.
with gr.Blocks(theme=gr.themes.Soft(), title="🧬 BPE DNA Tokenizer") as demo:
    gr.Markdown("""
# 🧬 BPE DNA Tokenizer - Interactive Demo
### Byte Pair Encoding for Genomic Sequences
This tokenizer was trained on the *E. coli* K-12 genome (4.6M base pairs) and achieves **5.208x compression**.
**Key Features:**
- 🎯 5,000 token vocabulary
- 🚀 5.21x compression ratio
- 🔬 Discovers biological patterns (codons, TATA boxes, etc.)
- ✅ 100% lossless encoding/decoding
---
""")
    # Input area: sequence textbox and submit button (wide), status (narrow).
    with gr.Row():
        with gr.Column(scale=2):
            input_sequence = gr.Textbox(
                label="🧬 Input DNA Sequence",
                placeholder="Enter DNA sequence (A, C, G, T, N)...\nExample: ATGAAACGCATTAGC",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("🔬 Tokenize Sequence", variant="primary", size="lg")
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="📊 Status", lines=1)
    # Results: rendered statistics markdown.
    with gr.Row():
        with gr.Column():
            stats_output = gr.Markdown(label="Statistics")
    # Results: raw token IDs and the round-trip decoded sequence.
    with gr.Row():
        with gr.Column():
            encoded_output = gr.Textbox(label="🔢 Encoded (Token IDs)", lines=3)
        with gr.Column():
            decoded_output = gr.Textbox(label="🧬 Decoded (Verification)", lines=3)
    gr.Markdown("""
---
## 📚 Example Sequences
Click any example below to try it out:
""")
    # Clickable examples; cache_examples=True runs fn on each example at
    # startup so clicking shows precomputed outputs.
    gr.Examples(
        examples=EXAMPLES,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output],
        fn=tokenize_and_display,
        cache_examples=True,
        label="Example DNA Sequences"
    )
    # Static footer: discovered motifs, model card, and external links.
    gr.Markdown("""
---
## 🔬 Biological Patterns Discovered
The tokenizer automatically learned these biological motifs:
| Pattern | Token ID | Biological Function |
|---------|----------|---------------------|
| ATG | 20 | Start codon (translation initiation) |
| TAA | 25 | Stop codon (translation termination) |
| TAG | 65 | Stop codon (translation termination) |
| TATAA | 279 | TATA box (transcription promoter) |
| AGGAGG | 3642 | Shine-Dalgarno (ribosome binding) |
| GCGC | 35 | CpG island (gene regulation) |
---
## 📊 Model Information
- **Training Dataset**: *E. coli* K-12 (GCF_000005845.2)
- **Vocabulary Size**: 5,000 tokens
- **Compression Ratio**: 5.208x
- **Token Length Range**: 1-26 bases
- **Training Time**: 88 minutes
---
## 🔗 Links
- [GitHub Repository](https://github.com/abi2024/bpe-dna-tokenizer)
- [Paper: Neural Machine Translation with Subword Units](https://arxiv.org/abs/1508.07909)
- [E. coli K-12 Reference](https://www.ncbi.nlm.nih.gov/genome/?term=escherichia+coli+K12)
---
**Built with 🧬 for genomics and 🤖 for machine learning**
""")
    # Set up the button click — same callable and outputs as the examples.
    submit_btn.click(
        fn=tokenize_and_display,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output]
    )
# Launch the app only when run directly (not when imported, e.g. by tests
# or a Space runner that imports `demo` itself).
if __name__ == "__main__":
    demo.launch()