Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import torch.nn.functional as F | |
| import spaces | |
class NextWordPredictor:
    """Rank the most likely next tokens for a piece of text using GPT-2.

    The tokenizer and model are loaded once in ``__init__`` and reused for
    every call to :meth:`predict_next_words`.
    """

    def __init__(self):
        """Load the pre-trained GPT-2 tokenizer and language model."""
        self.model_name = "gpt2"
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
        self.model = GPT2LMHeadModel.from_pretrained(self.model_name)

        # Reuse the end-of-sequence token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Convert the weights to half precision to make the model smaller.
        self.model.half()

        # Put the model in evaluation (inference) mode.
        self.model.eval()

    def predict_next_words(self, text, top_k=10):
        """Predict the next word for ``text``.

        Args:
            text: The prompt typed so far. Leading/trailing whitespace is
                ignored; ``None`` is treated like an empty string.
            top_k: Maximum number of candidates to return. Values larger
                than the vocabulary are clamped; values ``<= 0`` yield no
                candidates.

        Returns:
            A ``(results, suggested_words)`` tuple of two equally long lists,
            ordered from most to least likely. ``results`` holds display
            lines of the form ``"<word> | <prob> (<pct>%) <bar>"`` with the
            words padded to a common width; ``suggested_words`` holds the
            bare decoded words. Both lists are empty when there is nothing
            to predict.
        """
        text = (text or "").strip()
        if not text or top_k <= 0:
            return [], []

        # Tokenize input text
        inputs = self.tokenizer.encode(text, return_tensors='pt')

        # The model can only attend to a fixed number of positions; keep the
        # most recent tokens so long prompts do not overflow that limit.
        max_positions = getattr(
            getattr(self.model, "config", None), "n_positions", None
        )
        if max_positions and inputs.shape[-1] > max_positions:
            inputs = inputs[:, -max_positions:]

        # Get model predictions for the position after the last input token
        with torch.no_grad():
            outputs = self.model(inputs)
            last_token_logits = outputs.logits[0, -1, :]

        # The logits come from a half-precision model; up-cast before the
        # softmax so the reported probabilities are not degraded by float16.
        probabilities = F.softmax(last_token_logits.float(), dim=-1)

        # torch.topk raises if k exceeds the number of entries, so clamp it.
        k = min(int(top_k), probabilities.shape[-1])
        top_probs, top_ids = torch.topk(probabilities, k)

        # Decode every candidate first: the longest word is needed before any
        # line can be padded for alignment.
        candidates = [
            (self.tokenizer.decode(token_id).strip(), probability)
            for probability, token_id in zip(top_probs.tolist(), top_ids.tolist())
        ]
        if not candidates:
            return [], []

        max_word_length = max(len(word) for word, _ in candidates)
        bar_length = 20

        results = []
        suggested_words = []
        for word, probability in candidates:
            percentage = probability * 100
            # Progress bar: filled blocks proportional to the probability.
            filled_length = int(bar_length * probability)
            bar = '█' * filled_length + '▢' * (bar_length - filled_length)
            word_padded = word.ljust(max_word_length)
            results.append(
                f"{word_padded} | {probability:.4f} ({percentage:5.2f}%) {bar}"
            )
            suggested_words.append(word)

        return results, suggested_words
# Single shared predictor instance, created once at import time and reused by
# the UI callbacks.
predictor = NextWordPredictor()
def update_predictions(text):
    """Build the updates for the 10 prediction buttons from the current text.

    Returns a list of exactly 10 ``gr.update`` values: one visible, enabled
    button per prediction line, followed by hidden buttons for the unused
    slots. When there are no predictions every button is hidden.
    """
    labels, _ = predictor.predict_next_words(text)

    if not labels:
        return [gr.update(visible=False, interactive=True)] * 10

    # One visible button per prediction, capped at the 10 available slots.
    updates = [
        gr.update(value=label, visible=True, interactive=True)
        for label in labels[:10]
    ]
    # Hide whatever slots are left over (they stay enabled for the next round).
    updates.extend(
        gr.update(visible=False, interactive=True)
        for _ in range(10 - len(updates))
    )
    return updates
