|
|
""" |
|
|
Benchmark CLI for Vietnamese Text Classification. |
|
|
|
|
|
Compares Rust TextClassifier vs sklearn. |
|
|
|
|
|
Usage: |
|
|
python bench.py vntc |
|
|
python bench.py bank |
|
|
python bench.py synthetic |
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
import random |
|
|
from pathlib import Path |
|
|
|
|
|
import click |
|
|
from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVectorizer |
|
|
from sklearn.svm import LinearSVC as SklearnLinearSVC |
|
|
from sklearn.metrics import accuracy_score, f1_score, classification_report |
|
|
|
|
|
from underthesea import TextClassifier |
|
|
|
|
|
|
|
|
def read_file(filepath):
    """Read a text file, trying several encodings in order of likelihood.

    The decoded content is whitespace-normalized (all runs of whitespace
    collapsed to single spaces). An encoding is accepted only when it both
    decodes cleanly and yields more than 10 characters; otherwise the next
    candidate is tried.

    Returns the normalized text, or None when every encoding fails or the
    file is too short to be a useful document.
    """
    for encoding in ('utf-16', 'utf-16-le', 'utf-8', 'latin-1'):
        try:
            raw = Path(filepath).read_text(encoding=encoding)
        except (UnicodeDecodeError, UnicodeError):
            continue
        normalized = ' '.join(raw.split())
        if len(normalized) > 10:
            return normalized
    return None
|
|
|
|
|
|
|
|
def benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark scikit-learn TfidfVectorizer + LinearSVC.

    Fits a word 1-2 gram TF-IDF vectorizer and a LinearSVC on the training
    split, then measures end-to-end inference (vectorize + predict) on the
    test split.

    Args:
        train_texts, train_labels: training documents and their labels.
        test_texts, test_labels: evaluation documents and their labels.
        max_features: vocabulary size cap for the vectorizer.

    Returns:
        dict with keys "total_train" (seconds, vectorize + fit),
        "e2e_throughput" (test samples/sec), "accuracy", "f1_weighted".
    """
    click.echo("\n" + "=" * 70)
    click.echo("scikit-learn: TfidfVectorizer + LinearSVC")
    click.echo("=" * 70)

    # Fit the vectorizer on the training split only. The test split is
    # transformed later as part of the timed end-to-end inference step;
    # previously it was also transformed here (into an unused variable),
    # which double-counted test vectorization in the reported train time.
    click.echo(" Vectorizing...")
    t0 = time.perf_counter()
    vectorizer = SklearnTfidfVectorizer(max_features=max_features, ngram_range=(1, 2), min_df=2)
    X_train = vectorizer.fit_transform(train_texts)
    vec_time = time.perf_counter() - t0
    click.echo(f" Vectorization time: {vec_time:.2f}s")
    click.echo(f" Vocabulary size: {len(vectorizer.vocabulary_)}")

    click.echo(" Training LinearSVC...")
    t0 = time.perf_counter()
    clf = SklearnLinearSVC(C=1.0, max_iter=2000)
    clf.fit(X_train, train_labels)
    train_time = time.perf_counter() - t0
    click.echo(f" Training time: {train_time:.2f}s")

    # End-to-end = vectorize raw test texts + predict, which is the fair
    # comparison against the Rust classifier's predict_batch.
    click.echo(" End-to-end inference...")
    t0 = time.perf_counter()
    preds = clf.predict(vectorizer.transform(test_texts))
    e2e_time = time.perf_counter() - t0
    e2e_throughput = len(test_texts) / e2e_time
    click.echo(f" E2E time: {e2e_time:.2f}s ({e2e_throughput:.0f} samples/sec)")

    acc = accuracy_score(test_labels, preds)
    f1_w = f1_score(test_labels, preds, average='weighted')
    click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")

    return {
        "total_train": vec_time + train_time,
        "e2e_throughput": e2e_throughput,
        "accuracy": acc,
        "f1_weighted": f1_w,
    }
