rain1024 and Claude Opus 4.5 committed
Commit b059f86 · 1 parent: 2ac432e

Refactor: move Rust extensions to underthesea_core upstream


- Remove local Rust extensions (now in underthesea_core 3.1.7)
- Consolidate training scripts into unified CLI (src/train.py)
- Add benchmark CLI (src/bench.py) with vntc, bank, synthetic commands
- Remove src/sen module (TextClassifier now in underthesea_core)
- Add CLAUDE.md for project documentation
- Update dependencies to use underthesea_core>=3.1.7

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

CLAUDE.md ADDED
@@ -0,0 +1,77 @@
+ # CLAUDE.md
+
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+ ## Project Overview
+
+ Sen-1 is a lightweight Vietnamese text classification model combining TF-IDF vectorization with Linear SVM. Part of the UnderTheSea NLP ecosystem, it serves as a practical baseline compatible with the underthesea API.
+
+ ## Build & Development Commands
+
+ ### Running Tests
+ ```bash
+ pytest tests/test_classifier.py
+ ```
+
+ ### Development Installation
+ ```bash
+ pip install -e ".[dev]"
+ ```
+
+ ### Training on VNTC Dataset
+ ```bash
+ python src/scripts/train_vntc.py
+ ```
+
+ ## Architecture
+
+ **3-Stage Pipeline:**
+ ```
+ Input Text → TF-IDF Vectorizer (max_features=20k, ngram 1-2)
+ → Linear SVM (C=1.0, max_iter=1000)
+ → Label + Confidence
+ ```
+
+ **Key Design Decisions:**
+ - Operates at syllable-level (no word segmentation) for speed
+ - All-Rust implementation via `underthesea_core` for fast training and inference
+ - Model serialization uses binary format (bincode)
+
+ **Core Module:** `src/sen/text_classifier.py` contains `SenTextClassifier` wrapper with train/predict/evaluate/save/load methods.
+
+ **Rust Backend:** Uses `underthesea_core.TextClassifier` which combines TF-IDF vectorization and Linear SVM in a unified Rust implementation via PyO3.
+
+ ## Public API
+
+ ```python
+ from sen import SenTextClassifier, Sentence, Label, classify
+
+ # Pre-trained model inference
+ labels = classify("Văn bản tiếng Việt", model_path="models/sen-1")
+
+ # Custom training
+ clf = SenTextClassifier()
+ clf.train(texts, labels)
+ clf.save("my_model")
+ ```
+
+ ## Key Files
+
+ - `src/sen/text_classifier.py` - Main classifier implementation (wraps underthesea_core)
+ - `src/scripts/train_vntc.py` - Training script for VNTC dataset
+ - `src/scripts/bench_vntc_full.py` - Benchmark comparing sklearn vs Rust
+ - `TECHNICAL_REPORT.md` - Detailed methodology and benchmark results
+ - `RESEARCH_PLAN.md` - Future work roadmap (PhoBERT comparison, word segmentation)
+
+ ## Performance Benchmarks
+
+ - VNTC (news, 10 topics): 92.49% accuracy, 37.6s training
+ - UTS2017_Bank (14 categories): 75.76% accuracy
+ - Inference: 66,678 samples/sec batch, 0.465ms single
+
+ ## Known Limitations
+
+ 1. Syllable-level only (no word segmentation) - ~4.6% gap vs word-level approaches
+ 2. Single-label classification only
+ 3. Trained on news domain - may not generalize to social media/reviews
+ 4. Lower performance on imbalanced categories
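The "syllable-level, ngram 1-2" design in the CLAUDE.md above can be illustrated with a short pure-Python sketch. This is illustrative only: the real feature extraction happens inside `underthesea_core`'s Rust TF-IDF, and `syllable_ngrams` is a hypothetical helper, not part of the project's API.

```python
# Hypothetical sketch: syllable-level 1- and 2-gram extraction, mirroring the
# TF-IDF settings documented in CLAUDE.md (ngram range 1-2, whitespace-separated
# Vietnamese syllables, no word segmentation step).
def syllable_ngrams(text, n_min=1, n_max=2):
    syllables = text.lower().split()  # each whitespace token is one syllable
    grams = []
    for n in range(n_min, n_max + 1):
        for i in range(len(syllables) - n + 1):
            grams.append(" ".join(syllables[i:i + n]))
    return grams

print(syllable_ngrams("Văn bản tiếng Việt"))
# ['văn', 'bản', 'tiếng', 'việt', 'văn bản', 'bản tiếng', 'tiếng việt']
```

Because no word segmenter runs, "tiếng việt" is only captured as a bigram feature rather than a single segmented word, which is the source of the ~4.6% gap versus word-level approaches noted in the limitations.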
extensions/underthesea_core_extend/.gitignore DELETED
@@ -1,5 +0,0 @@
- target/
- .venv/
- __pycache__/
- *.so
- *.pyc
extensions/underthesea_core_extend/Cargo.lock DELETED
@@ -1,351 +0,0 @@
- # This file is automatically @generated by Cargo.
- # It is not intended for manual editing.
- version = 4
-
- [[package]]
- name = "allocator-api2"
- version = "0.2.21"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
-
- [[package]]
- name = "autocfg"
- version = "1.5.0"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
-
- [[package]]
- name = "cfg-if"
- version = "1.0.4"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
-
- [[package]]
- name = "crossbeam-deque"
- version = "0.8.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
- dependencies = [
-  "crossbeam-epoch",
-  "crossbeam-utils",
- ]
-
- [[package]]
- name = "crossbeam-epoch"
- version = "0.9.18"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
- dependencies = [
-  "crossbeam-utils",
- ]
-
- [[package]]
- name = "crossbeam-utils"
- version = "0.8.21"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
-
- [[package]]
- name = "either"
- version = "1.15.0"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
-
- [[package]]
- name = "equivalent"
- version = "1.0.2"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
-
- [[package]]
- name = "foldhash"
- version = "0.1.5"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
-
- [[package]]
- name = "hashbrown"
- version = "0.15.5"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
- dependencies = [
-  "allocator-api2",
-  "equivalent",
-  "foldhash",
-  "serde",
- ]
-
- [[package]]
- name = "heck"
- version = "0.5.0"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
-
- [[package]]
- name = "indoc"
- version = "2.0.7"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
- dependencies = [
-  "rustversion",
- ]
-
- [[package]]
- name = "itoa"
- version = "1.0.17"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
-
- [[package]]
- name = "libc"
- version = "0.2.180"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
-
- [[package]]
- name = "memchr"
- version = "2.7.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
-
- [[package]]
- name = "memoffset"
- version = "0.9.1"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
- dependencies = [
-  "autocfg",
- ]
-
- [[package]]
- name = "once_cell"
- version = "1.21.3"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
-
- [[package]]
- name = "portable-atomic"
- version = "1.13.1"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
-
- [[package]]
- name = "proc-macro2"
- version = "1.0.106"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
- dependencies = [
-  "unicode-ident",
- ]
-
- [[package]]
- name = "pyo3"
- version = "0.22.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
- dependencies = [
-  "cfg-if",
-  "indoc",
-  "libc",
-  "memoffset",
-  "once_cell",
-  "portable-atomic",
-  "pyo3-build-config",
-  "pyo3-ffi",
-  "pyo3-macros",
-  "unindent",
- ]
-
- [[package]]
- name = "pyo3-build-config"
- version = "0.22.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
- dependencies = [
-  "once_cell",
-  "target-lexicon",
- ]
-
- [[package]]
- name = "pyo3-ffi"
- version = "0.22.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
- dependencies = [
-  "libc",
-  "pyo3-build-config",
- ]
-
- [[package]]
- name = "pyo3-macros"
- version = "0.22.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
- dependencies = [
-  "proc-macro2",
-  "pyo3-macros-backend",
-  "quote",
-  "syn",
- ]
-
- [[package]]
- name = "pyo3-macros-backend"
- version = "0.22.6"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
- dependencies = [
-  "heck",
-  "proc-macro2",
-  "pyo3-build-config",
-  "quote",
-  "syn",
- ]
-
- [[package]]
- name = "quote"
- version = "1.0.44"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
- dependencies = [
-  "proc-macro2",
- ]
-
- [[package]]
- name = "rayon"
- version = "1.11.0"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
- dependencies = [
-  "either",
-  "rayon-core",
- ]
-
- [[package]]
- name = "rayon-core"
- version = "1.13.0"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
- dependencies = [
-  "crossbeam-deque",
-  "crossbeam-utils",
- ]
-
- [[package]]
- name = "rustversion"
- version = "1.0.22"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
-
- [[package]]
- name = "serde"
- version = "1.0.228"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
- dependencies = [
-  "serde_core",
-  "serde_derive",
- ]
-
- [[package]]
- name = "serde_core"
- version = "1.0.228"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
- dependencies = [
-  "serde_derive",
- ]
-
- [[package]]
- name = "serde_derive"
- version = "1.0.228"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
- dependencies = [
-  "proc-macro2",
-  "quote",
-  "syn",
- ]
-
- [[package]]
- name = "serde_json"
- version = "1.0.149"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
- dependencies = [
-  "itoa",
-  "memchr",
-  "serde",
-  "serde_core",
-  "zmij",
- ]
-
- [[package]]
- name = "syn"
- version = "2.0.114"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
- dependencies = [
-  "proc-macro2",
-  "quote",
-  "unicode-ident",
- ]
-
- [[package]]
- name = "target-lexicon"
- version = "0.12.16"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
-
- [[package]]
- name = "tinyvec"
- version = "1.10.0"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
- dependencies = [
-  "tinyvec_macros",
- ]
-
- [[package]]
- name = "tinyvec_macros"
- version = "0.1.1"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
-
- [[package]]
- name = "underthesea_core_extend"
- version = "0.1.0"
- dependencies = [
-  "hashbrown",
-  "pyo3",
-  "rayon",
-  "serde",
-  "serde_json",
-  "unicode-normalization",
- ]
-
- [[package]]
- name = "unicode-ident"
- version = "1.0.22"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
-
- [[package]]
- name = "unicode-normalization"
- version = "0.1.25"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
- dependencies = [
-  "tinyvec",
- ]
-
- [[package]]
- name = "unindent"
- version = "0.2.4"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
-
- [[package]]
- name = "zmij"
- version = "1.0.19"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445"
extensions/underthesea_core_extend/Cargo.toml DELETED
@@ -1,23 +0,0 @@
- [package]
- name = "underthesea_core_extend"
- version = "0.1.0"
- edition = "2021"
- description = "Rust extensions for underthesea - Text Classification"
- license = "Apache-2.0"
-
- [lib]
- name = "underthesea_core_extend"
- crate-type = ["cdylib"]
-
- [dependencies]
- pyo3 = { version = "0.22", features = ["extension-module"] }
- serde = { version = "1.0", features = ["derive"] }
- serde_json = "1.0"
- rayon = "1.10"
- hashbrown = { version = "0.15", features = ["serde"] }
- unicode-normalization = "0.1"
-
- [profile.release]
- lto = true
- codegen-units = 1
- opt-level = 3
extensions/underthesea_core_extend/pyproject.toml DELETED
@@ -1,22 +0,0 @@
- [build-system]
- requires = ["maturin>=1.0,<2.0"]
- build-backend = "maturin"
-
- [project]
- name = "underthesea_core_extend"
- version = "0.1.0"
- description = "Rust extensions for underthesea - Text Classification"
- requires-python = ">=3.10"
- license = { text = "Apache-2.0" }
- authors = [{ name = "UnderTheSea NLP", email = "anhv.ict91@gmail.com" }]
- classifiers = [
-     "Programming Language :: Rust",
-     "Programming Language :: Python :: Implementation :: CPython",
-     "Programming Language :: Python :: 3.10",
-     "Programming Language :: Python :: 3.11",
-     "Programming Language :: Python :: 3.12",
- ]
-
- [tool.maturin]
- features = ["pyo3/extension-module"]
- module-name = "underthesea_core_extend"
extensions/underthesea_core_extend/src/lib.rs DELETED
@@ -1,21 +0,0 @@
- //! underthesea_core_extend - Rust extensions for Vietnamese Text Classification
- //!
- //! Provides fast TF-IDF vectorization and Linear SVM classification.
-
- use pyo3::prelude::*;
-
- mod tfidf;
- mod svm;
-
- pub use tfidf::TfIdfVectorizer;
- pub use svm::{LinearSVM, SVMTrainer, FastSVMTrainer};
-
- /// Python module
- #[pymodule]
- fn underthesea_core_extend(m: &Bound<'_, PyModule>) -> PyResult<()> {
-     m.add_class::<TfIdfVectorizer>()?;
-     m.add_class::<LinearSVM>()?;
-     m.add_class::<SVMTrainer>()?;
-     m.add_class::<FastSVMTrainer>()?;
-     Ok(())
- }
extensions/underthesea_core_extend/src/svm.rs DELETED
@@ -1,512 +0,0 @@
- //! Optimized Linear SVM - LIBLINEAR-style Dual Coordinate Descent
- //!
- //! Pure Rust implementation of L2-regularized L2-loss SVM (dual form)
- //! Reference: "A Dual Coordinate Descent Method for Large-scale Linear SVM"
- //! Hsieh et al., ICML 2008
-
- use hashbrown::HashMap;
- use pyo3::prelude::*;
- use rayon::prelude::*;
- use serde::{Deserialize, Serialize};
- use std::fs::File;
- use std::io::{BufReader, BufWriter};
-
- /// Sparse feature vector
- pub type SparseVec = Vec<(u32, f32)>; // Use u32/f32 for memory efficiency
-
- /// Linear SVM Model
- #[pyclass]
- #[derive(Clone, Serialize, Deserialize)]
- pub struct LinearSVM {
-     weights: Vec<Vec<f32>>,
-     biases: Vec<f32>,
-     classes: Vec<String>,
-     n_features: usize,
- }
-
- #[pymethods]
- impl LinearSVM {
-     #[new]
-     pub fn new() -> Self {
-         Self {
-             weights: Vec::new(),
-             biases: Vec::new(),
-             classes: Vec::new(),
-             n_features: 0,
-         }
-     }
-
-     pub fn predict(&self, features: Vec<f64>) -> String {
-         let idx = self.predict_idx(&features);
-         self.classes[idx].clone()
-     }
-
-     pub fn predict_with_score(&self, features: Vec<f64>) -> (String, f64) {
-         let scores = self.decision_scores(&features);
-         let (idx, &max_score) = scores
-             .iter()
-             .enumerate()
-             .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
-             .unwrap();
-         let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
-         (self.classes[idx].clone(), confidence)
-     }
-
-     pub fn predict_batch(&self, features_batch: Vec<Vec<f64>>) -> Vec<String> {
-         features_batch
-             .par_iter()
-             .map(|f| {
-                 let idx = self.predict_idx(f);
-                 self.classes[idx].clone()
-             })
-             .collect()
-     }
-
-     pub fn predict_batch_sparse(&self, features_batch: Vec<Vec<(usize, f64)>>) -> Vec<String> {
-         features_batch
-             .par_iter()
-             .map(|f| {
-                 let idx = self.predict_idx_sparse(f);
-                 self.classes[idx].clone()
-             })
-             .collect()
-     }
-
-     pub fn predict_sparse_with_score(&self, features: Vec<(usize, f64)>) -> (String, f64) {
-         let scores = self.decision_scores_sparse(&features);
-         let (idx, &max_score) = scores
-             .iter()
-             .enumerate()
-             .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
-             .unwrap();
-         let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
-         (self.classes[idx].clone(), confidence)
-     }
-
-     pub fn decision_function(&self, features: Vec<f64>) -> Vec<f64> {
-         self.decision_scores(&features).into_iter().map(|x| x as f64).collect()
-     }
-
-     #[getter]
-     pub fn classes(&self) -> Vec<String> {
-         self.classes.clone()
-     }
-
-     #[getter]
-     pub fn n_classes(&self) -> usize {
-         self.classes.len()
-     }
-
-     #[getter]
-     pub fn n_features(&self) -> usize {
-         self.n_features
-     }
-
-     pub fn save(&self, path: &str) -> PyResult<()> {
-         let file = File::create(path)
-             .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
-         let writer = BufWriter::new(file);
-         serde_json::to_writer(writer, self)
-             .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
-         Ok(())
-     }
-
-     #[staticmethod]
-     pub fn load(path: &str) -> PyResult<Self> {
-         let file = File::open(path)
-             .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
-         let reader = BufReader::new(file);
-         let model: Self = serde_json::from_reader(reader)
-             .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
-         Ok(model)
-     }
- }
-
- impl LinearSVM {
-     #[inline]
-     fn predict_idx(&self, features: &[f64]) -> usize {
-         let mut best_idx = 0;
-         let mut best_score = f32::NEG_INFINITY;
-
-         for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
-             let score: f32 = w.iter()
-                 .zip(features.iter())
-                 .map(|(&wi, &fi)| wi * fi as f32)
-                 .sum::<f32>() + b;
-
-             if score > best_score {
-                 best_score = score;
-                 best_idx = idx;
-             }
-         }
-         best_idx
-     }
-
-     #[inline]
-     fn predict_idx_sparse(&self, features: &[(usize, f64)]) -> usize {
-         let mut best_idx = 0;
-         let mut best_score = f32::NEG_INFINITY;
-
-         for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
-             let score: f32 = features.iter()
-                 .map(|&(j, v)| w[j] * v as f32)
-                 .sum::<f32>() + b;
-
-             if score > best_score {
-                 best_score = score;
-                 best_idx = idx;
-             }
-         }
-         best_idx
-     }
-
-     fn decision_scores(&self, features: &[f64]) -> Vec<f32> {
-         self.weights
-             .iter()
-             .zip(self.biases.iter())
-             .map(|(w, &b)| {
-                 w.iter()
-                     .zip(features.iter())
-                     .map(|(&wi, &fi)| wi * fi as f32)
-                     .sum::<f32>() + b
-             })
-             .collect()
-     }
-
-     fn decision_scores_sparse(&self, features: &[(usize, f64)]) -> Vec<f32> {
-         self.weights
-             .iter()
-             .zip(self.biases.iter())
-             .map(|(w, &b)| {
-                 features.iter()
-                     .map(|&(j, v)| w[j] * v as f32)
-                     .sum::<f32>() + b
-             })
-             .collect()
-     }
- }
-
- /// LIBLINEAR-style SVM Trainer
- #[pyclass]
- pub struct SVMTrainer {
-     c: f64,
-     max_iter: usize,
-     tol: f64,
-     verbose: bool,
- }
-
- #[pymethods]
- impl SVMTrainer {
-     #[new]
-     #[pyo3(signature = (c=1.0, max_iter=1000, tol=0.1, verbose=false))]
-     pub fn new(c: f64, max_iter: usize, tol: f64, verbose: bool) -> Self {
-         Self { c, max_iter, tol, verbose }
-     }
-
-     pub fn set_c(&mut self, c: f64) {
-         self.c = c;
-     }
-
-     pub fn set_max_iter(&mut self, max_iter: usize) {
-         self.max_iter = max_iter;
-     }
-
-     pub fn set_verbose(&mut self, verbose: bool) {
-         self.verbose = verbose;
-     }
-
-     pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
-         let n_samples = features.len();
-         let n_features = if n_samples > 0 { features[0].len() } else { 0 };
-
-         // Convert to compact sparse format (f32 for memory/cache efficiency)
-         let sparse_features: Vec<SparseVec> = features
-             .par_iter()
-             .map(|dense| {
-                 dense
-                     .iter()
-                     .enumerate()
-                     .filter(|&(_, &v)| v.abs() > 1e-10)
-                     .map(|(i, &v)| (i as u32, v as f32))
-                     .collect()
-             })
-             .collect();
-
-         // Precompute ||x_i||^2
-         let x_sq_norms: Vec<f32> = sparse_features
-             .par_iter()
-             .map(|x| x.iter().map(|&(_, v)| v * v).sum())
-             .collect();
-
-         // Get unique classes
-         let mut classes: Vec<String> = labels.iter().cloned().collect();
-         classes.sort();
-         classes.dedup();
-         let n_classes = classes.len();
-
-         let class_to_idx: HashMap<String, usize> = classes
-             .iter()
-             .enumerate()
-             .map(|(i, c)| (c.clone(), i))
-             .collect();
-
-         let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
-
-         // Train binary classifiers in parallel (one-vs-rest)
-         let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
-             .into_par_iter()
-             .map(|class_idx| {
-                 let y_binary: Vec<i8> = y_idx
-                     .iter()
-                     .map(|&idx| if idx == class_idx { 1 } else { -1 })
-                     .collect();
-
-                 solve_l2r_l2_svc(
-                     &sparse_features,
-                     &y_binary,
-                     &x_sq_norms,
-                     n_features,
-                     self.c as f32,
-                     self.tol as f32,
-                     self.max_iter,
-                 )
-             })
-             .collect();
-
-         let weights = results.iter().map(|(w, _)| w.clone()).collect();
-         let biases = results.iter().map(|(_, b)| *b).collect();
-
-         LinearSVM {
-             weights,
-             biases,
-             classes,
-             n_features,
-         }
-     }
- }
-
- /// LIBLINEAR's solve_l2r_l2_svc - Dual Coordinate Descent for L2-loss SVM
- ///
- /// Solves: min_α 0.5 * α^T * Q * α - e^T * α, s.t. α_i ≥ 0
- /// where Q_ij = y_i * y_j * x_i^T * x_j + δ_ij / (2C)
- ///
- /// Primal-dual relationship: w = Σ α_i * y_i * x_i
- #[inline(never)]
- fn solve_l2r_l2_svc(
-     x: &[SparseVec],
-     y: &[i8],
-     x_sq_norms: &[f32],
-     n_features: usize,
-     c: f32,
-     eps: f32,
-     max_iter: usize,
- ) -> (Vec<f32>, f32) {
-     let n = x.len();
-
-     // D_ii = 1/(2C) for L2-loss SVM
-     let diag = 0.5 / c;
-
-     // QD[i] = ||x_i||^2 + D_ii
-     let qd: Vec<f32> = x_sq_norms.iter().map(|&xn| xn + diag).collect();
-
-     // Initialize α = 0
-     let mut alpha = vec![0.0f32; n];
-
-     // w = Σ α_i * y_i * x_i (initially 0)
-     let mut w = vec![0.0f32; n_features];
-
-     // Index for permutation
-     let mut index: Vec<usize> = (0..n).collect();
-
-     // Main loop
-     for iter in 0..max_iter {
-         // Shuffle indices
-         for i in 0..n {
-             let j = i + (iter * 1103515245 + 12345) % (n - i).max(1);
-             index.swap(i, j);
-         }
-
-         let mut max_violation = 0.0f32;
-
-         for &i in &index {
-             let yi = y[i] as f32;
-             let xi = &x[i];
-
-             // G = y_i * (w · x_i) - 1 + D_ii * α_i
-             let wxi: f32 = xi.iter().map(|&(j, v)| w[j as usize] * v).sum();
-             let g = yi * wxi - 1.0 + diag * alpha[i];
-
-             // Projected gradient (α ≥ 0, no upper bound for L2-loss)
-             let pg = if alpha[i] == 0.0 { g.min(0.0) } else { g };
-
-             max_violation = max_violation.max(pg.abs());
-
-             if pg.abs() > 1e-12 {
-                 let alpha_old = alpha[i];
-
-                 // α_i = max(0, α_i - G/Q_ii)
-                 alpha[i] = (alpha[i] - g / qd[i]).max(0.0);
-
-                 // Update w: w += (α_new - α_old) * y_i * x_i
-                 let d = (alpha[i] - alpha_old) * yi;
-                 if d.abs() > 1e-12 {
-                     for &(j, v) in xi.iter() {
-                         w[j as usize] += d * v;
-                     }
-                 }
-             }
-         }
-
-         // Stopping criterion
-         if max_violation <= eps {
-             break;
-         }
-     }
-
-     // Compute bias from KKT conditions
-     // For α_i > 0: y_i * (w · x_i + b) = 1 - α_i / (2C)
-     let mut bias_sum = 0.0f32;
-     let mut n_sv = 0;
-
-     for i in 0..n {
-         if alpha[i] > 1e-8 {
-             let yi = y[i] as f32;
-             let wxi: f32 = x[i].iter().map(|&(j, v)| w[j as usize] * v).sum();
-             // b = y_i * (1 - α_i * diag) - w · x_i
-             bias_sum += yi * (1.0 - alpha[i] * diag) - wxi;
-             n_sv += 1;
-         }
-     }
-
-     let bias = if n_sv > 0 { bias_sum / n_sv as f32 } else { 0.0 };
-
-     (w, bias)
- }
-
- /// Fast SVM using Pegasos algorithm
- #[pyclass]
- pub struct FastSVMTrainer {
-     c: f64,
-     max_iter: usize,
- }
-
- #[pymethods]
- impl FastSVMTrainer {
-     #[new]
-     #[pyo3(signature = (c=1.0, max_iter=100))]
-     pub fn new(c: f64, max_iter: usize) -> Self {
-         Self { c, max_iter }
-     }
-
-     pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
-         let n_samples = features.len();
-         let n_features = if n_samples > 0 { features[0].len() } else { 0 };
-
-         let sparse_features: Vec<SparseVec> = features
-             .par_iter()
-             .map(|dense| {
-                 dense
-                     .iter()
-                     .enumerate()
-                     .filter(|&(_, &v)| v.abs() > 1e-10)
-                     .map(|(i, &v)| (i as u32, v as f32))
-                     .collect()
-             })
-             .collect();
-
-         let mut classes: Vec<String> = labels.iter().cloned().collect();
-         classes.sort();
-         classes.dedup();
-         let n_classes = classes.len();
-
-         let class_to_idx: HashMap<String, usize> = classes
-             .iter()
-             .enumerate()
-             .map(|(i, c)| (c.clone(), i))
-             .collect();
-
-         let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
-
-         let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
-             .into_par_iter()
-             .map(|class_idx| {
-                 let y_binary: Vec<i8> = y_idx
-                     .iter()
-                     .map(|&idx| if idx == class_idx { 1 } else { -1 })
-                     .collect();
-
-                 pegasos(&sparse_features, &y_binary, n_features, self.c as f32, self.max_iter)
-             })
-             .collect();
-
-         LinearSVM {
-             weights: results.iter().map(|(w, _)| w.clone()).collect(),
-             biases: results.iter().map(|(_, b)| *b).collect(),
-             classes,
-             n_features,
-         }
-     }
- }
-
- /// Pegasos algorithm with lazy scaling
- #[inline(never)]
- fn pegasos(
-     x: &[SparseVec],
-     y: &[i8],
-     n_features: usize,
-     c: f32,
-     max_iter: usize,
- ) -> (Vec<f32>, f32) {
-     let n = x.len();
-     let lambda = 1.0 / c;
-
-     let mut w = vec![0.0f32; n_features];
-     let mut scale = 1.0f32;
-     let mut b = 0.0f32;
-
-     let eta0 = 0.5;
-     let t0 = 1.0 / (eta0 * lambda);
-
-     let mut indices: Vec<usize> = (0..n).collect();
-
-     for epoch in 0..max_iter {
-         // Shuffle
-         for i in 0..n {
-             let j = (i + epoch * 1103515245 + 12345) % n;
-             indices.swap(i, j);
-         }
-
-         for (t_inner, &i) in indices.iter().enumerate() {
-             let t = (epoch * n + t_inner) as f32;
-             let eta = 1.0 / (lambda * (t + t0));
-
-             let yi = y[i] as f32;
-             let xi = &x[i];
-
-             let margin: f32 = scale * xi.iter().map(|&(j, v)| w[j as usize] * v).sum::<f32>() + b;
-
-             scale *= 1.0 - eta * lambda;
-
-             if scale < 1e-9 {
-                 for wj in w.iter_mut() {
-                     *wj *= scale;
-                 }
-                 scale = 1.0;
-             }
-
-             if yi * margin < 1.0 {
-                 let update = eta / scale;
-                 for &(j, v) in xi.iter() {
-                     w[j as usize] += update * yi * v;
-                 }
-                 b += eta * yi * 0.1;
-             }
-         }
-     }
-
-     for wj in w.iter_mut() {
-         *wj *= scale;
-     }
-
-     (w, b)
- }
extensions/underthesea_core_extend/src/tfidf.rs DELETED
@@ -1,235 +0,0 @@
//! TF-IDF Vectorizer implementation

use hashbrown::HashMap;
use pyo3::prelude::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};

/// TF-IDF Vectorizer
///
/// Converts text documents into TF-IDF feature vectors.
#[pyclass]
#[derive(Clone, Serialize, Deserialize)]
pub struct TfIdfVectorizer {
    /// Vocabulary: word -> index
    vocab: HashMap<String, usize>,
    /// Inverse vocabulary: index -> word
    inv_vocab: Vec<String>,
    /// IDF values for each term
    idf: Vec<f64>,
    /// Number of documents used for fitting
    n_docs: usize,
    /// Maximum number of features
    max_features: usize,
    /// N-gram range (min, max)
    ngram_range: (usize, usize),
    /// Minimum document frequency
    min_df: usize,
    /// Maximum document frequency (as ratio)
    max_df: f64,
    /// Whether the vectorizer is fitted
    is_fitted: bool,
}

#[pymethods]
impl TfIdfVectorizer {
    /// Create a new TfIdfVectorizer
    #[new]
    #[pyo3(signature = (max_features=20000, ngram_range=(1, 2), min_df=1, max_df=1.0))]
    pub fn new(
        max_features: usize,
        ngram_range: (usize, usize),
        min_df: usize,
        max_df: f64,
    ) -> Self {
        Self {
            vocab: HashMap::new(),
            inv_vocab: Vec::new(),
            idf: Vec::new(),
            n_docs: 0,
            max_features,
            ngram_range,
            min_df,
            max_df,
            is_fitted: false,
        }
    }

    /// Fit the vectorizer on a list of documents
    pub fn fit(&mut self, documents: Vec<String>) {
        let n_docs = documents.len();
        self.n_docs = n_docs;

        // Count document frequency for each term
        let mut df: HashMap<String, usize> = HashMap::new();

        for doc in &documents {
            let tokens = self.tokenize(doc);
            let unique_tokens: std::collections::HashSet<_> = tokens.into_iter().collect();
            for token in unique_tokens {
                *df.entry(token).or_insert(0) += 1;
            }
        }

        // Filter by min_df and max_df
        let max_df_count = (self.max_df * n_docs as f64) as usize;
        let mut filtered: Vec<(String, usize)> = df
            .into_iter()
            .filter(|(_, count)| *count >= self.min_df && *count <= max_df_count)
            .collect();

        // Sort by frequency (descending) and take top max_features
        filtered.sort_by(|a, b| b.1.cmp(&a.1));
        filtered.truncate(self.max_features);

        // Build vocabulary
        self.vocab.clear();
        self.inv_vocab.clear();
        self.idf.clear();

        for (idx, (term, doc_freq)) in filtered.into_iter().enumerate() {
            self.vocab.insert(term.clone(), idx);
            self.inv_vocab.push(term);
            // IDF with smoothing: log((n_docs + 1) / (df + 1)) + 1
            let idf_value = ((n_docs as f64 + 1.0) / (doc_freq as f64 + 1.0)).ln() + 1.0;
            self.idf.push(idf_value);
        }

        self.is_fitted = true;
    }

    /// Transform a single document to TF-IDF vector (sparse format)
    ///
    /// Returns list of (index, value) tuples
    pub fn transform(&self, document: &str) -> Vec<(usize, f64)> {
        if !self.is_fitted {
            return Vec::new();
        }

        let tokens = self.tokenize(document);
        let mut tf: HashMap<usize, usize> = HashMap::new();

        for token in &tokens {
            if let Some(&idx) = self.vocab.get(token) {
                *tf.entry(idx).or_insert(0) += 1;
            }
        }

        let n_tokens = tokens.len() as f64;
        if n_tokens == 0.0 {
            return Vec::new();
        }

        let mut result: Vec<(usize, f64)> = tf
            .into_iter()
            .map(|(idx, count)| {
                let tf_value = count as f64 / n_tokens;
                let tfidf = tf_value * self.idf[idx];
                (idx, tfidf)
            })
            .collect();

        // L2 normalize
        let norm: f64 = result.iter().map(|(_, v)| v * v).sum::<f64>().sqrt();
        if norm > 0.0 {
            for (_, v) in &mut result {
                *v /= norm;
            }
        }

        result.sort_by_key(|(idx, _)| *idx);
        result
    }

    /// Transform a single document to dense TF-IDF vector
    pub fn transform_dense(&self, document: &str) -> Vec<f64> {
        let sparse = self.transform(document);
        let mut dense = vec![0.0; self.vocab.len()];
        for (idx, val) in sparse {
            dense[idx] = val;
        }
        dense
    }

    /// Transform multiple documents to dense TF-IDF vectors (parallel)
    pub fn transform_batch(&self, documents: Vec<String>) -> Vec<Vec<f64>> {
        documents
            .par_iter()
            .map(|doc| self.transform_dense(doc))
            .collect()
    }

    /// Transform multiple documents to sparse TF-IDF vectors (parallel)
    pub fn transform_batch_sparse(&self, documents: Vec<String>) -> Vec<Vec<(usize, f64)>> {
        documents
            .par_iter()
            .map(|doc| self.transform(doc))
            .collect()
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, documents: Vec<String>) -> Vec<Vec<f64>> {
        self.fit(documents.clone());
        self.transform_batch(documents)
    }

    /// Get vocabulary size
    #[getter]
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Get feature names (vocabulary terms)
    pub fn get_feature_names(&self) -> Vec<String> {
        self.inv_vocab.clone()
    }

    /// Check if vectorizer is fitted
    #[getter]
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }

    /// Save vectorizer to file
    pub fn save(&self, path: &str) -> PyResult<()> {
        let file = File::create(path)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
        let writer = BufWriter::new(file);
        serde_json::to_writer(writer, self)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
        Ok(())
    }

    /// Load vectorizer from file
    #[staticmethod]
    pub fn load(path: &str) -> PyResult<Self> {
        let file = File::open(path)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
        let reader = BufReader::new(file);
        let vectorizer: Self = serde_json::from_reader(reader)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
        Ok(vectorizer)
    }
}

impl TfIdfVectorizer {
    /// Tokenize document into n-grams
    fn tokenize(&self, document: &str) -> Vec<String> {
        let words: Vec<&str> = document.split_whitespace().collect();
        let mut tokens = Vec::new();

        for n in self.ngram_range.0..=self.ngram_range.1 {
            if n > words.len() {
                continue;
            }
            for i in 0..=(words.len() - n) {
                let ngram = words[i..i + n].join(" ");
                tokens.push(ngram);
            }
        }

        tokens
    }
}
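The vectorizer above uses smoothed IDF (`ln((n_docs + 1) / (df + 1)) + 1`), term frequency normalized by token count, and L2 normalization of the final vector. That scoring can be checked in a few lines of plain Python; the helper names here are illustrative, not part of the `underthesea_core` API.

```python
import math
from collections import Counter

def fit_idf(docs):
    """Document frequency -> smoothed IDF, mirroring the formula in tfidf.rs."""
    n = len(docs)
    df = Counter(tok for doc in docs for tok in set(doc.split()))
    return {t: math.log((n + 1) / (c + 1)) + 1 for t, c in df.items()}

def transform(doc, idf):
    """TF (count / n_tokens) * IDF, then L2-normalize the sparse vector."""
    toks = doc.split()
    tf = Counter(t for t in toks if t in idf)
    vec = {t: (c / len(toks)) * idf[t] for t, c in tf.items()}
    norm = math.sqrt(sum(v * v for v in vec.values()))
    return {t: v / norm for t, v in vec.items()} if norm > 0 else {}
```

Note that a term present in every document still gets a positive weight (IDF bottoms out at 1.0 rather than 0), which is what the `+ 1` smoothing buys.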
 
extensions/underthesea_core_extend/uv.lock DELETED
@@ -1,8 +0,0 @@
version = 1
revision = 3
requires-python = ">=3.10"

[[package]]
name = "underthesea-core-extend"
version = "0.1.0"
source = { editable = "." }
 
pyproject.toml CHANGED
@@ -1,7 +1,7 @@
 [project]
 name = "sen"
 version = "1.1.0"
-description = "Vietnamese Text Classification Model - Rust-powered"
+description = "Vietnamese Text Classification - Training scripts for underthesea_core"
 readme = "README.md"
 requires-python = ">=3.10"
 license = "Apache-2.0"
@@ -10,14 +10,16 @@ authors = [
 ]
 keywords = ["vietnamese", "nlp", "text-classification", "rust", "svm"]
 dependencies = [
-    "underthesea_core_extend>=0.1.0",
+    "underthesea>=6.0.0",
+    "click>=8.0.0",
 ]

 [project.optional-dependencies]
 dev = [
     "pytest>=7.0.0",
     "huggingface-hub>=0.20.0",
-    "maturin>=1.0.0",
+    "scikit-learn>=1.0.0",
+    "datasets>=2.0.0",
 ]

 [project.urls]
@@ -27,6 +29,3 @@ Repository = "https://github.com/undertheseanlp/sen"
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
-
-[tool.hatch.build.targets.wheel]
-packages = ["src/sen"]
 
 
src/bench.py ADDED
@@ -0,0 +1,328 @@
 
"""
Benchmark CLI for Vietnamese Text Classification.

Compares Rust TextClassifier vs sklearn.

Usage:
    python bench.py vntc
    python bench.py bank
    python bench.py synthetic
"""

import os
import time
import random
from pathlib import Path

import click
from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVectorizer
from sklearn.svm import LinearSVC as SklearnLinearSVC
from sklearn.metrics import accuracy_score, f1_score, classification_report

from underthesea_core import TextClassifier


def read_file(filepath):
    """Read text file with multiple encoding attempts."""
    for enc in ['utf-16', 'utf-16-le', 'utf-8', 'latin-1']:
        try:
            with open(filepath, 'r', encoding=enc) as f:
                text = ' '.join(f.read().split())
                if len(text) > 10:
                    return text
        except (UnicodeDecodeError, UnicodeError):
            continue
    return None


def benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark scikit-learn TF-IDF + LinearSVC."""
    click.echo("\n" + "=" * 70)
    click.echo("scikit-learn: TfidfVectorizer + LinearSVC")
    click.echo("=" * 70)

    # Vectorize
    click.echo(" Vectorizing...")
    t0 = time.perf_counter()
    vectorizer = SklearnTfidfVectorizer(max_features=max_features, ngram_range=(1, 2), min_df=2)
    X_train = vectorizer.fit_transform(train_texts)
    X_test = vectorizer.transform(test_texts)
    vec_time = time.perf_counter() - t0
    click.echo(f" Vectorization time: {vec_time:.2f}s")
    click.echo(f" Vocabulary size: {len(vectorizer.vocabulary_)}")

    # Train
    click.echo(" Training LinearSVC...")
    t0 = time.perf_counter()
    clf = SklearnLinearSVC(C=1.0, max_iter=2000)
    clf.fit(X_train, train_labels)
    train_time = time.perf_counter() - t0
    click.echo(f" Training time: {train_time:.2f}s")

    # End-to-end inference
    click.echo(" End-to-end inference...")
    t0 = time.perf_counter()
    X_test_e2e = vectorizer.transform(test_texts)
    preds = clf.predict(X_test_e2e)
    e2e_time = time.perf_counter() - t0
    e2e_throughput = len(test_texts) / e2e_time
    click.echo(f" E2E time: {e2e_time:.2f}s ({e2e_throughput:.0f} samples/sec)")

    # Metrics
    acc = accuracy_score(test_labels, preds)
    f1_w = f1_score(test_labels, preds, average='weighted')
    click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")

    return {
        "total_train": vec_time + train_time,
        "e2e_throughput": e2e_throughput,
        "accuracy": acc,
        "f1_weighted": f1_w,
    }


def benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=20000):
    """Benchmark Rust TextClassifier."""
    click.echo("\n" + "=" * 70)
    click.echo("Rust: TextClassifier (underthesea_core)")
    click.echo("=" * 70)

    clf = TextClassifier(
        max_features=max_features,
        ngram_range=(1, 2),
        min_df=2,
        c=1.0,
        max_iter=1000,
        tol=0.1,
    )

    # Train
    click.echo(" Training...")
    t0 = time.perf_counter()
    clf.fit(list(train_texts), list(train_labels))
    train_time = time.perf_counter() - t0
    click.echo(f" Training time: {train_time:.2f}s")
    click.echo(f" Vocabulary size: {clf.n_features}")

    # Inference
    click.echo(" Inference...")
    t0 = time.perf_counter()
    preds = clf.predict_batch(list(test_texts))
    infer_time = time.perf_counter() - t0
    throughput = len(test_texts) / infer_time
    click.echo(f" Inference time: {infer_time:.2f}s ({throughput:.0f} samples/sec)")

    # Metrics
    acc = accuracy_score(test_labels, preds)
    f1_w = f1_score(test_labels, preds, average='weighted')
    click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")

    return {
        "total_train": train_time,
        "throughput": throughput,
        "accuracy": acc,
        "f1_weighted": f1_w,
        "clf": clf,
        "preds": preds,
    }


def print_comparison(sklearn_results, rust_results):
    """Print comparison summary."""
    click.echo("\n" + "=" * 70)
    click.echo("COMPARISON SUMMARY")
    click.echo("=" * 70)
    click.echo(f"{'Metric':<30} {'sklearn':<20} {'Rust':<20}")
    click.echo("-" * 70)

    click.echo(f"{'Training time (s)':<30} {sklearn_results['total_train']:<20.2f} {rust_results['total_train']:<20.2f}")
    click.echo(f"{'Inference (samples/sec)':<30} {sklearn_results['e2e_throughput']:<20.0f} {rust_results['throughput']:<20.0f}")
    click.echo(f"{'Accuracy':<30} {sklearn_results['accuracy']:<20.4f} {rust_results['accuracy']:<20.4f}")
    click.echo(f"{'F1 (weighted)':<30} {sklearn_results['f1_weighted']:<20.4f} {rust_results['f1_weighted']:<20.4f}")

    click.echo("-" * 70)
    train_speedup = sklearn_results['total_train'] / rust_results['total_train'] if rust_results['total_train'] > 0 else 0
    infer_speedup = rust_results['throughput'] / sklearn_results['e2e_throughput'] if sklearn_results['e2e_throughput'] > 0 else 0
    click.echo(f"Speedup: Training {train_speedup:.2f}x, Inference {infer_speedup:.2f}x")
    click.echo("=" * 70)


@click.group()
def cli():
    """Benchmark Vietnamese text classification models."""
    pass


@cli.command()
@click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
              help='Path to VNTC dataset')
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
def vntc(data_dir, save_model, output):
    """Benchmark on VNTC dataset (10 topics, ~84k documents)."""
    click.echo("=" * 70)
    click.echo("VNTC Full Dataset Benchmark")
    click.echo("Vietnamese News Text Classification (10 Topics)")
    click.echo("=" * 70)

    train_dir = os.path.join(data_dir, "Train_Full")
    test_dir = os.path.join(data_dir, "Test_Full")

    # Load data
    click.echo("\nLoading training data...")
    t0 = time.perf_counter()
    train_texts, train_labels = [], []
    for folder in sorted(os.listdir(train_dir)):
        folder_path = os.path.join(train_dir, folder)
        if not os.path.isdir(folder_path):
            continue
        for fname in os.listdir(folder_path):
            if fname.endswith('.txt'):
                text = read_file(os.path.join(folder_path, fname))
                if text:
                    train_texts.append(text)
                    train_labels.append(folder)
    click.echo(f" Loaded {len(train_texts)} training samples in {time.perf_counter()-t0:.1f}s")

    click.echo("Loading test data...")
    t0 = time.perf_counter()
    test_texts, test_labels = [], []
    for folder in sorted(os.listdir(test_dir)):
        folder_path = os.path.join(test_dir, folder)
        if not os.path.isdir(folder_path):
            continue
        for fname in os.listdir(folder_path):
            if fname.endswith('.txt'):
                text = read_file(os.path.join(folder_path, fname))
                if text:
                    test_texts.append(text)
                    test_labels.append(folder)
    click.echo(f" Loaded {len(test_texts)} test samples in {time.perf_counter()-t0:.1f}s")

    # Run benchmarks
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels)

    print_comparison(sklearn_results, rust_results)

    if save_model:
        model_path = Path(output)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(model_path))
        size_mb = model_path.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")


@cli.command()
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
@click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
def bank(save_model, output):
    """Benchmark on UTS2017_Bank dataset (14 categories, banking domain)."""
    from datasets import load_dataset

    click.echo("=" * 70)
    click.echo("UTS2017_Bank Dataset Benchmark")
    click.echo("Vietnamese Banking Domain Text Classification (14 Categories)")
    click.echo("=" * 70)

    # Load data
    click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
    dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")

    train_texts = list(dataset["train"]["text"])
    train_labels = list(dataset["train"]["label"])
    test_texts = list(dataset["test"]["text"])
    test_labels = list(dataset["test"]["label"])

    click.echo(f" Train samples: {len(train_texts)}")
    click.echo(f" Test samples: {len(test_texts)}")
    click.echo(f" Categories: {len(set(train_labels))}")

    # Run benchmarks (smaller max_features for smaller dataset)
    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)

    print_comparison(sklearn_results, rust_results)

    click.echo("\nClassification Report (Rust):")
    click.echo(classification_report(test_labels, rust_results['preds']))

    if save_model:
        model_path = Path(output)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        rust_results['clf'].save(str(model_path))
        size_mb = model_path.stat().st_size / (1024 * 1024)
        click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")


@cli.command()
@click.option('--train-per-cat', default=340, help='Training samples per category')
@click.option('--test-per-cat', default=500, help='Test samples per category')
@click.option('--seed', default=42, help='Random seed')
def synthetic(train_per_cat, test_per_cat, seed):
    """Benchmark on synthetic VNTC-like data."""
    # Vietnamese text templates by category
    TEMPLATES = {
        "the_thao": ["Đội tuyển {} thắng {} với tỷ số {}", "Cầu thủ {} ghi bàn đẹp mắt"],
        "kinh_doanh": ["Chứng khoán {} điểm trong phiên giao dịch", "Ngân hàng {} công bố lãi suất {}"],
        "cong_nghe": ["Apple ra mắt {} với nhiều tính năng", "Trí tuệ nhân tạo đang thay đổi {}"],
        "chinh_tri": ["Quốc hội thông qua nghị quyết về {}", "Chủ tịch {} tiếp đón phái đoàn"],
        "van_hoa": ["Nghệ sĩ {} ra mắt album mới", "Liên hoan phim {} trao giải"],
        "khoa_hoc": ["Nhà khoa học phát hiện {} mới", "Nghiên cứu cho thấy {} có tác dụng"],
        "suc_khoe": ["Bộ Y tế cảnh báo về {} trong mùa", "Vaccine {} đạt hiệu quả cao"],
        "giao_duc": ["Trường {} công bố điểm chuẩn", "Học sinh đoạt huy chương tại Olympic"],
        "phap_luat": ["Tòa án xét xử vụ án {} với bị cáo", "Công an triệt phá đường dây"],
        "doi_song": ["Giá {} tăng trong tháng", "Người dân đổ xô đi mua {}"],
    }
    FILLS = {
        "the_thao": ["Việt Nam", "Thái Lan", "3-0", "bóng đá", "AFF Cup"],
        "kinh_doanh": ["tăng", "giảm", "VN-Index", "Vietcombank", "8%"],
        "cong_nghe": ["iPhone 16", "ChatGPT", "công việc", "VinAI", "5G"],
        "chinh_tri": ["kinh tế", "nước", "Trung Quốc", "Hà Nội", "phát triển"],
        "van_hoa": ["Mỹ Tâm", "Cannes", "nghệ thuật", "Hà Nội", "Bố Già"],
        "khoa_hoc": ["loài sinh vật", "trà xanh", "VNREDSat-1", "Nobel", "robot"],
        "suc_khoe": ["dịch cúm", "COVID-19", "Bạch Mai", "dinh dưỡng", "tiểu đường"],
        "giao_duc": ["Bách Khoa", "Việt Nam", "Toán", "THPT", "STEM"],
        "phap_luat": ["tham nhũng", "TP.HCM", "ma túy", "Hình sự", "gian lận"],
        "doi_song": ["xăng", "vàng", "nắng nóng", "Trung thu", "bún chả"],
    }

    def generate_sample(category):
        template = random.choice(TEMPLATES[category])
        fills = FILLS[category]
        n = template.count("{}")
        return template.format(*random.choices(fills, k=n))

    def generate_dataset(n_per_cat, categories):
        texts, labels = [], []
        for cat in categories:
            for _ in range(n_per_cat):
                texts.append(generate_sample(cat))
                labels.append(cat)
        combined = list(zip(texts, labels))
        random.shuffle(combined)
        return [t for t, _ in combined], [l for _, l in combined]

    click.echo("=" * 70)
    click.echo("Synthetic VNTC-like Benchmark")
    click.echo("=" * 70)

    random.seed(seed)
    categories = list(TEMPLATES.keys())

    click.echo("\nConfiguration:")
    click.echo(f" Categories: {len(categories)}")
    click.echo(f" Train samples: {train_per_cat * len(categories)}")
    click.echo(f" Test samples: {test_per_cat * len(categories)}")

    train_texts, train_labels = generate_dataset(train_per_cat, categories)
    test_texts, test_labels = generate_dataset(test_per_cat, categories)

    sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
    rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)

    print_comparison(sklearn_results, rust_results)


if __name__ == "__main__":
    cli()
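Every timing in bench.py follows the same `time.perf_counter` pattern (wall-clock delta, then samples divided by elapsed seconds). Stripped of the classifiers, the harness reduces to two small helpers; the names below are illustrative, not part of the script:

```python
import time

def timed(fn, *args):
    """Run fn(*args) and return (result, elapsed_seconds)."""
    t0 = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - t0

def throughput(n_samples, elapsed):
    """Samples per second, guarding against a zero-duration run."""
    return n_samples / elapsed if elapsed > 0 else float("inf")
```

`perf_counter` is preferred over `time.time` here because it is monotonic and has the highest available resolution for measuring short intervals.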
src/scripts/train.py DELETED
@@ -1,221 +0,0 @@
"""
Train Sen Text Classifier on Vietnamese news data.

Categories based on VNTC dataset:
- Chinh tri Xa hoi (Politics/Society)
- Doi song (Lifestyle)
- Khoa hoc (Science)
- Kinh doanh (Business)
- Phap luat (Law)
- Suc khoe (Health)
- The gioi (World)
- The thao (Sports)
- Van hoa (Culture)
- Vi tinh (Technology)
"""

import json
import os
import sys

sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")

from sen import SenTextClassifier


# Sample Vietnamese news data for each category
SAMPLE_DATA = {
    "chinh_tri_xa_hoi": [
        "Quốc hội thông qua nghị quyết về phát triển kinh tế xã hội",
        "Thủ tướng chủ trì họp Chính phủ thường kỳ tháng này",
        "Đại hội Đảng toàn quốc lần thứ XIII thành công tốt đẹp",
        "Chủ tịch nước tiếp đoàn đại biểu quốc tế",
        "Bộ Nội vụ triển khai cải cách hành chính",
        "Ủy ban Mặt trận Tổ quốc tổ chức hội nghị toàn quốc",
        "Đoàn thanh niên phát động phong trào tình nguyện",
        "Hội Liên hiệp Phụ nữ tổ chức đại hội đại biểu",
    ],
    "doi_song": [
        "Mẹo hay giúp tiết kiệm chi phí sinh hoạt hàng ngày",
        "Xu hướng thời trang mới nhất mùa thu đông năm nay",
        "Cách trang trí nhà cửa đón Tết đẹp và tiết kiệm",
        "Bí quyết nấu ăn ngon cho cả gia đình",
        "Kinh nghiệm du lịch Đà Nẵng tiết kiệm chi phí",
        "Cách chăm sóc cây cảnh trong nhà hiệu quả",
        "Chia sẻ cách dạy con học tập hiệu quả",
        "Mẹo vặt hay cho cuộc sống hàng ngày",
    ],
    "khoa_hoc": [
        "Các nhà khoa học phát hiện hành tinh mới ngoài hệ mặt trời",
        "Nghiên cứu mới về biến đổi khí hậu toàn cầu",
        "Vệ tinh nhân tạo được phóng thành công lên quỹ đạo",
        "Khám phá mới về nguồn gốc vũ trụ",
        "Công nghệ nano ứng dụng trong y học",
        "Phát hiện loài động vật mới ở rừng Amazon",
        "Nghiên cứu về trí tuệ nhân tạo và học máy",
        "Thí nghiệm vật lý hạt nhân tại CERN",
    ],
    "kinh_doanh": [
        "Chứng khoán Việt Nam tăng điểm mạnh phiên đầu tuần",
        "Ngân hàng Nhà nước điều chỉnh lãi suất điều hành",
        "Doanh nghiệp xuất khẩu gặp khó khăn do tỷ giá",
        "Thị trường bất động sản có dấu hiệu phục hồi",
        "VN-Index vượt mốc 1200 điểm trong phiên giao dịch",
        "FDI vào Việt Nam tăng trưởng ấn tượng",
        "Giá vàng thế giới biến động mạnh trong tuần",
        "Startup công nghệ Việt gọi vốn thành công Series A",
    ],
    "phap_luat": [
        "Tòa án xét xử vụ án tham nhũng lớn",
        "Công an triệt phá đường dây buôn lậu xuyên quốc gia",
        "Luật mới về bảo vệ môi trường có hiệu lực",
        "Khởi tố vụ án lừa đảo chiếm đoạt tài sản",
        "Bộ Công an cảnh báo thủ đoạn lừa đảo qua mạng",
        "Tòa án tuyên án vụ vi phạm an toàn giao thông",
        "Viện Kiểm sát truy tố các bị cáo trong vụ án kinh tế",
        "Cảnh sát giao thông xử lý vi phạm nồng độ cồn",
    ],
    "suc_khoe": [
        "Bộ Y tế khuyến cáo phòng chống dịch bệnh mùa đông",
        "Phát hiện phương pháp điều trị ung thư mới",
        "Cách phòng tránh các bệnh về đường hô hấp",
        "Chế độ ăn uống lành mạnh cho người cao tuổi",
        "Vaccine mới được phê duyệt sử dụng tại Việt Nam",
        "Bệnh viện Bạch Mai ứng dụng kỹ thuật mổ nội soi",
        "Cách chăm sóc sức khỏe tinh thần hiệu quả",
        "Tập thể dục đúng cách để có sức khỏe tốt",
    ],
    "the_gioi": [
        "Tổng thống Mỹ công bố chính sách đối ngoại mới",
        "Hội nghị thượng đỉnh G20 thảo luận về biến đổi khí hậu",
        "Xung đột vũ trang leo thang tại Trung Đông",
        "Liên Hợp Quốc họp khẩn về tình hình nhân đạo",
        "Châu Âu đối mặt với khủng hoảng năng lượng",
        "Trung Quốc công bố số liệu tăng trưởng kinh tế",
        "Nhật Bản bầu cử thủ tướng mới",
        "Nga và Ukraine tiếp tục đàm phán hòa bình",
    ],
    "the_thao": [
        "Đội tuyển Việt Nam thắng đậm trong trận giao hữu",
        "Cầu thủ Nguyễn Quang Hải ghi bàn đẹp mắt",
        "V-League 2024 khởi tranh vào tháng tới",
        "HLV Park Hang-seo chia tay bóng đá Việt Nam",
        "U23 Việt Nam vô địch giải Đông Nam Á",
        "Hoàng Xuân Vinh giành huy chương vàng Olympic",
        "Ánh Viên phá kỷ lục quốc gia môn bơi lội",
        "SEA Games 31 tổ chức thành công tại Việt Nam",
    ],
    "van_hoa": [
        "Liên hoan phim Việt Nam lần thứ 23 khai mạc",
        "Nghệ sĩ nhân dân được phong tặng danh hiệu cao quý",
        "Triển lãm tranh của họa sĩ nổi tiếng tại Hà Nội",
        "Ca sĩ Việt Nam giành giải thưởng âm nhạc châu Á",
        "Lễ hội truyền thống thu hút đông đảo du khách",
        "Phim Việt Nam được đề cử tại liên hoan phim quốc tế",
        "Nhạc sĩ sáng tác ca khúc mới về quê hương",
        "Bảo tàng lịch sử khai trương triển lãm mới",
    ],
    "vi_tinh": [
        "Apple ra mắt iPhone mới với nhiều tính năng",
        "Trí tuệ nhân tạo đang thay đổi cuộc sống",
        "Samsung công bố điện thoại gập thế hệ mới",
        "Microsoft phát hành bản cập nhật Windows",
        "ChatGPT và cuộc cách mạng trí tuệ nhân tạo",
        "5G được triển khai rộng rãi tại các thành phố lớn",
        "Ứng dụng di động phổ biến nhất trong năm",
        "An ninh mạng trước nguy cơ tấn công hacker",
    ],
}


def prepare_data():
    """Prepare training and validation data."""
    train_texts = []
    train_labels = []
    val_texts = []
    val_labels = []

    for label, texts in SAMPLE_DATA.items():
        # Use first 6 for training, last 2 for validation
        for text in texts[:6]:
            train_texts.append(text)
            train_labels.append(label)
        for text in texts[6:]:
            val_texts.append(text)
            val_labels.append(label)

    return train_texts, train_labels, val_texts, val_labels


def main():
    print("=" * 60)
    print("Training Sen Text Classifier")
    print("Based on VNTC Vietnamese News Classification")
    print("=" * 60)

    # Prepare data
    train_texts, train_labels, val_texts, val_labels = prepare_data()
    print("\nDataset:")
    print(f" - Training samples: {len(train_texts)}")
    print(f" - Validation samples: {len(val_texts)}")
    print(f" - Categories: {len(SAMPLE_DATA)}")

    # Initialize classifier
    classifier = SenTextClassifier(
        max_features=5000,
        ngram_range=(1, 2),
        min_df=1,
        max_df=0.95,
        sublinear_tf=True,
        C=1.0,
        max_iter=1000,
    )

    # Train
    print("\n" + "=" * 60)
    print("Training...")
    print("=" * 60)
    results = classifier.train(
        train_texts=train_texts,
        train_labels=train_labels,
        val_texts=val_texts,
        val_labels=val_labels,
    )

    # Evaluate
    print("\n" + "=" * 60)
    print("Evaluation on validation set:")
    print("=" * 60)
    classifier.evaluate(val_texts, val_labels)

    # Test predictions
    print("\n" + "=" * 60)
    print("Sample Predictions:")
    print("=" * 60)
    test_texts = [
        "Đội tuyển bóng đá Việt Nam chiến thắng",
        "Giá vàng tăng mạnh trong phiên giao dịch",
        "Apple công bố sản phẩm mới tại sự kiện",
        "Bộ Y tế cảnh báo dịch cúm mùa",
        "Quốc hội họp phiên bất thường",
    ]

    for text in test_texts:
        from sen import Sentence
        sentence = Sentence(text)
        classifier.predict(sentence)
        print(f" '{text}' -> {sentence.labels[0]}")

    # Save model
    save_path = "/home/anhvu2/projects/workspace_underthesea/sen/trained_model"
    print("\n" + "=" * 60)
    print(f"Saving model to: {save_path}")
    print("=" * 60)
    classifier.save(save_path)

    print("\nTraining completed!")
    return classifier


if __name__ == "__main__":
    main()
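The deleted script's `prepare_data` does a fixed per-category holdout (first 6 samples of each label for training, the rest for validation). That pattern generalizes to any label-keyed corpus; the sketch below is illustrative, with a hypothetical helper name:

```python
def split_per_category(data, n_train):
    """data: dict label -> list of texts; hold out everything past n_train per label."""
    train, val = ([], []), ([], [])
    for label, texts in data.items():
        for split, chunk in ((train, texts[:n_train]), (val, texts[n_train:])):
            split[0].extend(chunk)                 # texts
            split[1].extend([label] * len(chunk))  # matching labels
    return train, val
```

Because the cut is positional rather than random, the split is reproducible without a seed, at the cost of assuming the per-category lists are not ordered in a way that biases the holdout.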
src/scripts/train_sonar.py DELETED
@@ -1,234 +0,0 @@
-"""
-Reproduce sonar_core_1 training configuration.
-Uses CountVectorizer + TfidfTransformer + SVC pipeline.
-Target: 92.80% accuracy on VNTC (matching sonar_core_1)
-"""
-
-import os
-import sys
-import time
-import json
-from datetime import datetime
-
-import numpy as np
-import joblib
-from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
-from sklearn.svm import SVC, LinearSVC
-from sklearn.linear_model import LogisticRegression
-from sklearn.pipeline import Pipeline
-from sklearn.metrics import accuracy_score, classification_report, f1_score
-
-sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
-
-# VNTC data paths
-VNTC_BASE = "/home/anhvu2/projects/workspace_underthesea/VNTC_github/Data/10Topics/Ver1.1"
-TRAIN_DIR = os.path.join(VNTC_BASE, "Train_Full")
-TEST_DIR = os.path.join(VNTC_BASE, "Test_Full")
-
-# Category mapping
-CATEGORY_MAP = {
-    "Chinh tri Xa hoi": "Chinh tri Xa hoi",
-    "Doi song": "Doi song",
-    "Khoa hoc": "Khoa hoc",
-    "Kinh doanh": "Kinh doanh",
-    "Phap luat": "Phap luat",
-    "Suc khoe": "Suc khoe",
-    "The gioi": "The gioi",
-    "The thao": "The thao",
-    "Van hoa": "Van hoa",
-    "Vi tinh": "Vi tinh",
-}
-
-
-def read_file(filepath):
-    """Read text file with multiple encoding attempts."""
-    encodings = ['utf-16', 'utf-16-le', 'utf-8', 'latin-1']
-    for encoding in encodings:
-        try:
-            with open(filepath, 'r', encoding=encoding) as f:
-                text = f.read()
-                text = ' '.join(text.split())
-                if len(text) > 10:
-                    return text
-        except (UnicodeDecodeError, UnicodeError):
-            continue
-    return None
-
-
-def load_vntc_data(data_dir):
-    """Load VNTC data from directory."""
-    texts = []
-    labels = []
-
-    for folder_name, label in CATEGORY_MAP.items():
-        folder_path = os.path.join(data_dir, folder_name)
-        if not os.path.exists(folder_path):
-            print(f"  Warning: {folder_path} not found")
-            continue
-
-        files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]
-        for filename in files:
-            filepath = os.path.join(folder_path, filename)
-            text = read_file(filepath)
-            if text:
-                texts.append(text)
-                labels.append(label)
-
-    return np.array(texts), np.array(labels)
-
-
-def train_sonar_config():
-    """Train with sonar_core_1 configuration."""
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
-    print("=" * 70)
-    print("Reproducing sonar_core_1 Configuration")
-    print("=" * 70)
-
-    # Load data
-    print("\n[1/4] Loading VNTC data...")
-    start = time.time()
-    X_train, y_train = load_vntc_data(TRAIN_DIR)
-    X_test, y_test = load_vntc_data(TEST_DIR)
-    print(f"  Train: {len(X_train)} samples")
-    print(f"  Test: {len(X_test)} samples")
-    print(f"  Classes: {len(set(y_train))}")
-    print(f"  Load time: {time.time()-start:.1f}s")
-
-    # sonar_core_1 configuration
-    configs = [
-        {
-            "name": "SVC (sonar_core_1 config)",
-            "max_features": 20000,
-            "ngram_range": (1, 2),
-            "classifier": SVC(kernel='linear', probability=True, random_state=42),
-        },
-        {
-            "name": "LinearSVC (faster)",
-            "max_features": 20000,
-            "ngram_range": (1, 2),
-            "classifier": LinearSVC(C=1.0, max_iter=2000, random_state=42),
-        },
-        {
-            "name": "LogisticRegression",
-            "max_features": 20000,
-            "ngram_range": (1, 2),
-            "classifier": LogisticRegression(max_iter=1000, random_state=42),
-        },
-    ]
-
-    results = []
-    best_model = None
-    best_accuracy = 0
-
-    for i, config in enumerate(configs):
-        print(f"\n[{i+2}/4] Training: {config['name']}")
-        print("-" * 50)
-
-        # Create pipeline (sonar_core_1 style)
-        pipeline = Pipeline([
-            ('vect', CountVectorizer(
-                max_features=config['max_features'],
-                ngram_range=config['ngram_range']
-            )),
-            ('tfidf', TfidfTransformer(use_idf=True)),
-            ('clf', config['classifier']),
-        ])
-
-        # Train
-        start = time.time()
-        pipeline.fit(X_train, y_train)
-        train_time = time.time() - start
-
-        # Evaluate
-        train_pred = pipeline.predict(X_train)
-        test_pred = pipeline.predict(X_test)
-
-        train_acc = accuracy_score(y_train, train_pred)
-        test_acc = accuracy_score(y_test, test_pred)
-        test_f1 = f1_score(y_test, test_pred, average='weighted')
-
-        print(f"  Train accuracy: {train_acc:.4f}")
-        print(f"  Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
-        print(f"  Test F1: {test_f1:.4f}")
-        print(f"  Train time: {train_time:.1f}s")
-
-        results.append({
-            "name": config['name'],
-            "train_acc": train_acc,
-            "test_acc": test_acc,
-            "test_f1": test_f1,
-            "train_time": train_time,
-        })
-
-        if test_acc > best_accuracy:
-            best_accuracy = test_acc
-            best_model = pipeline
-            best_config = config
-
-    # Print comparison
-    print("\n" + "=" * 70)
-    print("Results Comparison")
-    print("=" * 70)
-    print(f"{'Model':<30} {'Test Acc':>10} {'Test F1':>10} {'Time':>10}")
-    print("-" * 70)
-    for r in results:
-        print(f"{r['name']:<30} {r['test_acc']*100:>9.2f}% {r['test_f1']:>10.4f} {r['train_time']:>9.1f}s")
-    print("-" * 70)
-    print(f"sonar_core_1 reference: {92.80:>9.2f}%")
-    print("=" * 70)
-
-    # Save best model
-    save_dir = "/home/anhvu2/projects/workspace_underthesea/sen/sen-general-1.0.0-20260202"
-    os.makedirs(save_dir, exist_ok=True)
-
-    # Save pipeline
-    joblib.dump(best_model, os.path.join(save_dir, "pipeline.joblib"))
-
-    # Save label encoder (for compatibility)
-    from sklearn.preprocessing import LabelEncoder
-    le = LabelEncoder()
-    le.fit(y_train)
-    joblib.dump(le, os.path.join(save_dir, "label_encoder.joblib"))
-
-    # Save metadata
-    metadata = {
-        "model_type": "sonar_core_1_reproduction",
-        "architecture": "CountVectorizer + TfidfTransformer + LinearSVC",
-        "max_features": best_config['max_features'],
-        "ngram_range": list(best_config['ngram_range']),
-        "train_samples": len(X_train),
-        "test_samples": len(X_test),
-        "train_accuracy": float(results[1]['train_acc']),  # LinearSVC
-        "test_accuracy": float(results[1]['test_acc']),
-        "test_f1_weighted": float(results[1]['test_f1']),
-        "labels": sorted(list(set(y_train))),
-        "timestamp": timestamp,
-    }
-
-    with open(os.path.join(save_dir, "metadata.json"), 'w') as f:
-        json.dump(metadata, f, indent=2, ensure_ascii=False)
-
-    print(f"\nBest model saved to: {save_dir}")
-
-    # Print detailed classification report for best model
-    print("\n" + "=" * 70)
-    print("Classification Report (LinearSVC)")
-    print("=" * 70)
-
-    # Retrain LinearSVC for report
-    pipeline = Pipeline([
-        ('vect', CountVectorizer(max_features=20000, ngram_range=(1, 2))),
-        ('tfidf', TfidfTransformer(use_idf=True)),
-        ('clf', LinearSVC(C=1.0, max_iter=2000, random_state=42)),
-    ])
-    pipeline.fit(X_train, y_train)
-    test_pred = pipeline.predict(X_test)
-
-    print(classification_report(y_test, test_pred))
-
-    return results
-
-
-if __name__ == "__main__":
-    train_sonar_config()
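The deleted script above delegates TF-IDF weighting to scikit-learn's `CountVectorizer` + `TfidfTransformer`. The underlying arithmetic can be sketched in plain Python — a minimal sketch using the classic idf = ln(N/df), not sklearn's smoothed variant, with an illustrative helper name:

```python
import math
from collections import Counter

def tf_idf(docs):
    """Compute classic tf-idf weights per document: tf(t, d) * ln(N / df(t))."""
    n = len(docs)
    # document frequency: number of docs containing each term at least once
    df = Counter()
    for doc in docs:
        df.update(set(doc.split()))
    weights = []
    for doc in docs:
        tf = Counter(doc.split())  # raw term counts in this document
        weights.append({t: c * math.log(n / df[t]) for t, c in tf.items()})
    return weights

docs = ["bóng đá việt nam", "giá vàng tăng", "bóng đá anh"]
w = tf_idf(docs)
# "bóng" appears in 2 of 3 docs -> idf = ln(3/2); "giá" in 1 of 3 -> idf = ln(3)
```

Terms shared across many documents ("bóng", "đá") are down-weighted relative to rarer, more discriminative ones — which is why the pipeline pairs the raw counts with an idf step before the SVM.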
src/scripts/train_vntc.py DELETED
@@ -1,181 +0,0 @@
-"""
-Train Sen Text Classifier on full VNTC dataset.
-
-VNTC: Vietnamese News Text Classification Corpus
-- 10 Topics, ~33,759 train / ~50,373 test documents
-- Reference: Vu et al. (2007) RIVF
-"""
-
-import os
-import sys
-import time
-from pathlib import Path
-
-sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
-
-from sen import SenTextClassifier
-
-# VNTC data paths
-VNTC_BASE = "/home/anhvu2/projects/workspace_underthesea/VNTC_github/Data/10Topics/Ver1.1"
-TRAIN_DIR = os.path.join(VNTC_BASE, "Train_Full")
-TEST_DIR = os.path.join(VNTC_BASE, "Test_Full")
-
-# Category mapping (folder name -> normalized label)
-CATEGORY_MAP = {
-    "Chinh tri Xa hoi": "chinh_tri_xa_hoi",
-    "Doi song": "doi_song",
-    "Khoa hoc": "khoa_hoc",
-    "Kinh doanh": "kinh_doanh",
-    "Phap luat": "phap_luat",
-    "Suc khoe": "suc_khoe",
-    "The gioi": "the_gioi",
-    "The thao": "the_thao",
-    "Van hoa": "van_hoa",
-    "Vi tinh": "vi_tinh",
-}
-
-
-def read_file(filepath):
-    """Read text file with multiple encoding attempts."""
-    encodings = ['utf-16', 'utf-16-le', 'utf-8', 'latin-1']
-
-    for encoding in encodings:
-        try:
-            with open(filepath, 'r', encoding=encoding) as f:
-                text = f.read()
-                # Clean up text (remove extra whitespace)
-                text = ' '.join(text.split())
-                if len(text) > 10:  # Valid text
-                    return text
-        except (UnicodeDecodeError, UnicodeError):
-            continue
-
-    return None
-
-
-def load_vntc_data(data_dir, max_per_category=None):
-    """Load VNTC data from directory."""
-    texts = []
-    labels = []
-    stats = {}
-
-    for folder_name, label in CATEGORY_MAP.items():
-        folder_path = os.path.join(data_dir, folder_name)
-        if not os.path.exists(folder_path):
-            print(f"  Warning: {folder_path} not found")
-            continue
-
-        files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]
-        if max_per_category:
-            files = files[:max_per_category]
-
-        count = 0
-        for filename in files:
-            filepath = os.path.join(folder_path, filename)
-            text = read_file(filepath)
-            if text:
-                texts.append(text)
-                labels.append(label)
-                count += 1
-
-        stats[label] = count
-
-    return texts, labels, stats
-
-
-def main():
-    print("=" * 70)
-    print("Training Sen Text Classifier on VNTC Dataset")
-    print("Vietnamese News Text Classification Corpus")
-    print("=" * 70)
-
-    # Load training data
-    print("\n[1/5] Loading training data...")
-    start = time.time()
-    train_texts, train_labels, train_stats = load_vntc_data(TRAIN_DIR)
-    print(f"  Loaded {len(train_texts)} training samples in {time.time()-start:.1f}s")
-    print("  Per category:")
-    for label, count in sorted(train_stats.items()):
-        print(f"    - {label}: {count}")
-
-    # Load test data
-    print("\n[2/5] Loading test data...")
-    start = time.time()
-    test_texts, test_labels, test_stats = load_vntc_data(TEST_DIR)
-    print(f"  Loaded {len(test_texts)} test samples in {time.time()-start:.1f}s")
-    print("  Per category:")
-    for label, count in sorted(test_stats.items()):
-        print(f"    - {label}: {count}")
-
-    # Initialize classifier
-    print("\n[3/5] Initializing classifier...")
-    classifier = SenTextClassifier(
-        max_features=10000,  # Increased for larger dataset
-        ngram_range=(1, 2),
-        min_df=2,  # Require term in at least 2 docs
-        max_df=0.95,
-        sublinear_tf=True,
-        C=1.0,
-        max_iter=2000,
-    )
-
-    # Train
-    print("\n[4/5] Training...")
-    start = time.time()
-    results = classifier.train(
-        train_texts=train_texts,
-        train_labels=train_labels,
-        val_texts=test_texts[:5000],  # Use subset for validation during training
-        val_labels=test_labels[:5000],
-    )
-    train_time = time.time() - start
-    print(f"  Training completed in {train_time:.1f}s")
-
-    # Evaluate on full test set
-    print("\n[5/5] Evaluating on full test set...")
-    start = time.time()
-    eval_results = classifier.evaluate(test_texts, test_labels)
-    eval_time = time.time() - start
-
-    print("\n" + "=" * 70)
-    print("VNTC Benchmark Results (10 Topics)")
-    print("=" * 70)
-    print(f"  Test samples: {len(test_texts)}")
-    print(f"  Accuracy: {eval_results['accuracy']:.4f} ({eval_results['accuracy']*100:.2f}%)")
-    print(f"  F1 (weighted):{eval_results['f1_weighted']:.4f}")
-    print(f"  Train time: {train_time:.1f}s")
-    print(f"  Eval time: {eval_time:.1f}s")
-    print("=" * 70)
-
-    # Save model
-    save_path = "/home/anhvu2/projects/workspace_underthesea/sen/sen-1.0.0-vntc"
-    print(f"\nSaving model to: {save_path}")
-    classifier.save(save_path)
-
-    # Sample predictions
-    print("\nSample Predictions:")
-    print("-" * 70)
-    test_samples = [
-        "Đội tuyển Việt Nam thắng đậm 3-0 trước Indonesia",
-        "Giá vàng tăng mạnh trong phiên giao dịch hôm nay",
-        "Apple ra mắt iPhone mới với nhiều tính năng hấp dẫn",
-        "Bộ Y tế cảnh báo về dịch cúm mùa đông",
-        "Quốc hội thông qua nghị quyết phát triển kinh tế",
-    ]
-
-    from sen import Sentence
-    for text in test_samples:
-        sentence = Sentence(text)
-        classifier.predict(sentence)
-        label = sentence.labels[0]
-        print(f"  '{text[:50]}...' -> {label.value} ({label.score:.2f})")
-
-    print("\n" + "=" * 70)
-    print("Training completed successfully!")
-    print("=" * 70)
-
-    return eval_results
-
-
-if __name__ == "__main__":
-    results = main()
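The `read_file` helper in the script above tries several encodings because VNTC ships its documents as UTF-16 text files. The fallback pattern can be exercised standalone — a sketch under the assumption that the first cleanly decoded, sufficiently long result wins (the helper name is illustrative):

```python
import os
import tempfile

def read_with_fallback(filepath, encodings=('utf-16', 'utf-16-le', 'utf-8', 'latin-1')):
    """Return the file's whitespace-normalized text using the first encoding that decodes."""
    for enc in encodings:
        try:
            with open(filepath, 'r', encoding=enc) as f:
                text = ' '.join(f.read().split())
            if len(text) > 10:  # reject near-empty or garbage decodes
                return text
        except (UnicodeDecodeError, UnicodeError):
            continue
    return None

# demo: a UTF-16 file (as in VNTC) is decoded transparently via the BOM
with tempfile.NamedTemporaryFile('w', encoding='utf-16', suffix='.txt', delete=False) as f:
    f.write("Đội tuyển bóng đá Việt Nam")
    path = f.name
text = read_with_fallback(path)
os.unlink(path)
```

Trying UTF-16 first matches the corpus's dominant encoding; the length check guards against an encoding that technically decodes but yields an implausibly short result.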
src/sen/__init__.py DELETED
@@ -1,26 +0,0 @@
-"""
-Sen-1: Vietnamese Text Classification by UnderTheSea NLP.
-
-Based on: "A Comparative Study on Vietnamese Text Classification Methods"
-Vu et al., RIVF 2007
-
-Methods:
-- TF-IDF vectorization (sklearn)
-- SVM (Support Vector Machine) classifier
-"""
-
-from .text_classifier import (
-    Label,
-    Sentence,
-    SenTextClassifier,
-    classify,
-)
-
-__version__ = "1.0.0"
-
-__all__ = [
-    "Label",
-    "Sentence",
-    "SenTextClassifier",
-    "classify",
-]
src/sen/text_classifier.py DELETED
@@ -1,374 +0,0 @@
-"""
-Sen Text Classifier - Rust-based classifier using underthesea_core_extend.
-
-Based on: "A Comparative Study on Vietnamese Text Classification Methods"
-Vu et al., RIVF 2007
-https://ieeexplore.ieee.org/document/4223084/
-
-Methods:
-- TF-IDF vectorization (Rust: underthesea_core_extend.TfIdfVectorizer)
-- Linear SVM classifier (Rust: underthesea_core_extend.LinearSVM)
-"""
-
-import json
-import os
-from typing import List, Optional, Union
-
-from underthesea_core_extend import TfIdfVectorizer, LinearSVM, SVMTrainer
-
-
-class Label:
-    """Label class compatible with underthesea."""
-
-    def __init__(self, value: str, score: float = 1.0):
-        self.value = value
-        self.score = min(max(score, 0.0), 1.0)
-
-    def __str__(self):
-        return f"{self.value} ({self.score:.4f})"
-
-    def __repr__(self):
-        return f"{self.value} ({self.score:.4f})"
-
-
-class Sentence:
-    """Sentence class compatible with underthesea."""
-
-    def __init__(self, text: str = None, labels: List[Label] = None):
-        self.text = text
-        self.labels = labels or []
-
-    def __str__(self):
-        return f'Sentence: "{self.text}" - Labels: {self.labels}'
-
-    def __repr__(self):
-        return f'Sentence: "{self.text}" - Labels: {self.labels}'
-
-    def add_labels(self, labels: List[Union[Label, str]]):
-        for label in labels:
-            if isinstance(label, str):
-                label = Label(label)
-            self.labels.append(label)
-
-
-class SenTextClassifier:
-    """
-    Rust-based text classifier using TF-IDF + Linear SVM.
-
-    Uses underthesea_core_extend for fast training and inference.
-    Compatible with underthesea API.
-
-    Reference:
-        Vu et al. "A Comparative Study on Vietnamese Text Classification Methods"
-        RIVF 2007
-    """
-
-    def __init__(
-        self,
-        # TF-IDF parameters
-        max_features: int = 20000,
-        ngram_range: tuple = (1, 2),
-        min_df: int = 1,
-        max_df: float = 1.0,
-        # SVM parameters
-        c: float = 1.0,
-        max_iter: int = 1000,
-        tol: float = 0.1,
-        verbose: bool = True,
-    ):
-        self.max_features = max_features
-        self.ngram_range = ngram_range
-        self.min_df = min_df
-        self.max_df = max_df
-        self.c = c
-        self.max_iter = max_iter
-        self.tol = tol
-        self.verbose = verbose
-
-        self.vectorizer: Optional[TfIdfVectorizer] = None
-        self.classifier: Optional[LinearSVM] = None
-        self.labels_: Optional[List[str]] = None
-
-    def train(
-        self,
-        train_texts: List[str],
-        train_labels: List[str],
-        val_texts: List[str] = None,
-        val_labels: List[str] = None,
-    ) -> dict:
-        """
-        Train the classifier.
-
-        Args:
-            train_texts: List of training texts
-            train_labels: List of training labels
-            val_texts: Optional validation texts
-            val_labels: Optional validation labels
-
-        Returns:
-            Dictionary with training metrics
-        """
-        # Get unique labels
-        self.labels_ = sorted(list(set(train_labels)))
-
-        # Build and fit vectorizer
-        self.vectorizer = TfIdfVectorizer(
-            max_features=self.max_features,
-            ngram_range=self.ngram_range,
-            min_df=self.min_df,
-            max_df=self.max_df,
-        )
-        self.vectorizer.fit(train_texts)
-
-        # Transform to features
-        train_features = self.vectorizer.transform_batch(train_texts)
-
-        # Build and train SVM model
-        trainer = SVMTrainer(
-            c=self.c,
-            max_iter=self.max_iter,
-            tol=self.tol,
-            verbose=self.verbose,
-        )
-        self.classifier = trainer.train(train_features, train_labels)
-
-        # Calculate training metrics
-        train_preds = self.classifier.predict_batch(train_features)
-        train_acc = sum(1 for p, t in zip(train_preds, train_labels) if p == t) / len(train_labels)
-
-        # Calculate F1 score
-        train_f1 = self._calculate_f1(train_labels, train_preds)
-
-        results = {
-            "train_accuracy": train_acc,
-            "train_f1": train_f1,
-            "num_classes": len(self.labels_),
-            "num_samples": len(train_texts),
-            "vocab_size": self.vectorizer.vocab_size,
-        }
-
-        print(f"Training completed:")
-        print(f"  - Samples: {len(train_texts)}")
-        print(f"  - Classes: {len(self.labels_)}")
-        print(f"  - Vocab size: {self.vectorizer.vocab_size}")
-        print(f"  - Train accuracy: {train_acc:.4f}")
-        print(f"  - Train F1: {train_f1:.4f}")
-
-        # Validation metrics
-        if val_texts and val_labels:
-            val_features = self.vectorizer.transform_batch(val_texts)
-            val_preds = self.classifier.predict_batch(val_features)
-            val_acc = sum(1 for p, t in zip(val_preds, val_labels) if p == t) / len(val_labels)
-            val_f1 = self._calculate_f1(val_labels, val_preds)
-
-            results["val_accuracy"] = val_acc
-            results["val_f1"] = val_f1
-
-            print(f"  - Val accuracy: {val_acc:.4f}")
-            print(f"  - Val F1: {val_f1:.4f}")
-
-        return results
-
-    def _calculate_f1(self, y_true: List[str], y_pred: List[str]) -> float:
-        """Calculate weighted F1 score."""
-        from collections import Counter
-
-        label_counts = Counter(y_true)
-        total = len(y_true)
-
-        f1_sum = 0.0
-        for label in self.labels_:
-            tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
-            fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
-            fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
-
-            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
-            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
-
-            weight = label_counts[label] / total
-            f1_sum += f1 * weight
-
-        return f1_sum
-
-    def predict(self, sentence: Sentence) -> None:
-        """
-        Predict label for a sentence (underthesea-compatible API).
-
-        Args:
-            sentence: Sentence object with text attribute
-        """
-        if self.classifier is None or self.vectorizer is None:
-            raise ValueError("Model not trained. Call train() first or load a model.")
-
-        features = self.vectorizer.transform_dense(sentence.text)
-        label_value, score = self.classifier.predict_with_score(features)
-
-        sentence.labels = []
-        sentence.add_labels([Label(label_value, score)])
-
-    def predict_batch(self, texts: List[str]) -> List[Label]:
-        """
-        Predict labels for multiple texts.
-
-        Args:
-            texts: List of texts to classify
-
-        Returns:
-            List of Label objects
-        """
-        if self.classifier is None or self.vectorizer is None:
-            raise ValueError("Model not trained. Call train() first or load a model.")
-
-        # Use dense transform (faster Python-Rust interface)
-        features = self.vectorizer.transform_batch(texts)
-        results = []
-        for feat in features:
-            label_value, score = self.classifier.predict_with_score(feat)
-            results.append(Label(label_value, float(score)))
-
-        return results
-
-    def evaluate(self, texts: List[str], labels: List[str]) -> dict:
-        """
-        Evaluate model on test data.
-
-        Args:
-            texts: List of texts
-            labels: List of true labels
-
-        Returns:
-            Dictionary with evaluation metrics
-        """
-        # Use dense transform (faster Python-Rust interface)
-        features = self.vectorizer.transform_batch(texts)
-        y_pred = self.classifier.predict_batch(features)
-
-        acc = sum(1 for p, t in zip(y_pred, labels) if p == t) / len(labels)
-        f1 = self._calculate_f1(labels, y_pred)
-
-        print(f"Evaluation:")
-        print(f"  - Accuracy: {acc:.4f}")
-        print(f"  - F1 (weighted): {f1:.4f}")
-
-        # Print classification report
-        self._print_classification_report(labels, y_pred)
-
-        return {"accuracy": acc, "f1": f1}
-
-    def _print_classification_report(self, y_true: List[str], y_pred: List[str]):
-        """Print classification report."""
-        from collections import Counter
-
-        print("\nClassification Report:")
-        print(f"{'':>20} {'precision':>10} {'recall':>10} {'f1-score':>10} {'support':>10}")
-        print()
-
-        label_counts = Counter(y_true)
-
-        for label in self.labels_:
-            tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
-            fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
-            fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
-
-            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
-            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
-            support = label_counts[label]
-
-            print(f"{label:>20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f} {support:>10}")
-
-        print()
-
-    def save(self, path: str) -> None:
-        """
-        Save model to disk.
-
-        Args:
-            path: Directory path to save model
-        """
-        os.makedirs(path, exist_ok=True)
-
-        # Save vectorizer
-        self.vectorizer.save(os.path.join(path, "vectorizer.json"))
-
-        # Save classifier
-        self.classifier.save(os.path.join(path, "classifier.json"))
-
-        # Save metadata
-        metadata = {
-            "estimator": "RUST_SVM",
-            "max_features": self.max_features,
-            "ngram_range": self.ngram_range,
-            "min_df": self.min_df,
-            "max_df": self.max_df,
-            "c": self.c,
-            "max_iter": self.max_iter,
-            "tol": self.tol,
-            "labels": self.labels_,
-            "vocab_size": self.vectorizer.vocab_size,
-            "n_classes": self.classifier.n_classes,
-        }
-        with open(os.path.join(path, "metadata.json"), "w", encoding="utf-8") as f:
-            json.dump(metadata, f, ensure_ascii=False, indent=2)
-
-        print(f"Model saved to: {path}")
-
-    @classmethod
-    def load(cls, path: str) -> "SenTextClassifier":
-        """
-        Load model from disk.
-
-        Args:
-            path: Directory path containing saved model
-
-        Returns:
-            Loaded SenTextClassifier instance
-        """
-        # Load metadata
-        with open(os.path.join(path, "metadata.json"), "r", encoding="utf-8") as f:
-            metadata = json.load(f)
-
-        # Create instance with saved parameters
-        classifier = cls(
-            max_features=metadata.get("max_features", 20000),
-            ngram_range=tuple(metadata.get("ngram_range", (1, 2))),
-            min_df=metadata.get("min_df", 1),
-            max_df=metadata.get("max_df", 1.0),
-            c=metadata.get("c", 1.0),
-            max_iter=metadata.get("max_iter", 1000),
-            tol=metadata.get("tol", 0.1),
-        )
-
-        # Load vectorizer
-        classifier.vectorizer = TfIdfVectorizer.load(os.path.join(path, "vectorizer.json"))
-
-        # Load SVM model
-        classifier.classifier = LinearSVM.load(os.path.join(path, "classifier.json"))
-        classifier.labels_ = metadata.get("labels", [])
-
-        print(f"Model loaded from: {path}")
-        return classifier
-
-
-def classify(text: str, model_path: str = None) -> List[str]:
-    """
-    Classify text using Sen model.
-
-    Args:
-        text: Input text to classify
-        model_path: Path to trained model
-
-    Returns:
-        List of predicted labels
-    """
-    if not hasattr(classify, "_classifier") or classify._model_path != model_path:
-        if model_path:
-            classify._classifier = SenTextClassifier.load(model_path)
-            classify._model_path = model_path
-        else:
-            raise ValueError("model_path is required")
-
-    sentence = Sentence(text)
-    classify._classifier.predict(sentence)
-    return [label.value for label in sentence.labels]
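The removed `_calculate_f1` computes a support-weighted F1 without pulling in sklearn. A standalone version of the same arithmetic, checked on a toy two-class case:

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """Support-weighted F1: per-class F1 averaged by class frequency in y_true."""
    labels = sorted(set(y_true))
    counts = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label in labels:
        # per-class confusion counts
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        # weight each class F1 by its share of the true labels
        score += f1 * counts[label] / total
    return score

y_true = ["a", "a", "b", "b"]
y_pred = ["a", "b", "b", "b"]
# class "a": P=1, R=0.5, F1=2/3; class "b": P=2/3, R=1, F1=0.8; weights 0.5 each
```

This matches sklearn's `f1_score(..., average='weighted')`, which the consolidated `src/train.py` now uses instead.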
src/train.py ADDED
@@ -0,0 +1,213 @@
1
+ """
2
+ Training CLI for Vietnamese Text Classification.
3
+
4
+ Usage:
5
+ python train.py vntc --output models/sen-vntc.bin
6
+ python train.py bank --output models/sen-bank.bin
7
+ """
8
+
9
+ import os
10
+ import time
11
+ from pathlib import Path
12
+
13
+ import click
14
+ from sklearn.metrics import accuracy_score, f1_score, classification_report
15
+
16
+ from underthesea_core import TextClassifier
17
+
18
+
19
+ def read_file(filepath):
20
+ """Read text file with multiple encoding attempts."""
21
+ for enc in ['utf-16', 'utf-8', 'latin-1']:
22
+ try:
23
+ with open(filepath, 'r', encoding=enc) as f:
24
+ text = ' '.join(f.read().split())
25
+ if len(text) > 10:
26
+ return text
27
+ except (UnicodeDecodeError, UnicodeError):
28
+ continue
29
+ return None
30
+
31
+
32
+ def load_vntc_data(data_dir):
33
+ """Load VNTC data from directory."""
34
+ texts, labels = [], []
35
+
36
+ for folder in sorted(os.listdir(data_dir)):
37
+ folder_path = os.path.join(data_dir, folder)
38
+ if not os.path.isdir(folder_path):
39
+ continue
40
+
41
+ for fname in os.listdir(folder_path):
42
+ if fname.endswith('.txt'):
43
+ text = read_file(os.path.join(folder_path, fname))
44
+ if text:
45
+ texts.append(text)
46
+ labels.append(folder)
47
+
48
+ return texts, labels
49
+
50
+
51
+ @click.group()
52
+ def cli():
53
+ """Train Vietnamese text classification models."""
54
+ pass
55
+
56
+
57
+ @cli.command()
58
+ @click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
59
+ help='Path to VNTC dataset')
60
+ @click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
61
+ @click.option('--max-features', default=20000, help='Maximum vocabulary size')
+ @click.option('--ngram-min', default=1, help='Minimum n-gram')
+ @click.option('--ngram-max', default=2, help='Maximum n-gram')
+ @click.option('--min-df', default=2, help='Minimum document frequency')
+ @click.option('--c', default=1.0, help='SVM regularization parameter')
+ @click.option('--max-iter', default=1000, help='Maximum iterations')
+ @click.option('--tol', default=0.1, help='Convergence tolerance')
+ def vntc(data_dir, output, max_features, ngram_min, ngram_max, min_df, c, max_iter, tol):
+     """Train on VNTC dataset (10 topics, ~84k documents)."""
+     click.echo("=" * 70)
+     click.echo("VNTC Dataset Training (10 Topics)")
+     click.echo("=" * 70)
+
+     train_dir = os.path.join(data_dir, "Train_Full")
+     test_dir = os.path.join(data_dir, "Test_Full")
+
+     # Load data
+     click.echo("\nLoading data...")
+     t0 = time.perf_counter()
+     train_texts, train_labels = load_vntc_data(train_dir)
+     test_texts, test_labels = load_vntc_data(test_dir)
+     load_time = time.perf_counter() - t0
+
+     click.echo(f"  Train samples: {len(train_texts)}")
+     click.echo(f"  Test samples:  {len(test_texts)}")
+     click.echo(f"  Categories:    {len(set(train_labels))}")
+     click.echo(f"  Load time:     {load_time:.2f}s")
+
+     # Train
+     click.echo("\nTraining Rust TextClassifier...")
+     clf = TextClassifier(
+         max_features=max_features,
+         ngram_range=(ngram_min, ngram_max),
+         min_df=min_df,
+         c=c,
+         max_iter=max_iter,
+         tol=tol,
+     )
+
+     t0 = time.perf_counter()
+     clf.fit(train_texts, train_labels)
+     train_time = time.perf_counter() - t0
+     click.echo(f"  Training time:   {train_time:.2f}s")
+     click.echo(f"  Vocabulary size: {clf.n_features}")
+
+     # Evaluate
+     click.echo("\nEvaluating...")
+     t0 = time.perf_counter()
+     preds = clf.predict_batch(test_texts)
+     infer_time = time.perf_counter() - t0
+     throughput = len(test_texts) / infer_time
+
+     acc = accuracy_score(test_labels, preds)
+     f1_w = f1_score(test_labels, preds, average='weighted')
+     f1_m = f1_score(test_labels, preds, average='macro')
+
+     click.echo(f"  Inference: {infer_time:.3f}s ({throughput:.0f} samples/sec)")
+
+     click.echo("\n" + "=" * 70)
+     click.echo("RESULTS")
+     click.echo("=" * 70)
+     click.echo(f"  Accuracy:      {acc:.4f} ({acc*100:.2f}%)")
+     click.echo(f"  F1 (weighted): {f1_w:.4f}")
+     click.echo(f"  F1 (macro):    {f1_m:.4f}")
+
+     click.echo("\nClassification Report:")
+     click.echo(classification_report(test_labels, preds))
+
+     # Save model
+     model_path = Path(output)
+     model_path.parent.mkdir(parents=True, exist_ok=True)
+     clf.save(str(model_path))
+
+     size_mb = model_path.stat().st_size / (1024 * 1024)
+     click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
+
+
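The `load_vntc_data` helper used above is defined earlier in `src/train.py`, outside this hunk. As a reference for the expected on-disk layout, here is a minimal loader sketch; the one-subdirectory-per-category layout under `Train_Full`/`Test_Full` and the UTF-16 file encoding are assumptions (matching the original VNTC release), not the actual helper:

```python
import os

def load_vntc_data(data_dir, encoding="utf-16"):
    """Sketch of a VNTC-style loader (assumption: not the real helper).

    Expects one subdirectory per category under data_dir, each holding
    plain-text documents; returns parallel (texts, labels) lists.
    """
    texts, labels = [], []
    for category in sorted(os.listdir(data_dir)):
        cat_dir = os.path.join(data_dir, category)
        if not os.path.isdir(cat_dir):
            continue  # skip stray files at the top level
        for fname in sorted(os.listdir(cat_dir)):
            path = os.path.join(cat_dir, fname)
            with open(path, encoding=encoding, errors="ignore") as f:
                texts.append(f.read())
            labels.append(category)  # directory name doubles as the label
    return texts, labels
```

The label is simply the category directory name, which is why the CLI can report `len(set(train_labels))` as the category count.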
+ @cli.command()
+ @click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
+ @click.option('--max-features', default=10000, help='Maximum vocabulary size')
+ @click.option('--ngram-min', default=1, help='Minimum n-gram')
+ @click.option('--ngram-max', default=2, help='Maximum n-gram')
+ @click.option('--min-df', default=1, help='Minimum document frequency')
+ @click.option('--c', default=1.0, help='SVM regularization parameter')
+ @click.option('--max-iter', default=1000, help='Maximum iterations')
+ @click.option('--tol', default=0.1, help='Convergence tolerance')
+ def bank(output, max_features, ngram_min, ngram_max, min_df, c, max_iter, tol):
+     """Train on UTS2017_Bank dataset (14 categories, banking domain)."""
+     from datasets import load_dataset
+
+     click.echo("=" * 70)
+     click.echo("UTS2017_Bank Dataset Training (14 Categories)")
+     click.echo("=" * 70)
+
+     # Load data
+     click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
+     dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")
+
+     train_texts = list(dataset["train"]["text"])
+     train_labels = list(dataset["train"]["label"])
+     test_texts = list(dataset["test"]["text"])
+     test_labels = list(dataset["test"]["label"])
+
+     click.echo(f"  Train samples: {len(train_texts)}")
+     click.echo(f"  Test samples:  {len(test_texts)}")
+     click.echo(f"  Categories:    {len(set(train_labels))}")
+
+     # Train
+     click.echo("\nTraining Rust TextClassifier...")
+     clf = TextClassifier(
+         max_features=max_features,
+         ngram_range=(ngram_min, ngram_max),
+         min_df=min_df,
+         c=c,
+         max_iter=max_iter,
+         tol=tol,
+     )
+
+     t0 = time.perf_counter()
+     clf.fit(train_texts, train_labels)
+     train_time = time.perf_counter() - t0
+     click.echo(f"  Training time:   {train_time:.3f}s")
+     click.echo(f"  Vocabulary size: {clf.n_features}")
+
+     # Evaluate
+     click.echo("\nEvaluating...")
+     preds = clf.predict_batch(test_texts)
+
+     acc = accuracy_score(test_labels, preds)
+     f1_w = f1_score(test_labels, preds, average='weighted')
+     f1_m = f1_score(test_labels, preds, average='macro')
+
+     click.echo("\n" + "=" * 70)
+     click.echo("RESULTS")
+     click.echo("=" * 70)
+     click.echo(f"  Accuracy:      {acc:.4f} ({acc*100:.2f}%)")
+     click.echo(f"  F1 (weighted): {f1_w:.4f}")
+     click.echo(f"  F1 (macro):    {f1_m:.4f}")
+
+     click.echo("\nClassification Report:")
+     click.echo(classification_report(test_labels, preds))
+
+     # Save model
+     model_path = Path(output)
+     model_path.parent.mkdir(parents=True, exist_ok=True)
+     clf.save(str(model_path))
+
+     size_mb = model_path.stat().st_size / (1024 * 1024)
+     click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
+
+
+ if __name__ == "__main__":
+     cli()
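Both commands drive the same TF-IDF + Linear SVM pipeline behind underthesea_core's Rust `TextClassifier`. For readers more familiar with scikit-learn, here is a rough equivalent of what the CLI trains; this is an illustrative sketch with toy data, not the Rust implementation, and the scikit-learn classes are stand-ins mirroring the CLI's option names:

```python
# Illustrative scikit-learn stand-in for the Rust TextClassifier the two
# commands train (assumption: equivalent behavior, not the real pipeline).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC

# Toy syllable-level Vietnamese samples; real runs use VNTC or UTS2017_Bank
texts = [
    "bóng đá hôm nay",         # the_thao
    "giá vàng tăng mạnh",      # kinh_doanh
    "trận đấu kịch tính",      # the_thao
    "thị trường chứng khoán",  # kinh_doanh
]
labels = ["the_thao", "kinh_doanh", "the_thao", "kinh_doanh"]

# Same knobs the CLI exposes: vocabulary cap, 1-2 grams, min-df, C, max-iter, tol
pipe = make_pipeline(
    TfidfVectorizer(max_features=20000, ngram_range=(1, 2), min_df=1),
    LinearSVC(C=1.0, max_iter=1000, tol=0.1),
)
pipe.fit(texts, labels)
prediction = pipe.predict(["kết quả trận bóng"])[0]
```

In the repository itself this role is filled by `underthesea_core`'s `TextClassifier`, which keeps the whole fit/predict path in Rust and serializes models with bincode.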
tests/test_classifier.py DELETED
@@ -1,165 +0,0 @@
- """Test Sen Text Classifier (sklearn-based)."""
-
- import sys
- sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
-
- from sen import SenTextClassifier, Sentence, Label
-
-
- def test_training():
-     """Test training on sample Vietnamese sentiment data."""
-     print("=" * 60)
-     print("Test: Training sklearn-based Text Classifier")
-     print("=" * 60)
-
-     # Sample Vietnamese sentiment data
-     train_texts = [
-         "Sản phẩm rất tốt, tôi hài lòng!",
-         "Chất lượng tuyệt vời, giao hàng nhanh",
-         "Hàng đẹp, đóng gói cẩn thận",
-         "Mình rất thích sản phẩm này",
-         "Shop phục vụ nhiệt tình, sẽ ủng hộ tiếp",
-         "Hàng chính hãng, giá tốt",
-         "Rất đáng tiền, recommend cho mọi người",
-         "Chất liệu tốt, may đẹp",
-         "Hàng tệ quá, không như mô tả",
-         "Chất lượng kém, không đáng tiền",
-         "Giao hàng chậm, đóng gói cẩu thả",
-         "Sản phẩm lỗi, shop không hỗ trợ",
-         "Thất vọng, không bao giờ mua lại",
-         "Hàng giả, không nên mua",
-         "Tệ lắm, phí tiền",
-         "Màu không đúng, size sai",
-     ]
-     train_labels = [
-         "positive", "positive", "positive", "positive",
-         "positive", "positive", "positive", "positive",
-         "negative", "negative", "negative", "negative",
-         "negative", "negative", "negative", "negative",
-     ]
-
-     val_texts = [
-         "Hàng ok, sẽ mua lại",
-         "Tệ lắm, không nên mua",
-     ]
-     val_labels = ["positive", "negative"]
-
-     # Initialize and train
-     classifier = SenTextClassifier(
-         max_features=1000,
-         ngram_range=(1, 2),
-         min_df=1,
-         C=1.0,
-     )
-
-     print("\nTraining...")
-     history = classifier.train(
-         train_texts=train_texts,
-         train_labels=train_labels,
-         val_texts=val_texts,
-         val_labels=val_labels,
-     )
-     print()
-
-     # Test predictions
-     print("Testing predictions:")
-     test_texts = [
-         "Sản phẩm tuyệt vời!",
-         "Hàng rất tệ",
-         "Giao hàng nhanh, hàng đẹp",
-         "Thất vọng với chất lượng",
-         "Chất lượng tốt, giá hợp lý",
-         "Không đáng tiền, hàng kém",
-     ]
-
-     for text in test_texts:
-         sentence = Sentence(text)
-         classifier.predict(sentence)
-         print(f"  '{text}' -> {sentence.labels[0]}")
-     print()
-
-     # Test batch prediction
-     print("Batch prediction:")
-     labels = classifier.predict_batch(test_texts)
-     for text, label in zip(test_texts, labels):
-         print(f"  '{text}' -> {label}")
-     print()
-
-     # Save model
-     save_path = "/tmp/sen-classifier-sklearn"
-     classifier.save(save_path)
-
-     # Load and test
-     print("\nLoading saved model...")
-     loaded_classifier = SenTextClassifier.load(save_path)
-
-     sentence = Sentence("Rất hài lòng với sản phẩm")
-     loaded_classifier.predict(sentence)
-     print(f"Loaded model prediction: '{sentence.text}' -> {sentence.labels[0]}")
-     print()
-
-
- def test_multiclass():
-     """Test multi-class classification (news categories)."""
-     print("=" * 60)
-     print("Test: Multi-class Classification (News Categories)")
-     print("=" * 60)
-
-     # Sample Vietnamese news data (simulating VNTC categories)
-     train_texts = [
-         # Thể thao (Sports)
-         "Đội tuyển Việt Nam thắng 3-0 trước Indonesia",
-         "Cầu thủ Nguyễn Quang Hải ghi bàn đẹp mắt",
-         "V-League 2024 khởi tranh vào tháng tới",
-         "HLV Park Hang-seo chia tay bóng đá Việt Nam",
-         # Kinh doanh (Business)
-         "Chứng khoán tăng điểm mạnh phiên đầu tuần",
-         "Ngân hàng Nhà nước điều chỉnh lãi suất",
-         "Doanh nghiệp xuất khẩu gặp khó khăn",
-         "Thị trường bất động sản phục hồi",
-         # Công nghệ (Technology)
-         "Apple ra mắt iPhone 16 với nhiều tính năng mới",
-         "Trí tuệ nhân tạo đang thay đổi cuộc sống",
-         "Startup công nghệ Việt Nam gọi vốn thành công",
-         "5G được triển khai rộng rãi tại Việt Nam",
-     ]
-     train_labels = [
-         "the_thao", "the_thao", "the_thao", "the_thao",
-         "kinh_doanh", "kinh_doanh", "kinh_doanh", "kinh_doanh",
-         "cong_nghe", "cong_nghe", "cong_nghe", "cong_nghe",
-     ]
-
-     # Initialize and train
-     classifier = SenTextClassifier(
-         max_features=500,
-         ngram_range=(1, 2),
-         min_df=1,
-     )
-
-     print("\nTraining...")
-     classifier.train(train_texts, train_labels)
-     print()
-
-     # Test predictions
-     print("Testing predictions:")
-     test_texts = [
-         "Ronaldo ghi hat-trick trong trận đấu",
-         "VN-Index tăng 10 điểm hôm nay",
-         "Samsung ra mắt điện thoại mới",
-         "Đội bóng đá Việt Nam vô địch AFF Cup",
-         "Lãi suất ngân hàng giảm mạnh",
-     ]
-
-     for text in test_texts:
-         sentence = Sentence(text)
-         classifier.predict(sentence)
-         print(f"  '{text}' -> {sentence.labels[0]}")
-     print()
-
-
- if __name__ == "__main__":
-     test_training()
-     test_multiclass()
-     print("=" * 60)
-     print("All tests completed!")
-     print("=" * 60)
uv.lock ADDED
The diff for this file is too large to render. See raw diff