Khmer LSTM Autocomplete

A word-level LSTM model for Khmer text autocomplete.

Model Architecture

  • Type: LSTM (1 layer)
  • Embedding dim: 128
  • Hidden dim: 256
  • Tokenizer: khmercut

Usage

import os
import pickle
import torch
import torch.nn as nn

import subprocess
import sys

# Install dependencies if needed.
# pip is invoked through the running interpreter (sys.executable -m pip) so the
# package is installed into the environment this script actually uses, and
# check_call raises CalledProcessError if the installation fails instead of
# falling through to a second, confusing ImportError.
try:
    from khmercut import tokenize
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "khmercut"])
    from khmercut import tokenize

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "huggingface_hub"])
    from huggingface_hub import hf_hub_download

# ── 1. Download files from HuggingFace ──────────────────────────────────────
# Both artifacts are fetched from the same Hub repository; the values returned
# by hf_hub_download are used below as local file paths.
HF_REPO_ID = "phonsobon/khmer_auto_completed"

print("Downloading model and vocab from HuggingFace...")
model_path = hf_hub_download(HF_REPO_ID, "khmer_lstm_autocomplete_best.pth")
vocab_path = hf_hub_download(HF_REPO_ID, "vocab_mapping.pkl")

# ── 2. Load vocabulary ───────────────────────────────────────────────────────
print("Loading vocabulary...")
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only run this against a repository you trust.
with open(vocab_path, "rb") as f:
    vocab_data = pickle.load(f)

# The pickle is read as a mapping with "word_to_idx", "idx_to_word" and "vocab".
word_to_idx, idx_to_word = vocab_data["word_to_idx"], vocab_data["idx_to_word"]
vocab_size = len(vocab_data["vocab"])
print(f"Vocabulary size: {vocab_size} words")

# ── 3. Define model ──────────────────────────────────────────────────────────
class KhmerLSTMAutocomplete(nn.Module):
    """Next-word model: embedding -> single LSTM -> linear layer over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256):
        """Create the embedding, LSTM and output layers.

        Args:
            vocab_size: Number of rows in the embedding table and number of
                output logits.
            embedding_dim: Size of each token embedding.
            hidden_dim: Size of the LSTM hidden state.
        """
        super().__init__()
        # The attribute names below become the state-dict key prefixes, so they
        # must not change if previously saved weights are to load.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,  # index 0 is treated as padding by the embedding
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Return vocabulary logits computed from the last time step.

        Args:
            x: Batch of token-index sequences (batch first).

        Returns:
            One row of ``vocab_size`` logits per sequence in the batch.
        """
        embedded = self.embedding(x)
        sequence_out, _final_state = self.lstm(embedded)
        # Only the output at the final position is used for the prediction.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

# ── 4. Load model weights ────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = KhmerLSTMAutocomplete(vocab_size)
# The checkpoint is a downloaded file, so restrict unpickling to tensors and
# plain containers (weights_only=True); a tampered file then cannot execute
# arbitrary code while the state dict is being read.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # switch the module to evaluation mode before predicting
print("Model loaded successfully!\n")

# ── 5. Autocomplete function ─────────────────────────────────────────────────
WINDOW_SIZE = 1  # number of trailing tokens passed to the model as context

def get_autocomplete_suggestions(input_text, top_k=3):
    """Print the model's top next-word suggestions for ``input_text``.

    The text is tokenized, reduced to the last ``WINDOW_SIZE`` tokens
    (left-padded with ``<PAD>`` when shorter), mapped to vocabulary indices
    (unknown words become ``<UNK>``) and fed to the model. Each suggestion is
    printed as the input text with the predicted word appended, followed by
    its probability; ``<PAD>`` and ``<UNK>`` predictions are skipped.

    Args:
        input_text: Text to complete.
        top_k: Maximum number of candidates to consider. Values larger than
            the number of model outputs are capped instead of raising.

    Returns:
        None. Results are written to stdout.
    """
    tokens = [t.strip() for t in tokenize(input_text) if t.strip()]

    # Fit the context to exactly WINDOW_SIZE tokens.
    if len(tokens) < WINDOW_SIZE:
        tokens = ["<PAD>"] * (WINDOW_SIZE - len(tokens)) + tokens
    else:
        tokens = tokens[-WINDOW_SIZE:]

    unk_idx = word_to_idx["<UNK>"]
    input_idxs = [word_to_idx.get(w, unk_idx) for w in tokens]
    input_tensor = torch.tensor([input_idxs], dtype=torch.long).to(device)

    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.softmax(logits, dim=-1).squeeze(0)
        # torch.topk raises when k exceeds the number of entries, so cap it.
        k = min(top_k, probs.size(0))
        top_probs, top_idxs = torch.topk(probs, k)

    print(f"Input: '{input_text}'")
    print("Suggestions:")
    has_suggestions = False
    prefix = input_text.strip()
    # rank keeps the model's ordering, so numbers may skip filtered tokens.
    for rank, (prob, idx) in enumerate(
        zip(top_probs.tolist(), top_idxs.tolist()), start=1
    ):
        word = idx_to_word[idx]
        if word in ("<PAD>", "<UNK>"):
            continue
        suggestion = f"{prefix}{word}".strip()
        print(f"  {rank}. {suggestion}  ({prob * 100:.1f}%)")
        has_suggestions = True
    if not has_suggestions:
        print("No relevant suggestions found.")
    print()

# ── 6. Test autocomplete ─────────────────────────────────────────────────────
print("=" * 50)
print("        KHMER AUTOCOMPLETE TEST")
print("=" * 50 + "\n")

# Sample Khmer prompts. All entries must be real Khmer script: text whose UTF-8
# bytes were mis-decoded with another code page would tokenize into unknown
# words and make the demo meaningless.
test_inputs = [
    "សូម",
    "សូមឯកឧត្តមរដ្ឋមន្ត្រីមេត្តា",
    "សូមលោកស្រីប្រធាន",
    "អរគុណ",
    "ខ្ញុំ",
]

for text in test_inputs:
    get_autocomplete_suggestions(text, top_k=3)

print("=" * 50)
print("Testing complete!")
print("=" * 50)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support