def disable_all_buttons():
    """Return 10 updates that make every prediction button non-interactive."""
    locked = gr.update(interactive=False)
    return [locked] * 10
def add_word_to_text(current_text, button_value):
    """Append the word shown on a prediction button to the current text.

    ``button_value`` is a button label of the form ``"<word> | <details>"``;
    only the part before the first ``" | "`` is used. An empty label leaves
    the text unchanged. If the text is blank the word replaces it; otherwise
    the word is appended, separated by exactly one space unless the text
    already ends with one.
    """
    if not button_value:
        return current_text

    # Everything before the first " | " separator is the (padded) word.
    label_word, _, _ = button_value.partition(" | ")
    word = label_word.strip()

    if not current_text.strip():
        return word

    separator = '' if current_text.endswith(' ') else ' '
    return current_text + separator + word
# Build the Gradio UI: a text box, example prompts, and 10 clickable
# prediction buttons that are refreshed whenever the text changes.
with gr.Blocks(title="Next Word Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Next Word Predictor")
    gr.Markdown("Type a sentence and see the top 10 most likely next words with their probabilities! **Click on any prediction to add that word to your text.**")

    # Prompt input
    with gr.Row():
        text_input = gr.Textbox(
            label="Enter your text",
            placeholder="Start typing a sentence...",
            lines=4,
            interactive=True
        )

    # Canned example prompts that fill the text box when selected
    with gr.Row():
        gr.Examples(
            examples=[
                ["The weather today is"],
                ["I love to eat"],
                ["Machine learning is"],
                ["The quick brown fox"],
                # ["In the future, we will"]
            ],
            inputs=text_input
        )

    # NOTE(review): the original indentation was lost, so it is assumed that
    # both Markdown elements live in this row -- confirm against the app.
    with gr.Row():
        gr.Markdown("### Top 10 Next Word Predictions")
        gr.Markdown("*Click any prediction below to add it to your text*")

    # Ten initially hidden buttons, each in its own row; update_predictions
    # fills them in and makes them visible.
    prediction_buttons = []
    for _ in range(10):
        with gr.Row():
            button = gr.Button(
                value="",
                visible=False,
                variant="secondary",
                size="sm",
                interactive=True
            )
        prediction_buttons.append(button)

    # Refresh the predictions whenever the text box content changes.
    text_input.change(
        fn=update_predictions,
        inputs=text_input,
        outputs=prediction_buttons
    )

    # Clicking a prediction runs three steps in sequence.
    for button in prediction_buttons:
        # 1. Disable every button when any button is clicked.
        buttons_locked = button.click(
            fn=disable_all_buttons,
            inputs=[],
            outputs=prediction_buttons
        )
        # 2. Append the clicked button's word to the text.
        word_appended = buttons_locked.then(
            fn=add_word_to_text,
            inputs=[text_input, button],
            outputs=text_input
        )
        # 3. Recompute the predictions, which also re-enables the buttons.
        word_appended.then(
            fn=update_predictions,
            inputs=text_input,
            outputs=prediction_buttons
        )

    # gr.Markdown("### How it works:")
    # gr.Markdown("""
    # - Uses GPT-2 language model to predict next words
    # - Applies softmax to convert logits to probabilities
    # - Shows top 10 most likely words with percentages and aligned visual bars
    # - Updates predictions in real-time as you type
    # - **Click on any prediction button to add that word to your text automatically**
    # - **Buttons are disabled while processing to prevent multiple clicks**
    # - Progress bars show relative probability: █ = filled, ▢ = empty outline
    # - All bars are perfectly aligned for easy comparison
    # """)

# Launch the app when run as a script
if __name__ == "__main__":
    demo.launch()