File size: 13,404 Bytes
b059f86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903cdb2
b059f86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""
Benchmark CLI for Vietnamese Text Classification.

Compares Rust TextClassifier vs sklearn.

Usage:
    python bench.py vntc
    python bench.py bank
    python bench.py synthetic
"""

import os
import time
import random
from pathlib import Path

import click
from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVectorizer
from sklearn.svm import LinearSVC as SklearnLinearSVC
from sklearn.metrics import accuracy_score, f1_score, classification_report

from underthesea import TextClassifier


def read_file(filepath):
    """Read a text file, trying several encodings in order.

    Whitespace runs are collapsed to single spaces. Returns the
    normalized text, or None when no encoding produces more than
    10 characters of content.
    """
    for encoding in ('utf-16', 'utf-16-le', 'utf-8', 'latin-1'):
        try:
            with open(filepath, 'r', encoding=encoding) as handle:
                raw = handle.read()
        except (UnicodeDecodeError, UnicodeError):
            # Wrong encoding for this file -- try the next candidate.
            continue
        normalized = ' '.join(raw.split())
        if len(normalized) > 10:
            return normalized
    return None


def benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark scikit-learn TF-IDF + LinearSVC.

    Fits a TfidfVectorizer + LinearSVC pipeline, times vectorization,
    training, and end-to-end inference, and returns a metrics dict with
    keys: total_train, e2e_throughput, accuracy, f1_weighted.
    """
    banner = "=" * 70
    click.echo("\n" + banner)
    click.echo("scikit-learn: TfidfVectorizer + LinearSVC")
    click.echo(banner)

    # Time vectorization of both splits (fit on train, transform test too,
    # so the reported time covers the full feature-extraction cost).
    click.echo("  Vectorizing...")
    started = time.perf_counter()
    vectorizer = SklearnTfidfVectorizer(max_features=max_features, ngram_range=(1, 2), min_df=2)
    X_train = vectorizer.fit_transform(train_texts)
    _ = vectorizer.transform(test_texts)  # counted in the vectorization timing
    vectorize_seconds = time.perf_counter() - started
    click.echo(f"    Vectorization time: {vectorize_seconds:.2f}s")
    click.echo(f"    Vocabulary size: {len(vectorizer.vocabulary_)}")

    # Time classifier training on the pre-vectorized features.
    click.echo("  Training LinearSVC...")
    started = time.perf_counter()
    model = SklearnLinearSVC(C=1.0, max_iter=2000)
    model.fit(X_train, train_labels)
    fit_seconds = time.perf_counter() - started
    click.echo(f"    Training time: {fit_seconds:.2f}s")

    # End-to-end inference: raw text -> features -> predictions.
    click.echo("  End-to-end inference...")
    started = time.perf_counter()
    predictions = model.predict(vectorizer.transform(test_texts))
    e2e_seconds = time.perf_counter() - started
    samples_per_sec = len(test_texts) / e2e_seconds
    click.echo(f"    E2E time: {e2e_seconds:.2f}s ({samples_per_sec:.0f} samples/sec)")

    accuracy = accuracy_score(test_labels, predictions)
    weighted_f1 = f1_score(test_labels, predictions, average='weighted')
    click.echo(f"  Results: Accuracy={accuracy:.4f}, F1={weighted_f1:.4f}")

    return {
        "total_train": vectorize_seconds + fit_seconds,
        "e2e_throughput": samples_per_sec,
        "accuracy": accuracy,
        "f1_weighted": weighted_f1,
    }


def benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark Rust TextClassifier.

    Trains the underthesea_core-backed TextClassifier and times training
    and batch inference. Returns a dict with keys: total_train, throughput,
    accuracy, f1_weighted, plus the fitted classifier ('clf') and its
    test-set predictions ('preds') for downstream reporting/saving.
    """
    banner = "=" * 70
    click.echo("\n" + banner)
    click.echo("Rust: TextClassifier (underthesea_core)")
    click.echo(banner)

    # Hyperparameters mirror the sklearn setup where the APIs overlap.
    classifier = TextClassifier(
        max_features=max_features,
        ngram_range=(1, 2),
        min_df=2,
        c=1.0,
        max_iter=1000,
        tol=0.1,
    )

    # Train (vectorization happens inside fit, so this is the full cost).
    click.echo("  Training...")
    started = time.perf_counter()
    classifier.fit(list(train_texts), list(train_labels))
    fit_seconds = time.perf_counter() - started
    click.echo(f"    Training time: {fit_seconds:.2f}s")
    click.echo(f"    Vocabulary size: {classifier.n_features}")

    # Batch inference over the raw test texts.
    click.echo("  Inference...")
    started = time.perf_counter()
    predictions = classifier.predict_batch(list(test_texts))
    infer_seconds = time.perf_counter() - started
    samples_per_sec = len(test_texts) / infer_seconds
    click.echo(f"    Inference time: {infer_seconds:.2f}s ({samples_per_sec:.0f} samples/sec)")

    accuracy = accuracy_score(test_labels, predictions)
    weighted_f1 = f1_score(test_labels, predictions, average='weighted')
    click.echo(f"  Results: Accuracy={accuracy:.4f}, F1={weighted_f1:.4f}")

    return {
        "total_train": fit_seconds,
        "throughput": samples_per_sec,
        "accuracy": accuracy,
        "f1_weighted": weighted_f1,
        "clf": classifier,
        "preds": predictions,
    }


def print_comparison(sklearn_results, rust_results):
    """Print comparison summary.

    Expects the dicts produced by benchmark_sklearn / benchmark_rust and
    echoes a side-by-side table plus relative speedup figures.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    click.echo("\n" + heavy_rule)
    click.echo("COMPARISON SUMMARY")
    click.echo(heavy_rule)
    click.echo(f"{'Metric':<30} {'sklearn':<20} {'Rust':<20}")
    click.echo(light_rule)

    sk_train = sklearn_results['total_train']
    rs_train = rust_results['total_train']
    sk_rate = sklearn_results['e2e_throughput']
    rs_rate = rust_results['throughput']
    click.echo(f"{'Training time (s)':<30} {sk_train:<20.2f} {rs_train:<20.2f}")
    click.echo(f"{'Inference (samples/sec)':<30} {sk_rate:<20.0f} {rs_rate:<20.0f}")
    click.echo(f"{'Accuracy':<30} {sklearn_results['accuracy']:<20.4f} {rust_results['accuracy']:<20.4f}")
    click.echo(f"{'F1 (weighted)':<30} {sklearn_results['f1_weighted']:<20.4f} {rust_results['f1_weighted']:<20.4f}")

    click.echo(light_rule)
    # Guard against division by zero when a benchmark recorded no time/rate.
    if rs_train > 0:
        train_speedup = sk_train / rs_train
    else:
        train_speedup = 0
    if sk_rate > 0:
        infer_speedup = rs_rate / sk_rate
    else:
        infer_speedup = 0
    click.echo(f"Speedup: Training {train_speedup:.2f}x, Inference {infer_speedup:.2f}x")
    click.echo(heavy_rule)


# Root click group; the docstring below doubles as the CLI's --help text,
# and subcommands (vntc, bank, synthetic) attach via @cli.command().
@click.group()
def cli():
    """Benchmark Vietnamese text classification models."""
    pass


def _load_labeled_dir(root_dir):
    """Collect (texts, labels) from a VNTC-style directory layout.

    Each immediate subfolder of *root_dir* is treated as a category label;
    every ``.txt`` file inside it becomes one sample. Files that cannot be
    decoded or are too short (see read_file) are silently skipped.
    """
    texts, labels = [], []
    for folder in sorted(os.listdir(root_dir)):
        folder_path = os.path.join(root_dir, folder)
        if not os.path.isdir(folder_path):
            continue
        for fname in os.listdir(folder_path):
            if not fname.endswith('.txt'):
                continue
            text = read_file(os.path.join(folder_path, fname))
            if text:
                texts.append(text)
                labels.append(folder)  # the folder name doubles as the class label
    return texts, labels


@cli.command()
@click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
              help='Path to VNTC dataset')
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
def vntc(data_dir, save_model, output):
    """Benchmark on VNTC dataset (10 topics, ~84k documents)."""
    click.echo("=" * 70)
    click.echo("VNTC Full Dataset Benchmark")
    click.echo("Vietnamese News Text Classification (10 Topics)")
    click.echo("=" * 70)

    train_dir = os.path.join(data_dir, "Train_Full")
    test_dir = os.path.join(data_dir, "Test_Full")

    # Load both splits via the shared loader (was previously duplicated inline).
    click.echo("\nLoading training data...")
    t0 = time.perf_counter()
    train_texts, train_labels = _load_labeled_dir(train_dir)
    click.echo(f"  Loaded {len(train_texts)} training samples in {time.perf_counter()-t0:.1f}s")

    click.echo("Loading test data...")
    t0 = time.perf_counter()
    test_texts, test_labels = _load_labeled_dir(test_dir)
    click.echo(f"  Loaded {len(test_texts)} test samples in {time.perf_counter()-t0:.1f}s")

    # Run both benchmarks and print the side-by-side comparison.
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels)

    print_comparison(sklearn_results, rust_results)

    if save_model:
        model_path = Path(output)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(model_path))
        size_mb = model_path.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")


