JermaineAI commited on
Commit
5d7219a
·
0 Parent(s):

Initial commit: FastAPI backend for Nigerian Pidgin prediction

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.pkl filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python base keeps the image small (the torch wheel still dominates size).
FROM python:3.10-slim

WORKDIR /app

# Install dependencies first so this layer is cached across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code and model artifacts (LFS-tracked .pt/.pkl files).
COPY api.py .
COPY model/ model/

# Hugging Face Spaces (Docker SDK) expects the app to listen on port 7860.
EXPOSE 7860

# Run the API.
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Nigerian Pidgin Predictor API
3
+ emoji: 🚀
4
+ colorFrom: green
5
+ colorTo: yellow
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ # Nigerian Pidgin Next-Word Predictor API
12
+
13
+ FastAPI backend serving LSTM + Trigram models for next-word prediction.
14
+
15
+ ## Endpoints
16
+
17
+ - `GET /` - API info
18
+ - `GET /health` - Health check
19
+ - `POST /predict` - Get predictions from both models
20
+ - `GET /predict/lstm?context=...&top_k=5` - LSTM predictions
21
+ - `GET /predict/trigram?context=...&top_k=5` - Trigram predictions
22
+
23
+ ## Example
24
+
25
+ ```bash
26
+ curl -X POST "https://jaykay73-nextword-pidgin-api.hf.space/predict" \
27
+ -H "Content-Type: application/json" \
28
+ -d '{"context": "i dey", "top_k": 5}'
29
+ ```
api.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI backend for Nigerian Pidgin Next-Word Prediction.
3
+ Serves both LSTM and Trigram models as REST API.
4
+ Deploy to Hugging Face Spaces with Docker SDK.
5
+ """
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ from typing import List, Tuple, Optional
11
+ import torch
12
+ import torch.nn as nn
13
+ import pickle
14
+ import re
15
+ import os
16
+
17
+ # =============================================================================
18
+ # FastAPI App
19
+ # =============================================================================
20
# Application instance; uvicorn serves it as ``api:app`` (see Dockerfile CMD).
app = FastAPI(
    title="Nigerian Pidgin Next-Word Predictor API",
    description="LSTM + Trigram models for Nigerian Pidgin next-word prediction",
    version="1.0.0"
)

# Enable CORS for all origins so any frontend can call the API.
# NOTE(review): per the CORS spec, browsers reject a wildcard origin on
# credentialed requests; if credentials are ever actually needed, list
# explicit origins instead of "*" — confirm against Starlette's CORSMiddleware docs.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
34
+
35
+ # =============================================================================
36
+ # Special Tokens
37
+ # =============================================================================
38
# Special vocabulary tokens. These must match the tokens used when the
# checkpoints were trained/serialized — do not "normalize" them here.
PAD_TOKEN = '<PAD>'   # padding index 0 in the LSTM embedding
UNK_TOKEN = '<UNK>'   # out-of-vocabulary fallback
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '</EOS>'  # NOTE(review): asymmetric with '<SOS>' — presumably matches training; verify.
START_TOKEN = '<s>'   # trigram-model sentence start
END_TOKEN = '</s>'    # trigram-model sentence end
44
+
45
+ # =============================================================================
46
+ # Text Processing
47
+ # =============================================================================
48
def clean_text(text: str) -> str:
    """Normalize raw input: lowercase, drop URLs and @mentions, keep the
    word part of hashtags, and collapse runs of whitespace."""
    lowered = text.lower()
    # Applied in order: URL forms first, then social markers, then whitespace.
    substitutions = (
        (r'https?://\S+', ''),
        (r'www\.\S+', ''),
        (r'@\w+', ''),
        (r'#(\w+)', r'\1'),
        (r'\s+', ' '),
    )
    for pattern, replacement in substitutions:
        lowered = re.sub(pattern, replacement, lowered)
    return lowered.strip()
56
+
57
def tokenize(text: str) -> List[str]:
    """Split text into tokens: words (apostrophes kept, e.g. "don't")
    and individual punctuation marks from .,!?;: as separate tokens."""
    return re.findall(r"[\w']+|[.,!?;:]", text)
60
+
61
+ # =============================================================================
62
+ # LSTM Model
63
+ # =============================================================================
64
class LSTMLanguageModel(nn.Module):
    """Word-level LSTM language model: embed -> LSTM -> dropout -> vocab logits.

    Submodule attribute names (embedding/lstm/dropout/fc) are the
    state_dict keys of the trained checkpoint and must stay stable.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 256,
                 hidden_dim: int = 512, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Inter-layer LSTM dropout is only meaningful for stacked layers;
        # torch warns if it is set with num_layers == 1.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=inter_layer_dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map a token-id batch of shape (batch, seq_len) to next-word
        logits of shape (batch, vocab_size), using the last timestep only."""
        hidden_states, _ = self.lstm(self.embedding(x))
        final_state = hidden_states[:, -1, :]
        return self.fc(self.dropout(final_state))
80
+
81
+ # =============================================================================
82
+ # Trigram Model
83
+ # =============================================================================
84
class TrigramLM:
    """Add-k (Laplace) smoothed trigram language model.

    Instances are trained elsewhere and restored via pickle, so the class
    name and attribute names must stay stable.
    """

    def __init__(self, smoothing: float = 1.0):
        self.smoothing = smoothing   # add-k smoothing constant
        self.unigram_counts = {}
        self.bigram_counts = {}
        self.trigram_counts = {}
        self.vocab = set()

    def probability(self, w3: str, w1: str, w2: str) -> float:
        """Smoothed P(w3 | w1, w2) = (c(w1,w2,w3) + k) / (c(w1,w2) + k*|V|)."""
        numerator = self.trigram_counts.get((w1, w2, w3), 0) + self.smoothing
        denominator = self.bigram_counts.get((w1, w2), 0) + self.smoothing * len(self.vocab)
        if denominator <= 0:
            # Degenerate case: empty vocab with zero smoothing.
            return 0.0
        return numerator / denominator

    def predict_next_words(self, context: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Score every vocabulary word against the last two context words
        (padding short contexts with the sentence-start marker) and return
        the top_k (word, probability) pairs, highest probability first."""
        history = context.lower().split()
        if len(history) >= 2:
            w1, w2 = history[-2], history[-1]
        elif history:
            w1, w2 = START_TOKEN, history[0]
        else:
            w1 = w2 = START_TOKEN

        # Sentence boundary markers are never valid predictions.
        scored = [
            (word, self.probability(word, w1, w2))
            for word in self.vocab
            if word not in (START_TOKEN, END_TOKEN, '<s>', '</s>')
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_k]
117
+
118
+ # =============================================================================
119
+ # Global Models (loaded once at startup)
120
+ # =============================================================================
121
# Globals populated by load_models() at startup; None means "not loaded",
# and the endpoints degrade gracefully (503 or empty prediction lists).
lstm_model = None     # LSTMLanguageModel in eval mode
word_to_idx = None    # dict: token -> vocabulary index
idx_to_word = None    # dict: vocabulary index -> token (keys may be int or str)
trigram_model = None  # TrigramLM instance

