"""Interactive Gradio demo for a BPE tokenizer trained on the E. coli K-12 genome.

Loads a pickled vocabulary + merge-rule list, then exposes encode/decode and
a compression-statistics view through a Gradio Blocks UI.
"""

import pickle
import json  # retained from original module (may be used by external tooling)
from collections import Counter
from typing import Dict, List, Optional, Tuple

import gradio as gr


def load_tokenizer(tokenizer_path: str = "ecoli_bpe_tokenizer.pkl"):
    """Load the trained BPE tokenizer.

    Returns:
        (vocab, merge_rules) on success, or (None, None) if the file is
        missing or unreadable.

    NOTE(security): ``pickle.load`` executes arbitrary code if the file is
    untrusted — only load tokenizer files you created yourself.
    """
    try:
        with open(tokenizer_path, 'rb') as f:
            tokenizer = pickle.load(f)
        return tokenizer['vocab'], tokenizer['merge_rules']
    except Exception as e:
        # Best-effort: the UI degrades gracefully when no tokenizer is present.
        print(f"Error loading tokenizer: {e}")
        return None, None


# Initialize tokenizer at import time so the Gradio callbacks can use it.
vocab, merge_rules = load_tokenizer()


def get_pair_counts(tokens: List[str]) -> Dict[Tuple[str, str], int]:
    """Count frequency of adjacent token pairs.

    Args:
        tokens: sequence of token strings.

    Returns:
        Counter mapping each (left, right) adjacent pair to its frequency.
    """
    # zip(tokens, tokens[1:]) yields every adjacent pair in one C-level pass.
    return Counter(zip(tokens, tokens[1:]))


def merge_pair(tokens: List[str], pair: Tuple[str, str], new_token: str) -> List[str]:
    """Replace all non-overlapping occurrences of ``pair`` with ``new_token``.

    Scans left-to-right; when a match is found both members are consumed,
    so overlapping matches are not re-merged (standard BPE behavior).
    """
    merged: List[str] = []
    i = 0
    n = len(tokens)
    while i < n:
        if i + 1 < n and (tokens[i], tokens[i + 1]) == pair:
            merged.append(new_token)
            i += 2  # skip both halves of the merged pair
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def encode_sequence(text: str,
                    vocab: Dict[str, int],
                    merge_rules: List[Tuple[str, str]]) -> List[int]:
    """Encode a DNA sequence to BPE token IDs.

    Args:
        text: DNA string; case-insensitive, surrounding whitespace ignored.
        vocab: token string -> integer ID.
        merge_rules: ordered list of (left, right) pairs learned at training.

    Returns:
        List of token IDs (empty for empty input).

    Raises:
        ValueError: if the sequence contains characters other than A,C,G,T,N.
    """
    if not text:
        return []

    # Normalize, then validate against the DNA alphabet (N = unknown base).
    valid_bases = set('ACGTN')
    text = text.upper().strip()
    if not all(c in valid_bases for c in text):
        invalid = set(text) - valid_bases
        raise ValueError(
            f"Invalid characters in sequence: {invalid}. Only A, C, G, T, N allowed."
        )

    # Start from character-level tokens, then replay merges in training order.
    tokens = list(text)
    for pair in merge_rules:
        tokens = merge_pair(tokens, pair, pair[0] + pair[1])

    # Map tokens to IDs; unknown tokens shouldn't occur (every merge produces
    # a vocab entry) but are skipped with a warning rather than crashing.
    token_ids: List[int] = []
    for token in tokens:
        if token in vocab:
            token_ids.append(vocab[token])
        else:
            print(f"Warning: Unknown token '{token}'")
    return token_ids


