Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pickle | |
| import json | |
| from typing import List, Tuple, Dict | |
| # Load the trained tokenizer | |
def load_tokenizer(tokenizer_path="ecoli_bpe_tokenizer.pkl"):
    """Read the pickled BPE tokenizer and return ``(vocab, merge_rules)``.

    The pickle is expected to hold a mapping with the keys ``'vocab'`` and
    ``'merge_rules'``. Any failure (missing file, unreadable pickle, missing
    key) is printed and reported to the caller as ``(None, None)`` instead of
    being raised.
    """
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # tokenizer files you trust.
    try:
        with open(tokenizer_path, 'rb') as handle:
            payload = pickle.load(handle)
        vocab_table = payload['vocab']
        rules = payload['merge_rules']
    except Exception as err:
        print(f"Error loading tokenizer: {err}")
        return None, None
    return vocab_table, rules
# Initialize tokenizer
# Loaded once at import time from the default path. Both names are None when
# loading fails; analyze_sequence checks them before encoding.
vocab, merge_rules = load_tokenizer()
def get_pair_counts(tokens: List[str]) -> Dict:
    """Count how often each pair of adjacent tokens occurs.

    Args:
        tokens: Token sequence to scan.

    Returns:
        A ``collections.Counter`` mapping ``(left, right)`` tuples to the
        number of times that pair appears side by side. Overlapping
        occurrences are all counted (``AAA`` yields ``('A', 'A')`` twice).
        Empty for inputs shorter than two tokens.
    """
    from collections import Counter

    # Pair every token with its successor; zip stops at the shorter slice,
    # so empty and single-token inputs produce no pairs.
    return Counter(zip(tokens, tokens[1:]))
def merge_pair(tokens: List[str], pair: Tuple[str, str], new_token: str) -> List[str]:
    """Replace every non-overlapping occurrence of ``pair`` with ``new_token``.

    The scan runs left to right and consumes both tokens of a match, so
    overlapping candidates are merged greedily from the left
    (``A A A`` with pair ``(A, A)`` becomes ``AA A``).

    Args:
        tokens: Current token sequence; it is not modified.
        pair: The two adjacent tokens to merge. Any two-item sequence is
            accepted (tuple or list).
        new_token: Token written in place of each matched pair.

    Returns:
        A new list with the merges applied.
    """
    # Normalise once: a list-shaped pair would otherwise never compare equal
    # to the tuple built from the token stream, and no merge would happen.
    target = tuple(pair)

    new_tokens = []
    i = 0
    last = len(tokens) - 1
    while i <= last:
        if i < last and (tokens[i], tokens[i + 1]) == target:
            new_tokens.append(new_token)
            i += 2  # skip both halves of the merged pair
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens
def encode_sequence(text: str, vocab: Dict[str, int], merge_rules: List[Tuple[str, str]]) -> List[int]:
    """Encode a DNA sequence into BPE token IDs.

    The text is upper-cased and stripped of leading/trailing whitespace,
    split into single bases, and then each merge rule is applied in the order
    given.

    Args:
        text: DNA sequence made of A, C, G, T, N (case-insensitive).
        vocab: Mapping from token string to token ID.
        merge_rules: Ordered ``(left, right)`` pairs to merge.

    Returns:
        The token IDs, or ``[]`` for empty input. Tokens missing from
        ``vocab`` are reported with a printed warning and left out.

    Raises:
        ValueError: If the sequence contains characters other than
            A, C, G, T, N.
    """
    if not text:
        return []

    # Validate DNA sequence
    valid_bases = set('ACGTN')
    text = text.upper().strip()
    invalid = set(text) - valid_bases
    if invalid:
        raise ValueError(f"Invalid characters in sequence: {invalid}. Only A, C, G, T, N allowed.")

    # Start with character-level tokenization
    tokens = list(text)

    # Apply merge rules in order
    for pair in merge_rules:
        # Nothing left to merge once fewer than two tokens remain, so stop
        # instead of walking the rest of the rule list.
        if len(tokens) < 2:
            break
        # Pass a tuple so list-shaped rules still match inside merge_pair.
        tokens = merge_pair(tokens, tuple(pair), pair[0] + pair[1])

    # Convert to IDs
    token_ids = []
    for token in tokens:
        if token in vocab:
            token_ids.append(vocab[token])
        else:
            # Shouldn't happen with a consistent vocab; skip rather than fail
            print(f"Warning: Unknown token '{token}'")
    return token_ids
