import html
import json
import re

import gradio as gr
from transformers import AutoTokenizer

# Supported models: UI display name -> Hugging Face Hub repository id.
SUPPORTED_MODELS = {
    "Llama-2": "meta-llama/Llama-2-7b-chat-hf",
    "Llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
    "Qwen2": "Qwen/Qwen2-7B-Instruct",
    "Gemma-2": "google/gemma-2-9b-it",
    "GPT-2": "gpt2",
    "BERT": "bert-base-uncased",
}

# Currently loaded tokenizer and the display name of the model it belongs to.
# These are module-level globals, so they are shared by every request the app
# serves.
current_tokenizer = None
current_model_name = None

# Background colors cycled across consecutive tokens so that neighbouring
# tokens are visually distinguishable.
TOKEN_COLORS = [
    "#e3f2fd",  # Light blue
    "#f3e5f5",  # Light purple
    "#e8f5e8",  # Light green
    "#fff3e0",  # Light orange
    "#fce4ec",  # Light pink
    "#e0f2f1",  # Light teal
    "#f1f8e9",  # Light lime
    "#fafafa",  # Light gray
    "#fff8e1",  # Light amber
    "#e8eaf6",  # Light indigo
]

# Border color shared by every token / token-id chip.
TOKEN_BORDER_COLOR = "#2196f3"


def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` and make it the current one.

    Args:
        model_name: A key of ``SUPPORTED_MODELS``.

    Returns:
        A human-readable status string. On failure the previously loaded
        tokenizer (if any) is left in place.
    """
    global current_tokenizer, current_model_name
    try:
        model_path = SUPPORTED_MODELS[model_name]
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the Hub repository; confirm it is really needed for these models.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        return f"❌ Loading failed: {str(e)}"
    # Only switch state once loading has fully succeeded.
    current_tokenizer = tokenizer
    current_model_name = model_name
    return f"✅ Successfully loaded {model_name} tokenizer"


def _render_chips(items):
    """Return HTML rendering ``items`` as inline chips with cycling colors.

    Each item is converted with ``str`` and HTML-escaped, so token text such as
    ``<s>`` or ``&`` is displayed literally instead of being parsed as markup.
    """
    chips = []
    for i, item in enumerate(items):
        bg_color = TOKEN_COLORS[i % len(TOKEN_COLORS)]
        chips.append(
            '<span style="display:inline-block; margin:2px; padding:2px 6px; '
            f"background-color:{bg_color}; border:1px solid {TOKEN_BORDER_COLOR}; "
            'border-radius:4px; font-family:monospace; white-space:pre;">'
            f"{html.escape(str(item))}</span>"
        )
    return "".join(chips)


def _render_section(title, items):
    """Return an HTML section: a heading followed by chips for ``items``."""
    return (
        '<div style="margin-bottom:20px;">'
        f'<h3 style="margin:0 0 8px 0;">{html.escape(title)}</h3>'
        f'<div style="line-height:2;">{_render_chips(items)}</div>'
        "</div>"
    )


def visualize_tokens(text, model_name):
    """Tokenize ``text`` and return an HTML visualization plus statistics.

    If ``model_name`` is given and differs from the currently loaded model, its
    tokenizer is loaded first, so the result always matches the UI selection.

    Args:
        text: Text to tokenize.
        model_name: Key of ``SUPPORTED_MODELS`` selected in the UI.

    Returns:
        A 2-tuple matching the two Gradio outputs: ``(html, stats)`` on
        success, or ``(message, None)`` when there is nothing to show.
    """
    if not text or not text.strip():
        return "Please enter text to analyze", None

    if model_name and model_name != current_model_name:
        status = load_tokenizer(model_name)
        if current_model_name != model_name:
            # Loading failed; show why instead of using a mismatched tokenizer.
            return html.escape(status), None

    if current_tokenizer is None:
        return "Please select and load a model first", None

    try:
        # No return_tensors: plain Python lists avoid requiring PyTorch.
        token_ids = current_tokenizer(text, add_special_tokens=True)["input_ids"]
        tokens = current_tokenizer.convert_ids_to_tokens(token_ids)

        html_output = (
            '<div style="padding:10px;">'
            + _render_section("Tokenization Results:", tokens)
            + _render_section("Token IDs:", token_ids)
            + "</div>"
        )

        vocab_size = current_tokenizer.vocab_size
        stats = f"Total tokens: {len(tokens)}\nVocabulary size: {vocab_size:,}"
        return html_output, stats
    except Exception as e:
        return f"❌ Processing failed: {html.escape(str(e))}", None


# Create Gradio interface
with gr.Blocks(title="Token Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Token Visualizer")
    gr.Markdown("This is a tool for visualizing the text tokenization process. Select a model, input text, and view the tokenization results.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Model")
            model_dropdown = gr.Dropdown(
                choices=list(SUPPORTED_MODELS.keys()),
                label="Select Model",
                value="GPT-2",
            )
            load_btn = gr.Button("Load Tokenizer", variant="primary")
            load_status = gr.Textbox(label="Loading Status", interactive=False)

            gr.Markdown("### 2. Input Text")
            text_input = gr.Textbox(
                label="Enter text to tokenize",
                placeholder="Example: Hello, how are you today?",
                lines=4,
            )
            visualize_btn = gr.Button("Visualize", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### 3. Visualization Results")
            html_output = gr.HTML(label="Token Visualization")
            stats_output = gr.Textbox(label="Statistics", interactive=False)

    # Event binding
    load_btn.click(
        fn=load_tokenizer,
        inputs=[model_dropdown],
        outputs=[load_status],
    )

    visualize_btn.click(
        fn=visualize_tokens,
        inputs=[text_input, model_dropdown],
        outputs=[html_output, stats_output],
    )


if __name__ == "__main__":
    demo.launch()