# sen-1 / src/bench.py
# Commit 903cdb2: Refactor training to Hydra config and use underthesea imports
"""
Benchmark CLI for Vietnamese Text Classification.
Compares Rust TextClassifier vs sklearn.
Usage:
python bench.py vntc
python bench.py bank
python bench.py synthetic
"""
import os
import time
import random
from pathlib import Path
import click
from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVectorizer
from sklearn.svm import LinearSVC as SklearnLinearSVC
from sklearn.metrics import accuracy_score, f1_score, classification_report
from underthesea import TextClassifier
def read_file(filepath):
    """Read a text file, trying several encodings in order.

    Attempts utf-16, utf-16-le, utf-8 and latin-1 in turn; all runs of
    whitespace are collapsed to single spaces. Returns the cleaned text
    from the first encoding that both decodes and yields more than 10
    characters, or None when no candidate does.
    """
    for encoding in ('utf-16', 'utf-16-le', 'utf-8', 'latin-1'):
        try:
            with open(filepath, 'r', encoding=encoding) as handle:
                raw = handle.read()
        except (UnicodeDecodeError, UnicodeError):
            # Wrong encoding for this file; try the next candidate.
            continue
        normalized = ' '.join(raw.split())
        # Too-short results are treated as a failed decode as well.
        if len(normalized) > 10:
            return normalized
    return None
def benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark scikit-learn TF-IDF + LinearSVC.

    Fits a TfidfVectorizer on the training texts, trains a LinearSVC,
    then measures end-to-end inference (vectorize + predict) on the
    test split.

    Returns a dict with:
        total_train: vectorization + training wall time in seconds
        e2e_throughput: end-to-end test samples per second
        accuracy, f1_weighted: quality on the test split
    """
    click.echo("\n" + "=" * 70)
    click.echo("scikit-learn: TfidfVectorizer + LinearSVC")
    click.echo("=" * 70)
    # Vectorize. Fix: the original also transformed the test set here into
    # an unused X_test variable; that dead work is removed so vec_time
    # measures fit_transform only. The test transform is timed in the
    # end-to-end section below.
    click.echo(" Vectorizing...")
    t0 = time.perf_counter()
    vectorizer = SklearnTfidfVectorizer(max_features=max_features, ngram_range=(1, 2), min_df=2)
    X_train = vectorizer.fit_transform(train_texts)
    vec_time = time.perf_counter() - t0
    click.echo(f" Vectorization time: {vec_time:.2f}s")
    click.echo(f" Vocabulary size: {len(vectorizer.vocabulary_)}")
    # Train
    click.echo(" Training LinearSVC...")
    t0 = time.perf_counter()
    clf = SklearnLinearSVC(C=1.0, max_iter=2000)
    clf.fit(X_train, train_labels)
    train_time = time.perf_counter() - t0
    click.echo(f" Training time: {train_time:.2f}s")
    # End-to-end inference: transform + predict together, to compare
    # fairly against the Rust classifier's single predict_batch call.
    click.echo(" End-to-end inference...")
    t0 = time.perf_counter()
    X_test_e2e = vectorizer.transform(test_texts)
    preds = clf.predict(X_test_e2e)
    e2e_time = time.perf_counter() - t0
    e2e_throughput = len(test_texts) / e2e_time
    click.echo(f" E2E time: {e2e_time:.2f}s ({e2e_throughput:.0f} samples/sec)")
    # Metrics
    acc = accuracy_score(test_labels, preds)
    f1_w = f1_score(test_labels, preds, average='weighted')
    click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")
    return {
        "total_train": vec_time + train_time,
        "e2e_throughput": e2e_throughput,
        "accuracy": acc,
        "f1_weighted": f1_w,
    }
def benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark the Rust-backed TextClassifier (underthesea_core).

    Trains on the given split, runs batch prediction on the test split,
    and returns a dict with timing, throughput and quality metrics plus
    the fitted classifier and its predictions (for saving/reporting).
    """
    click.echo("\n" + "=" * 70)
    click.echo("Rust: TextClassifier (underthesea_core)")
    click.echo("=" * 70)
    classifier = TextClassifier(
        max_features=max_features,
        ngram_range=(1, 2),
        min_df=2,
        c=1.0,
        max_iter=1000,
        tol=0.1,
    )
    # Training phase
    click.echo(" Training...")
    tic = time.perf_counter()
    classifier.fit(list(train_texts), list(train_labels))
    fit_seconds = time.perf_counter() - tic
    click.echo(f" Training time: {fit_seconds:.2f}s")
    click.echo(f" Vocabulary size: {classifier.n_features}")
    # Inference phase
    click.echo(" Inference...")
    tic = time.perf_counter()
    predictions = classifier.predict_batch(list(test_texts))
    predict_seconds = time.perf_counter() - tic
    samples_per_sec = len(test_texts) / predict_seconds
    click.echo(f" Inference time: {predict_seconds:.2f}s ({samples_per_sec:.0f} samples/sec)")
    # Quality metrics
    accuracy = accuracy_score(test_labels, predictions)
    weighted_f1 = f1_score(test_labels, predictions, average='weighted')
    click.echo(f" Results: Accuracy={accuracy:.4f}, F1={weighted_f1:.4f}")
    results = {
        "total_train": fit_seconds,
        "throughput": samples_per_sec,
        "accuracy": accuracy,
        "f1_weighted": weighted_f1,
        "clf": classifier,
        "preds": predictions,
    }
    return results
