Khmer LSTM Autocomplete
A word-level LSTM model for Khmer text autocomplete.
Model Architecture
- Type: LSTM (1 layer)
- Embedding dim: 128
- Hidden dim: 256
- Tokenizer: khmercut
Usage
import os
import pickle
import torch
import torch.nn as nn
# Install dependencies if needed
def _pip_install(package):
    """Install *package* into the interpreter that is running this script.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import subprocess
    import sys

    # ``sys.executable -m pip`` targets the active environment; a bare ``pip``
    # found on PATH may belong to a different Python installation.
    # check_call surfaces a failed install immediately instead of letting the
    # script continue to a second, less informative ImportError.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    from khmercut import tokenize
except ImportError:
    _pip_install("khmercut")
    from khmercut import tokenize

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    _pip_install("huggingface_hub")
    from huggingface_hub import hf_hub_download
# ── 1. Download files from HuggingFace ──────────────────────────────────────
print("Downloading model and vocab from HuggingFace...")
# Both artifacts live in the same Hub repository; each call's result is used
# below as a local file path.
model_path = hf_hub_download(
    repo_id="phonsobon/khmer_auto_completed",
    filename="khmer_lstm_autocomplete_best.pth",
)
vocab_path = hf_hub_download(
    repo_id="phonsobon/khmer_auto_completed",
    filename="vocab_mapping.pkl",
)
# ── 2. Load vocabulary ──────────────────────────────────────────────────────
print("Loading vocabulary...")
# SECURITY: pickle.load can execute arbitrary code embedded in the file, and
# this file was just downloaded from a remote repository. Only run this
# against a repository you trust; a data-only format (e.g. JSON) would be a
# safer way to ship the vocabulary.
with open(vocab_path, "rb") as f:
    vocab_data = pickle.load(f)

# The pickle is read as a mapping with three entries:
#   "word_to_idx": token -> integer id
#   "idx_to_word": integer id -> token
#   "vocab":       collection whose length sets the model's vocabulary size
word_to_idx = vocab_data["word_to_idx"]
idx_to_word = vocab_data["idx_to_word"]
vocab_size = len(vocab_data["vocab"])
print(f"Vocabulary size: {vocab_size} words")
# ── 3. Define model ─────────────────────────────────────────────────────────
class KhmerLSTMAutocomplete(nn.Module):
    """Embedding -> LSTM -> linear model that scores every vocabulary entry.

    Args:
        vocab_size: number of rows in the embedding table and number of
            output scores.
        embedding_dim: size of each word embedding.
        hidden_dim: size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256):
        super().__init__()
        # Index 0 is reserved for padding: its embedding is initialised to
        # zeros and receives no gradient.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Return logits of shape (batch, vocab_size) for ids ``x`` of shape (batch, seq)."""
        embedded = self.embedding(x)
        sequence_out, _state = self.lstm(embedded)
        # Only the output at the final time step feeds the classifier.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)
# ── 4. Load model weights ───────────────────────────────────────────────────
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = KhmerLSTMAutocomplete(vocab_size)
# SECURITY: the checkpoint was downloaded from a remote repository.
# weights_only=True restricts unpickling to tensors and plain containers so a
# tampered file cannot execute arbitrary code (needs PyTorch 1.13 or newer).
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # inference mode: disables training-only behaviour such as dropout
print("Model loaded successfully!\n")
# ── 5. Autocomplete function ────────────────────────────────────────────────
# Number of trailing tokens fed to the model.
# NOTE(review): presumably this must match the context length used in
# training — confirm before changing.
WINDOW_SIZE = 1
# Vocabulary entries that must never be shown as a suggestion.
_SPECIAL_TOKENS = ("<PAD>", "<UNK>")


def get_autocomplete_suggestions(input_text, top_k=3):
    """Print up to ``top_k`` next-word completions for ``input_text``.

    The text is tokenized, the last ``WINDOW_SIZE`` tokens (left-padded with
    ``<PAD>`` when there are fewer) are mapped to ids, and the module-level
    ``model`` scores every vocabulary entry. The highest-probability words
    that are not special tokens are appended to the stripped input and
    printed with their probability as a percentage.

    Args:
        input_text: text to complete.
        top_k: maximum number of suggestions to print; values below 1 print
            no suggestions.

    Returns:
        None. All output goes to stdout.

    Raises:
        KeyError: if ``word_to_idx`` has no ``"<UNK>"`` entry.
    """
    tokens = [t.strip() for t in tokenize(input_text) if t.strip()]

    # Keep only the trailing window, then left-pad if the input was shorter.
    tokens = tokens[-WINDOW_SIZE:]
    tokens = ["<PAD>"] * (WINDOW_SIZE - len(tokens)) + tokens

    unk_idx = word_to_idx["<UNK>"]
    input_idxs = [word_to_idx.get(w, unk_idx) for w in tokens]
    input_tensor = torch.tensor([input_idxs], dtype=torch.long, device=device)

    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.softmax(logits, dim=-1).squeeze(0)

    # Ask for extra candidates so that filtering out the special tokens does
    # not leave fewer than ``top_k`` suggestions, and never request more
    # entries than the vocabulary holds (torch.topk would raise).
    wanted = max(top_k, 0)
    fetch = min(wanted + len(_SPECIAL_TOKENS), probs.numel())
    top_probs, top_idxs = torch.topk(probs, fetch)

    candidates = (
        (idx_to_word[idx], prob * 100)
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    )
    suggestions = [
        (word, pct) for word, pct in candidates if word not in _SPECIAL_TOKENS
    ][:wanted]

    print(f"Input: '{input_text}'")
    print("Suggestions:")
    prefix = input_text.strip()
    # Number the surviving suggestions consecutively (1, 2, 3, ...).
    for rank, (word, pct) in enumerate(suggestions, start=1):
        suggestion = f"{prefix}{word}".strip()
        print(f" {rank}. {suggestion} ({pct:.1f}%)")
    if not suggestions:
        print("No relevant suggestions found.")
    print()
# ── 6. Test autocomplete ────────────────────────────────────────────────────
# Opening banner: rule, title, rule, then one blank line.
print("=" * 50, " KHMER AUTOCOMPLETE TEST", "=" * 50 + "\n", sep="\n")

# Sample prompts run through the autocomplete demo.
# NOTE(review): these literals look mis-encoded (Greek letters where Khmer
# script is expected) — restore them from the original source if available.
test_inputs = [
    "ααΌα",
    "ααΌαα―αα§αααααααααααααααΈααααααΆ",
    "ααΌααααααααΈαααααΆα",
    "α’ααα»α",
    "αααα»α",
]

for text in test_inputs:
    get_autocomplete_suggestions(text, top_k=3)

# Closing banner.
print("=" * 50, "Testing complete!", "=" * 50, sep="\n")