JermaineAI committed on
Commit
d670bc8
·
1 Parent(s): 8bdabff

Enhanced Streamlit app with LSTM + Trigram comparison

Browse files
Files changed (3) hide show
  1. app.py +317 -75
  2. model/trigram_model.pkl +3 -0
  3. save_trigram.py +38 -0
app.py CHANGED
@@ -1,30 +1,126 @@
1
  """
2
- Streamlit app for Nigerian Pidgin Next-Word Prediction.
3
- Deploy to Hugging Face Spaces.
4
  """
5
 
6
  import streamlit as st
7
  import torch
8
  import torch.nn as nn
9
  import re
10
- from typing import List, Dict
 
 
 
11
 
12
- # Page config
 
 
13
  st.set_page_config(
14
  page_title="Nigerian Pidgin Predictor",
15
  page_icon="πŸ’¬",
16
- layout="centered"
17
  )
18
 
19
- # Special tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  PAD_TOKEN = '<PAD>'
21
  UNK_TOKEN = '<UNK>'
22
  SOS_TOKEN = '<SOS>'
23
  EOS_TOKEN = '</EOS>'
 
 
24
 
25
-
 
 
26
  def clean_text(text: str) -> str:
27
- """Clean text while preserving Nigerian Pidgin features."""
28
  text = text.lower()
29
  text = re.sub(r'https?://\S+', '', text)
30
  text = re.sub(r'www\.\S+', '', text)
@@ -33,61 +129,125 @@ def clean_text(text: str) -> str:
33
  text = re.sub(r'\s+', ' ', text)
34
  return text.strip()
35
 
36
-
37
  def tokenize(text: str) -> List[str]:
38
- """Simple word tokenization."""
39
  tokens = re.findall(r"[\w']+|[.,!?;:]", text)
40
  return tokens
41
 
42
-
 
 
43
  class LSTMLanguageModel(nn.Module):
44
- """LSTM-based language model for next-word prediction."""
45
-
46
- def __init__(
47
- self,
48
- vocab_size: int,
49
- embed_dim: int = 256,
50
- hidden_dim: int = 512,
51
- num_layers: int = 2,
52
- dropout: float = 0.3
53
- ):
54
  super().__init__()
55
  self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
56
- self.lstm = nn.LSTM(
57
- embed_dim, hidden_dim, num_layers=num_layers,
58
- batch_first=True, dropout=dropout if num_layers > 1 else 0
59
- )
60
  self.dropout = nn.Dropout(dropout)
61
  self.fc = nn.Linear(hidden_dim, vocab_size)
62
- self.hidden_dim = hidden_dim
63
- self.num_layers = num_layers
64
 
65
  def forward(self, x):
66
  embedded = self.embedding(x)
67
  lstm_out, _ = self.lstm(embedded)
68
  last_out = lstm_out[:, -1, :]
69
  out = self.dropout(last_out)
70
- logits = self.fc(out)
71
- return logits
72
 
73
-
74
- @st.cache_resource
75
- def load_model():
76
- """Load model (cached)."""
77
- checkpoint = torch.load('model/lstm_pidgin_model.pt', map_location='cpu')
78
- word_to_idx = checkpoint['word_to_idx']
79
- idx_to_word = checkpoint['idx_to_word']
80
- vocab_size = checkpoint['vocab_size']
 
 
81
 
82
- model = LSTMLanguageModel(vocab_size=vocab_size)
83
- model.load_state_dict(checkpoint['model_state_dict'])
84
- model.eval()
 
 
 
 
85
 
