willsh1997 commited on
Commit
6ddb051
·
1 Parent(s): 240ea59

:sparkles: initial commit - token prediction app

Browse files
Files changed (3) hide show
  1. README.md +14 -2
  2. next_word_predictor.py +183 -0
  3. requirements.txt +87 -0
README.md CHANGED
@@ -1,2 +1,14 @@
1
- # widget-token-predictor
2
- widget for demonstrating next token prediction in GPT2
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Next Word Predictor
3
+ emoji: 🏆
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.23.3
8
+ app_file: next_word_predictor.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: interactive next-word prediction demo with GPT-2
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
next_word_predictor.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ import torch.nn.functional as F
5
+ import spaces
6
+
7
class NextWordPredictor:
    """Wraps GPT-2 to rank the most likely next tokens for a text prefix."""

    def __init__(self):
        # Load the pre-trained GPT-2 model and its tokenizer once, up front.
        self.model_name = "gpt2"
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
        self.model = GPT2LMHeadModel.from_pretrained(self.model_name)

        # GPT-2 ships without a pad token; reuse EOS so padding is defined.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Inference only — switch off dropout/batch-norm training behavior.
        self.model.eval()

    @spaces.GPU
    def predict_next_words(self, text, top_k=10):
        """
        Predict the next word given input text.

        Returns a pair of lists: formatted "word | prob (pct%) bar" strings
        for display, and the bare suggested words, both ordered most-likely
        first and at most top_k long.
        """
        text = text.strip()
        if not text:
            # Nothing typed yet — no predictions to show.
            return [], []

        # Encode the prompt and score it without tracking gradients.
        token_ids = self.tokenizer.encode(text, return_tensors='pt')
        with torch.no_grad():
            # Logits for the position after the final input token.
            logits = self.model(token_ids).logits[0, -1, :]

        # Softmax over the full vocabulary, then keep the k best candidates.
        probs = F.softmax(logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, top_k)

        # Decode each candidate token and capture its probability.
        candidates = [
            (self.tokenizer.decode(idx.item()).strip(), p.item())
            for p, idx in zip(top_probs, top_indices)
        ]

        # Pad every word to the longest one so the bars line up visually.
        width = max(len(word) for word, _ in candidates)

        results = []
        suggested_words = []
        for word, probability in candidates:
            percentage = probability * 100
            # 20-slot text progress bar: filled blocks then empty outlines.
            bar_length = 20
            filled_length = int(bar_length * probability)
            bar = '█' * filled_length + '▢' * (bar_length - filled_length)

            results.append(
                f"{word.ljust(width)} | {probability:.4f} ({percentage:5.2f}%) {bar}"
            )
            suggested_words.append(word)

        return results, suggested_words
73
+
74
# Initialize the predictor as a module-level singleton so the GPT-2 weights
# are loaded once at import time and shared by every Gradio callback below.
predictor = NextWordPredictor()
76
+
77
def update_predictions(text):
    """Refresh the ten prediction buttons from the current textbox value."""
    predictions_list, suggested_words = predictor.predict_next_words(text)

    # Empty input produces no predictions — hide every button.
    if not predictions_list:
        return [gr.update(visible=False)] * 10

    # One visible button per prediction; hide any unused trailing slots.
    return [
        gr.update(value=predictions_list[slot], visible=True)
        if slot < len(predictions_list)
        else gr.update(visible=False)
        for slot in range(10)
    ]
93
+
94
def add_word_to_text(current_text, button_value):
    """Append the word shown on a clicked prediction button to the text."""
    if not button_value:
        # Button carries no prediction — leave the text untouched.
        return current_text

    # Button labels look like "word    | 0.1234 (12.34%) ████…"; the word is
    # everything before the first " | " separator, minus the alignment padding.
    word = button_value.split(" | ")[0].strip()

    if not current_text.strip():
        # Text is empty (or whitespace only): the word becomes the whole text.
        return word

    # Insert a single separating space unless one is already there.
    separator = '' if current_text.endswith(' ') else ' '
    return current_text + separator + word
