"""Gradio demo that suggests the most likely next words for a prompt using GPT-2."""

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F
import spaces

# Number of prediction buttons shown in the UI (and default number of candidates).
NUM_PREDICTIONS = 10
# Width, in characters, of the textual probability bar.
BAR_LENGTH = 20
# Separator between the candidate word and its statistics in a button label.
LABEL_SEPARATOR = " | "


class NextWordPredictor:
    """Rank candidate next tokens for a text prompt with a pre-trained GPT-2."""

    def __init__(self):
        # Load pre-trained GPT-2 model and tokenizer
        self.model_name = "gpt2"
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
        self.model = GPT2LMHeadModel.from_pretrained(self.model_name)

        # Set padding token (reuses the end-of-sequence token)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Make the model smaller by converting its weights to half precision
        self.model.half()

        # Set model to evaluation mode
        self.model.eval()

    @spaces.GPU
    def predict_next_words(self, text, top_k=NUM_PREDICTIONS):
        """Predict the next word for ``text``.

        Args:
            text: Prompt to continue. Leading/trailing whitespace is ignored.
            top_k: Maximum number of candidates to return.

        Returns:
            A ``(labels, words)`` pair of equally long lists, ordered from most
            to least likely. Each label has the form
            ``"<word> | <probability> (<percent>%) <bar>"`` with words padded to
            a common width; ``words`` holds the bare candidate words. Both lists
            are empty when the prompt is blank or ``top_k`` is not positive.
        """
        text = (text or "").strip()
        if not text or top_k <= 0:
            return [], []

        # Tokenize input text
        inputs = self.tokenizer.encode(text, return_tensors="pt")
        # The model only has `n_positions` position embeddings; longer prompts
        # would fail in the forward pass, so keep just the most recent tokens.
        inputs = inputs[:, -self.model.config.n_positions:]
        inputs = inputs.to(self.model.device)

        # Get model predictions for the position after the last input token
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs.logits[0, -1, :]

        # The weights are half precision, so the logits are too. Do the softmax
        # in float32 for accuracy and because half-precision softmax is not
        # available on every backend.
        probabilities = F.softmax(predictions.float(), dim=-1)

        # torch.topk raises when k exceeds the number of entries.
        k = min(top_k, probabilities.size(-1))
        top_k_probs, top_k_indices = torch.topk(probabilities, k)

        words_with_probs = [
            (self.tokenizer.decode(token_id).strip(), probability)
            for probability, token_id in zip(
                top_k_probs.tolist(), top_k_indices.tolist()
            )
        ]

        # Pad every word to the longest one so the columns line up
        max_word_length = max(len(word) for word, _ in words_with_probs)

        results = []
        suggested_words = []
        for word, probability in words_with_probs:
            percentage = probability * 100
            filled_length = int(BAR_LENGTH * probability)
            bar = "█" * filled_length + "▢" * (BAR_LENGTH - filled_length)
            word_padded = word.ljust(max_word_length)
            results.append(
                f"{word_padded}{LABEL_SEPARATOR}"
                f"{probability:.4f} ({percentage:5.2f}%) {bar}"
            )
            suggested_words.append(word)

        return results, suggested_words


# Initialize the predictor
predictor = NextWordPredictor()


def update_predictions(text):
    """Return one update per prediction button for the current text.

    Buttons that receive a prediction are shown with its label; the remaining
    ones are hidden. Every button is (re-)enabled.
    """
    predictions_list, _ = predictor.predict_next_words(text)

    updates = [
        gr.update(value=label, visible=True, interactive=True)
        for label in predictions_list[:NUM_PREDICTIONS]
    ]
    # Hide the buttons that have no prediction to display.
    updates.extend(
        gr.update(visible=False, interactive=True)
        for _ in range(NUM_PREDICTIONS - len(updates))
    )
    return updates


def disable_all_buttons():
    """Disable all prediction buttons."""
    return [gr.update(interactive=False) for _ in range(NUM_PREDICTIONS)]


def add_word_to_text(current_text, button_value):
    """Append the word shown on a prediction button to the current text.

    Args:
        current_text: Text currently in the input box.
        button_value: Label of the clicked button, as produced by
            ``NextWordPredictor.predict_next_words``.

    Returns:
        The text with the word appended, separated by a single space unless
        the text already ends in whitespace. The text is returned unchanged
        when the label is empty or carries no word.
    """
    if not button_value:
        return current_text

    # The statistics after the separator never contain the separator, so
    # splitting from the right keeps words that themselves contain it intact.
    word = button_value.rsplit(LABEL_SEPARATOR, 1)[0].strip()
    if not word:
        # Whitespace-only tokens decode to an empty word; nothing to add.
        return current_text

    if not current_text or not current_text.strip():
        return word

    # Add a space only if the text doesn't already end with whitespace
    # (the multi-line textbox may end with a newline as well as a space).
    separator = "" if current_text[-1].isspace() else " "
    return current_text + separator + word


# Create Gradio interface
with gr.Blocks(title="Next Word Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Next Word Predictor")
    gr.Markdown(
        "Type a sentence and see the top 10 most likely next words with their probabilities! \n"
        "**Click on any prediction to add that word to your text.**"
    )

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter your text",
            placeholder="Start typing a sentence...",
            lines=4,
            interactive=True,
        )

    with gr.Row():
        # Examples
        gr.Examples(
            examples=[
                ["The weather today is"],
                ["I love to eat"],
                ["Machine learning is"],
                ["The quick brown fox"],
                # ["In the future, we will"]
            ],
            inputs=text_input,
        )

    with gr.Row():
        gr.Markdown("### Top 10 Next Word Predictions")
        gr.Markdown("*Click any prediction below to add it to your text*")

    # Create the clickable buttons for predictions, one per row
    prediction_buttons = []
    for _ in range(NUM_PREDICTIONS):
        with gr.Row():
            btn = gr.Button(
                value="",
                visible=False,
                variant="secondary",
                size="sm",
                interactive=True,
            )
            prediction_buttons.append(btn)

    # Update predictions as user types
    text_input.change(
        fn=update_predictions,
        inputs=text_input,
        outputs=prediction_buttons,
    )

    # Add click handlers for each prediction button
    for btn in prediction_buttons:
        # First disable all buttons when any button is clicked
        btn.click(
            fn=disable_all_buttons,
            inputs=[],
            outputs=prediction_buttons,
        ).then(
            # Then add the word to text
            fn=add_word_to_text,
            inputs=[text_input, btn],
            outputs=text_input,
        ).then(
            # Finally update predictions (which will re-enable buttons)
            fn=update_predictions,
            inputs=text_input,
            outputs=prediction_buttons,
        )

    # gr.Markdown("### How it works:")
    # gr.Markdown("""
    # - Uses GPT-2 language model to predict next words
    # - Applies softmax to convert logits to probabilities
    # - Shows top 10 most likely words with percentages and aligned visual bars
    # - Updates predictions in real-time as you type
    # - **Click on any prediction button to add that word to your text automatically**
    # - **Buttons are disabled while processing to prevent multiple clicks**
    # - Progress bars show relative probability: █ = filled, ▢ = empty outline
    # - All bars are perfectly aligned for easy comparison
    # """)

# Launch the app
if __name__ == "__main__":
    demo.launch()