# scripts/test_trainers.py
# Uploaded by rain1024 — "Add examples and test trainer script" (commit f189f3f)
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "python-crfsuite>=0.9.11",
# "datasets>=2.0.0",
# "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
# ]
# ///
"""
Test script to compare python-crfsuite and underthesea-core trainers
on a tiny dataset to validate correctness.
"""
import time
def create_tiny_dataset():
    """Return three hand-labelled Vietnamese sentences for smoke testing.

    Each entry is a (tokens, upos_tags) pair of equal-length lists.
    """
    # "Tôi yêu Việt Nam"
    sentence_one = (["Tôi", "yêu", "Việt_Nam"], ["PRON", "VERB", "PROPN"])
    # "Hà Nội đẹp"
    sentence_two = (["Hà_Nội", "đẹp"], ["PROPN", "ADJ"])
    # "Tôi ở Hà Nội"
    sentence_three = (["Tôi", "ở", "Hà_Nội"], ["PRON", "VERB", "PROPN"])
    return [sentence_one, sentence_two, sentence_three]
def create_medium_dataset(num_sentences=100):
    """Load up to *num_sentences* (tokens, upos) pairs from the UDD-1 corpus.

    Rows whose token and tag lists are empty or mismatched in length are
    skipped (but still count toward the *num_sentences* scan limit).
    """
    from datasets import load_dataset

    train_split = load_dataset("undertheseanlp/UDD-1")["train"]
    pairs = []
    for index, item in enumerate(train_split):
        if index >= num_sentences:
            break
        tokens = item["tokens"]
        tags = item["upos"]  # already plain strings
        # Only keep well-formed, non-empty, aligned rows.
        if tokens and tags and len(tokens) == len(tags):
            pairs.append((tokens, tags))
    return pairs
def extract_features(tokens, position):
    """Build the feature dict for the token at *position* in *tokens*.

    Features: the word itself, its lowercase form, and the neighbouring
    tokens (with "__BOS__"/"__EOS__" sentinels at sentence boundaries).
    """
    current = tokens[position]
    previous = tokens[position - 1] if position > 0 else "__BOS__"
    following = tokens[position + 1] if position + 1 < len(tokens) else "__EOS__"
    return {
        "word": current,
        "lower": current.lower(),
        "prev": previous,
        "next": following,
    }
def sentence_to_features(tokens):
    """Convert a sentence into per-token lists of "key=value" feature strings."""
    feature_rows = []
    for position in range(len(tokens)):
        row = extract_features(tokens, position)
        feature_rows.append([f"{key}={value}" for key, value in row.items()])
    return feature_rows
def test_python_crfsuite(data, max_iter=10):
    """Train and evaluate a python-crfsuite model on *data*.

    Args:
        data: list of (tokens, tags) sentence pairs.
        max_iter: maximum number of training iterations.

    Returns:
        The opened ``pycrfsuite.Tagger`` backed by the trained model.
    """
    import pycrfsuite
    X_train = [sentence_to_features(tokens) for tokens, _ in data]
    y_train = [tags for _, tags in data]
    print("\n=== Python-CRFsuite ===")
    print(f"Training data: {len(data)} sentences")
    trainer = pycrfsuite.Trainer(verbose=True)
    for xseq, yseq in zip(X_train, y_train):
        trainer.append(xseq, yseq)
    # Hyperparameters mirror the underthesea-core run for a fair comparison.
    trainer.set_params({
        "c1": 0.1,
        "c2": 0.01,
        "max_iterations": max_iter,
        "feature.possible_transitions": True,
    })
    start = time.time()
    trainer.train("/tmp/test_pycrfsuite.model")
    elapsed = time.time() - start
    print(f"Training time: {elapsed:.4f}s")
    # Evaluate on the training data itself — a correctness sanity check,
    # not a measure of generalization.
    tagger = pycrfsuite.Tagger()
    tagger.open("/tmp/test_pycrfsuite.model")
    correct = 0
    total = 0
    for tokens, gold in data:
        features = sentence_to_features(tokens)
        pred = tagger.tag(features)
        for p, g in zip(pred, gold):
            if p == g:
                correct += 1
            total += 1
    # Guard against ZeroDivisionError when data is empty.
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {correct}/{total} = {accuracy:.4f}")
    return tagger
def test_underthesea_core(data, max_iter=10):
    """Train and evaluate an underthesea-core CRF model on *data*.

    Args:
        data: list of (tokens, tags) sentence pairs.
        max_iter: maximum number of training iterations.

    Returns:
        The ``CRFTagger`` built from the trained model, or ``None`` when the
        underthesea-core trainer classes are not importable.
    """
    try:
        from underthesea_core import CRFTrainer, CRFModel, CRFTagger
    except ImportError:
        # Fall back to the nested module path used by some wheel layouts.
        try:
            from underthesea_core.underthesea_core import CRFTrainer, CRFModel, CRFTagger
        except ImportError:
            print("\n=== Underthesea-core ===")
            print("ERROR: CRFTrainer not available")
            return None
    X_train = [sentence_to_features(tokens) for tokens, _ in data]
    y_train = [tags for _, tags in data]
    print("\n=== Underthesea-core ===")
    print(f"Training data: {len(data)} sentences")
    # Same iterations as CRFsuite for fair speed comparison
    trainer = CRFTrainer(
        loss_function="lbfgs",
        l1_penalty=0.1,
        l2_penalty=0.01,
        max_iterations=max_iter,
        verbose=1,  # Show sparse feature count
    )
    start = time.time()
    model = trainer.train(X_train, y_train)
    elapsed = time.time() - start
    print(f"Training time: {elapsed:.4f}s")
    # Round-trip through disk so save/load is exercised as well.
    model.save("/tmp/test_underthesea.crf")
    model = CRFModel.load("/tmp/test_underthesea.crf")
    tagger = CRFTagger.from_model(model)
    correct = 0
    total = 0
    for tokens, gold in data:
        features = sentence_to_features(tokens)
        pred = tagger.tag(features)
        for p, g in zip(pred, gold):
            if p == g:
                correct += 1
            total += 1
    # Guard against ZeroDivisionError when data is empty (consistent with
    # test_python_crfsuite).
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {correct}/{total} = {accuracy:.4f}")
    return tagger
def main():
    """Entry point: build a dataset and run both CRF trainer backends on it."""
    import sys

    # Optional CLI argument: how many sentences to train on (default 100).
    num_sentences = int(sys.argv[1]) if len(sys.argv) > 1 else 100
    banner = "=" * 60
    print(banner)
    print(f"Comparing CRF Trainers on {num_sentences} sentences")
    print(banner)
    # Tiny requests use the hand-written fixture; larger ones pull UDD-1.
    if num_sentences <= 3:
        data = create_tiny_dataset()
    else:
        data = create_medium_dataset(num_sentences)
    total_tokens = sum(len(tokens) for tokens, _ in data)
    print(f"Total tokens: {total_tokens}")
    max_iter = 100
    # Run both backends with identical iteration budgets.
    test_python_crfsuite(data, max_iter)
    test_underthesea_core(data, max_iter)
    print("\n" + banner)
if __name__ == "__main__":
    main()