phonsobon/khmer-word-segmentation
Viewer • Updated • 447k • 23
A word-level language model trained on the phonsobon/khmer-word-segmentation dataset (~358k administrative Khmer sentences).
| Component | Detail |
|---|---|
| Model | Bidirectional LSTM |
| Layers | 2 |
| Embedding dim | 128 |
| Hidden dim | 256 (× 2 for bidirectional) |
| Dropout | 0.3 |
| Vocabulary | restricted to words with frequency ≥ 10 |
| File | Description |
|---|---|
| `pytorch_model.bin` | Model weights (state dict) |
| `config.json` | Architecture & tokenisation config |
| `word2idx.json` | Word → index mapping |
| `idx2word.json` | Index → word mapping |
import json, re, torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
REPO = "phonsobon/khmer_prediction_sentence"

# Load config and vocab.
# Each file is opened in a ``with`` block so its handle is closed as soon as
# parsing finishes, and is decoded explicitly as UTF-8: the vocabulary holds
# Khmer text, and the platform's default encoding is not guaranteed to be UTF-8.
with open(hf_hub_download(REPO, "config.json"), encoding="utf-8") as fh:
    config = json.load(fh)
with open(hf_hub_download(REPO, "word2idx.json"), encoding="utf-8") as fh:
    word2idx = json.load(fh)
with open(hf_hub_download(REPO, "idx2word.json"), encoding="utf-8") as fh:
    # JSON object keys are always strings; turn them back into int indices so
    # the mapping can be queried with the integer ids the model produces.
    idx2word = {int(k): v for k, v in json.load(fh).items()}
class KhmerLanguageModel(nn.Module):
    """Bidirectional LSTM that scores every vocabulary word as the next word.

    Args:
        vocab_size: Number of rows in the embedding table and number of output
            logits. Index 0 is treated as padding by the embedding layer.
        embed_dim: Size of each word embedding.
        hidden_dim: Hidden size of the LSTM *per direction*; the classifier
            therefore receives ``hidden_dim * 2`` features.
        dropout: Dropout probability applied to the embeddings, between LSTM
            layers, and to the features fed to the classifier.
        num_layers: Number of stacked LSTM layers. Defaults to 2, the value
            that was previously hard-coded.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, dropout=0.3,
                 num_layers=2):
        super().__init__()
        # Attribute names are part of the checkpoint format: they determine
        # the keys expected by ``load_state_dict``.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
            num_layers=num_layers,
            # nn.LSTM only applies dropout *between* layers, so it is
            # meaningless (and triggers a warning) for a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, x):
        """Return next-word logits of shape ``(batch, vocab_size)``.

        Args:
            x: Integer tensor of word indices; the LSTM is built with
                ``batch_first=True``, so the batch dimension comes first.
        """
        x = self.dropout(self.embedding(x))
        # Only the final hidden states are needed, not the per-step outputs.
        _, (h, _) = self.lstm(x)
        # The last two entries of ``h`` are the forward and backward final
        # states of the top layer; concatenate them along the feature axis.
        return self.fc(self.dropout(torch.cat([h[-2], h[-1]], dim=1)))
# Build the network with the vocabulary size recorded in the config; every
# other hyper-parameter falls back to the constructor's defaults.
model = KhmerLanguageModel(config["vocab_size"])
# The checkpoint is a downloaded file, and torch.load unpickles it.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, so a tampered file cannot run arbitrary code.
state = torch.load(
    hf_hub_download(REPO, "pytorch_model.bin"),
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state)
model.eval()  # inference mode: disables dropout
def predict(text, max_length=config["max_length"]):
    """Greedily extend *text* one word at a time and return the joined result.

    The prompt is prefixed with ``startseq`` and split on whitespace into
    words; words missing from ``word2idx`` map to the ``<OOV>`` index. At each
    step the index sequence is left-padded with 0 (or truncated to its last
    ``max_length`` entries), the model's highest-scoring word is appended, and
    generation stops on an unknown index, ``<PAD>``, ``endseq``, or after
    ``max_length`` steps.

    Args:
        text: Prompt whose words are separated by spaces.
        max_length: Both the model input window and the maximum number of
            generated words.

    Returns:
        The prompt plus the generated words with every space removed. The
        ``startseq``/``endseq`` markers are removed wherever they occur,
        including inside the prompt itself.

    Raises:
        KeyError: If ``word2idx`` has no ``<OOV>`` entry.
    """
    oov = word2idx["<OOV>"]
    in_text = "startseq " + text
    # Tokenise the prompt once and grow the index list incrementally instead
    # of re-splitting and re-mapping the whole string on every step.
    seq = [word2idx.get(w, oov) for w in in_text.split()]
    with torch.no_grad():
        for _ in range(max_length):
            # Left-pad to the window size, then keep only the newest entries.
            padded = ([0] * (max_length - len(seq)) + seq)[-max_length:]
            x = torch.tensor([padded], dtype=torch.long)
            idx = model(x).argmax(dim=1).item()
            word = idx2word.get(idx)
            if not word or word == "<PAD>":
                break
            in_text += " " + word
            # Map through word2idx (splitting on whitespace) so the sequence
            # matches exactly what re-tokenising ``in_text`` would yield.
            seq.extend(word2idx.get(w, oov) for w in word.split())
            if word == "endseq":
                break
    # Strip spaces for natural Khmer output
    return in_text.replace("startseq", "").replace("endseq", "").replace(" ", "").strip()
# Example: complete a short prompt and print the generated sentence.
# NOTE(review): the string literal below looks mis-encoded (UTF-8 bytes decoded
# with a Greek code page). It was presumably Khmer text, possibly "ការអភិវឌ្ឍ";
# confirm against the original model card and restore it before running.
print(predict("ααΆαα’αα·αααα"))