86
- return model, word_to_idx, idx_to_word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- def predict_next_words(context: str, model, word_to_idx, idx_to_word, top_k: int = 5):
90
- """Predict next words given context."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  if not context.strip():
92
  return []
93
 
@@ -97,63 +257,145 @@ def predict_next_words(context: str, model, word_to_idx, idx_to_word, top_k: int
97
 
98
  unk_idx = word_to_idx.get(UNK_TOKEN, 1)
99
  indices = [word_to_idx.get(t, unk_idx) for t in tokens]
100
-
101
  x = torch.tensor([indices], dtype=torch.long)
102
 
103
  with torch.no_grad():
104
  logits = model(x)
105
  probs = torch.softmax(logits, dim=-1)
106
 
107
- top_probs, top_indices = torch.topk(probs[0], top_k)
108
 
109
  results = []
110
  for prob, idx in zip(top_probs.numpy(), top_indices.numpy()):
111
  word = idx_to_word.get(str(idx), idx_to_word.get(idx, UNK_TOKEN))
112
  if word not in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]:
113
  results.append((word, float(prob)))
 
 
114
 
115
  return results
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- # Load model
119
- model, word_to_idx, idx_to_word = load_model()
 
120
 
121
- # UI
122
- st.title("πŸ’¬ Nigerian Pidgin Next-Word Predictor")
123
- st.markdown("**LSTM Language Model** trained on Nigerian Pidgin text.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- # Input
 
126
  context = st.text_input(
127
- "Enter Nigerian Pidgin text:",
128
- placeholder="e.g., 'i dey', 'wetin you', 'how far'"
 
129
  )
130
 
131
- top_k = st.slider("Number of predictions:", 1, 10, 5)
 
 
 
 
 
 
132
 
133
- # Predict button
134
- if st.button("Predict", type="primary") or context:
135
- if context:
136
- predictions = predict_next_words(context, model, word_to_idx, idx_to_word, top_k)
 
 
 
 
 
 
 
 
 
 
137
 
138
- if predictions:
139
- st.subheader("Predictions:")
140
- for word, prob in predictions:
141
- st.markdown(f"**{word}** β€” {prob:.1%}")
 
 
 
 
 
 
 
 
 
142
  else:
143
- st.warning("No predictions available.")
 
144
  else:
145
- st.info("Enter some text to get predictions.")
146
-
147
- # Examples
148
- st.markdown("---")
149
- st.markdown("**Try these examples:**")
150
- cols = st.columns(4)
151
- examples = ["i dey", "wetin you", "how far", "e don"]
152
- for col, ex in zip(cols, examples):
153
- if col.button(ex):
154
- st.session_state['context'] = ex
155
- st.rerun()
156
 
157
  # Footer
158
  st.markdown("---")
159
- st.caption("Trained on NaijaSenti + BBC Pidgin corpus (~10k texts)")
 
 
 
 
 
 
1
  """
