sedtha committed on
Commit
2bcaf25
·
verified ·
1 Parent(s): 846dce8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +123 -95
main.py CHANGED
@@ -1,95 +1,123 @@
1
- import torch
2
- import torch.nn as nn
3
- import re
4
- from fastapi import FastAPI
5
- from pydantic import BaseModel
6
-
7
- # =====================================================
8
- # App
9
- # =====================================================
10
- app = FastAPI(title="Khmer Spell Correction API")
11
-
12
- # =====================================================
13
- # Utils
14
- # =====================================================
15
- def preprocess_khmer_text(text: str) -> str:
16
- text = re.sub(r"\s+", " ", text)
17
- text = re.sub(r"[^\u1780-\u17FF\u200B\u0020-\u007E]", "", text)
18
- return text.strip()
19
-
20
- # =====================================================
21
- # Model
22
- # =====================================================
23
- class KhmerSpellLSTM(nn.Module):
24
- def __init__(self, vocab_size, embedding_dim, hidden_dim):
25
- super().__init__()
26
- self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
27
- self.lstm = nn.LSTM(
28
- embedding_dim,
29
- hidden_dim,
30
- batch_first=True,
31
- bidirectional=True
32
- )
33
- self.fc = nn.Linear(hidden_dim * 2, vocab_size)
34
-
35
- def forward(self, x):
36
- x = self.embedding(x)
37
- x, _ = self.lstm(x)
38
- return self.fc(x)
39
-
40
- # =====================================================
41
- # Load Model ONCE
42
- # =====================================================
43
- device = torch.device("cpu")
44
-
45
- checkpoint = torch.load("khmer_spell_lstm.pth", map_location=device)
46
-
47
- char_to_idx = checkpoint["char_to_idx"]
48
- idx_to_char = {i: c for c, i in char_to_idx.items()}
49
- max_length = checkpoint["max_length"]
50
-
51
- model = KhmerSpellLSTM(
52
- vocab_size=len(char_to_idx),
53
- embedding_dim=128,
54
- hidden_dim=256
55
- )
56
-
57
- model.load_state_dict(checkpoint["model_state_dict"])
58
- model.eval()
59
-
60
- # =====================================================
61
- # Inference
62
- # =====================================================
63
- def predict(text: str) -> str:
64
- text = preprocess_khmer_text(text)
65
- input_len = len(text)
66
-
67
- seq = [char_to_idx.get(c, char_to_idx["<UNK>"]) for c in text]
68
- seq += [char_to_idx["<PAD>"]] * (max_length - len(seq))
69
- seq = torch.tensor(seq[:max_length]).unsqueeze(0)
70
-
71
- with torch.no_grad():
72
- out = model(seq)
73
- pred = torch.argmax(out, dim=-1)[0][:input_len]
74
-
75
- return "".join(idx_to_char[i.item()] for i in pred)
76
-
77
- # =====================================================
78
- # Schema
79
- # =====================================================
80
- class TextInput(BaseModel):
81
- text: str
82
-
83
- # =====================================================
84
- # Routes
85
- # =====================================================
86
- @app.get("/")
87
- def health():
88
- return {"status": "Khmer Spell API running"}
89
-
90
- @app.post("/predict")
91
- def spell_correct(data: TextInput):
92
- return {
93
- "input": data.text,
94
- "output": predict(data.text)
95
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+
8
# =====================================================
# 1. FastAPI App
# =====================================================
app = FastAPI(
    title="Khmer Spell Correction API",
    version="1.0",
)

# Allow CORS for testing.
# NOTE(review): "*" origins/methods/headers is wide open -- fine for a
# demo, but restrict before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
23
+
24
# =====================================================
# 2. Utils
# =====================================================
def preprocess_khmer_text(text: str) -> str:
    """Clean and normalize Khmer text.

    Collapses every whitespace run to a single space, drops any character
    outside the Khmer block (U+1780-U+17FF), the zero-width space
    (U+200B), and printable ASCII (U+0020-U+007E), then trims
    leading/trailing whitespace.
    """
    collapsed = re.sub(r"\s+", " ", text)
    cleaned = re.sub(r"[^\u1780-\u17FF\u200B\u0020-\u007E]", "", collapsed)
    return cleaned.strip()
32
+
33
# =====================================================
# 3. Model Definition
# =====================================================
class KhmerSpellLSTM(nn.Module):
    """Character-level bidirectional LSTM for Khmer spell correction.

    Maps a batch of padded character-index sequences to per-position
    logits over the vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        # padding_idx=0 -- assumes index 0 is the <PAD> token; confirm
        # against the checkpoint's char_to_idx mapping.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # Inter-layer dropout is only legal on a stacked LSTM.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
        )
        # Head matches the checkpoint's fc layout:
        # 2*hidden (bidirectional concat) -> hidden -> vocab.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, x):
        """Return logits of shape (batch, seq_len, vocab_size)."""
        embedded = self.embedding(x)
        features, _ = self.lstm(embedded)
        return self.fc(features)
60
+
61
# =====================================================
# 4. Load Model ONCE
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints from trusted sources (consider weights_only=True if the
# checkpoint format permits it).
checkpoint = torch.load("result/khmer_spell_lstm.pth", map_location=device)

char_to_idx = checkpoint["char_to_idx"]
# Older checkpoints may lack a "vocab" entry; fall back to the
# character map's keys so vocab_size stays consistent with char_to_idx.
vocab = checkpoint.get("vocab", char_to_idx.keys())
max_length = checkpoint["max_length"]
idx_to_char = {index: char for char, index in char_to_idx.items()}

model = KhmerSpellLSTM(
    vocab_size=len(vocab),
    embedding_dim=128,
    hidden_dim=256,
).to(device)

# Load weights and freeze the model in eval mode (disables dropout).
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print("✅ Khmer Spell LSTM loaded successfully")
83
+
84
# =====================================================
# 5. Inference Function
# =====================================================
def predict(text: str) -> str:
    """Spell-correct *text* with the character-level LSTM.

    The input is cleaned, encoded to character indices (unknown
    characters map to <UNK>), padded with <PAD> / truncated to
    ``max_length``, and the per-position argmax is decoded back to
    characters, trimmed to the cleaned input's length.
    """
    text = preprocess_khmer_text(text)
    input_len = len(text)

    # Encode, then pad up to max_length; truncation below handles the
    # case where the cleaned text is longer than max_length.
    encoded = [char_to_idx.get(ch, char_to_idx["<UNK>"]) for ch in text]
    encoded += [char_to_idx["<PAD>"]] * (max_length - len(encoded))
    batch = torch.tensor(encoded[:max_length]).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    # Keep the prediction the same length as the input.
    predicted = torch.argmax(logits, dim=-1)[0][:input_len]

    return "".join(idx_to_char[idx.item()] for idx in predicted)
103
+
104
# =====================================================
# 6. API Schema
# =====================================================
class TextInput(BaseModel):
    """Request body for /predict: the raw text to correct."""

    text: str
109
+
110
# =====================================================
# 7. Routes
# =====================================================
@app.get("/")
def health_check():
    """Liveness probe: confirms the API process is running."""
    return {"status": "Khmer Spell API running"}


@app.post("/predict")
def spell_correct(data: TextInput):
    """Run spell correction on the posted text and echo the input."""
    return {
        "input": data.text,
        "output": predict(data.text),
    }