|
|
|
|
|
|
|
|
def benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark the Rust-backed TextClassifier from underthesea_core.

    Trains on the given split, runs batched inference on the test split,
    and reports timing plus quality metrics. The fitted classifier and its
    raw predictions are returned as well so callers can save the model or
    print a per-class report.
    """
    banner = "=" * 70
    click.echo("\n" + banner)
    click.echo("Rust: TextClassifier (underthesea_core)")
    click.echo(banner)

    # Configuration mirrors the sklearn benchmark (vocab cap, 1-2 grams, min_df).
    clf = TextClassifier(
        max_features=max_features,
        ngram_range=(1, 2),
        min_df=2,
        c=1.0,
        max_iter=1000,
        tol=0.1,
    )

    click.echo(" Training...")
    tic = time.perf_counter()
    clf.fit(list(train_texts), list(train_labels))
    train_time = time.perf_counter() - tic
    click.echo(f" Training time: {train_time:.2f}s")
    click.echo(f" Vocabulary size: {clf.n_features}")

    click.echo(" Inference...")
    tic = time.perf_counter()
    predictions = clf.predict_batch(list(test_texts))
    infer_time = time.perf_counter() - tic
    throughput = len(test_texts) / infer_time
    click.echo(f" Inference time: {infer_time:.2f}s ({throughput:.0f} samples/sec)")

    acc = accuracy_score(test_labels, predictions)
    f1_w = f1_score(test_labels, predictions, average='weighted')
    click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")

    results = {
        "total_train": train_time,
        "throughput": throughput,
        "accuracy": acc,
        "f1_weighted": f1_w,
        "clf": clf,
        "preds": predictions,
    }
    return results
|
|
|
|
|
|
|
|
def print_comparison(sklearn_results, rust_results):
    """Print a side-by-side summary table of the two benchmark runs."""
    sep = "=" * 70
    dash = "-" * 70

    def row(label, left, right, spec):
        # Column widths match the header line: 30 / 20 / 20.
        click.echo(f"{label:<30} {left:<20{spec}} {right:<20{spec}}")

    click.echo("\n" + sep)
    click.echo("COMPARISON SUMMARY")
    click.echo(sep)
    click.echo(f"{'Metric':<30} {'sklearn':<20} {'Rust':<20}")
    click.echo(dash)

    row("Training time (s)", sklearn_results['total_train'], rust_results['total_train'], ".2f")
    row("Inference (samples/sec)", sklearn_results['e2e_throughput'], rust_results['throughput'], ".0f")
    row("Accuracy", sklearn_results['accuracy'], rust_results['accuracy'], ".4f")
    row("F1 (weighted)", sklearn_results['f1_weighted'], rust_results['f1_weighted'], ".4f")

    click.echo(dash)
    # Guard against zero denominators so the summary never divides by zero.
    rust_train = rust_results['total_train']
    train_speedup = sklearn_results['total_train'] / rust_train if rust_train > 0 else 0
    sklearn_tp = sklearn_results['e2e_throughput']
    infer_speedup = rust_results['throughput'] / sklearn_tp if sklearn_tp > 0 else 0
    click.echo(f"Speedup: Training {train_speedup:.2f}x, Inference {infer_speedup:.2f}x")
    click.echo(sep)
|
|
|
|
|
|
|
|
@click.group()
def cli():
    """Benchmark Vietnamese text classification models."""
|
|
|
|
|
|
|
|
def _load_labeled_dir(root):
    """Load one split of a VNTC-style corpus.

    *root* contains one sub-directory per topic; every ``.txt`` file inside
    is a document and the sub-directory name is used as its label.

    Returns:
        (texts, labels) lists of equal length; files that cannot be decoded
        (``read_file`` returns None) are skipped.
    """
    texts, labels = [], []
    for folder in sorted(os.listdir(root)):
        folder_path = os.path.join(root, folder)
        if not os.path.isdir(folder_path):
            continue
        for fname in os.listdir(folder_path):
            if fname.endswith('.txt'):
                text = read_file(os.path.join(folder_path, fname))
                if text:
                    texts.append(text)
                    labels.append(folder)
    return texts, labels


@cli.command()
@click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
              help='Path to VNTC dataset')
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
def vntc(data_dir, save_model, output):
    """Benchmark on VNTC dataset (10 topics, ~84k documents)."""
    click.echo("=" * 70)
    click.echo("VNTC Full Dataset Benchmark")
    click.echo("Vietnamese News Text Classification (10 Topics)")
    click.echo("=" * 70)

    train_dir = os.path.join(data_dir, "Train_Full")
    test_dir = os.path.join(data_dir, "Test_Full")

    # The two splits share the same on-disk layout, so loading is shared
    # via _load_labeled_dir (previously duplicated inline).
    click.echo("\nLoading training data...")
    t0 = time.perf_counter()
    train_texts, train_labels = _load_labeled_dir(train_dir)
    click.echo(f" Loaded {len(train_texts)} training samples in {time.perf_counter()-t0:.1f}s")

    click.echo("Loading test data...")
    t0 = time.perf_counter()
    test_texts, test_labels = _load_labeled_dir(test_dir)
    click.echo(f" Loaded {len(test_texts)} test samples in {time.perf_counter()-t0:.1f}s")

    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels)

    print_comparison(sklearn_results, rust_results)

    if save_model:
        model_path = Path(output)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(model_path))
        size_mb = model_path.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
|
|
@cli.command()
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
def bank(save_model, output):
    """Benchmark on UTS2017_Bank dataset (14 categories, banking domain)."""
    # Imported lazily: `datasets` is only required by this sub-command.
    from datasets import load_dataset

    banner = "=" * 70
    click.echo(banner)
    click.echo("UTS2017_Bank Dataset Benchmark")
    click.echo("Vietnamese Banking Domain Text Classification (14 Categories)")
    click.echo(banner)

    click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
    dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")

    # Materialize both splits as plain lists for the benchmark helpers.
    splits = {
        name: (list(dataset[name]["text"]), list(dataset[name]["label"]))
        for name in ("train", "test")
    }
    train_texts, train_labels = splits["train"]
    test_texts, test_labels = splits["test"]

    click.echo(f" Train samples: {len(train_texts)}")
    click.echo(f" Test samples: {len(test_texts)}")
    click.echo(f" Categories: {len(set(train_labels))}")

    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)

    print_comparison(sklearn_results, rust_results)

    click.echo("\nClassification Report (Rust):")
    click.echo(classification_report(test_labels, rust_results['preds']))

    if save_model:
        model_path = Path(output)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(model_path))
        size_mb = model_path.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
|
|
@cli.command()
@click.option('--train-per-cat', default=340, help='Training samples per category')
@click.option('--test-per-cat', default=500, help='Test samples per category')
@click.option('--seed', default=42, help='Random seed')
def synthetic(train_per_cat, test_per_cat, seed):
    """Benchmark on synthetic VNTC-like data.

    Generates balanced, template-based Vietnamese sentences for 10 topic
    categories and runs both benchmarks on the synthetic splits. Output is
    deterministic for a given --seed.
    """
    # One list of format templates per category; each "{}" slot is filled
    # with a word drawn from the matching FILLS list below.
    TEMPLATES = {
        "the_thao": ["Đội tuyển {} thắng {} với tỷ số {}", "Cầu thủ {} ghi bàn đẹp mắt"],
        "kinh_doanh": ["Chứng khoán {} điểm trong phiên giao dịch", "Ngân hàng {} công bố lãi suất {}"],
        "cong_nghe": ["Apple ra mắt {} với nhiều tính năng", "Trí tuệ nhân tạo đang thay đổi {}"],
        "chinh_tri": ["Quốc hội thông qua nghị quyết về {}", "Chủ tịch {} tiếp đón phái đoàn"],
        "van_hoa": ["Nghệ sĩ {} ra mắt album mới", "Liên hoan phim {} trao giải"],
        "khoa_hoc": ["Nhà khoa học phát hiện {} mới", "Nghiên cứu cho thấy {} có tác dụng"],
        "suc_khoe": ["Bộ Y tế cảnh báo về {} trong mùa", "Vaccine {} đạt hiệu quả cao"],
        "giao_duc": ["Trường {} công bố điểm chuẩn", "Học sinh đoạt huy chương tại Olympic"],
        "phap_luat": ["Tòa án xét xử vụ án {} với bị cáo", "Công an triệt phá đường dây"],
        "doi_song": ["Giá {} tăng trong tháng", "Người dân đổ xô đi mua {}"],
    }
    # Filler vocabulary per category.
    FILLS = {
        "the_thao": ["Việt Nam", "Thái Lan", "3-0", "bóng đá", "AFF Cup"],
        "kinh_doanh": ["tăng", "giảm", "VN-Index", "Vietcombank", "8%"],
        "cong_nghe": ["iPhone 16", "ChatGPT", "công việc", "VinAI", "5G"],
        "chinh_tri": ["kinh tế", "nước", "Trung Quốc", "Hà Nội", "phát triển"],
        "van_hoa": ["Mỹ Tâm", "Cannes", "nghệ thuật", "Hà Nội", "Bố Già"],
        "khoa_hoc": ["loài sinh vật", "trà xanh", "VNREDSat-1", "Nobel", "robot"],
        "suc_khoe": ["dịch cúm", "COVID-19", "Bạch Mai", "dinh dưỡng", "tiểu đường"],
        "giao_duc": ["Bách Khoa", "Việt Nam", "Toán", "THPT", "STEM"],
        "phap_luat": ["tham nhũng", "TP.HCM", "ma túy", "Hình sự", "gian lận"],
        "doi_song": ["xăng", "vàng", "nắng nóng", "Trung thu", "bún chả"],
    }

    def generate_sample(category):
        # Pick a random template and fill its "{}" slots. random.choices
        # samples WITH replacement, so a filler word may repeat in one sample.
        template = random.choice(TEMPLATES[category])
        fills = FILLS[category]
        n = template.count("{}")
        return template.format(*random.choices(fills, k=n))

    def generate_dataset(n_per_cat, categories):
        # Build a balanced set (n_per_cat samples per category), then shuffle
        # so the benchmarked models never see label-sorted input.
        texts, labels = [], []
        for cat in categories:
            for _ in range(n_per_cat):
                texts.append(generate_sample(cat))
                labels.append(cat)
        combined = list(zip(texts, labels))
        random.shuffle(combined)
        return [t for t, _ in combined], [l for _, l in combined]

    click.echo("=" * 70)
    click.echo("Synthetic VNTC-like Benchmark")
    click.echo("=" * 70)

    random.seed(seed)
    categories = list(TEMPLATES.keys())

    # Fixed: this header was an f-string with no placeholders (ruff F541).
    click.echo("\nConfiguration:")
    click.echo(f" Categories: {len(categories)}")
    click.echo(f" Train samples: {train_per_cat * len(categories)}")
    click.echo(f" Test samples: {test_per_cat * len(categories)}")

    train_texts, train_labels = generate_dataset(train_per_cat, categories)
    test_texts, test_labels = generate_dataset(test_per_cat, categories)

    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)

    print_comparison(sklearn_results, rust_results)
|
|
|
|
|
|
|
|
# Script entry point: dispatch to the vntc / bank / synthetic sub-commands.
if __name__ == "__main__":
    cli()
|
|
|