@app.on_event("startup")  # NOTE(review): on_event is deprecated in recent FastAPI; migrate to a lifespan handler eventually.
async def load_models():
    """Load the LSTM checkpoint and the pickled trigram model from model/.

    Failures are logged and swallowed on purpose so the API still starts;
    /health reports which models actually loaded.
    """
    global lstm_model, word_to_idx, idx_to_word, trigram_model

    # Load LSTM checkpoint: weights plus vocabulary mappings in one file.
    try:
        checkpoint = torch.load('model/lstm_pidgin_model.pt', map_location='cpu')
        word_to_idx = checkpoint['word_to_idx']
        idx_to_word = checkpoint['idx_to_word']
        vocab_size = checkpoint['vocab_size']

        lstm_model = LSTMLanguageModel(vocab_size=vocab_size)
        lstm_model.load_state_dict(checkpoint['model_state_dict'])
        lstm_model.eval()  # disable dropout for inference
        print(f"LSTM model loaded! Vocab size: {vocab_size}")
    except Exception as e:
        print(f"Failed to load LSTM model: {e}")

    # Load pickled trigram model. NOTE(review): pickle.load executes code
    # from the file — only ever ship trusted artifacts in model/.
    try:
        with open('model/trigram_model.pkl', 'rb') as f:
            trigram_model = pickle.load(f)
        print(f"Trigram model loaded! Vocab size: {len(trigram_model.vocab)}")
    except Exception as e:
        print(f"Failed to load Trigram model: {e}")
151
+
152
+ # =============================================================================
153
+ # Request/Response Models
154
+ # =============================================================================
155
class PredictionRequest(BaseModel):
    """Request body for POST /predict."""
    context: str    # text to continue
    top_k: int = 5  # number of candidate words to return
    # "lstm", "trigram", or "both".
    # NOTE(review): currently unused — POST /predict always runs both models.
    model: str = "lstm"
159
+
160
class Prediction(BaseModel):
    """A single candidate next word with its model-assigned probability."""
    word: str
    probability: float
163
+
164
class PredictionResponse(BaseModel):
    """Response for the single-model GET endpoints."""
    context: str
    model: str  # which model produced the predictions ("lstm" or "trigram")
    predictions: List[Prediction]
168
+
169
class BothModelsResponse(BaseModel):
    """Response for POST /predict: both models' predictions side by side."""
    context: str
    lstm: List[Prediction]
    trigram: List[Prediction]
