JermaineAI commited on
Commit
8bdabff
·
1 Parent(s): faa30e1

Switch from Gradio to Streamlit

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +71 -58
  3. requirements.txt +1 -1
README.md CHANGED
@@ -3,8 +3,8 @@ title: Nigerian Pidgin Next-Word Predictor
3
  emoji: 💬
4
  colorFrom: green
5
  colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
3
  emoji: 💬
4
  colorFrom: green
5
  colorTo: yellow
6
+ sdk: streamlit
7
+ sdk_version: 1.32.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -1,19 +1,26 @@
1
  """
2
- Gradio app for Nigerian Pidgin Next-Word Prediction.
3
  Deploy to Hugging Face Spaces.
4
  """
5
 
6
- import gradio as gr
7
  import torch
8
  import torch.nn as nn
9
  import re
10
- from typing import List, Dict, Tuple
 
 
 
 
 
 
 
11
 
12
  # Special tokens
13
  PAD_TOKEN = '<PAD>'
14
  UNK_TOKEN = '<UNK>'
15
  SOS_TOKEN = '<SOS>'
16
- EOS_TOKEN = '<EOS>'
17
 
18
 
19
  def clean_text(text: str) -> str:
@@ -64,83 +71,89 @@ class LSTMLanguageModel(nn.Module):
64
  return logits
65
 
66
 
67
- # Load model
68
- print("Loading model...")
69
- checkpoint = torch.load('model/lstm_pidgin_model.pt', map_location='cpu')
70
- word_to_idx = checkpoint['word_to_idx']
71
- idx_to_word = checkpoint['idx_to_word']
72
- vocab_size = checkpoint['vocab_size']
73
-
74
- model = LSTMLanguageModel(vocab_size=vocab_size)
75
- model.load_state_dict(checkpoint['model_state_dict'])
76
- model.eval()
77
- print(f"Model loaded! Vocab size: {vocab_size:,}")
 
 
78
 
79
 
80
- def predict_next_words(context: str, top_k: int = 5) -> str:
81
  """Predict next words given context."""
82
  if not context.strip():
83
- return "Please enter some text..."
84
 
85
- # Tokenize and convert to indices
86
  tokens = tokenize(clean_text(context))
87
  if not tokens:
88
- return "No valid tokens found in input."
89
 
90
  unk_idx = word_to_idx.get(UNK_TOKEN, 1)
91
  indices = [word_to_idx.get(t, unk_idx) for t in tokens]
92
 
93
- # Create input tensor
94
  x = torch.tensor([indices], dtype=torch.long)
95
 
96
  with torch.no_grad():
97
  logits = model(x)
98
  probs = torch.softmax(logits, dim=-1)
99
 
100
- # Get top-k predictions
101
  top_probs, top_indices = torch.topk(probs[0], top_k)
102
 
103
  results = []
104
  for prob, idx in zip(top_probs.numpy(), top_indices.numpy()):
105
  word = idx_to_word.get(str(idx), idx_to_word.get(idx, UNK_TOKEN))
106
  if word not in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]:
107
- results.append(f"**{word}** ({prob:.1%})")
108
-
109
- return "\n".join(results) if results else "No predictions available."
110
-
111
-
112
- # Gradio Interface
113
- demo = gr.Interface(
114
- fn=predict_next_words,
115
- inputs=[
116
- gr.Textbox(
117
- label="Enter Nigerian Pidgin text",
118
- placeholder="e.g., 'i dey', 'wetin you', 'how far'",
119
- lines=2
120
- ),
121
- gr.Slider(
122
- minimum=1, maximum=10, value=5, step=1,
123
- label="Number of predictions"
124
- )
125
- ],
126
- outputs=gr.Markdown(label="Predicted next words"),
127
- title="🇳🇬 Nigerian Pidgin Next-Word Predictor",
128
- description="""
129
- **LSTM Language Model** trained on Nigerian Pidgin text.
130
-
131
- Enter some Pidgin text and get predictions for the next word!
132
 
133
- Try: "i dey", "wetin you", "na the", "how far", "e don"
134
- """,
135
- examples=[
136
- ["i dey", 5],
137
- ["wetin you", 5],
138
- ["how far", 5],
139
- ["na the", 5],
140
- ["e don", 5],
141
- ],
142
- theme=gr.themes.Soft()
 
 
 
 
143
  )
144
 
145
- if __name__ == "__main__":
146
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Streamlit app for Nigerian Pidgin Next-Word Prediction.
3
  Deploy to Hugging Face Spaces.
