# NOTE: lines below were captured from a Hugging Face Spaces page ("Spaces: Sleeping");
# they are page chrome, not part of the application source.
"""
FastAPI backend for Nigerian Pidgin Next-Word Prediction.
Serves both LSTM and Trigram models as a REST API.
Deploy to Hugging Face Spaces with the Docker SDK.
"""
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| import pickle | |
| import re | |
| import os | |
| # ============================================================================= | |
| # FastAPI App | |
| # ============================================================================= | |
# Application instance; metadata shows up in the auto-generated /docs UI.
app = FastAPI(
    title="Nigerian Pidgin Next-Word Predictor API",
    description="LSTM + Trigram models for Nigerian Pidgin next-word prediction",
    version="1.0.0"
)

# Enable CORS for all origins
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# permissive setup intended for a public demo frontend; lock down the origin
# list if credentialed requests ever matter — confirm with the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ============================================================================= | |
| # Special Tokens | |
| # ============================================================================= | |
# Special vocabulary tokens. The first four are filtered out of LSTM
# predictions (see predict_lstm); PAD_TOKEN is assumed to sit at index 0
# (the Embedding layer uses padding_idx=0).
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '</EOS>'  # NOTE(review): asymmetric with SOS_TOKEN ('<SOS>' vs '</EOS>') — confirm this matches the training vocab
START_TOKEN = '<s>'   # sentence-boundary markers (trigram-style); unused in this chunk
END_TOKEN = '</s>'
| # ============================================================================= | |
| # Text Processing | |
| # ============================================================================= | |
def clean_text(text: str) -> str:
    """Normalize raw text: lowercase, drop URLs and @mentions, unwrap
    hashtags to their word, and collapse whitespace to single spaces."""
    lowered = text.lower()
    substitutions = (
        (r'https?://\S+', ''),  # http(s) URLs
        (r'www\.\S+', ''),      # bare www URLs
        (r'@\w+', ''),          # @mentions
        (r'#(\w+)', r'\1'),     # keep the hashtag word, drop the '#'
        (r'\s+', ' '),          # collapse runs of whitespace
    )
    for pattern, replacement in substitutions:
        lowered = re.sub(pattern, replacement, lowered)
    return lowered.strip()
def tokenize(text: str) -> List[str]:
    """Split text into word tokens (apostrophes kept inside words) and
    standalone punctuation marks."""
    return re.findall(r"[\w']+|[.,!?;:]", text)