109
+
110
# Create the Gradio interface: a textbox on the left, ten clickable
# prediction buttons on the right, wired so that typing refreshes the
# predictions and clicking a button appends its word to the text.
with gr.Blocks(title="Next Word Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Next Word Predictor")
    gr.Markdown("Type a sentence and see the top 10 most likely next words with their probabilities! **Click on any prediction to add that word to your text.**")

    with gr.Row():
        with gr.Column(scale=2):
            # Free-text prompt; every edit triggers update_predictions below.
            text_input = gr.Textbox(
                label="Enter your text",
                placeholder="Start typing a sentence...",
                lines=4,
                interactive=True
            )

        with gr.Column(scale=1):
            gr.Markdown("### Top 10 Next Word Predictions")
            gr.Markdown("*Click any prediction below to add it to your text*")

            # Create 10 clickable buttons for predictions. They start hidden;
            # update_predictions() sets their value/visibility per keystroke.
            prediction_buttons = []
            for i in range(10):
                btn = gr.Button(
                    value="",
                    visible=False,
                    variant="secondary",
                    size="sm"
                )
                prediction_buttons.append(btn)

    # Update predictions as the user types.
    text_input.change(
        fn=update_predictions,
        inputs=text_input,
        outputs=prediction_buttons
    )

    # Click handlers for each prediction button: first append the chosen word
    # to the textbox, then re-run the predictions against the new text.
    for btn in prediction_buttons:
        btn.click(
            fn=add_word_to_text,
            inputs=[text_input, btn],
            outputs=text_input
        ).then(
            fn=update_predictions,
            inputs=text_input,
            outputs=prediction_buttons
        )

    # Ready-made example prompts the user can click to populate the textbox.
    gr.Examples(
        examples=[
            ["The weather today is"],
            ["I love to eat"],
            ["Machine learning is"],
            ["The quick brown fox"],
            ["In the future, we will"]
        ],
        inputs=text_input
    )

    gr.Markdown("### How it works:")
    gr.Markdown("""
    - Uses GPT-2 language model to predict next words
    - Applies softmax to convert logits to probabilities
    - Shows top 10 most likely words with percentages and aligned visual bars
    - Updates predictions in real-time as you type
    - **Click on any prediction button to add that word to your text automatically**
    - Progress bars show relative probability: █ = filled, ▢ = empty outline
    - All bars are perfectly aligned for easy comparison
    """)
180
+
181
# Launch the app when run as a script. NOTE(review): on Hugging Face Spaces
# this file is the configured app_file, so `demo` is served either way.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.4.0
2
+ aiofiles==23.2.1
3
+ annotated-types==0.7.0
4
+ anyio==4.8.0
5
+ asttokens==3.0.0
6
+ bitsandbytes==0.45.4
7
+ certifi==2025.1.31
8
+ charset-normalizer==3.4.1
9
+ click==8.1.8
10
+ comm==0.2.2
11
+ debugpy==1.8.12
12
+ decorator==5.1.1
13
+ exceptiongroup==1.2.2
14
+ executing==2.2.0
15
+ fastapi==0.115.8
16
+ ffmpy==0.5.0
17
+ filelock==3.17.0
18
+ fsspec==2025.2.0
19
+ gradio==5.16.1
20
+ gradio_client==1.7.0
21
+ h11==0.14.0
22
+ httpcore==1.0.7
23
+ httpx==0.28.1
24
+ huggingface-hub==0.28.1
25
+ idna==3.10
26
+ ipykernel==6.29.5
27
+ ipython==8.32.0
28
+ jedi==0.19.2
29
+ Jinja2==3.1.5
30
+ jupyter_client==8.6.3
31
+ jupyter_core==5.7.2
32
+ markdown-it-py==3.0.0
33
+ MarkupSafe==2.1.5
34
+ matplotlib-inline==0.1.7
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
+ nest-asyncio==1.6.0
38
+ networkx==3.4.2
39
+ numpy==2.2.3
40
+ orjson==3.10.15
41
+ packaging==24.2
42
+ pandas==2.2.3
43
+ parso==0.8.4
44
+ pexpect==4.9.0
45
+ pillow==11.1.0
46
+ platformdirs==4.3.6
47
+ prompt_toolkit==3.0.50
48
+ psutil==7.0.0
49
+ ptyprocess==0.7.0
50
+ pure_eval==0.2.3
51
+ pydantic==2.10.6
52
+ pydantic_core==2.27.2
53
+ pydub==0.25.1
54
+ Pygments==2.19.1
55
+ python-dateutil==2.9.0.post0
56
+ python-multipart==0.0.20
57
+ pytz==2025.1
58
+ PyYAML==6.0.2
59
+ pyzmq==26.2.1
60
+ regex==2024.11.6
61
+ requests==2.32.3
62
+ rich==13.9.4
63
+ ruff==0.9.6
64
+ safehttpx==0.1.6
65
+ safetensors==0.5.2
66
+ semantic-version==2.10.0
67
+ shellingham==1.5.4
68
+ six==1.17.0
69
+ sniffio==1.3.1
70
+ stack-data==0.6.3
71
+ starlette==0.45.3
72
+ sympy==1.13.1
73
+ tokenizers==0.21.0
74
+ tomlkit==0.13.2
75
+ torch==2.4.0
76
+ tornado==6.4.2
77
+ tqdm==4.67.1
78
+ traitlets==5.14.3
79
+ transformers==4.49.0
80
+ typer==0.15.1
81
+ typing_extensions==4.12.2
82
+ tzdata==2025.1
83
+ urllib3==2.3.0
84
+ uvicorn==0.34.0
85
+ wcwidth==0.2.13
86
+ websockets==14.2
87
+