4
  """
5
 
6
+ import streamlit as st
7
  import torch
8
  import torch.nn as nn
9
  import re
10
+ from typing import List, Dict
11
+
12
# Page config — placed before any other st.* call (Streamlit requires
# set_page_config to be the first Streamlit command in the script).
st.set_page_config(
    page_title="Nigerian Pidgin Predictor",
    page_icon="💬",
    layout="centered"
)
18
 
19
# Special tokens — must match the vocabulary inside the training checkpoint
# (model/lstm_pidgin_model.pt, which this commit does not change).
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
SOS_TOKEN = '<SOS>'
# BUGFIX: was '</EOS>' after the Streamlit rewrite; the checkpoint vocab (and the
# pre-commit source) uses '<EOS>', so the mismatched literal would never match and
# '<EOS>' would leak into user-visible predictions instead of being filtered out.
EOS_TOKEN = '<EOS>'
24
 
25
 
26
  def clean_text(text: str) -> str:
 
71
  return logits
72
 
73
 
74
@st.cache_resource
def load_model():
    """Load the LSTM checkpoint and vocab mappings once; cached across reruns."""
    ckpt = torch.load('model/lstm_pidgin_model.pt', map_location='cpu')

    lm = LSTMLanguageModel(vocab_size=ckpt['vocab_size'])
    lm.load_state_dict(ckpt['model_state_dict'])
    lm.eval()  # inference only — disable dropout etc.

    return lm, ckpt['word_to_idx'], ckpt['idx_to_word']
87
 
88
 
89
def predict_next_words(context: str, model, word_to_idx, idx_to_word, top_k: int = 5):
    """Return up to ``top_k`` (word, probability) next-word predictions.

    Args:
        context: Raw Pidgin text typed by the user.
        model: Trained LSTMLanguageModel in eval mode; ``model(x)`` yields
            next-word logits (indexed as ``logits[0]`` below, so presumably
            shaped (batch, vocab) — matches the original usage).
        word_to_idx: Vocabulary mapping word -> index.
        idx_to_word: Reverse mapping; keys may be str or int depending on how
            the checkpoint serialized the dict, so both are tried.
        top_k: Maximum number of predictions to return.

    Returns:
        List of ``(word, probability)`` tuples, best first. Empty list when the
        input is blank or tokenizes to nothing.
    """
    if not context.strip():
        return []

    tokens = tokenize(clean_text(context))
    if not tokens:
        return []

    # Map tokens to indices, falling back to <UNK> for out-of-vocab words.
    unk_idx = word_to_idx.get(UNK_TOKEN, 1)
    indices = [word_to_idx.get(t, unk_idx) for t in tokens]

    x = torch.tensor([indices], dtype=torch.long)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=-1)

    # FIX: the previous version took exactly top_k candidates and *then*
    # filtered special tokens, so the user could receive fewer than top_k
    # predictions. Over-sample by the number of special tokens, then truncate.
    # Also clamp to vocab size so torch.topk cannot raise when top_k is large.
    special = {PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN}
    k = min(top_k + len(special), probs.shape[-1])
    top_probs, top_indices = torch.topk(probs[0], k)

    results = []
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        # str(idx) first: a JSON round-trip of the checkpoint stringifies keys.
        word = idx_to_word.get(str(idx), idx_to_word.get(idx, UNK_TOKEN))
        if word not in special:
            results.append((word, float(prob)))
        if len(results) == top_k:
            break

    return results
116
+
117
+
118
# Load model (cached by @st.cache_resource, so reruns are cheap)
model, word_to_idx, idx_to_word = load_model()

# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
st.title("💬 Nigerian Pidgin Next-Word Predictor")
st.markdown("**LSTM Language Model** trained on Nigerian Pidgin text.")

# Input. BUGFIX: key="context" binds this widget to st.session_state['context'].
# Without it, the example buttons below wrote to a session-state key nothing
# read, so clicking an example did nothing.
context = st.text_input(
    "Enter Nigerian Pidgin text:",
    placeholder="e.g., 'i dey', 'wetin you', 'how far'",
    key="context",
)

top_k = st.slider("Number of predictions:", 1, 10, 5)

# Predict on button press, or automatically once there is text in the box.
if st.button("Predict", type="primary") or context:
    if context:
        predictions = predict_next_words(context, model, word_to_idx, idx_to_word, top_k)

        if predictions:
            st.subheader("Predictions:")
            for word, prob in predictions:
                st.markdown(f"**{word}** — {prob:.1%}")
        else:
            st.warning("No predictions available.")
    else:
        st.info("Enter some text to get predictions.")

# Examples: each button stores its text under the input widget's session-state
# key and reruns, so the prediction section above sees it on the next pass.
st.markdown("---")
st.markdown("**Try these examples:**")
cols = st.columns(4)
examples = ["i dey", "wetin you", "how far", "e don"]
for col, ex in zip(cols, examples):
    if col.button(ex):
        st.session_state['context'] = ex
        st.rerun()

# Footer
st.markdown("---")
st.caption("Trained on NaijaSenti + BBC Pidgin corpus (~10k texts)")
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  datasets>=2.14.0
2
  torch>=2.0.0
3
- gradio>=4.0.0
 
1
  datasets>=2.14.0
2
  torch>=2.0.0
3
+ streamlit>=1.32.0