import gradio as gr
import json
import os
import time
import torch
from typing import List, Dict, Tuple
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load environment variables (e.g. HF_NEXT_TOKEN_PREDICTOR_TOKEN) from a .env file.
load_dotenv()

# Configuration
MODEL_ID = "Qwen/Qwen3-0.6B"
# Hugging Face access token; empty string when the variable is not set.
HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')

# Initialize model and tokenizer once, at import time, for local inference.
print("Loading model and tokenizer...")
# Forward the configured token to the Hub downloads (previously it was read but
# never used). An empty token becomes None so transformers falls back to its
# default credential lookup instead of sending an empty bearer token.
_HF_AUTH_TOKEN = HF_TOKEN or None
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=_HF_AUTH_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=_HF_AUTH_TOKEN)
def show_token(token: str) -> str:
    """Return a printable label for a decoded token.

    Invisible tokens are replaced by visible markers so they can be shown on a
    button:

    * a run of n newlines becomes ``⏎`` (n == 1) or ``⏎×n``;
    * any other whitespace-only token of length n becomes ``␣`` (n == 1) or ``␣×n``;
    * every other token is returned unchanged.
    """
    # Newline-only runs ("\n\n", ...) get their own marker so they are not
    # mistaken for runs of spaces.
    if token and token.strip("\n") == "":
        return "⏎" if len(token) == 1 else f"⏎×{len(token)}"
    if token.strip() == "":
        return "␣" if len(token) == 1 else f"␣×{len(token)}"
    return token
