sedtha commited on
Commit
846dce8
·
verified ·
1 Parent(s): 6bf1e0f

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +10 -0
  2. khmer_spell_lstm.pth +3 -0
  3. main.py +95 -0
  4. requirements.txt +4 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
khmer_spell_lstm.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdb4948f7bf7b078fd6195db4d8745aa7f65d96eac7edd164da06b55d803a22d
3
+ size 10207041
main.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+
7
+ # =====================================================
8
+ # App
9
+ # =====================================================
10
+ app = FastAPI(title="Khmer Spell Correction API")
11
+
12
+ # =====================================================
13
+ # Utils
14
+ # =====================================================
15
+ def preprocess_khmer_text(text: str) -> str:
16
+ text = re.sub(r"\s+", " ", text)
17
+ text = re.sub(r"[^\u1780-\u17FF\u200B\u0020-\u007E]", "", text)
18
+ return text.strip()
19
+
20
+ # =====================================================
21
+ # Model
22
+ # =====================================================
23
+ class KhmerSpellLSTM(nn.Module):
24
+ def __init__(self, vocab_size, embedding_dim, hidden_dim):
25
+ super().__init__()
26
+ self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
27
+ self.lstm = nn.LSTM(
28
+ embedding_dim,
29
+ hidden_dim,
30
+ batch_first=True,
31
+ bidirectional=True
32
+ )
33
+ self.fc = nn.Linear(hidden_dim * 2, vocab_size)
34
+
35
+ def forward(self, x):
36
+ x = self.embedding(x)
37
+ x, _ = self.lstm(x)
38
+ return self.fc(x)
39
+
40
+ # =====================================================
41
+ # Load Model ONCE
42
+ # =====================================================
43
+ device = torch.device("cpu")
44
+
45
+ checkpoint = torch.load("khmer_spell_lstm.pth", map_location=device)
46
+
47
+ char_to_idx = checkpoint["char_to_idx"]
48
+ idx_to_char = {i: c for c, i in char_to_idx.items()}
49
+ max_length = checkpoint["max_length"]
50
+
51
+ model = KhmerSpellLSTM(
52
+ vocab_size=len(char_to_idx),
53
+ embedding_dim=128,
54
+ hidden_dim=256
55
+ )
56
+
57
+ model.load_state_dict(checkpoint["model_state_dict"])
58
+ model.eval()
59
+
60
+ # =====================================================
61
+ # Inference
62
+ # =====================================================
63
+ def predict(text: str) -> str:
64
+ text = preprocess_khmer_text(text)
65
+ input_len = len(text)
66
+
67
+ seq = [char_to_idx.get(c, char_to_idx["<UNK>"]) for c in text]
68
+ seq += [char_to_idx["<PAD>"]] * (max_length - len(seq))
69
+ seq = torch.tensor(seq[:max_length]).unsqueeze(0)
70
+
71
+ with torch.no_grad():
72
+ out = model(seq)
73
+ pred = torch.argmax(out, dim=-1)[0][:input_len]
74
+
75
+ return "".join(idx_to_char[i.item()] for i in pred)
76
+
77
+ # =====================================================
78
+ # Schema
79
+ # =====================================================
80
+ class TextInput(BaseModel):
81
+ text: str
82
+
83
+ # =====================================================
84
+ # Routes
85
+ # =====================================================
86
+ @app.get("/")
87
+ def health():
88
+ return {"status": "Khmer Spell API running"}
89
+
90
+ @app.post("/predict")
91
+ def spell_correct(data: TextInput):
92
+ return {
93
+ "input": data.text,
94
+ "output": predict(data.text)
95
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ pydantic