def decode_sequence(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """Decode token IDs back to a DNA sequence.

    Unknown IDs are skipped with a warning; known tokens are concatenated
    in order, which inverts :func:`encode_sequence` exactly.
    """
    if not token_ids:
        return ""

    # Invert the vocabulary once for O(1) lookups.
    id_to_token = {idx: token for token, idx in vocab.items()}

    pieces: List[str] = []
    for token_id in token_ids:
        if token_id in id_to_token:
            pieces.append(id_to_token[token_id])
        else:
            print(f"Warning: Unknown token ID {token_id}")
    return ''.join(pieces)


def analyze_sequence(sequence: str):
    """Encode a DNA sequence and build the four UI outputs.

    Returns:
        (stats_markdown, encoded_text, decoded_text, status_text).
    """
    if not sequence or not vocab or not merge_rules:
        return "Please enter a valid DNA sequence.", "", "", ""

    try:
        # Round-trip: encode, then decode to verify the encoding is lossless.
        encoded_ids = encode_sequence(sequence, vocab, merge_rules)
        decoded = decode_sequence(encoded_ids, vocab)

        # Human-readable token strings for display.
        id_to_token = {idx: token for token, idx in vocab.items()}
        token_strings = [id_to_token[tid] for tid in encoded_ids if tid in id_to_token]

        # Statistics. The lossless check must compare against the SAME
        # normalization encode_sequence applies (upper + strip); comparing the
        # raw input would falsely FAIL on whitespace-padded sequences.
        normalized = sequence.upper().strip()
        original_length = len(sequence)
        compressed_length = len(encoded_ids)
        compression_ratio = original_length / compressed_length if compressed_length > 0 else 0

        stats = f"""
## 📊 Compression Statistics
- **Original Length**: {original_length:,} bases
- **Compressed Length**: {compressed_length:,} tokens
- **Compression Ratio**: {compression_ratio:.3f}x
- **Lossless Check**: {'✅ PASS' if normalized == decoded else '❌ FAIL'}

## 🔢 Token IDs
```
{encoded_ids[:50]}{'...' if len(encoded_ids) > 50 else ''}
```

## 🧬 Learned Tokens
```
{token_strings[:20]}{'...' if len(token_strings) > 20 else ''}
```
"""

        encoded_output = f"Token IDs: {encoded_ids}"
        return stats, encoded_output, decoded, "✅ Success!"
    except Exception as e:
        # Surface any failure (e.g. invalid characters) in both the stats
        # panel and the status box instead of crashing the UI callback.
        return f"❌ Error: {str(e)}", "", "", f"❌ Error: {str(e)}"


def tokenize_and_display(sequence: str):
    """Main function for the Gradio interface (thin wrapper for clarity)."""
    return analyze_sequence(sequence)


# Example DNA sequences shown in the UI (each is a one-element input row).
EXAMPLES = [
    ["ATGAAACGCATTAGCACCACCATTACCACCACCATCA"],  # Random sequence
    ["ATGATGATGATG"],  # Start codon repeats
    ["TATAATATATATAA"],  # TATA box variations
    ["GCGCGCGCGCGC"],  # CpG islands
    ["AAAAAAAAAAAAAAAA"],  # Poly-A tail
    ["AGGAGGTAAATG"],  # Shine-Dalgarno + Start codon
    ["ATGCGGCGTGAACGCCTTATCCGGCC"],  # Longest learned token
    ["AGCTTTTCATTCTGACTGCAACGGGCAATATGTCTCTGTGTGGATTAAAAAAAGAGTGTCTGATAGCAGC"],  # E. coli actual sequence
]


# ---------------------------------------------------------------------------
# Gradio interface (built at import time; launched under __main__ below).
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="🧬 BPE DNA Tokenizer") as demo:
    gr.Markdown("""
# 🧬 BPE DNA Tokenizer - Interactive Demo
### Byte Pair Encoding for Genomic Sequences

This tokenizer was trained on the *E. coli* K-12 genome (4.6M base pairs) and achieves **5.208x compression**.

**Key Features:**
- 🎯 5,000 token vocabulary
- 🚀 5.21x compression ratio
- 🔬 Discovers biological patterns (codons, TATA boxes, etc.)
- ✅ 100% lossless encoding/decoding

---
""")

    with gr.Row():
        with gr.Column(scale=2):
            input_sequence = gr.Textbox(
                label="🧬 Input DNA Sequence",
                placeholder="Enter DNA sequence (A, C, G, T, N)...\nExample: ATGAAACGCATTAGC",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("🔬 Tokenize Sequence", variant="primary", size="lg")
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="📊 Status", lines=1)

    with gr.Row():
        with gr.Column():
            stats_output = gr.Markdown(label="Statistics")

    with gr.Row():
        with gr.Column():
            encoded_output = gr.Textbox(label="🔢 Encoded (Token IDs)", lines=3)
        with gr.Column():
            decoded_output = gr.Textbox(label="🧬 Decoded (Verification)", lines=3)

    gr.Markdown("""
---
## 📚 Example Sequences
Click any example below to try it out:
""")

    gr.Examples(
        examples=EXAMPLES,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output],
        fn=tokenize_and_display,
        cache_examples=True,
        label="Example DNA Sequences"
    )

    gr.Markdown("""
---
## 🔬 Biological Patterns Discovered

The tokenizer automatically learned these biological motifs:

| Pattern | Token ID | Biological Function |
|---------|----------|---------------------|
| ATG | 20 | Start codon (translation initiation) |
| TAA | 25 | Stop codon (translation termination) |
| TAG | 65 | Stop codon (translation termination) |
| TATAA | 279 | TATA box (transcription promoter) |
| AGGAGG | 3642 | Shine-Dalgarno (ribosome binding) |
| GCGC | 35 | CpG island (gene regulation) |

---
## 📊 Model Information
- **Training Dataset**: *E. coli* K-12 (GCF_000005845.2)
- **Vocabulary Size**: 5,000 tokens
- **Compression Ratio**: 5.208x
- **Token Length Range**: 1-26 bases
- **Training Time**: 88 minutes

---
## 🔗 Links
- [GitHub Repository](https://github.com/abi2024/bpe-dna-tokenizer)
- [Paper: Neural Machine Translation with Subword Units](https://arxiv.org/abs/1508.07909)
- [E. coli K-12 Reference](https://www.ncbi.nlm.nih.gov/genome/?term=escherichia+coli+K12)

---
**Built with 🧬 for genomics and 🤖 for machine learning**
""")

    # Set up the button click
    submit_btn.click(
        fn=tokenize_and_display,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output]
    )


# Launch the app
if __name__ == "__main__":
    demo.launch()