def print_comparison(sklearn_results, rust_results):
    """Print a side-by-side summary table of the two benchmark runs."""
    heavy_rule = "=" * 70
    light_rule = "-" * 70
    click.echo("\n" + heavy_rule)
    click.echo("COMPARISON SUMMARY")
    click.echo(heavy_rule)
    click.echo(f"{'Metric':<30} {'sklearn':<20} {'Rust':<20}")
    click.echo(light_rule)
    # (label, sklearn value, rust value, number format)
    rows = [
        ("Training time (s)", sklearn_results['total_train'], rust_results['total_train'], ".2f"),
        ("Inference (samples/sec)", sklearn_results['e2e_throughput'], rust_results['throughput'], ".0f"),
        ("Accuracy", sklearn_results['accuracy'], rust_results['accuracy'], ".4f"),
        ("F1 (weighted)", sklearn_results['f1_weighted'], rust_results['f1_weighted'], ".4f"),
    ]
    for label, sk_value, rs_value, spec in rows:
        click.echo(f"{label:<30} {sk_value:<20{spec}} {rs_value:<20{spec}}")
    click.echo(light_rule)
    # Guard against division by zero when a run was degenerate.
    if rust_results['total_train'] > 0:
        train_speedup = sklearn_results['total_train'] / rust_results['total_train']
    else:
        train_speedup = 0
    if sklearn_results['e2e_throughput'] > 0:
        infer_speedup = rust_results['throughput'] / sklearn_results['e2e_throughput']
    else:
        infer_speedup = 0
    click.echo(f"Speedup: Training {train_speedup:.2f}x, Inference {infer_speedup:.2f}x")
    click.echo(heavy_rule)
# Root click group; subcommands registered below: vntc, bank, synthetic.
@click.group()
def cli():
    """Benchmark Vietnamese text classification models."""
    pass
@cli.command()
@click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
              help='Path to VNTC dataset')
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
def vntc(data_dir, save_model, output):
    """Benchmark on VNTC dataset (10 topics, ~84k documents)."""
    click.echo("=" * 70)
    click.echo("VNTC Full Dataset Benchmark")
    click.echo("Vietnamese News Text Classification (10 Topics)")
    click.echo("=" * 70)
    train_dir = os.path.join(data_dir, "Train_Full")
    test_dir = os.path.join(data_dir, "Test_Full")

    def load_split(split_dir):
        # One subdirectory per topic; the directory name is the label.
        # Files that fail to decode (read_file returns None) are skipped.
        texts, labels = [], []
        for folder in sorted(os.listdir(split_dir)):
            folder_path = os.path.join(split_dir, folder)
            if not os.path.isdir(folder_path):
                continue
            for fname in os.listdir(folder_path):
                if fname.endswith('.txt'):
                    text = read_file(os.path.join(folder_path, fname))
                    if text:
                        texts.append(text)
                        labels.append(folder)
        return texts, labels

    # Load data (the two splits share identical layout, hence one helper).
    click.echo("\nLoading training data...")
    t0 = time.perf_counter()
    train_texts, train_labels = load_split(train_dir)
    click.echo(f" Loaded {len(train_texts)} training samples in {time.perf_counter()-t0:.1f}s")
    click.echo("Loading test data...")
    t0 = time.perf_counter()
    test_texts, test_labels = load_split(test_dir)
    click.echo(f" Loaded {len(test_texts)} test samples in {time.perf_counter()-t0:.1f}s")
    # Run benchmarks
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels)
    print_comparison(sklearn_results, rust_results)
    if save_model:
        model_path = Path(output)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(model_path))
        size_mb = model_path.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
@cli.command()
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
def bank(save_model, output):
    """Benchmark on UTS2017_Bank dataset (14 categories, banking domain)."""
    # Imported lazily so the other subcommands work without `datasets`.
    from datasets import load_dataset
    banner = "=" * 70
    click.echo(banner)
    click.echo("UTS2017_Bank Dataset Benchmark")
    click.echo("Vietnamese Banking Domain Text Classification (14 Categories)")
    click.echo(banner)
    # Fetch the classification split from the HuggingFace hub.
    click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
    dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")
    train_split, test_split = dataset["train"], dataset["test"]
    train_texts = list(train_split["text"])
    train_labels = list(train_split["label"])
    test_texts = list(test_split["text"])
    test_labels = list(test_split["label"])
    click.echo(f" Train samples: {len(train_texts)}")
    click.echo(f" Test samples: {len(test_texts)}")
    click.echo(f" Categories: {len(set(train_labels))}")
    # Run both benchmarks; smaller max_features suits this smaller dataset.
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    print_comparison(sklearn_results, rust_results)
    click.echo("\nClassification Report (Rust):")
    click.echo(classification_report(test_labels, rust_results['preds']))
    if save_model:
        destination = Path(output)
        destination.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(destination))
        size_mb = destination.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {destination} ({size_mb:.2f} MB)")
