# next-token-predictor / next_word_predictor.py
# Source page header: willsh1997 — ":wrench: minor change - test github actions" (commit ec2e454)
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F
import spaces
class NextWordPredictor:
    """Rank the most likely next tokens for a prompt using pre-trained GPT-2.

    The model and tokenizer are loaded once at construction time.
    ``predict_next_words`` formats the top-k candidates as aligned text rows
    with a probability bar, suitable for use as button labels.
    """

    def __init__(self):
        # Load pre-trained GPT-2 model and tokenizer
        self.model_name = "gpt2"
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
        self.model = GPT2LMHeadModel.from_pretrained(self.model_name)
        # GPT-2 ships without a pad token; reuse EOS so padding never fails
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Store the weights in fp16 to halve memory. Logits are upcast to
        # float32 before the softmax in predict_next_words.
        self.model.half()
        # Evaluation mode: disables dropout so predictions are deterministic
        self.model.eval()

    @spaces.GPU
    def predict_next_words(self, text, top_k=10):
        """Predict the most likely next tokens for ``text``.

        Args:
            text: Prompt to continue. Leading/trailing whitespace is ignored.
            top_k: Number of candidates to return. Clamped to the vocabulary
                size; values <= 0 yield no candidates.

        Returns:
            ``(results, suggested_words)``. ``results`` holds one display row
            per candidate, formatted as ``"<word> | <prob> (<pct>%) <bar>"``,
            with each word left-aligned and padded to a common width.
            ``suggested_words`` holds the bare, whitespace-stripped decoded
            tokens in the same order, most likely first. Both lists are empty
            for blank input.
        """
        text = text.strip()
        if not text or top_k <= 0:
            return [], []

        # Tokenize and keep only the most recent tokens that fit in the
        # model's context window; longer prompts would index past the
        # position-embedding table and raise.
        input_ids = self.tokenizer.encode(text, return_tensors='pt')
        input_ids = input_ids[:, -self.model.config.n_positions:]
        # The inputs must live on the same device as the model weights.
        input_ids = input_ids.to(self.model.device)

        with torch.no_grad():
            outputs = self.model(input_ids)
        # Logits for the token after the last input position. Upcast from
        # fp16 so the softmax runs in float32. Half-precision softmax is
        # unsupported on CPU in some PyTorch versions and keeps fewer digits.
        logits = outputs.logits[0, -1, :].float()
        probabilities = F.softmax(logits, dim=-1)

        # torch.topk raises if k exceeds the number of entries.
        top_k = min(top_k, probabilities.shape[-1])
        top_k_probs, top_k_indices = torch.topk(probabilities, top_k)

        # (word, probability, percentage) per candidate, most likely first
        words_with_probs = [
            (self.tokenizer.decode(idx).strip(), prob, prob * 100)
            for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist())
        ]

        # Pad every word to the longest one so the bars line up.
        max_word_length = max(len(word) for word, _, _ in words_with_probs)
        bar_length = 20
        results = []
        suggested_words = []
        for word, probability, percentage in words_with_probs:
            filled_length = int(bar_length * probability)
            bar = '█' * filled_length + '▢' * (bar_length - filled_length)
            word_padded = word.ljust(max_word_length)
            result = f"{word_padded} | {probability:.4f} ({percentage:5.2f}%) {bar}"
            results.append(result)
            suggested_words.append(word)
        return results, suggested_words
# Build one shared predictor at import time. GPT-2 is loaded here once, and
# every Gradio callback below reuses this same instance.
predictor: NextWordPredictor = NextWordPredictor()
def update_predictions(text):
    """Compute the updates for the ten prediction buttons from ``text``.

    The first buttons each show one formatted prediction row, and any
    remaining buttons are hidden. Every update sets ``interactive=True``,
    which undoes a previous disable.
    """
    rows, _ = predictor.predict_next_words(text)
    shown = [
        gr.update(value=row, visible=True, interactive=True)
        for row in rows[:10]
    ]
    hidden = [
        gr.update(visible=False, interactive=True)
        for _ in range(10 - len(shown))
    ]
    return shown + hidden
def disable_all_buttons():
    """Return ten updates, one per prediction button, that make each button non-clickable."""
    return [gr.update(interactive=False) for _ in range(10)]