2
+ Nigerian Pidgin Next-Word Prediction - Streamlit App
3
+ Supports both LSTM and Trigram models for comparison.
4
  """
5
 
6
  import streamlit as st
7
  import torch
8
  import torch.nn as nn
9
  import re
10
+ import pickle
11
+ import os
12
+ from collections import Counter
13
+ from typing import List, Dict, Tuple, Optional
14
 
15
+ # =============================================================================
16
+ # Page Config & Custom CSS
17
+ # =============================================================================
# Page-level settings: browser tab title/icon and a wide layout so the two
# model columns can sit side by side.
st.set_page_config(
    page_title="Nigerian Pidgin Predictor",
    page_icon="💬",
    layout="wide",
)

# Custom CSS injected as raw HTML. It styles the header banner, the
# prediction cards emitted by render_predictions, the radio/button widgets
# and the footer.
st.markdown("""
<style>
    /* Main container */
    .main > div {
        padding-top: 2rem;
    }

    /* Header styling */
    .main-header {
        background: linear-gradient(135deg, #1a5f2a 0%, #2d8a3e 50%, #f4c430 100%);
        padding: 2rem;
        border-radius: 15px;
        margin-bottom: 2rem;
        text-align: center;
        color: white;
    }

    .main-header h1 {
        color: white !important;
        margin-bottom: 0.5rem;
    }

    /* Prediction cards */
    .prediction-card {
        background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
        border-radius: 12px;
        padding: 1rem 1.5rem;
        margin: 0.5rem 0;
        border-left: 4px solid #2d8a3e;
        transition: transform 0.2s, box-shadow 0.2s;
    }

    .prediction-card:hover {
        transform: translateX(5px);
        box-shadow: 0 4px 12px rgba(0,0,0,0.1);
    }

    .word {
        font-size: 1.3rem;
        font-weight: 600;
        color: #1a5f2a;
    }

    .prob {
        font-size: 1rem;
        color: #666;
    }

    /* Model selector */
    .stRadio > div {
        display: flex;
        gap: 1rem;
    }

    /* Example buttons */
    .stButton > button {
        border-radius: 20px;
        border: 2px solid #2d8a3e;
        background: white;
        color: #2d8a3e;
        transition: all 0.3s;
    }

    .stButton > button:hover {
        background: #2d8a3e;
        color: white;
    }

    /* Comparison columns */
    .model-column {
        background: #f8f9fa;
        border-radius: 12px;
        padding: 1rem;
    }

    /* Footer */
    .footer {
        text-align: center;
        padding: 2rem;
        color: #666;
        font-size: 0.9rem;
    }
</style>
""", unsafe_allow_html=True)
109
+
110
+ # =============================================================================
111
+ # Special Tokens
112
+ # =============================================================================
# Control tokens of the LSTM vocabulary; predict_lstm drops these from its
# output so they are never shown as predictions.
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
SOS_TOKEN = '<SOS>'
# NOTE(review): '</EOS>' is asymmetric with '<SOS>'; presumably it matches the
# token stored in the checkpoint vocabulary -- confirm before "fixing" it.
EOS_TOKEN = '</EOS>'

# Sentence-boundary markers used by the trigram model.
START_TOKEN = '<s>'
END_TOKEN = '</s>'
119
 
120
+ # =============================================================================
121
+ # Text Processing
122
+ # =============================================================================
123
  def clean_text(text: str) -> str:
 
124
  text = text.lower()
125
  text = re.sub(r'https?://\S+', '', text)
126
  text = re.sub(r'www\.\S+', '', text)
 
129
  text = re.sub(r'\s+', ' ', text)
130
  return text.strip()
131
 
 
# A token is either a run of word characters/apostrophes (so contractions
# such as "don't" stay whole) or a single punctuation mark from .,!?;:
_TOKEN_PATTERN = re.compile(r"[\w']+|[.,!?;:]")


def tokenize(text: str) -> List[str]:
    """Split *text* into word and punctuation tokens.

    Args:
        text: Input string; no lowercasing or cleaning is done here.

    Returns:
        Tokens in order of appearance. Characters matching neither
        alternative of the pattern (whitespace, other symbols) are dropped.
        An empty string yields an empty list.
    """
    return _TOKEN_PATTERN.findall(text)
135
 
136
+ # =============================================================================
137
+ # LSTM Model
138
+ # =============================================================================
class LSTMLanguageModel(nn.Module):
    """Embedding -> LSTM -> dropout -> linear next-word classifier.

    Submodule attribute names (``embedding``, ``lstm``, ``dropout``, ``fc``)
    are the keys of the state dict, so they must stay as they are for saved
    checkpoints to load.

    Args:
        vocab_size: Number of entries in the vocabulary (also the output size).
        embed_dim: Width of the token embeddings.
        hidden_dim: LSTM hidden state width.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between LSTM layers (only when
            ``num_layers > 1``) and on the final hidden output.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 256,
                 hidden_dim: int = 512, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        # Index 0 is the padding row; its embedding is kept at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Inter-layer dropout is only passed for stacked LSTMs, so a
        # single-layer model gets 0 here.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return next-word logits for a batch of token-index sequences.

        Args:
            x: Integer tensor of token indices, shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` with unnormalised scores
            computed from the LSTM output at the last time step.
        """
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        # Only the final time step is used to predict the next word.
        last_out = lstm_out[:, -1, :]
        out = self.dropout(last_out)
        return self.fc(out)
 
155
 
156
+ # =============================================================================
157
+ # Trigram Model
158
+ # =============================================================================
159
class TrigramLM:
    """Trigram language model with additive smoothing over n-gram counters.

    The counters and vocabulary start empty and are filled by whoever builds
    the model; this class only reads them when scoring.

    Args:
        smoothing: Constant added to every trigram count.
    """

    def __init__(self, smoothing: float = 1.0):
        self.smoothing = smoothing
        self.unigram_counts = Counter()
        self.bigram_counts = Counter()
        self.trigram_counts = Counter()
        self.vocab = set()

    def probability(self, w3: str, w1: str, w2: str) -> float:
        """Return the smoothed probability of *w3* following ``(w1, w2)``.

        Computed as ``(count(w1, w2, w3) + k) / (count(w1, w2) + k * |V|)``
        where ``k`` is ``self.smoothing``. Returns 0.0 when the denominator
        is not positive (empty vocabulary or zero smoothing with an unseen
        history).
        """
        numerator = self.trigram_counts.get((w1, w2, w3), 0) + self.smoothing
        denominator = (self.bigram_counts.get((w1, w2), 0)
                       + self.smoothing * len(self.vocab))
        return numerator / denominator if denominator > 0 else 0.0

    def predict_next_words(self, context: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return the *top_k* most probable next words for *context*.

        The context is lowercased and split on whitespace; the last two words
        form the history, left-padded with ``START_TOKEN`` when fewer than
        two words are given.

        Args:
            context: Text typed so far.
            top_k: Maximum number of predictions; values <= 0 yield ``[]``.

        Returns:
            ``(word, probability)`` pairs sorted by descending probability.
            Ties are broken by descending unigram count and then
            alphabetically, so the result does not depend on set iteration
            order.
        """
        if top_k <= 0:
            return []

        # Left-pad so that taking the last two items works for 0, 1 or many words.
        history = [START_TOKEN, START_TOKEN] + context.lower().split()
        w1, w2 = history[-2], history[-1]

        # The denominator depends only on the history, so compute it once
        # instead of once per vocabulary word.
        denominator = (self.bigram_counts.get((w1, w2), 0)
                       + self.smoothing * len(self.vocab))
        boundary_tokens = {START_TOKEN, END_TOKEN}

        candidates = []
        for word in self.vocab:
            if word in boundary_tokens:
                continue
            numerator = self.trigram_counts.get((w1, w2, word), 0) + self.smoothing
            prob = numerator / denominator if denominator > 0 else 0.0
            candidates.append((word, prob))

        # With additive smoothing every unseen word gets the same probability;
        # an explicit tie-break keeps the output stable between runs.
        candidates.sort(
            key=lambda item: (-item[1], -self.unigram_counts.get(item[0], 0), item[0])
        )
        return candidates[:top_k]
192
 
193
+ # =============================================================================
194
+ # Model Loading
195
+ # =============================================================================
196
@st.cache_resource
def load_lstm_model():
    """Load the LSTM checkpoint and return the model with its vocabulary maps.

    Returns:
        ``(model, word_to_idx, idx_to_word, True)`` on success, with the model
        switched to eval mode, or ``(None, None, None, False)`` if anything
        goes wrong while loading. The failure is logged rather than raised so
        the page can still render and report that the model is unavailable.
    """
    import logging  # local import: keeps this block self-contained

    checkpoint_path = 'model/lstm_pidgin_model.pt'
    try:
        # NOTE: torch.load unpickles the file; only ship checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        word_to_idx = checkpoint['word_to_idx']
        idx_to_word = checkpoint['idx_to_word']

        model = LSTMLanguageModel(vocab_size=checkpoint['vocab_size'])
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model, word_to_idx, idx_to_word, True
    except Exception:
        # Broad catch is deliberate at this UI boundary, but the cause must
        # not vanish: record the traceback for whoever reads the logs.
        logging.getLogger(__name__).exception(
            "Failed to load LSTM model from %s", checkpoint_path
        )
        return None, None, None, False
211
 
212
def _build_demo_trigram():
    """Build a tiny TrigramLM from a few hard-coded Nigerian Pidgin sentences."""
    model = TrigramLM(smoothing=1.0)
    phrases = [
        ['i', 'dey', 'go'],
        ['i', 'dey', 'come'],
        ['wetin', 'you', 'dey', 'do'],
        ['how', 'far'],
        ['e', 'don', 'happen'],
        ['na', 'the', 'matter'],
        ['you', 'no', 'sabi'],
        ['make', 'we', 'go'],
    ]
    for phrase in phrases:
        # Two start markers so the first real word has a full trigram history.
        sent = [START_TOKEN, START_TOKEN] + phrase + [END_TOKEN]
        model.vocab.update(sent)
        model.unigram_counts.update(sent)
        model.bigram_counts.update(zip(sent, sent[1:]))
        model.trigram_counts.update(zip(sent, sent[1:], sent[2:]))
    return model


@st.cache_resource
def load_trigram_model():
    """Load the saved trigram model, or fall back to a small demo model.

    Returns:
        ``(model, True)`` when the pickle was loaded or the demo model was
        built, ``(None, False)`` if loading/building raised. The failure is
        logged rather than raised so the page can still render.
    """
    import logging  # local import: keeps this block self-contained

    model_path = 'model/trigram_model.pkl'
    try:
        if os.path.exists(model_path):
            # SECURITY: pickle.load executes code embedded in the file. Only
            # acceptable because this file is shipped with the app; never
            # point this at user-supplied data.
            # NOTE(review): unpickling needs the module that defined the
            # pickled class to be importable here -- confirm in deployment.
            with open(model_path, 'rb') as f:
                model = pickle.load(f)
            return model, True

        return _build_demo_trigram(), True
    except Exception:
        # Broad catch is deliberate at this UI boundary, but keep the cause.
        logging.getLogger(__name__).exception(
            "Failed to load trigram model from %s", model_path
        )
        return None, False
246
+
247
+ # =============================================================================
248
+ # Prediction Functions
249
+ # =============================================================================
250
+ def predict_lstm(context: str, model, word_to_idx, idx_to_word, top_k: int = 5):
251
  if not context.strip():
252
  return []
253
 
 
257
 
258
  unk_idx = word_to_idx.get(UNK_TOKEN, 1)
259
  indices = [word_to_idx.get(t, unk_idx) for t in tokens]
 
260
  x = torch.tensor([indices], dtype=torch.long)
261
 
262
  with torch.no_grad():
263
  logits = model(x)
264
  probs = torch.softmax(logits, dim=-1)
265
 
266
+ top_probs, top_indices = torch.topk(probs[0], top_k + 5)
267
 
268
  results = []
269
  for prob, idx in zip(top_probs.numpy(), top_indices.numpy()):
270
  word = idx_to_word.get(str(idx), idx_to_word.get(idx, UNK_TOKEN))
271
  if word not in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]:
272
  results.append((word, float(prob)))
273
+ if len(results) >= top_k:
274
+ break
275
 
276
  return results
277
 
278
def predict_trigram(context: str, model, top_k: int = 5) -> List[Tuple[str, float]]:
    """Return trigram predictions for *context*, or ``[]`` when unavailable.

    Args:
        context: Text typed so far; ``None``, empty or whitespace-only input
            yields no predictions.
        model: Object exposing ``predict_next_words(context, top_k)``, or
            ``None`` when no trigram model is loaded.
        top_k: Number of predictions to request from the model.

    Returns:
        Whatever ``model.predict_next_words`` returns, or an empty list.
    """
    # Check for a missing model and for None/blank text before calling
    # .strip(), so a None context returns [] instead of raising.
    if model is None or not context or not context.strip():
        return []
    return model.predict_next_words(context, top_k)
282
+
283
+ # =============================================================================
284
+ # UI Components
285
+ # =============================================================================
286
def render_predictions(predictions: List[Tuple[str, float]], model_name: str):
    """Render each prediction as a styled card, or a warning if there are none.

    Args:
        predictions: ``(word, probability)`` pairs in display order.
        model_name: Label used in the "no predictions" warning.
    """
    import html  # local import: keeps this block self-contained

    if not predictions:
        st.warning(f"No predictions from {model_name}")
        return

    for word, prob in predictions:
        # The markup is rendered with unsafe_allow_html, so escape the word:
        # vocabulary entries may contain characters such as '<' or '&'.
        safe_word = html.escape(str(word))
        card = (
            '<div class="prediction-card">'
            f'<span class="word">{safe_word}</span>'
            f'<span class="prob"> — {prob:.1%}</span>'
            '</div>'
        )
        st.markdown(card, unsafe_allow_html=True)
298
 
299
+ # =============================================================================
300
+ # Main App
301
+ # =============================================================================
302
 
303
# Header
st.markdown("""
<div class="main-header">
    <h1>💬 Nigerian Pidgin Next-Word Predictor</h1>
    <p>Compare LSTM neural network vs Trigram statistical model</p>
</div>
""", unsafe_allow_html=True)

# Load models (both loaders are wrapped in st.cache_resource)
lstm_model, word_to_idx, idx_to_word, lstm_loaded = load_lstm_model()
trigram_model, trigram_loaded = load_trigram_model()

# Sidebar
with st.sidebar:
    st.header("⚙️ Settings")

    model_choice = st.radio(
        "Select Model:",
        ["🤖 LSTM (Neural)", "📊 Trigram (Statistical)", "⚔️ Compare Both"],
        index=2,
    )

    top_k = st.slider("Number of predictions:", 1, 10, 5)

    st.markdown("---")
    st.markdown("### 📖 About")
    st.markdown(
        "**LSTM Model**: Neural network that learns patterns from data. "
        "Better at capturing complex dependencies.\n\n"
        "**Trigram Model**: Statistical model using word co-occurrence counts. "
        "Fast and interpretable."
    )

    st.markdown("---")
    st.markdown("### 🔗 Links")
    st.markdown("[GitHub](https://github.com/Jaykay73/nextword-pidgin)")

# Session-state key that backs the text box, so example buttons can fill it.
CONTEXT_KEY = "context_input"


def _use_example(example: str) -> None:
    """Button callback: put *example* into the text box before the rerun."""
    st.session_state[CONTEXT_KEY] = example


def _show_lstm(heading: str, text: str, k: int) -> None:
    """Render the LSTM section: heading plus predictions or an error."""
    st.markdown(heading)
    if lstm_loaded:
        render_predictions(
            predict_lstm(text, lstm_model, word_to_idx, idx_to_word, k), "LSTM"
        )
    else:
        st.error("LSTM model not loaded")


def _show_trigram(heading: str, text: str, k: int) -> None:
    """Render the trigram section: heading plus predictions or an error."""
    st.markdown(heading)
    if trigram_loaded:
        render_predictions(predict_trigram(text, trigram_model, k), "Trigram")
    else:
        st.error("Trigram model not loaded")


# Main input
st.markdown("### Enter Nigerian Pidgin text:")
context = st.text_input(
    label="Context",
    placeholder="e.g., 'i dey', 'wetin you', 'how far'",
    label_visibility="collapsed",
    key=CONTEXT_KEY,
)

# Example buttons. The value is written through an on_click callback rather
# than a plain assignment: a local assignment only lasts for the single rerun
# triggered by the click, so the example vanished as soon as any other widget
# (slider, radio) was touched, and it never appeared in the text box.
st.markdown("**Try these examples:**")
examples = ["i dey", "wetin you", "how far", "e don", "make we"]
for col, ex in zip(st.columns(5), examples):
    col.button(ex, use_container_width=True, on_click=_use_example, args=(ex,))

# Predictions
if context:
    st.markdown("---")

    if "Compare" in model_choice:
        col1, col2 = st.columns(2)
        with col1:
            _show_lstm("### 🤖 LSTM Neural Network", context, top_k)
        with col2:
            _show_trigram("### 📊 Trigram Statistical", context, top_k)
    elif "LSTM" in model_choice:
        _show_lstm("### 🤖 LSTM Predictions", context, top_k)
    else:
        _show_trigram("### 📊 Trigram Predictions", context, top_k)

# Footer
st.markdown("---")
st.markdown("""
<div class="footer">
    <p>Trained on <strong>NaijaSenti</strong> + <strong>BBC Pidgin</strong> corpus (~10k texts)</p>
    <p>🇳🇬 Nigerian Pidgin Language Model</p>
</div>
""", unsafe_allow_html=True)
model/trigram_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fe85f4bc3b84c739e714e35e15cc80cf35108947d3c194ca9079edf09cd4149
3
+ size 15507557
save_trigram.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Save the trained trigram model for use in the Streamlit app.
3
+ """
4
+
5
+ import pickle
6
+ import os
7
+ from src.data_loader import load_all_texts
8
+ from src.preprocessing import preprocess_corpus
9
+ from src.trigram_model import TrigramLM
10
+
11
def save_trigram_model():
    """Train a trigram model on the full corpus and pickle it for the app.

    Loads the texts (BBC data included), preprocesses them into sentences,
    trains a ``TrigramLM`` with smoothing 1.0, writes it to
    ``model/trigram_model.pkl`` and prints a few sample predictions.
    """
    print("Loading data...")
    corpus = load_all_texts(include_bbc=True)

    print("Preprocessing...")
    training_sentences = preprocess_corpus(corpus)

    print("Training trigram model...")
    trigram = TrigramLM(smoothing=1.0)
    trigram.train(training_sentences)

    # The output directory may not exist on a fresh checkout.
    os.makedirs('model', exist_ok=True)

    print("Saving model...")
    with open('model/trigram_model.pkl', 'wb') as out_file:
        pickle.dump(trigram, out_file)

    print("Done! Saved to model/trigram_model.pkl")

    # Quick smoke test of the freshly trained model.
    print("\nTest predictions:")
    sample_contexts = ("i dey", "wetin you", "how far")
    for ctx in sample_contexts:
        preds = trigram.predict_next_words(ctx, top_k=3)
        print(f" '{ctx}' -> {preds}")


if __name__ == "__main__":
    save_trigram_model()