import html
import json
import re

import gradio as gr
from transformers import AutoTokenizer
# Display name offered in the model dropdown -> Hugging Face Hub repository id
# that is handed to AutoTokenizer.from_pretrained(). Insertion order is the
# order in which the choices are listed in the UI.
SUPPORTED_MODELS = {
    "Llama-2": "meta-llama/Llama-2-7b-chat-hf",
    "Llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
    "Qwen2": "Qwen/Qwen2-7B-Instruct",
    "Gemma-2": "google/gemma-2-9b-it",
    "GPT-2": "gpt2",
    "BERT": "bert-base-uncased",
}

# Module-wide handle to the tokenizer most recently loaded by load_tokenizer().
# Stays None until a load has succeeded.
current_tokenizer = None
# Pastel background colors assigned to consecutive tokens in rotation
# (token i gets TOKEN_COLORS[i % len(TOKEN_COLORS)]). Every entry must be
# distinct so that the palette really offers len(TOKEN_COLORS) different hues.
TOKEN_COLORS = [
    "#e3f2fd",  # Light blue
    "#f3e5f5",  # Light purple
    "#e8f5e8",  # Light green
    "#fff3e0",  # Light orange
    "#fce4ec",  # Light pink
    "#e0f2f1",  # Light teal
    "#f1f8e9",  # Light lime
    "#fafafa",  # Light gray
    "#fff8e1",  # Light amber
    "#e8eaf6",  # Light indigo (was a duplicate of the light-purple value)
]
def load_tokenizer(model_name):
    """Load the tokenizer for *model_name* and make it the active tokenizer.

    Args:
        model_name: A key of ``SUPPORTED_MODELS`` (the dropdown's display name).

    Returns:
        A human-readable status string: a success message, or a message
        starting with "❌ Loading failed:" describing what went wrong.
        Errors are reported through the return value, never raised, so the
        UI can show them in the status box.

    Side effects:
        On success, rebinds the module-level ``current_tokenizer``. On any
        failure the previously loaded tokenizer (if any) is left in place.
    """
    global current_tokenizer

    # Resolve the name up front so a bad key yields a readable message rather
    # than the bare repr of a KeyError.
    model_path = SUPPORTED_MODELS.get(model_name)
    if model_path is None:
        return f"❌ Loading failed: unknown model {model_name!r}"

    try:
        # NOTE(security): trust_remote_code=True allows the Hub repository to
        # execute its own Python code on this machine. Only keep repositories
        # you trust in SUPPORTED_MODELS.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        # UI boundary: download, authentication and parsing errors are all
        # surfaced to the user instead of crashing the callback.
        return f"❌ Loading failed: {str(e)}"

    current_tokenizer = tokenizer
    return f"✅ Successfully loaded {model_name} tokenizer"
# Inline-styled chip used for both token strings and token ids.
# NOTE(review): the original markup was not recoverable; these styles are a
# reconstruction (pastel background, blue border) — adjust to taste.
_TOKEN_CHIP_TEMPLATE = (
    '<span style="background-color: {bg}; border: 1px solid #2196f3; '
    "border-radius: 4px; padding: 2px 6px; margin: 2px; "
    'display: inline-block; color: #000;">{label}</span>'
)


def _render_chips(labels):
    """Return *labels* as a run of colored HTML ``<span>`` chips.

    Each label is converted with ``str()`` and HTML-escaped, so tokens such as
    ``<s>`` are shown literally instead of being interpreted as markup.
    Background colors are taken from ``TOKEN_COLORS`` in rotation.
    """
    palette_size = len(TOKEN_COLORS)
    return "".join(
        _TOKEN_CHIP_TEMPLATE.format(
            bg=TOKEN_COLORS[index % palette_size],
            label=html.escape(str(label)),
        )
        for index, label in enumerate(labels)
    )


def visualize_tokens(text, model_name):
    """Tokenize *text* with the active tokenizer and render the result.

    Args:
        text: The text to tokenize.
        model_name: The dropdown value supplied by the UI. It is not used
            here; the tokenizer set by ``load_tokenizer`` is what is applied.

    Returns:
        A 2-tuple ``(html, stats)`` matching the two bound UI outputs:
        ``html`` is the markup showing token strings and token ids as colored
        chips, ``stats`` is a text summary with the token count and the
        tokenizer's ``vocab_size``. When nothing can be rendered (no
        tokenizer loaded, blank text, or tokenization failed) the first
        element is a plain message and the second is ``None``.
    """
    # Compare against None explicitly: the truthiness of a tokenizer object
    # is not a reliable "is it loaded" signal.
    if current_tokenizer is None:
        return "Please select and load a model first", None
    if not text or not text.strip():
        return "Please enter text to analyze", None

    try:
        # Plain Python lists are enough here, so no tensor backend is needed.
        token_ids = current_tokenizer(text, add_special_tokens=True)["input_ids"]
        tokens = current_tokenizer.convert_ids_to_tokens(token_ids)
        vocab_size = current_tokenizer.vocab_size
    except Exception as e:
        # UI boundary: report the failure in the HTML pane. The message is
        # escaped because that pane renders its content as markup.
        return f"❌ Processing failed: {html.escape(str(e))}", None

    html_output = (
        '<div style="font-family: monospace; line-height: 2;">'
        "<h3>Tokenization Results:</h3>"
        f"<div>{_render_chips(tokens)}</div>"
        "<h3>Token IDs:</h3>"
        f"<div>{_render_chips(token_ids)}</div>"
        "</div>"
    )
    stats = f"Total tokens: {len(tokens)}\nVocabulary size: {vocab_size:,}"
    return html_output, stats
# Assemble the Gradio UI: controls in a narrow left column, results in a
# wider right column. Components appear in the order they are created.
with gr.Blocks(title="Token Visualizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Token Visualizer")
    gr.Markdown(
        "This is a tool for visualizing the text tokenization process. "
        "Select a model, input text, and view the tokenization results."
    )

    with gr.Row():
        # Left: choose and load a tokenizer, then enter the text.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Model")
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=list(SUPPORTED_MODELS),
                value="GPT-2",
            )
            load_btn = gr.Button("Load Tokenizer", variant="primary")
            load_status = gr.Textbox(label="Loading Status", interactive=False)

            gr.Markdown("### 2. Input Text")
            text_input = gr.Textbox(
                label="Enter text to tokenize",
                placeholder="Example: Hello, how are you today?",
                lines=4,
            )
            visualize_btn = gr.Button("Visualize", variant="primary")

        # Right: rendered tokens plus summary statistics.
        with gr.Column(scale=2):
            gr.Markdown("### 3. Visualization Results")
            html_output = gr.HTML(label="Token Visualization")
            stats_output = gr.Textbox(label="Statistics", interactive=False)

    # Wire the buttons to their callbacks. Each callback's return value is
    # mapped positionally onto the listed output components.
    load_btn.click(
        fn=load_tokenizer,
        inputs=[model_dropdown],
        outputs=[load_status],
    )
    visualize_btn.click(
        fn=visualize_tokens,
        inputs=[text_input, model_dropdown],
        outputs=[html_output, stats_output],
    )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()