@cli.command()
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
def bank(save_model, output):
    """Benchmark on UTS2017_Bank dataset (14 categories, banking domain)."""
    # Imported lazily so the other subcommands work without `datasets`.
    from datasets import load_dataset

    divider = "=" * 70
    click.echo(divider)
    click.echo("UTS2017_Bank Dataset Benchmark")
    click.echo("Vietnamese Banking Domain Text Classification (14 Categories)")
    click.echo(divider)

    click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
    bank_ds = load_dataset("undertheseanlp/UTS2017_Bank", "classification")

    train_split = bank_ds["train"]
    test_split = bank_ds["test"]
    train_texts = list(train_split["text"])
    train_labels = list(train_split["label"])
    test_texts = list(test_split["text"])
    test_labels = list(test_split["label"])

    click.echo(f"  Train samples: {len(train_texts)}")
    click.echo(f"  Test samples: {len(test_texts)}")
    click.echo(f"  Categories: {len(set(train_labels))}")

    # Smaller vocabulary cap: the banking corpus is far smaller than VNTC.
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)

    print_comparison(sklearn_results, rust_results)

    click.echo("\nClassification Report (Rust):")
    click.echo(classification_report(test_labels, rust_results['preds']))

    if save_model:
        destination = Path(output)
        destination.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(destination))
        size_mb = destination.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {destination} ({size_mb:.2f} MB)")


@cli.command()
@click.option('--train-per-cat', default=340, help='Training samples per category')
@click.option('--test-per-cat', default=500, help='Test samples per category')
@click.option('--seed', default=42, help='Random seed')
def synthetic(train_per_cat, test_per_cat, seed):
    """Benchmark on synthetic VNTC-like data."""
    # Vietnamese text templates by category
    TEMPLATES = {
        "the_thao": ["Đội tuyển {} thắng {} với tỷ số {}", "Cầu thủ {} ghi bàn đẹp mắt"],
        "kinh_doanh": ["Chứng khoán {} điểm trong phiên giao dịch", "Ngân hàng {} công bố lãi suất {}"],
        "cong_nghe": ["Apple ra mắt {} với nhiều tính năng", "Trí tuệ nhân tạo đang thay đổi {}"],
        "chinh_tri": ["Quốc hội thông qua nghị quyết về {}", "Chủ tịch {} tiếp đón phái đoàn"],
        "van_hoa": ["Nghệ sĩ {} ra mắt album mới", "Liên hoan phim {} trao giải"],
        "khoa_hoc": ["Nhà khoa học phát hiện {} mới", "Nghiên cứu cho thấy {} có tác dụng"],
        "suc_khoe": ["Bộ Y tế cảnh báo về {} trong mùa", "Vaccine {} đạt hiệu quả cao"],
        "giao_duc": ["Trường {} công bố điểm chuẩn", "Học sinh đoạt huy chương tại Olympic"],
        "phap_luat": ["Tòa án xét xử vụ án {} với bị cáo", "Công an triệt phá đường dây"],
        "doi_song": ["Giá {} tăng trong tháng", "Người dân đổ xô đi mua {}"],
    }
    FILLS = {
        "the_thao": ["Việt Nam", "Thái Lan", "3-0", "bóng đá", "AFF Cup"],
        "kinh_doanh": ["tăng", "giảm", "VN-Index", "Vietcombank", "8%"],
        "cong_nghe": ["iPhone 16", "ChatGPT", "công việc", "VinAI", "5G"],
        "chinh_tri": ["kinh tế", "nước", "Trung Quốc", "Hà Nội", "phát triển"],
        "van_hoa": ["Mỹ Tâm", "Cannes", "nghệ thuật", "Hà Nội", "Bố Già"],
        "khoa_hoc": ["loài sinh vật", "trà xanh", "VNREDSat-1", "Nobel", "robot"],
        "suc_khoe": ["dịch cúm", "COVID-19", "Bạch Mai", "dinh dưỡng", "tiểu đường"],
        "giao_duc": ["Bách Khoa", "Việt Nam", "Toán", "THPT", "STEM"],
        "phap_luat": ["tham nhũng", "TP.HCM", "ma túy", "Hình sự", "gian lận"],
        "doi_song": ["xăng", "vàng", "nắng nóng", "Trung thu", "bún chả"],
    }

    def make_text(category):
        # Pick a random template, then fill each {} slot from the category pool
        # (sampling with replacement).
        template = random.choice(TEMPLATES[category])
        pool = FILLS[category]
        slots = template.count("{}")
        return template.format(*random.choices(pool, k=slots))

    def build_split(n_per_cat, categories):
        # Generate n_per_cat samples per category, then shuffle so labels
        # are interleaved before returning parallel (texts, labels) lists.
        samples = []
        for category in categories:
            for _ in range(n_per_cat):
                samples.append((make_text(category), category))
        random.shuffle(samples)
        return [text for text, _ in samples], [label for _, label in samples]

    click.echo("=" * 70)
    click.echo("Synthetic VNTC-like Benchmark")
    click.echo("=" * 70)

    random.seed(seed)
    categories = list(TEMPLATES.keys())

    click.echo(f"\nConfiguration:")
    click.echo(f"  Categories: {len(categories)}")
    click.echo(f"  Train samples: {train_per_cat * len(categories)}")
    click.echo(f"  Test samples: {test_per_cat * len(categories)}")

    train_texts, train_labels = build_split(train_per_cat, categories)
    test_texts, test_labels = build_split(test_per_cat, categories)

    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)

    print_comparison(sklearn_results, rust_results)


if __name__ == "__main__":
    cli()