@cli.command()
@click.option('--train-per-cat', default=340, help='Training samples per category')
@click.option('--test-per-cat', default=500, help='Test samples per category')
@click.option('--seed', default=42, help='Random seed')
def synthetic(train_per_cat, test_per_cat, seed):
    """Benchmark on synthetic VNTC-like data."""
    # Vietnamese text templates by category
    TEMPLATES = {
        "the_thao": ["Đội tuyển {} thắng {} với tỷ số {}", "Cầu thủ {} ghi bàn đẹp mắt"],
        "kinh_doanh": ["Chứng khoán {} điểm trong phiên giao dịch", "Ngân hàng {} công bố lãi suất {}"],
        "cong_nghe": ["Apple ra mắt {} với nhiều tính năng", "Trí tuệ nhân tạo đang thay đổi {}"],
        "chinh_tri": ["Quốc hội thông qua nghị quyết về {}", "Chủ tịch {} tiếp đón phái đoàn"],
        "van_hoa": ["Nghệ sĩ {} ra mắt album mới", "Liên hoan phim {} trao giải"],
        "khoa_hoc": ["Nhà khoa học phát hiện {} mới", "Nghiên cứu cho thấy {} có tác dụng"],
        "suc_khoe": ["Bộ Y tế cảnh báo về {} trong mùa", "Vaccine {} đạt hiệu quả cao"],
        "giao_duc": ["Trường {} công bố điểm chuẩn", "Học sinh đoạt huy chương tại Olympic"],
        "phap_luat": ["Tòa án xét xử vụ án {} với bị cáo", "Công an triệt phá đường dây"],
        "doi_song": ["Giá {} tăng trong tháng", "Người dân đổ xô đi mua {}"],
    }
    # Candidate slot fillers per category.
    FILLS = {
        "the_thao": ["Việt Nam", "Thái Lan", "3-0", "bóng đá", "AFF Cup"],
        "kinh_doanh": ["tăng", "giảm", "VN-Index", "Vietcombank", "8%"],
        "cong_nghe": ["iPhone 16", "ChatGPT", "công việc", "VinAI", "5G"],
        "chinh_tri": ["kinh tế", "nước", "Trung Quốc", "Hà Nội", "phát triển"],
        "van_hoa": ["Mỹ Tâm", "Cannes", "nghệ thuật", "Hà Nội", "Bố Già"],
        "khoa_hoc": ["loài sinh vật", "trà xanh", "VNREDSat-1", "Nobel", "robot"],
        "suc_khoe": ["dịch cúm", "COVID-19", "Bạch Mai", "dinh dưỡng", "tiểu đường"],
        "giao_duc": ["Bách Khoa", "Việt Nam", "Toán", "THPT", "STEM"],
        "phap_luat": ["tham nhũng", "TP.HCM", "ma túy", "Hình sự", "gian lận"],
        "doi_song": ["xăng", "vàng", "nắng nóng", "Trung thu", "bún chả"],
    }

    def make_text(category):
        # One random template, slots filled with random category keywords
        # (sampled with replacement). RNG call order matters for seeding:
        # choice first, then choices, exactly once per sample.
        template = random.choice(TEMPLATES[category])
        slots = template.count("{}")
        return template.format(*random.choices(FILLS[category], k=slots))

    def make_split(n_per_cat, categories):
        # Generate per-category samples in order, then shuffle once.
        pairs = [(make_text(cat), cat) for cat in categories for _ in range(n_per_cat)]
        random.shuffle(pairs)
        return [text for text, _ in pairs], [label for _, label in pairs]

    click.echo("=" * 70)
    click.echo("Synthetic VNTC-like Benchmark")
    click.echo("=" * 70)
    random.seed(seed)
    categories = list(TEMPLATES.keys())
    click.echo(f"\nConfiguration:")
    click.echo(f" Categories: {len(categories)}")
    click.echo(f" Train samples: {train_per_cat * len(categories)}")
    click.echo(f" Test samples: {test_per_cat * len(categories)}")
    train_texts, train_labels = make_split(train_per_cat, categories)
    test_texts, test_labels = make_split(test_per_cat, categories)
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    print_comparison(sklearn_results, rust_results)
# Script entry point: dispatch to the click command group.
if __name__ == "__main__":
    cli()