| # ============================================================================= | |
| # LSTM Model | |
| # ============================================================================= | |
class LSTMLanguageModel(nn.Module):
    """Next-word language model: embedding -> stacked LSTM -> dropout -> vocab logits."""

    def __init__(self, vocab_size: int, embed_dim: int = 256,
                 hidden_dim: int = 512, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        # padding_idx=0 keeps the <PAD> embedding fixed at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Inter-layer dropout is only meaningful for a stacked LSTM.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=inter_layer_dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids (batch, seq) to next-word logits (batch, vocab_size)."""
        hidden_states, _ = self.lstm(self.embedding(x))
        final_state = hidden_states[:, -1, :]  # only the last time step feeds the head
        return self.fc(self.dropout(final_state))
| # ============================================================================= | |
| # Trigram Model | |
| # ============================================================================= | |
| # Import directly from src to ensure compatibility with pickle | |
| from src.trigram_model import TrigramLM | |
| # ============================================================================= | |
| # Global Models (loaded once at startup) | |
| # ============================================================================= | |
# Module-level model state populated by load_models(); None until loaded.
lstm_model = None     # LSTMLanguageModel, put into eval mode after loading
word_to_idx = None    # mapping token -> vocab index (from the LSTM checkpoint)
idx_to_word = None    # mapping vocab index -> token (keys may be str or int — see predict_lstm)
trigram_model = None  # TrigramLM unpickled from model/trigram_model.pkl
@app.on_event("startup")
async def load_models():
    """Load both models once when the app starts.

    Fix: this coroutine was defined but never invoked, so the globals stayed
    None and every endpoint reported its model as unavailable; registering it
    as the FastAPI startup hook makes the loading actually happen.

    Each model is loaded best-effort: a missing/corrupt artifact degrades the
    API (that model returns no predictions) instead of crashing startup.
    """
    global lstm_model, word_to_idx, idx_to_word, trigram_model

    # 1. LSTM — the checkpoint bundles vocab mappings alongside the weights.
    try:
        checkpoint = torch.load('model/lstm_pidgin_model.pt', map_location='cpu')
        word_to_idx = checkpoint['word_to_idx']
        idx_to_word = checkpoint['idx_to_word']
        vocab_size = checkpoint['vocab_size']
        lstm_model = LSTMLanguageModel(vocab_size=vocab_size)
        lstm_model.load_state_dict(checkpoint['model_state_dict'])
        lstm_model.eval()  # inference only: disables dropout
        print(f"LSTM model loaded! Vocab size: {vocab_size}")
    except Exception as e:
        print(f"Failed to load LSTM model: {e}")

    # 2. Trigram — unpickling requires src.trigram_model.TrigramLM to be
    # importable under the same module path it was pickled with.
    try:
        with open('model/trigram_model.pkl', 'rb') as f:
            trigram_model = pickle.load(f)
        print(f"Trigram model loaded! Vocab size: {len(trigram_model.vocab)}")
    except Exception as e:
        print(f"Failed to load Trigram model: {e}")
| # ============================================================================= | |
| # Request/Response Models | |
| # ============================================================================= | |
class PredictionRequest(BaseModel):
    # Request body for POST /predict.
    context: str            # text whose next word should be predicted
    top_k: int = 5          # number of candidate words to return
    model: str = "lstm"     # "lstm", "trigram", or "both"
class Prediction(BaseModel):
    # One candidate next word with its model-assigned probability.
    word: str
    probability: float
class PredictionResponse(BaseModel):
    # Response for the single-model GET endpoints.
    context: str                 # echoed input context
    model: str                   # which model produced the predictions ("lstm" or "trigram")
    predictions: List[Prediction]
class BothModelsResponse(BaseModel):
    # Response for POST /predict: side-by-side output from both models.
    context: str                 # echoed input context
    lstm: List[Prediction]
    trigram: List[Prediction]
| # ============================================================================= | |
| # Prediction Functions | |
| # ============================================================================= | |
def predict_lstm(context: str, top_k: int = 5) -> List[Prediction]:
    """Top-k next-word candidates from the LSTM, with special tokens
    filtered out; [] when the model is unloaded or the context is empty."""
    if lstm_model is None or not context.strip():
        return []
    tokens = tokenize(clean_text(context))
    if not tokens:
        return []
    fallback = word_to_idx.get(UNK_TOKEN, 1)
    ids = torch.tensor([[word_to_idx.get(tok, fallback) for tok in tokens]],
                       dtype=torch.long)
    with torch.no_grad():
        probs = torch.softmax(lstm_model(ids), dim=-1)
    # Over-fetch so dropping special tokens can still yield top_k real words.
    top_probs, top_ids = torch.topk(probs[0], top_k + 5)
    special = {PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN}
    picks: List[Prediction] = []
    for prob, idx in zip(top_probs.tolist(), top_ids.tolist()):
        # Vocab may have been saved with str or int keys; try both.
        word = idx_to_word.get(str(idx), idx_to_word.get(idx, UNK_TOKEN))
        if word in special:
            continue
        picks.append(Prediction(word=word, probability=float(prob)))
        if len(picks) >= top_k:
            break
    return picks
def predict_trigram(context: str, top_k: int = 5) -> List[Prediction]:
    """Top-k next-word candidates from the trigram model; [] when the
    model is unloaded or the context is empty."""
    if trigram_model is None or not context.strip():
        return []
    return [
        Prediction(word=word, probability=prob)
        for word, prob in trigram_model.predict_next_words(context, top_k)
    ]
| # ============================================================================= | |
| # API Endpoints | |
| # ============================================================================= | |
@app.get("/")
async def root():
    """Landing endpoint listing the available API routes.

    Fix: the handler had no route decorator, so it was never registered
    with the app (the paths it advertises below pin down the mapping).
    """
    return {
        "message": "Nigerian Pidgin Next-Word Predictor API",
        "endpoints": {
            "/predict": "POST - Get predictions",
            "/predict/lstm": "GET - LSTM predictions",
            "/predict/trigram": "GET - Trigram predictions",
            "/health": "GET - Health check",
            "/debug": "GET - System info"
        }
    }
@app.get("/health")
async def health():
    """Health probe reporting which models are loaded.

    Fix: registered on /health (it was defined without a route decorator
    and thus unreachable).
    """
    return {
        "status": "healthy",
        "lstm_loaded": lstm_model is not None,
        "trigram_loaded": trigram_model is not None,
        "vocab_size": len(word_to_idx) if word_to_idx else 0
    }
@app.get("/debug")
async def debug_info():
    """Return debug information about the environment.

    Fix: registered on /debug (it was defined without a route decorator
    and thus unreachable).
    """
    import sys
    return {
        "cwd": os.getcwd(),
        "files_root": os.listdir('.'),
        "files_model": os.listdir('model') if os.path.exists('model') else "MISSING",
        "files_src": os.listdir('src') if os.path.exists('src') else "MISSING",
        "python_path": sys.path,
        "lstm_model_type": str(type(lstm_model)) if lstm_model else "None",
        "trigram_model_type": str(type(trigram_model)) if trigram_model else "None",
    }
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Get predictions from both models.

    Fix: registered on POST /predict (it was defined without a route
    decorator and thus unreachable).
    NOTE(review): request.model is currently ignored — both models always run.
    """
    try:
        lstm_preds = predict_lstm(request.context, request.top_k)
        trigram_preds = predict_trigram(request.context, request.top_k)
        return BothModelsResponse(
            context=request.context,
            lstm=lstm_preds,
            trigram=trigram_preds
        )
    except Exception as e:
        import traceback
        traceback.print_exc()  # full stack to the Space logs for diagnosis
        raise HTTPException(status_code=500, detail=f"Prediction Failed: {str(e)}")
@app.get("/predict/lstm")
async def predict_lstm_endpoint(context: str, top_k: int = 5):
    """Get LSTM predictions for `context` (query parameters).

    Fix: registered on /predict/lstm (it was defined without a route
    decorator and thus unreachable). Returns 503 while the model is
    not loaded, 500 on prediction errors.
    """
    if lstm_model is None:
        raise HTTPException(status_code=503, detail="LSTM model not loaded")
    try:
        predictions = predict_lstm(context, top_k)
        return PredictionResponse(
            context=context,
            model="lstm",
            predictions=predictions
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"LSTM Prediction Failed: {str(e)}")
@app.get("/predict/trigram")
async def predict_trigram_endpoint(context: str, top_k: int = 5):
    """Get Trigram predictions for `context` (query parameters).

    Fix: registered on /predict/trigram (it was defined without a route
    decorator and thus unreachable). Returns 503 while the model is
    not loaded, 500 on prediction errors.
    """
    if trigram_model is None:
        raise HTTPException(status_code=503, detail="Trigram model not loaded")
    try:
        predictions = predict_trigram(context, top_k)
        return PredictionResponse(
            context=context,
            model="trigram",
            predictions=predictions
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Trigram Prediction Failed: {str(e)}")
| # ============================================================================= | |
| # Run with: uvicorn api:app --reload | |
| # ============================================================================= | |