173
+
174
+ # =============================================================================
175
+ # Prediction Functions
176
+ # =============================================================================
177
def predict_lstm(context: str, top_k: int = 5) -> List[Prediction]:
    """Return up to top_k next-word Predictions from the LSTM model.

    Returns [] when the model is not loaded, the context is blank, or it
    tokenizes to nothing. Unknown tokens map to the <UNK> index.
    """
    if lstm_model is None or not context.strip():
        return []

    tokens = tokenize(clean_text(context))
    if not tokens:
        return []

    # Fallback index 1 assumes <UNK> occupies slot 1 when absent from the
    # vocab mapping — TODO confirm against the training pipeline.
    unk_idx = word_to_idx.get(UNK_TOKEN, 1)
    indices = [word_to_idx.get(t, unk_idx) for t in tokens]
    x = torch.tensor([indices], dtype=torch.long)

    with torch.no_grad():
        logits = lstm_model(x)
        probs = torch.softmax(logits, dim=-1)

    # Over-fetch so special tokens can be filtered out, but never request
    # more candidates than the vocabulary holds — torch.topk raises when
    # k exceeds the dimension size.
    k = min(top_k + 5, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs[0], k)

    special = {PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN}
    results: List[Prediction] = []
    # .tolist() yields plain Python ints/floats, so the str(idx) fallback
    # below behaves consistently regardless of checkpoint key types.
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        # Checkpoints may store idx_to_word with str or int keys.
        word = idx_to_word.get(str(idx), idx_to_word.get(idx, UNK_TOKEN))
        if word not in special:
            results.append(Prediction(word=word, probability=float(prob)))
            if len(results) >= top_k:
                break

    return results
204
+
205
def predict_trigram(context: str, top_k: int = 5) -> List[Prediction]:
    """Return up to top_k next-word Predictions from the trigram model,
    or [] when the model is not loaded or the context is blank."""
    if trigram_model is None:
        return []
    if not context.strip():
        return []
    return [
        Prediction(word=word, probability=prob)
        for word, prob in trigram_model.predict_next_words(context, top_k)
    ]
211
+
212
+ # =============================================================================
213
+ # API Endpoints
214
+ # =============================================================================
215
@app.get("/")
async def root():
    """Landing route: API name plus a map of the available endpoints."""
    endpoints = {
        "/predict": "POST - Get predictions",
        "/predict/lstm": "GET - LSTM predictions",
        "/predict/trigram": "GET - Trigram predictions",
        "/health": "GET - Health check",
    }
    return {
        "message": "Nigerian Pidgin Next-Word Predictor API",
        "endpoints": endpoints,
    }
226
+
227
@app.get("/health")
async def health():
    """Health probe: reports whether each model finished loading at startup."""
    lstm_ready = lstm_model is not None
    trigram_ready = trigram_model is not None
    return {
        "status": "healthy",
        "lstm_loaded": lstm_ready,
        "trigram_loaded": trigram_ready,
    }
234
+
235
@app.post("/predict", response_model=BothModelsResponse)
async def predict(request: PredictionRequest):
    """Run both models on the request context and return their predictions.

    NOTE(review): the request's `model` field is ignored here — both models
    are always evaluated; the single-model GET endpoints cover that case.
    """
    context, k = request.context, request.top_k
    return BothModelsResponse(
        context=context,
        lstm=predict_lstm(context, k),
        trigram=predict_trigram(context, k),
    )
243
+
244
@app.get("/predict/lstm")
async def predict_lstm_endpoint(context: str, top_k: int = 5):
    """GET wrapper around the LSTM predictor; 503 until the model loads."""
    if lstm_model is None:
        raise HTTPException(status_code=503, detail="LSTM model not loaded")
    return PredictionResponse(
        context=context,
        model="lstm",
        predictions=predict_lstm(context, top_k),
    )
256
+
257
@app.get("/predict/trigram")
async def predict_trigram_endpoint(context: str, top_k: int = 5):
    """GET wrapper around the trigram predictor; 503 until the model loads."""
    if trigram_model is None:
        raise HTTPException(status_code=503, detail="Trigram model not loaded")
    return PredictionResponse(
        context=context,
        model="trigram",
        predictions=predict_trigram(context, top_k),
    )
269
+
270
+ # =============================================================================
271
+ # Run with: uvicorn api:app --reload
272
+ # =============================================================================
model/lstm_pidgin_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:239784faa3a5af6f08025a11f9705c75f30e3a1106f669c6b297dbbca21de04a
3
+ size 64095297
model/trigram_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fe85f4bc3b84c739e714e35e15cc80cf35108947d3c194ca9079edf09cd4149
3
+ size 15507557
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi>=0.100.0
2
+ uvicorn>=0.23.0
3
+ torch>=2.0.0
4
+ pydantic>=2.0.0