def _top_p_top_k(probs: torch.Tensor, top_k: int, top_p: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the most likely candidates from a 1-D probability vector.

    Keeps the smallest prefix of tokens (sorted by descending probability)
    whose cumulative mass reaches ``top_p`` (the nucleus), then truncates that
    prefix to at most ``top_k`` entries.

    Returns:
        ``(token_ids, token_probs)``: 1-D tensors sorted by descending probability.
    """
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    reached = torch.nonzero(cumulative >= top_p, as_tuple=True)[0]
    # Float rounding can leave the total mass just below 1.0, so with top_p
    # near 1 the threshold may never be reached; keep every token in that case.
    nucleus_size = int(reached[0]) + 1 if reached.numel() > 0 else sorted_probs.numel()
    keep = min(max(int(top_k), 0), nucleus_size)
    return sorted_ids[:keep], sorted_probs[:keep]


def predict_next_token(text: str, top_k: int = 10, temperature: float = 1.0, top_p: float = 0.9) -> Tuple[List[Dict], str]:
    """Predict the most likely next tokens for ``text`` with the local model.

    The last-position logits are divided by ``temperature`` and converted to
    probabilities. Candidates are then restricted to the top-p nucleus and to
    at most ``top_k`` entries. Reported probabilities are those of the full
    temperature-scaled distribution; they are not renormalised within the nucleus.

    Args:
        text: Prompt to continue. Empty / whitespace-only / None input is rejected.
        top_k: Maximum number of candidates returned.
        temperature: Softmax temperature; non-positive values are clamped to a
            tiny positive number to avoid division by zero.
        top_p: Nucleus cumulative-probability threshold.

    Returns:
        ``(tokens_data, status)``. ``tokens_data`` is a list of
        ``{"token": str, "prob": float}`` dicts sorted by descending probability,
        and ``status`` is a timing message. On empty input or any error the
        result is ``([], message)``.
    """
    if not text or not text.strip():
        return [], "Please enter some text to predict from"
    start_time = time.perf_counter()
    try:
        encoded = tokenizer(text, return_tensors="pt").to(model.device)
        # One forward pass gives the raw next-token logits. generate() is not
        # needed, and its `scores` are post-processed by the generation config.
        with torch.inference_mode():
            logits = model(**encoded).logits[0, -1, :].float()

        safe_temperature = max(float(temperature), 1e-6)
        probs = torch.softmax(logits / safe_temperature, dim=-1)

        top_ids, top_probs = _top_p_top_k(probs, top_k, float(top_p))
        tokens_data = [
            {"token": tokenizer.decode([token_id]), "prob": prob}
            for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
        ]

        prediction_time = int((time.perf_counter() - start_time) * 1000)
        return tokens_data, f"Prediction time: {prediction_time}ms"
    except Exception as e:
        # UI boundary: report the failure in the status line instead of crashing the app.
        return [], f"❌ Error: {str(e)}"
def create_clickable_token_display(tokens_data: List[Dict]) -> str:
    """Render prediction entries as an HTML text fragment (no JavaScript).

    Each entry of ``tokens_data`` is expected to have a ``"token"`` (str) and a
    ``"prob"`` (float in [0, 1]) key. It is rendered as its display label
    followed by its percentage with two decimals, each on its own line.
    Token text is HTML-escaped, so tokens such as ``"<"`` or ``"</"`` cannot
    break the surrounding markup.
    """
    from html import escape  # local import: avoids a file-level name clash with HTML strings

    pieces = ["\n"]
    for entry in tokens_data:
        label = escape(show_token(entry["token"]))
        percentage = f"{entry['prob'] * 100:.2f}%"
        pieces.append(f"\n{label}\n{percentage}\n")
    pieces.append("\n")
    # join() instead of repeated += keeps construction linear in the number of tokens.
    return "".join(pieces)
# Custom CSS to match the Token Visualizer gradient color scheme.
# NOTE(review): the .gr-* selectors (.gr-textbox, .gr-slider, .gr-label, .gr-info)
# look like Gradio 3.x class names; confirm they still match elements in the
# installed Gradio version.
custom_css = """
/* Main container with gradient background like Token Visualizer */
.gradio-container {
    background: linear-gradient(135deg, #0f172a 0%, #1e293b 50%, #0f172a 100%) !important;
    color: #e2e8f0 !important;
    min-height: 100vh !important;
}

/* Blocks with subtle transparency and borders */
.block {
    background: rgba(15, 23, 42, 0.8) !important;
    border: 1px solid rgba(148, 163, 184, 0.2) !important;
    border-radius: 12px !important;
    backdrop-filter: blur(8px) !important;
}

/* Tab styling */
.tab-item {
    background: rgba(30, 41, 59, 0.6) !important;
    border: 1px solid rgba(148, 163, 184, 0.2) !important;
    border-radius: 8px !important;
}

/* Token buttons with modern gradient and hover effects */
.token-button {
    background: linear-gradient(135deg, #1e293b 0%, #334155 100%) !important;
    border: 1px solid rgba(148, 163, 184, 0.3) !important;
    color: #e2e8f0 !important;
    border-radius: 8px !important;
    margin: 0px !important;
    padding: 3px 8px !important;
    font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
    transition: all 0.3s ease !important;
    font-size: 12px !important;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3) !important;
}

.token-button:hover {
    background: linear-gradient(135deg, #334155 0%, #475569 100%) !important;
    border-color: rgba(148, 163, 184, 0.5) !important;
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4) !important;
}

/* Input fields styling */
.gr-textbox {
    background: rgba(30, 41, 59, 0.6) !important;
    border: 1px solid rgba(148, 163, 184, 0.3) !important;
    border-radius: 8px !important;
    color: #e2e8f0 !important;
}

/* Slider styling */
.gr-slider {
    background: rgba(30, 41, 59, 0.6) !important;
}

.gr-slider input[type="range"] {
    background: linear-gradient(to right, #3b82f6, #06b6d4) !important;
}

/* Labels and text (single rule; previously split across two declarations) */
.gr-label {
    color: #cbd5e1 !important;
    font-weight: 500 !important;
    position: relative !important;
}

.gr-info {
    color: #94a3b8 !important;
    background: rgba(30, 41, 59, 0.4) !important;
    border-radius: 6px !important;
    padding: 4px 8px !important;
    border: 1px solid rgba(148, 163, 184, 0.2) !important;
}

/* Remove Gradio's default spacing between buttons */
.token-button + .token-button {
    margin-top: 0px !important;
}

/* Remove gaps in the column containing buttons */
div:has(> .token-button) {
    gap: 0px !important;
}

/* Target Gradio's automatic spacing */
.block > div > div {
    gap: 0px !important;
}

/* Hide spinner arrows on number inputs */
input[type="number"]::-webkit-outer-spin-button,
input[type="number"]::-webkit-inner-spin-button {
    -webkit-appearance: none !important;
    margin: 0 !important;
}

input[type="number"] {
    -moz-appearance: textfield !important;
}

/* Header styling to match Token Visualizer */
h1, h2, h3, h4 {
    background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 100%) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent !important;
    background-clip: text !important;
}
"""
# Create Gradio interface
# Upper bound on prediction buttons. It is also the Top-K slider maximum, so
# the two cannot disagree.
MAX_TOKEN_BUTTONS = 15
# Prompt pre-filled in the textbox.
DEFAULT_PROMPT = "Twinkle, twinkle, little "

with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
    gr.HTML("""
<h1>Next-Token Predictor</h1>
<p>Explore how AI predicts the next word! Click on predictions to append them.</p>
""")

    with gr.Column():
        text_input = gr.Textbox(
            label="Enter your prompt:",
            placeholder="Type anything... predictions update automatically!",
            value=DEFAULT_PROMPT,
            lines=3,
            info="💡 Try: 'The weather today is', 'I think that', 'Once upon a time'",
        )

        # Next Token Predictions directly below input
        with gr.Column():
            gr.HTML("<h3>Next Token Predictions</h3>")
            # Pre-create one hidden button per possible prediction. The prediction
            # handler relabels/shows the first N of them and hides the rest.
            token_buttons = [
                gr.Button(
                    value="",
                    visible=False,
                    elem_classes=["token-button"],
                    size="sm",
                )
                for _ in range(MAX_TOKEN_BUTTONS)
            ]

        # Parameter controls below predictions
        with gr.Row():
            with gr.Column():
                top_k = gr.Slider(
                    minimum=5,
                    maximum=MAX_TOKEN_BUTTONS,
                    value=10,
                    step=1,
                    label="Top-K",
                    info="Number of most likely words to consider",
                    show_label=True,
                    interactive=True,
                )
            with gr.Column():
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness in predictions",
                    show_label=True,
                    interactive=True,
                )
            with gr.Column():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P",
                    info="Probability threshold for word selection",
                    show_label=True,
                    interactive=True,
                )

        timing_info = gr.HTML(value="✨ Predictions update as you type!")

    # Prediction data behind the visible buttons. The click handlers read it to
    # find which token a given button stands for.
    current_tokens = gr.State([])

    def update_predictions_and_buttons(text, k, temp, p):
        """Run a prediction and build the updates for status, state and buttons.

        Returns a list in the order of ``outputs``: the timing/status string,
        the token data (capped at ``MAX_TOKEN_BUTTONS`` so it lines up with the
        buttons), then one update per token button.
        """
        tokens_data, timing = predict_next_token(text, int(k), float(temp), float(p))
        tokens_data = tokens_data[:MAX_TOKEN_BUTTONS]
        shown = [
            gr.Button(
                value=f"{show_token(item['token'])} ({item['prob'] * 100:.1f}%)",
                visible=True,
            )
            for item in tokens_data
        ]
        hidden = [gr.Button(visible=False) for _ in range(MAX_TOKEN_BUTTONS - len(shown))]
        return [timing, tokens_data] + shown + hidden

    def append_token_to_input(current_text, tokens_data, button_index):
        """Append the token behind button ``button_index`` to ``current_text``.

        Returns ``current_text`` unchanged when no token exists at that index.
        """
        if tokens_data and 0 <= button_index < len(tokens_data):
            return (current_text or "") + tokens_data[button_index]["token"]
        return current_text

    prediction_inputs = [text_input, top_k, temperature, top_p]
    outputs = [timing_info, current_tokens] + token_buttons

    # Auto-predict on any input change (typing, slider moves, appended tokens).
    for component in prediction_inputs:
        component.change(
            update_predictions_and_buttons,
            inputs=prediction_inputs,
            outputs=outputs,
        )

    # Click handlers. The idx=i default binds each lambda to its own button;
    # a plain closure would see only the final loop value.
    for i, btn in enumerate(token_buttons):
        btn.click(
            lambda text, tokens, idx=i: append_token_to_input(text, tokens, idx),
            inputs=[text_input, current_tokens],
            outputs=[text_input],
        )

    # Initial predictions on page load, read from the components' actual values
    # rather than a hard-coded duplicate of their defaults.
    app.load(
        update_predictions_and_buttons,
        inputs=prediction_inputs,
        outputs=outputs,
    )
def _main() -> None:
    """Start the Gradio server without creating a public share link."""
    app.launch(share=False)


if __name__ == "__main__":
    _main()