import re

import torch
import torch.nn as nn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# =====================================================
# 1. FastAPI App
# =====================================================
app = FastAPI(
    title="Khmer Spell Correction API",
    version="1.0"
)

# Allow CORS for testing
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide-open CORS — restrict before production
    allow_methods=["*"],
    allow_headers=["*"],
)

# =====================================================
# 2. Utils
# =====================================================

# Compiled once at import time so the per-request hot path does not pay
# the `re` cache lookup on every call.
_WHITESPACE_RE = re.compile(r'\s+')
# Keep the Khmer Unicode block (U+1780-U+17FF), the zero-width space
# (U+200B, used as a Khmer word separator), and printable ASCII
# (U+0020-U+007E); strip everything else.
_DISALLOWED_RE = re.compile(r'[^\u1780-\u17FF\u200B\u0020-\u007E]')


def preprocess_khmer_text(text: str) -> str:
    """Clean and normalize Khmer text.

    Collapses whitespace runs to a single space, removes characters
    outside the Khmer block / ZWSP / printable ASCII, and trims
    leading and trailing whitespace.
    """
    text = _WHITESPACE_RE.sub(' ', text)
    text = _DISALLOWED_RE.sub('', text)
    return text.strip()


# =====================================================
# 3. Model Definition
# =====================================================
class KhmerSpellLSTM(nn.Module):
    """Character-level bidirectional LSTM for Khmer spell correction.

    Maps a batch of character-index sequences to per-position logits
    over the vocabulary (one predicted character per input position).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        # Index 0 is the padding token; its embedding is kept at zero.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM warns if dropout is non-zero with a single layer.
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        # Match checkpoint fc: bidirectional output (2 * hidden_dim) -> vocab logits.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, vocab_size)
        )

    def forward(self, x):
        """x: (batch, seq) int indices -> (batch, seq, vocab_size) logits."""
        emb = self.embedding(x)
        out, _ = self.lstm(emb)
        return self.fc(out)


# =====================================================
# 4.
# Load Model ONCE
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source.
checkpoint = torch.load("khmer_spell_lstm.pth", map_location=device)

char_to_idx = checkpoint["char_to_idx"]
# Older checkpoints may lack an explicit vocab; fall back to the mapping keys
# (len() works on a dict view, which is all we need below).
vocab = checkpoint.get("vocab", char_to_idx.keys())
max_length = checkpoint["max_length"]
idx_to_char = {i: c for c, i in char_to_idx.items()}

model = KhmerSpellLSTM(
    vocab_size=len(vocab),
    embedding_dim=128,
    hidden_dim=256
).to(device)

# Load weights
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print("✅ Khmer Spell LSTM loaded successfully")


# =====================================================
# 5. Inference Function
# =====================================================
def predict(text: str) -> str:
    """Spell-correct ``text`` with the character-level LSTM.

    The input is cleaned, mapped to character indices, padded/truncated
    to ``max_length``, run through the model, and the argmax characters
    are decoded back to a string the same length as the cleaned
    (and possibly truncated) input.
    """
    text = preprocess_khmer_text(text)
    if not text:
        # Nothing to correct; skip the model call entirely.
        return ""

    # Anything beyond max_length is dropped — the model never sees it,
    # so the output is capped at max_length characters too.
    input_len = min(len(text), max_length)

    # "" is assumed to be the pad/unknown token in the checkpoint vocab
    # — TODO confirm against the training code.
    pad_idx = char_to_idx[""]
    seq = [char_to_idx.get(c, pad_idx) for c in text[:max_length]]
    seq += [pad_idx] * (max_length - len(seq))
    batch = torch.tensor(seq).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    pred = torch.argmax(logits, dim=-1)[0]

    # Keep the prediction the same length as the input.
    # (Fixes an off-by-one: the original sliced input_len + 1 positions,
    # returning one trailing padding-position character.)
    pred = pred[:input_len]
    return "".join(idx_to_char[i.item()] for i in pred)


# =====================================================
# 6. API Schema
# =====================================================
class TextInput(BaseModel):
    # Raw user text to spell-correct.
    text: str


# =====================================================
# 7. Routes
# =====================================================
@app.get("/")
def health_check():
    """Liveness probe."""
    return {"status": "Khmer Spell API running"}


@app.post("/predict")
def spell_correct(data: TextInput):
    """Run spell correction on the posted text and echo both versions."""
    corrected_text = predict(data.text)
    return {
        "input": data.text,
        "output": corrected_text
    }