File size: 3,706 Bytes
2bcaf25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edd7bc2
2bcaf25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13abb53
2bcaf25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torch.nn as nn
import re
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware

# =====================================================
# 1. FastAPI App
# =====================================================
# Single application instance; routes are registered on it below.
app = FastAPI(
    title="Khmer Spell Correction API",
    version="1.0"
)

# Allow CORS for testing
# NOTE(review): wildcard origins/methods/headers is fine for local testing
# but should be restricted before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# =====================================================
# 2. Utils
# =====================================================
def preprocess_khmer_text(text: str) -> str:
    """Normalize *text* for the spell model.

    Collapses runs of whitespace to single spaces, then removes every
    character outside the Khmer block (U+1780–U+17FF), the zero-width
    space (U+200B), and printable ASCII, and strips the edges.
    """
    # Whitespace must be collapsed BEFORE filtering: tabs/newlines fall
    # outside the allowed ranges and would otherwise be dropped outright
    # instead of becoming a single space.
    collapsed = re.sub(r'\s+', ' ', text)
    cleaned = re.sub(r'[^\u1780-\u17FF\u200B\u0020-\u007E]', '', collapsed)
    return cleaned.strip()

# =====================================================
# 3. Model Definition
# =====================================================
class KhmerSpellLSTM(nn.Module):
    """Character-level bidirectional LSTM for Khmer spell correction.

    Maps a batch of character-id sequences to per-position logits over
    the vocabulary: (batch, seq_len) -> (batch, seq_len, vocab_size).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        # Id 0 is the padding token; its embedding row stays zero.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # PyTorch warns if inter-layer dropout is set with a single layer.
        lstm_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
            bidirectional=True,
        )
        # Head layout mirrors the saved checkpoint's "fc" keys: project the
        # bidirectional output (2 * hidden_dim) down, activate, regularize,
        # then project to vocabulary logits.
        self.fc = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, x):
        """Return logits of shape (batch, seq_len, vocab_size) for ids *x*."""
        embedded = self.embedding(x)
        features, _ = self.lstm(embedded)
        return self.fc(features)

# =====================================================
# 4. Load Model ONCE
# =====================================================
# Prefer GPU when available; weights and input tensors both go here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the checkpoint — only load files from
# a trusted source.
checkpoint = torch.load("khmer_spell_lstm.pth", map_location=device)

# Character <-> id mappings saved alongside the weights.
char_to_idx = checkpoint["char_to_idx"]
# Fall back to the mapping's keys if the checkpoint carries no vocab entry;
# only len() and iteration are needed, so a keys view suffices.
vocab = checkpoint.get("vocab", char_to_idx.keys())
max_length = checkpoint["max_length"]  # fixed sequence length, presumably from training — TODO confirm
idx_to_char = {i: c for c, i in char_to_idx.items()}

# Architecture hyperparameters must match the checkpoint exactly for
# load_state_dict below to succeed.
model = KhmerSpellLSTM(
    vocab_size=len(vocab),
    embedding_dim=128,
    hidden_dim=256
).to(device)

# Load weights
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: disables the dropout layers
print("✅ Khmer Spell LSTM loaded successfully")

# =====================================================
# 5. Inference Function
# =====================================================
def predict(text: str) -> str:
    """Spell-correct *text* with the loaded LSTM.

    The input is normalized, encoded to character ids (unknown characters
    map to <UNK>), padded/truncated to ``max_length``, run through the
    model, and the per-position argmax ids are decoded back to characters.
    The result has the same length as the (cleaned, possibly truncated)
    input.
    """
    text = preprocess_khmer_text(text)
    # The model only ever sees max_length positions, so cap the length we
    # decode back out; longer input is silently truncated.
    input_len = min(len(text), max_length)

    seq = [char_to_idx.get(c, char_to_idx["<UNK>"]) for c in text]
    seq += [char_to_idx["<PAD>"]] * (max_length - len(seq))
    seq = torch.tensor(seq[:max_length]).unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(seq)                     # (1, max_length, vocab_size)
        pred = torch.argmax(out, dim=-1)[0]  # (max_length,)

    # Keep the prediction the same length as the input. (Bug fix: the
    # original sliced input_len + 1, appending one spurious character
    # decoded from a padding position.)
    pred = pred[:input_len]

    return "".join(idx_to_char[i.item()] for i in pred)

# =====================================================
# 6. API Schema
# =====================================================
class TextInput(BaseModel):
    """Request body for POST /predict."""
    text: str  # raw Khmer text to be spell-corrected

# =====================================================
# 7. Routes
# =====================================================
@app.get("/")
def health_check():
    """Liveness probe: report that the service is up and responding."""
    return {"status": "Khmer Spell API running"}

@app.post("/predict")
def spell_correct(data: TextInput):
    """Run spell correction on the submitted text and echo both versions."""
    result = predict(data.text)
    return {"input": data.text, "output": result}