def decode_sequence(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """Turn a list of token IDs back into the DNA string they spell.

    Each ID is looked up in the inverse of ``vocab`` and the resulting token
    strings are concatenated. IDs that are not in the vocabulary are reported
    with a printed warning and left out. Empty input yields ``""``.
    """
    if not token_ids:
        return ""

    # Invert the token -> ID mapping
    lookup = dict(zip(vocab.values(), vocab.keys()))

    pieces = []
    for current_id in token_ids:
        try:
            pieces.append(lookup[current_id])
        except KeyError:
            print(f"Warning: Unknown token ID {current_id}")

    return ''.join(pieces)
def analyze_sequence(sequence: str):
    """Encode a DNA sequence and build the report shown in the UI.

    All whitespace is removed from the input and it is upper-cased before
    encoding, so multi-line pastes and trailing newlines neither trigger an
    invalid-character error nor distort the length and lossless check.

    Args:
        sequence: Raw text entered by the user.

    Returns:
        A 4-tuple ``(stats_markdown, encoded_text, decoded_text, status)``.
        On invalid input or any error the first element carries the message
        and the encoded/decoded elements are empty strings.
    """
    normalized = "".join(sequence.split()).upper() if sequence else ""
    if not normalized:
        return "Please enter a valid DNA sequence.", "", "", ""
    if not vocab or not merge_rules:
        # The module-level tokenizer failed to load; say so instead of
        # blaming the user's input.
        return "Tokenizer is not loaded; cannot encode sequences.", "", "", ""

    try:
        # Encode
        encoded_ids = encode_sequence(normalized, vocab, merge_rules)
        # Decode to verify
        decoded = decode_sequence(encoded_ids, vocab)

        # Get token strings
        id_to_token = {idx: token for token, idx in vocab.items()}
        token_strings = [id_to_token[tid] for tid in encoded_ids if tid in id_to_token]

        # Calculate statistics on the normalized text, i.e. on exactly what
        # was encoded
        original_length = len(normalized)
        compressed_length = len(encoded_ids)
        compression_ratio = original_length / compressed_length if compressed_length > 0 else 0
        lossless = 'β PASS' if normalized == decoded else 'β FAIL'

        # Long results are truncated in the summary; the full ID list is
        # returned separately below.
        ids_preview = f"{encoded_ids[:50]}{'...' if len(encoded_ids) > 50 else ''}"
        tokens_preview = f"{token_strings[:20]}{'...' if len(token_strings) > 20 else ''}"

        # Format results
        stats = "\n".join([
            "",
            "## π Compression Statistics",
            f"- **Original Length**: {original_length:,} bases",
            f"- **Compressed Length**: {compressed_length:,} tokens",
            f"- **Compression Ratio**: {compression_ratio:.3f}x",
            f"- **Lossless Check**: {lossless}",
            "## π’ Token IDs",
            "```",
            ids_preview,
            "```",
            "## 𧬠Learned Tokens",
            "```",
            tokens_preview,
            "```",
            "",
        ])

        # Format encoded output
        encoded_output = f"Token IDs: {encoded_ids}"
        # Format decoded output
        decoded_output = decoded
        return stats, encoded_output, decoded_output, "β Success!"
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing
        return f"β Error: {str(e)}", "", "", f"β Error: {str(e)}"
def tokenize_and_display(sequence: str):
    """Gradio callback: run the analysis for ``sequence`` and hand back its result.

    Thin entry point used by both the submit button and the examples widget;
    the return value is whatever ``analyze_sequence`` produces.
    """
    outputs = analyze_sequence(sequence)
    return outputs
# Example DNA sequences
# Each entry is a one-element list: one value for the single input component
# (input_sequence) that gr.Examples is wired to below.
# NOTE(review): the trailing labels describe the intended motif; nothing in
# this file checks them against the loaded tokenizer or a reference genome.
EXAMPLES = [
    ["ATGAAACGCATTAGCACCACCATTACCACCACCATCA"],  # Random sequence
    ["ATGATGATGATG"],  # Start codon repeats
    ["TATAATATATATAA"],  # TATA box variations
    ["GCGCGCGCGCGC"],  # CpG islands
    ["AAAAAAAAAAAAAAAA"],  # Poly-A tail
    ["AGGAGGTAAATG"],  # Shine-Dalgarno + Start codon
    ["ATGCGGCGTGAACGCCTTATCCGGCC"],  # Longest learned token
    ["AGCTTTTCATTCTGACTGCAACGGGCAATATGTCTCTGTGTGGATTAAAAAAAGAGTGTCTGATAGCAGC"],  # E. coli actual sequence
]
# Create Gradio interface
# Page structure, top to bottom: header text, input row (sequence box and
# button next to a status box), statistics panel, encoded/decoded row,
# clickable examples, static reference text. The button wiring comes last.
with gr.Blocks(theme=gr.themes.Soft(), title="𧬠BPE DNA Tokenizer") as demo:
    # Static header.
    # NOTE(review): the figures quoted here are hard-coded text, not computed
    # from the loaded tokenizer.
    gr.Markdown("""
    # 𧬠BPE DNA Tokenizer - Interactive Demo
    ### Byte Pair Encoding for Genomic Sequences
    This tokenizer was trained on the *E. coli* K-12 genome (4.6M base pairs) and achieves **5.208x compression**.
    **Key Features:**
    - π― 5,000 token vocabulary
    - π 5.21x compression ratio
    - π¬ Discovers biological patterns (codons, TATA boxes, etc.)
    - β 100% lossless encoding/decoding
    ---
    """)
    # Input row: the wide column holds the sequence box and submit button,
    # the narrow column holds the one-line status message.
    with gr.Row():
        with gr.Column(scale=2):
            input_sequence = gr.Textbox(
                label="𧬠Input DNA Sequence",
                placeholder="Enter DNA sequence (A, C, G, T, N)...\nExample: ATGAAACGCATTAGC",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("π¬ Tokenize Sequence", variant="primary", size="lg")
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="π Status", lines=1)
    # Statistics panel: receives the first element returned by
    # tokenize_and_display.
    with gr.Row():
        with gr.Column():
            stats_output = gr.Markdown(label="Statistics")
    # Encoded token IDs and the decoded round-trip text, side by side.
    with gr.Row():
        with gr.Column():
            encoded_output = gr.Textbox(label="π’ Encoded (Token IDs)", lines=3)
        with gr.Column():
            decoded_output = gr.Textbox(label="𧬠Decoded (Verification)", lines=3)
    gr.Markdown("""
    ---
    ## π Example Sequences
    Click any example below to try it out:
    """)
    # Examples feed input_sequence and use the same callback and outputs as
    # the submit button.
    # NOTE(review): with cache_examples=True Gradio presumably runs fn for
    # every example ahead of time, which needs the tokenizer to have loaded;
    # confirm against the Gradio version in use.
    gr.Examples(
        examples=EXAMPLES,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output],
        fn=tokenize_and_display,
        cache_examples=True,
        label="Example DNA Sequences"
    )
    # Static reference text.
    # NOTE(review): the token IDs and model figures below are hard-coded and
    # are not read from the loaded vocabulary.
    gr.Markdown("""
    ---
    ## π¬ Biological Patterns Discovered
    The tokenizer automatically learned these biological motifs:
    | Pattern | Token ID | Biological Function |
    |---------|----------|---------------------|
    | ATG | 20 | Start codon (translation initiation) |
    | TAA | 25 | Stop codon (translation termination) |
    | TAG | 65 | Stop codon (translation termination) |
    | TATAA | 279 | TATA box (transcription promoter) |
    | AGGAGG | 3642 | Shine-Dalgarno (ribosome binding) |
    | GCGC | 35 | CpG island (gene regulation) |
    ---
    ## π Model Information
    - **Training Dataset**: *E. coli* K-12 (GCF_000005845.2)
    - **Vocabulary Size**: 5,000 tokens
    - **Compression Ratio**: 5.208x
    - **Token Length Range**: 1-26 bases
    - **Training Time**: 88 minutes
    ---
    ## π Links
    - [GitHub Repository](https://github.com/abi2024/bpe-dna-tokenizer)
    - [Paper: Neural Machine Translation with Subword Units](https://arxiv.org/abs/1508.07909)
    - [E. coli K-12 Reference](https://www.ncbi.nlm.nih.gov/genome/?term=escherichia+coli+K12)
    ---
    **Built with 𧬠for genomics and π€ for machine learning**
    """)
    # Set up the button click
    # The output order matches the 4-tuple returned by tokenize_and_display:
    # stats, encoded text, decoded text, status.
    submit_btn.click(
        fn=tokenize_and_display,
        inputs=input_sequence,
        outputs=[stats_output, encoded_output, decoded_output, status_output]
    )
# Launch the app
# Guarded so that importing this module builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()