"""Next-Token Predictor: a Gradio app that shows, live, which tokens a local
causal LM considers most likely to come next, filtered by temperature / top-p /
top-k. Clicking a predicted token appends it to the prompt.
"""

import gradio as gr
import json
import os
import time

import torch
from typing import List, Dict, Tuple
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load environment variables from .env file
load_dotenv()

# Configuration
MODEL_ID = "Qwen/Qwen3-0.6B"
# NOTE(review): read but never passed to from_pretrained — presumably reserved
# for gated models; confirm whether it should be forwarded as `token=`.
HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')

# Initialize model and tokenizer (local inference like the working app)
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)


def show_token(token: str) -> str:
    """Format a raw token for display.

    A lone newline becomes the return symbol; whitespace-only tokens become a
    visible-space glyph annotated with their length; everything else is shown
    verbatim.
    """
    if token == "\n":
        return "⏎"
    elif token.strip() == "":
        return f"␣{'' if len(token) == 1 else '×' + str(len(token))}"
    return token


def predict_next_token(text: str, top_k: int = 10, temperature: float = 1.0,
                       top_p: float = 0.9) -> Tuple[List[Dict], str]:
    """Predict next tokens using the local model with temperature scaling,
    nucleus (top-p) filtering, then a top-k cap.

    Args:
        text: Prompt to continue. Empty/whitespace-only input short-circuits.
        top_k: Maximum number of candidates to return after top-p filtering.
        temperature: Softmax temperature; >1 flattens, <1 sharpens.
        top_p: Cumulative-probability cutoff for nucleus filtering.

    Returns:
        (tokens_data, status) where tokens_data is a list of
        {"token": str, "prob": float} sorted by descending probability and
        status is a timing or error message.
    """
    if not text.strip():
        return [], "Please enter some text to predict from"

    start_time = time.time()
    try:
        # Guard against division by zero: the UI slider bottoms out at 0.1,
        # but direct callers could pass 0.
        temperature = max(float(temperature), 1e-6)

        # Single greedy step just to obtain the next-token score distribution.
        tokens = tokenizer(text, return_tensors="pt", padding=False)
        with torch.no_grad():
            out = model.generate(
                **tokens,
                max_new_tokens=1,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )

        # Raw logits for the generated position, temperature-scaled.
        logits = out.scores[0]                      # shape (1, vocab)
        scaled_logits = logits / temperature
        scores = torch.softmax(scaled_logits, dim=-1)

        # Nucleus filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative mass reaches top_p.
        sorted_probs, sorted_indices = torch.sort(scores, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        cutoff_index = torch.where(cumulative_probs >= top_p)[1]
        if len(cutoff_index) > 0:
            cutoff = cutoff_index[0].item() + 1     # include the crossing token
            top_p_indices = sorted_indices[0, :cutoff]
            top_p_probs = sorted_probs[0, :cutoff]
        else:
            # Fallback if top_p is very low / never reached: take up to 50.
            top_p_indices = sorted_indices[0, :min(50, len(sorted_indices[0]))]
            top_p_probs = sorted_probs[0, :min(50, len(sorted_probs[0]))]

        # Cap the nucleus at top_k entries.
        final_k = min(top_k, len(top_p_indices))
        final_indices = top_p_indices[:final_k]
        final_probs = top_p_probs[:final_k]

        # Decode each candidate id individually so per-token text is exact.
        token_ids = [int(idx) for idx in final_indices]
        probs = [float(prob) for prob in final_probs]
        tokens_text = [tokenizer.decode([tid]) for tid in token_ids]

        tokens_data = [
            {"token": tokens_text[i], "prob": probs[i]}
            for i in range(len(token_ids))
        ]

        prediction_time = int((time.time() - start_time) * 1000)
        return tokens_data, f"Prediction time: {prediction_time}ms"

    except Exception as e:
        # Surface any inference failure to the UI rather than crashing the app.
        return [], f"❌ Error: {str(e)}"


def create_clickable_token_display(tokens_data: List[Dict]) -> str:
    """Create an HTML display of predicted tokens - simplified without JavaScript.

    NOTE(review): the original markup was stripped by extraction; the tags
    below are a minimal reconstruction that preserves the displayed content
    (formatted token + percentage per entry) — confirm against the intended
    styling.
    """
    html = """
<div class="token-display">
"""
    for i, token_data in enumerate(tokens_data):
        token_display = show_token(token_data['token'])
        percentage = f"{token_data['prob'] * 100:.2f}%"
        html += f"""
<div class="token-item">
    <span class="token-text">{token_display}</span>
    <span class="token-prob">{percentage}</span>
</div>
"""
    html += """
</div>
"""
    return html


# Custom CSS to match the Token Visualizer gradient color scheme
custom_css = """
/* Main container with gradient background like Token Visualizer */
.gradio-container {
    background: linear-gradient(135deg, #0f172a 0%, #1e293b 50%, #0f172a 100%) !important;
    color: #e2e8f0 !important;
    min-height: 100vh !important;
}

/* Blocks with subtle transparency and borders */
.block {
    background: rgba(15, 23, 42, 0.8) !important;
    border: 1px solid rgba(148, 163, 184, 0.2) !important;
    border-radius: 12px !important;
    backdrop-filter: blur(8px) !important;
}

/* Tab styling */
.tab-item {
    background: rgba(30, 41, 59, 0.6) !important;
    border: 1px solid rgba(148, 163, 184, 0.2) !important;
    border-radius: 8px !important;
}

/* Token buttons with modern gradient and hover effects */
.token-button {
    background: linear-gradient(135deg, #1e293b 0%, #334155 100%) !important;
    border: 1px solid rgba(148, 163, 184, 0.3) !important;
    color: #e2e8f0 !important;
    border-radius: 8px !important;
    margin: 0px !important;
    padding: 3px 8px !important;
    font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
    transition: all 0.3s ease !important;
    font-size: 12px !important;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3) !important;
}

.token-button:hover {
    background: linear-gradient(135deg, #334155 0%, #475569 100%) !important;
    border-color: rgba(148, 163, 184, 0.5) !important;
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4) !important;
}

/* Input fields styling */
.gr-textbox {
    background: rgba(30, 41, 59, 0.6) !important;
    border: 1px solid rgba(148, 163, 184, 0.3) !important;
    border-radius: 8px !important;
    color: #e2e8f0 !important;
}

/* Slider styling */
.gr-slider {
    background: rgba(30, 41, 59, 0.6) !important;
}

.gr-slider input[type="range"] {
    background: linear-gradient(to right, #3b82f6, #06b6d4) !important;
}

/* Labels and text */
.gr-label {
    color: #cbd5e1 !important;
    font-weight: 500 !important;
}

.gr-info {
    color: #94a3b8 !important;
    background: rgba(30, 41, 59, 0.4) !important;
    border-radius: 6px !important;
    padding: 4px 8px !important;
    border: 1px solid rgba(148, 163, 184, 0.2) !important;
}

/* Remove Gradio's default spacing between buttons */
.token-button + .token-button {
    margin-top: 0px !important;
}

/* Remove gaps in the column containing buttons */
div:has(> .token-button) {
    gap: 0px !important;
}

/* Target Gradio's automatic spacing */
.block > div > div {
    gap: 0px !important;
}

/* Hide spinner arrows on number inputs */
input[type="number"]::-webkit-outer-spin-button,
input[type="number"]::-webkit-inner-spin-button {
    -webkit-appearance: none !important;
    margin: 0 !important;
}

input[type="number"] {
    -moz-appearance: textfield !important;
}

/* Header styling to match Token Visualizer */
h1, h2, h3, h4 {
    background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 100%) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent !important;
    background-clip: text !important;
}

/* Interactive tooltip icons in labels */
.gr-label {
    position: relative !important;
}

/* Add tooltip functionality with JavaScript */
.gr-label:has-text("ⓘ"):hover::after {
    content: attr(data-tooltip);
    position: absolute;
    background: rgba(15, 23, 42, 0.95);
    color: #e2e8f0;
    padding: 8px 12px;
    border-radius: 6px;
    font-size: 12px;
    max-width: 250px;
    border: 1px solid rgba(148, 163, 184, 0.3);
    z-index: 1000;
    top: 100%;
    left: 0;
    margin-top: 5px;
    white-space: pre-wrap;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
    # NOTE(review): header markup reconstructed — original tags were stripped
    # by extraction; text content is preserved verbatim.
    gr.HTML("""
    <div style="text-align: center;">
        <h1>Next-Token Predictor</h1>
        <p>Explore how AI predicts the next word! Click on predictions to append them.</p>
    </div>
    """)

    with gr.Column():
        text_input = gr.Textbox(
            label="Enter your prompt:",
            placeholder="Type anything... predictions update automatically!",
            value="Twinkle, twinkle, little ",
            lines=3,
            info="💡 Try: 'The weather today is', 'I think that', 'Once upon a time'"
        )

    # Next Token Predictions directly below input
    with gr.Column():
        gr.HTML("<h3>Next Token Predictions</h3>")

        # Create buttons for each possible token (shown/hidden as needed).
        token_buttons = []
        for i in range(15):  # Support up to 15 tokens
            btn = gr.Button(
                value="",
                visible=False,
                elem_classes=["token-button"],
                size="sm"
            )
            token_buttons.append(btn)

    # Parameter controls below predictions
    with gr.Row():
        with gr.Column():
            top_k = gr.Slider(
                minimum=5,
                maximum=15,
                value=10,
                step=1,
                label="Top-K",
                info="Number of most likely words to consider",
                show_label=True,
                interactive=True
            )
        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature",
                info="Controls randomness in predictions",
                show_label=True,
                interactive=True
            )
        with gr.Column():
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-P",
                info="Probability threshold for word selection",
                show_label=True,
                interactive=True
            )

    timing_info = gr.HTML(value="<div>✨ Predictions update as you type!</div>")

    # Store current tokens data so click handlers know what each button holds.
    current_tokens = gr.State([])

    def update_predictions_and_buttons(text, k, temp, p):
        """Run a prediction and produce updates for timing, state and all 15
        token buttons (unused buttons are hidden)."""
        tokens_data, timing = predict_next_token(text, int(k), float(temp), float(p))

        button_updates = []
        for i in range(15):
            if i < len(tokens_data):
                token = tokens_data[i]['token']
                prob = tokens_data[i]['prob']
                display_token = show_token(token)
                button_label = f"{display_token} ({prob*100:.1f}%)"
                button_updates.append(gr.Button(value=button_label, visible=True))
            else:
                button_updates.append(gr.Button(visible=False))

        # Order must match `outputs` below: timing, state, then the buttons.
        return [timing, tokens_data] + button_updates

    def append_token_to_input(current_text, tokens_data, button_index):
        """Append the raw token behind the clicked button to the prompt."""
        if tokens_data and 0 <= button_index < len(tokens_data):
            token = tokens_data[button_index]['token']
            return current_text + token
        return current_text

    # Auto-predict on any input change (prompt text or any sampling slider).
    outputs = [timing_info, current_tokens] + token_buttons
    for component in [text_input, top_k, temperature, top_p]:
        component.change(
            update_predictions_and_buttons,
            inputs=[text_input, top_k, temperature, top_p],
            outputs=outputs
        )

    # Set up click handlers for each token button. `idx=i` binds the loop
    # variable at definition time (avoids the late-binding closure pitfall).
    for i, btn in enumerate(token_buttons):
        btn.click(
            lambda text, tokens, idx=i: append_token_to_input(text, tokens, idx),
            inputs=[text_input, current_tokens],
            outputs=[text_input]
        )

    # Load initial predictions on app start (mirrors the textbox default).
    app.load(
        lambda: update_predictions_and_buttons("Twinkle, twinkle, little ", 10, 1.0, 0.9),
        outputs=outputs
    )


if __name__ == "__main__":
    app.launch(share=False)