def add_word_to_text(current_text, button_value):
    """Append the word shown on a prediction button to the textbox contents.

    Args:
        current_text: Current textbox contents (``None`` is treated as empty).
        button_value: Button label in the ``"<word> | <prob> ..."`` format
            produced by ``NextWordPredictor.predict_next_words``.

    Returns:
        The updated text. The word is separated from existing text by a
        single space, unless the text already ends in whitespace (space,
        newline, tab, ...). If the label is empty or carries no word, the
        text is returned unchanged.
    """
    current_text = current_text or ""
    if not button_value:
        return current_text
    # The word is everything before the LAST " | ". The probability/bar part
    # to its right never contains that separator, so splitting from the
    # right is unambiguous even if the word itself contains " | ".
    word = button_value.rsplit(" | ", 1)[0].strip()
    if not word:
        # A whitespace-only token (e.g. a newline) stripped to nothing:
        # appending would only add a stray space.
        return current_text
    if not current_text.strip():
        return word
    # Multi-line textbox: a trailing newline or tab already separates words.
    if current_text[-1].isspace():
        return current_text + word
    return current_text + ' ' + word
# Create Gradio interface.
# Components appear in the page in the order they are created inside these
# context managers, so statement order here defines the layout.
with gr.Blocks(title="Next Word Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Next Word Predictor")
    gr.Markdown("Type a sentence and see the top 10 most likely next words with their probabilities! **Click on any prediction to add that word to your text.**")
    with gr.Row():
        # Editable prompt; its value drives every prediction below
        text_input = gr.Textbox(
            label="Enter your text",
            placeholder="Start typing a sentence...",
            lines=4,
            interactive=True
        )
    with gr.Row():
        # Example prompts; selecting one fills text_input
        gr.Examples(
            examples=[
                ["The weather today is"],
                ["I love to eat"],
                ["Machine learning is"],
                ["The quick brown fox"],
                # ["In the future, we will"]
            ],
            inputs=text_input
        )
    with gr.Row():
        gr.Markdown("### Top 10 Next Word Predictions")
        gr.Markdown("*Click any prediction below to add it to your text*")
    # Create 10 clickable buttons for predictions, one per row. They start
    # hidden and are revealed by update_predictions.
    # NOTE: the count 10 must match the 10 updates returned by
    # update_predictions / disable_all_buttons and the default top_k=10 of
    # predict_next_words; the loop index `i` is unused.
    prediction_buttons = []
    for i in range(10):
        with gr.Row():
            btn = gr.Button(
                value="",
                visible=False,
                variant="secondary",
                size="sm",
                interactive=True
            )
            prediction_buttons.append(btn)
    # Update predictions as user types
    # NOTE(review): Gradio's `.change` also fires on programmatic value
    # updates (examples, add_word_to_text), so a button click may compute
    # predictions twice (here and in the final `.then` below). Confirm for
    # the Gradio version in use.
    text_input.change(
        fn=update_predictions,
        inputs=text_input,
        outputs=prediction_buttons
    )
    # Add click handlers for each prediction button
    for btn in prediction_buttons:
        # First disable all buttons when any button is clicked
        btn.click(
            fn=disable_all_buttons,
            inputs=[],
            outputs=prediction_buttons
        ).then(
            # Then add the word to text. The clicked button is an input, so
            # its current label (the formatted prediction row) is passed.
            fn=add_word_to_text,
            inputs=[text_input, btn],
            outputs=text_input
        ).then(
            # Finally update predictions (which will re-enable buttons)
            fn=update_predictions,
            inputs=text_input,
            outputs=prediction_buttons
        )
    # NOTE(review): the "How it works" section below is commented out;
    # either restore it or delete it.
    # gr.Markdown("### How it works:")
    # gr.Markdown("""
    # - Uses GPT-2 language model to predict next words
    # - Applies softmax to convert logits to probabilities
    # - Shows top 10 most likely words with percentages and aligned visual bars
    # - Updates predictions in real-time as you type
    # - **Click on any prediction button to add that word to your text automatically**
    # - **Buttons are disabled while processing to prevent multiple clicks**
    # - Progress bars show relative probability: █ = filled, ▢ = empty outline
    # - All bars are perfectly aligned for easy comparison
    # """)
# Start the Gradio server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()