Refactor: move Rust extensions to underthesea_core upstream
- Remove local Rust extensions (now in underthesea_core 3.1.7)
- Consolidate training scripts into unified CLI (src/train.py)
- Add benchmark CLI (src/bench.py) with vntc, bank, synthetic commands
- Remove src/sen module (TextClassifier now in underthesea_core)
- Add CLAUDE.md for project documentation
- Update dependencies to use underthesea_core>=3.1.7
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- CLAUDE.md +77 -0
- extensions/underthesea_core_extend/.gitignore +0 -5
- extensions/underthesea_core_extend/Cargo.lock +0 -351
- extensions/underthesea_core_extend/Cargo.toml +0 -23
- extensions/underthesea_core_extend/pyproject.toml +0 -22
- extensions/underthesea_core_extend/src/lib.rs +0 -21
- extensions/underthesea_core_extend/src/svm.rs +0 -512
- extensions/underthesea_core_extend/src/tfidf.rs +0 -235
- extensions/underthesea_core_extend/uv.lock +0 -8
- pyproject.toml +5 -6
- src/bench.py +328 -0
- src/scripts/train.py +0 -221
- src/scripts/train_sonar.py +0 -234
- src/scripts/train_vntc.py +0 -181
- src/sen/__init__.py +0 -26
- src/sen/text_classifier.py +0 -374
- src/train.py +213 -0
- tests/test_classifier.py +0 -165
- uv.lock +0 -0
CLAUDE.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
Sen-1 is a lightweight Vietnamese text classification model combining TF-IDF vectorization with Linear SVM. Part of the UnderTheSea NLP ecosystem, it serves as a practical baseline compatible with the underthesea API.
|
| 8 |
+
|
| 9 |
+
## Build & Development Commands
|
| 10 |
+
|
| 11 |
+
### Running Tests
|
| 12 |
+
```bash
|
| 13 |
+
pytest tests/
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
### Development Installation
|
| 17 |
+
```bash
|
| 18 |
+
pip install -e ".[dev]"
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### Training on VNTC Dataset
|
| 22 |
+
```bash
|
| 23 |
+
python src/train.py
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Architecture
|
| 27 |
+
|
| 28 |
+
**3-Stage Pipeline:**
|
| 29 |
+
```
|
| 30 |
+
Input Text → TF-IDF Vectorizer (max_features=20k, ngram 1-2)
|
| 31 |
+
→ Linear SVM (C=1.0, max_iter=1000)
|
| 32 |
+
→ Label + Confidence
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
**Key Design Decisions:**
|
| 36 |
+
- Operates at syllable-level (no word segmentation) for speed
|
| 37 |
+
- All-Rust implementation via `underthesea_core` for fast training and inference
|
| 38 |
+
- Model serialization uses binary format (bincode)
|
| 39 |
+
|
| 40 |
+
**Core Module:** `underthesea_core.TextClassifier` provides the train/predict/evaluate/save/load functionality — the local `src/sen/text_classifier.py` wrapper was removed in this refactor.
|
| 41 |
+
|
| 42 |
+
**Rust Backend:** Uses `underthesea_core.TextClassifier` which combines TF-IDF vectorization and Linear SVM in a unified Rust implementation via PyO3.
|
| 43 |
+
|
| 44 |
+
## Public API
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
from sen import SenTextClassifier, Sentence, Label, classify
|
| 48 |
+
|
| 49 |
+
# Pre-trained model inference
|
| 50 |
+
labels = classify("Văn bản tiếng Việt", model_path="models/sen-1")
|
| 51 |
+
|
| 52 |
+
# Custom training
|
| 53 |
+
clf = SenTextClassifier()
|
| 54 |
+
clf.train(texts, labels)
|
| 55 |
+
clf.save("my_model")
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
## Key Files
|
| 59 |
+
|
| 60 |
+
- `underthesea_core` (dependency) - Main `TextClassifier` implementation (moved upstream; the local `src/sen` wrapper was removed)
|
| 61 |
+
- `src/train.py` - Unified training CLI (replaces the per-dataset scripts in `src/scripts/`)
|
| 62 |
+
- `src/bench.py` - Benchmark CLI with vntc, bank, and synthetic commands
|
| 63 |
+
- `TECHNICAL_REPORT.md` - Detailed methodology and benchmark results
|
| 64 |
+
- `RESEARCH_PLAN.md` - Future work roadmap (PhoBERT comparison, word segmentation)
|
| 65 |
+
|
| 66 |
+
## Performance Benchmarks
|
| 67 |
+
|
| 68 |
+
- VNTC (news, 10 topics): 92.49% accuracy, 37.6s training
|
| 69 |
+
- UTS2017_Bank (14 categories): 75.76% accuracy
|
| 70 |
+
- Inference: 66,678 samples/sec batch, 0.465ms single
|
| 71 |
+
|
| 72 |
+
## Known Limitations
|
| 73 |
+
|
| 74 |
+
1. Syllable-level only (no word segmentation) - ~4.6% gap vs word-level approaches
|
| 75 |
+
2. Single-label classification only
|
| 76 |
+
3. Trained on news domain - may not generalize to social media/reviews
|
| 77 |
+
4. Lower performance on imbalanced categories
|
extensions/underthesea_core_extend/.gitignore
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
target/
|
| 2 |
-
.venv/
|
| 3 |
-
__pycache__/
|
| 4 |
-
*.so
|
| 5 |
-
*.pyc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/Cargo.lock
DELETED
|
@@ -1,351 +0,0 @@
|
|
| 1 |
-
# This file is automatically @generated by Cargo.
|
| 2 |
-
# It is not intended for manual editing.
|
| 3 |
-
version = 4
|
| 4 |
-
|
| 5 |
-
[[package]]
|
| 6 |
-
name = "allocator-api2"
|
| 7 |
-
version = "0.2.21"
|
| 8 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 9 |
-
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
| 10 |
-
|
| 11 |
-
[[package]]
|
| 12 |
-
name = "autocfg"
|
| 13 |
-
version = "1.5.0"
|
| 14 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 15 |
-
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
| 16 |
-
|
| 17 |
-
[[package]]
|
| 18 |
-
name = "cfg-if"
|
| 19 |
-
version = "1.0.4"
|
| 20 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 21 |
-
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
| 22 |
-
|
| 23 |
-
[[package]]
|
| 24 |
-
name = "crossbeam-deque"
|
| 25 |
-
version = "0.8.6"
|
| 26 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 27 |
-
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
| 28 |
-
dependencies = [
|
| 29 |
-
"crossbeam-epoch",
|
| 30 |
-
"crossbeam-utils",
|
| 31 |
-
]
|
| 32 |
-
|
| 33 |
-
[[package]]
|
| 34 |
-
name = "crossbeam-epoch"
|
| 35 |
-
version = "0.9.18"
|
| 36 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 37 |
-
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
| 38 |
-
dependencies = [
|
| 39 |
-
"crossbeam-utils",
|
| 40 |
-
]
|
| 41 |
-
|
| 42 |
-
[[package]]
|
| 43 |
-
name = "crossbeam-utils"
|
| 44 |
-
version = "0.8.21"
|
| 45 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 46 |
-
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
| 47 |
-
|
| 48 |
-
[[package]]
|
| 49 |
-
name = "either"
|
| 50 |
-
version = "1.15.0"
|
| 51 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 52 |
-
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
| 53 |
-
|
| 54 |
-
[[package]]
|
| 55 |
-
name = "equivalent"
|
| 56 |
-
version = "1.0.2"
|
| 57 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 58 |
-
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
| 59 |
-
|
| 60 |
-
[[package]]
|
| 61 |
-
name = "foldhash"
|
| 62 |
-
version = "0.1.5"
|
| 63 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 64 |
-
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
| 65 |
-
|
| 66 |
-
[[package]]
|
| 67 |
-
name = "hashbrown"
|
| 68 |
-
version = "0.15.5"
|
| 69 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 70 |
-
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
| 71 |
-
dependencies = [
|
| 72 |
-
"allocator-api2",
|
| 73 |
-
"equivalent",
|
| 74 |
-
"foldhash",
|
| 75 |
-
"serde",
|
| 76 |
-
]
|
| 77 |
-
|
| 78 |
-
[[package]]
|
| 79 |
-
name = "heck"
|
| 80 |
-
version = "0.5.0"
|
| 81 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 82 |
-
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
| 83 |
-
|
| 84 |
-
[[package]]
|
| 85 |
-
name = "indoc"
|
| 86 |
-
version = "2.0.7"
|
| 87 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 88 |
-
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
|
| 89 |
-
dependencies = [
|
| 90 |
-
"rustversion",
|
| 91 |
-
]
|
| 92 |
-
|
| 93 |
-
[[package]]
|
| 94 |
-
name = "itoa"
|
| 95 |
-
version = "1.0.17"
|
| 96 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 97 |
-
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
| 98 |
-
|
| 99 |
-
[[package]]
|
| 100 |
-
name = "libc"
|
| 101 |
-
version = "0.2.180"
|
| 102 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 103 |
-
checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
|
| 104 |
-
|
| 105 |
-
[[package]]
|
| 106 |
-
name = "memchr"
|
| 107 |
-
version = "2.7.6"
|
| 108 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 109 |
-
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
|
| 110 |
-
|
| 111 |
-
[[package]]
|
| 112 |
-
name = "memoffset"
|
| 113 |
-
version = "0.9.1"
|
| 114 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 115 |
-
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
| 116 |
-
dependencies = [
|
| 117 |
-
"autocfg",
|
| 118 |
-
]
|
| 119 |
-
|
| 120 |
-
[[package]]
|
| 121 |
-
name = "once_cell"
|
| 122 |
-
version = "1.21.3"
|
| 123 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 124 |
-
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
| 125 |
-
|
| 126 |
-
[[package]]
|
| 127 |
-
name = "portable-atomic"
|
| 128 |
-
version = "1.13.1"
|
| 129 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 130 |
-
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
| 131 |
-
|
| 132 |
-
[[package]]
|
| 133 |
-
name = "proc-macro2"
|
| 134 |
-
version = "1.0.106"
|
| 135 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 136 |
-
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
| 137 |
-
dependencies = [
|
| 138 |
-
"unicode-ident",
|
| 139 |
-
]
|
| 140 |
-
|
| 141 |
-
[[package]]
|
| 142 |
-
name = "pyo3"
|
| 143 |
-
version = "0.22.6"
|
| 144 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 145 |
-
checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
|
| 146 |
-
dependencies = [
|
| 147 |
-
"cfg-if",
|
| 148 |
-
"indoc",
|
| 149 |
-
"libc",
|
| 150 |
-
"memoffset",
|
| 151 |
-
"once_cell",
|
| 152 |
-
"portable-atomic",
|
| 153 |
-
"pyo3-build-config",
|
| 154 |
-
"pyo3-ffi",
|
| 155 |
-
"pyo3-macros",
|
| 156 |
-
"unindent",
|
| 157 |
-
]
|
| 158 |
-
|
| 159 |
-
[[package]]
|
| 160 |
-
name = "pyo3-build-config"
|
| 161 |
-
version = "0.22.6"
|
| 162 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 163 |
-
checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
|
| 164 |
-
dependencies = [
|
| 165 |
-
"once_cell",
|
| 166 |
-
"target-lexicon",
|
| 167 |
-
]
|
| 168 |
-
|
| 169 |
-
[[package]]
|
| 170 |
-
name = "pyo3-ffi"
|
| 171 |
-
version = "0.22.6"
|
| 172 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 173 |
-
checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
|
| 174 |
-
dependencies = [
|
| 175 |
-
"libc",
|
| 176 |
-
"pyo3-build-config",
|
| 177 |
-
]
|
| 178 |
-
|
| 179 |
-
[[package]]
|
| 180 |
-
name = "pyo3-macros"
|
| 181 |
-
version = "0.22.6"
|
| 182 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 183 |
-
checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
|
| 184 |
-
dependencies = [
|
| 185 |
-
"proc-macro2",
|
| 186 |
-
"pyo3-macros-backend",
|
| 187 |
-
"quote",
|
| 188 |
-
"syn",
|
| 189 |
-
]
|
| 190 |
-
|
| 191 |
-
[[package]]
|
| 192 |
-
name = "pyo3-macros-backend"
|
| 193 |
-
version = "0.22.6"
|
| 194 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 195 |
-
checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
|
| 196 |
-
dependencies = [
|
| 197 |
-
"heck",
|
| 198 |
-
"proc-macro2",
|
| 199 |
-
"pyo3-build-config",
|
| 200 |
-
"quote",
|
| 201 |
-
"syn",
|
| 202 |
-
]
|
| 203 |
-
|
| 204 |
-
[[package]]
|
| 205 |
-
name = "quote"
|
| 206 |
-
version = "1.0.44"
|
| 207 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 208 |
-
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
|
| 209 |
-
dependencies = [
|
| 210 |
-
"proc-macro2",
|
| 211 |
-
]
|
| 212 |
-
|
| 213 |
-
[[package]]
|
| 214 |
-
name = "rayon"
|
| 215 |
-
version = "1.11.0"
|
| 216 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 217 |
-
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
|
| 218 |
-
dependencies = [
|
| 219 |
-
"either",
|
| 220 |
-
"rayon-core",
|
| 221 |
-
]
|
| 222 |
-
|
| 223 |
-
[[package]]
|
| 224 |
-
name = "rayon-core"
|
| 225 |
-
version = "1.13.0"
|
| 226 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 227 |
-
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
|
| 228 |
-
dependencies = [
|
| 229 |
-
"crossbeam-deque",
|
| 230 |
-
"crossbeam-utils",
|
| 231 |
-
]
|
| 232 |
-
|
| 233 |
-
[[package]]
|
| 234 |
-
name = "rustversion"
|
| 235 |
-
version = "1.0.22"
|
| 236 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 237 |
-
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
| 238 |
-
|
| 239 |
-
[[package]]
|
| 240 |
-
name = "serde"
|
| 241 |
-
version = "1.0.228"
|
| 242 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 243 |
-
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
|
| 244 |
-
dependencies = [
|
| 245 |
-
"serde_core",
|
| 246 |
-
"serde_derive",
|
| 247 |
-
]
|
| 248 |
-
|
| 249 |
-
[[package]]
|
| 250 |
-
name = "serde_core"
|
| 251 |
-
version = "1.0.228"
|
| 252 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 253 |
-
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
|
| 254 |
-
dependencies = [
|
| 255 |
-
"serde_derive",
|
| 256 |
-
]
|
| 257 |
-
|
| 258 |
-
[[package]]
|
| 259 |
-
name = "serde_derive"
|
| 260 |
-
version = "1.0.228"
|
| 261 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 262 |
-
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
| 263 |
-
dependencies = [
|
| 264 |
-
"proc-macro2",
|
| 265 |
-
"quote",
|
| 266 |
-
"syn",
|
| 267 |
-
]
|
| 268 |
-
|
| 269 |
-
[[package]]
|
| 270 |
-
name = "serde_json"
|
| 271 |
-
version = "1.0.149"
|
| 272 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 273 |
-
checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
|
| 274 |
-
dependencies = [
|
| 275 |
-
"itoa",
|
| 276 |
-
"memchr",
|
| 277 |
-
"serde",
|
| 278 |
-
"serde_core",
|
| 279 |
-
"zmij",
|
| 280 |
-
]
|
| 281 |
-
|
| 282 |
-
[[package]]
|
| 283 |
-
name = "syn"
|
| 284 |
-
version = "2.0.114"
|
| 285 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 286 |
-
checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
|
| 287 |
-
dependencies = [
|
| 288 |
-
"proc-macro2",
|
| 289 |
-
"quote",
|
| 290 |
-
"unicode-ident",
|
| 291 |
-
]
|
| 292 |
-
|
| 293 |
-
[[package]]
|
| 294 |
-
name = "target-lexicon"
|
| 295 |
-
version = "0.12.16"
|
| 296 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 297 |
-
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
| 298 |
-
|
| 299 |
-
[[package]]
|
| 300 |
-
name = "tinyvec"
|
| 301 |
-
version = "1.10.0"
|
| 302 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 303 |
-
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
| 304 |
-
dependencies = [
|
| 305 |
-
"tinyvec_macros",
|
| 306 |
-
]
|
| 307 |
-
|
| 308 |
-
[[package]]
|
| 309 |
-
name = "tinyvec_macros"
|
| 310 |
-
version = "0.1.1"
|
| 311 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 312 |
-
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
| 313 |
-
|
| 314 |
-
[[package]]
|
| 315 |
-
name = "underthesea_core_extend"
|
| 316 |
-
version = "0.1.0"
|
| 317 |
-
dependencies = [
|
| 318 |
-
"hashbrown",
|
| 319 |
-
"pyo3",
|
| 320 |
-
"rayon",
|
| 321 |
-
"serde",
|
| 322 |
-
"serde_json",
|
| 323 |
-
"unicode-normalization",
|
| 324 |
-
]
|
| 325 |
-
|
| 326 |
-
[[package]]
|
| 327 |
-
name = "unicode-ident"
|
| 328 |
-
version = "1.0.22"
|
| 329 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 330 |
-
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
| 331 |
-
|
| 332 |
-
[[package]]
|
| 333 |
-
name = "unicode-normalization"
|
| 334 |
-
version = "0.1.25"
|
| 335 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 336 |
-
checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
|
| 337 |
-
dependencies = [
|
| 338 |
-
"tinyvec",
|
| 339 |
-
]
|
| 340 |
-
|
| 341 |
-
[[package]]
|
| 342 |
-
name = "unindent"
|
| 343 |
-
version = "0.2.4"
|
| 344 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 345 |
-
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
|
| 346 |
-
|
| 347 |
-
[[package]]
|
| 348 |
-
name = "zmij"
|
| 349 |
-
version = "1.0.19"
|
| 350 |
-
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 351 |
-
checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/Cargo.toml
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
[package]
|
| 2 |
-
name = "underthesea_core_extend"
|
| 3 |
-
version = "0.1.0"
|
| 4 |
-
edition = "2021"
|
| 5 |
-
description = "Rust extensions for underthesea - Text Classification"
|
| 6 |
-
license = "Apache-2.0"
|
| 7 |
-
|
| 8 |
-
[lib]
|
| 9 |
-
name = "underthesea_core_extend"
|
| 10 |
-
crate-type = ["cdylib"]
|
| 11 |
-
|
| 12 |
-
[dependencies]
|
| 13 |
-
pyo3 = { version = "0.22", features = ["extension-module"] }
|
| 14 |
-
serde = { version = "1.0", features = ["derive"] }
|
| 15 |
-
serde_json = "1.0"
|
| 16 |
-
rayon = "1.10"
|
| 17 |
-
hashbrown = { version = "0.15", features = ["serde"] }
|
| 18 |
-
unicode-normalization = "0.1"
|
| 19 |
-
|
| 20 |
-
[profile.release]
|
| 21 |
-
lto = true
|
| 22 |
-
codegen-units = 1
|
| 23 |
-
opt-level = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/pyproject.toml
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
[build-system]
|
| 2 |
-
requires = ["maturin>=1.0,<2.0"]
|
| 3 |
-
build-backend = "maturin"
|
| 4 |
-
|
| 5 |
-
[project]
|
| 6 |
-
name = "underthesea_core_extend"
|
| 7 |
-
version = "0.1.0"
|
| 8 |
-
description = "Rust extensions for underthesea - Text Classification"
|
| 9 |
-
requires-python = ">=3.10"
|
| 10 |
-
license = { text = "Apache-2.0" }
|
| 11 |
-
authors = [{ name = "UnderTheSea NLP", email = "anhv.ict91@gmail.com" }]
|
| 12 |
-
classifiers = [
|
| 13 |
-
"Programming Language :: Rust",
|
| 14 |
-
"Programming Language :: Python :: Implementation :: CPython",
|
| 15 |
-
"Programming Language :: Python :: 3.10",
|
| 16 |
-
"Programming Language :: Python :: 3.11",
|
| 17 |
-
"Programming Language :: Python :: 3.12",
|
| 18 |
-
]
|
| 19 |
-
|
| 20 |
-
[tool.maturin]
|
| 21 |
-
features = ["pyo3/extension-module"]
|
| 22 |
-
module-name = "underthesea_core_extend"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/src/lib.rs
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
//! underthesea_core_extend - Rust extensions for Vietnamese Text Classification
|
| 2 |
-
//!
|
| 3 |
-
//! Provides fast TF-IDF vectorization and Linear SVM classification.
|
| 4 |
-
|
| 5 |
-
use pyo3::prelude::*;
|
| 6 |
-
|
| 7 |
-
mod tfidf;
|
| 8 |
-
mod svm;
|
| 9 |
-
|
| 10 |
-
pub use tfidf::TfIdfVectorizer;
|
| 11 |
-
pub use svm::{LinearSVM, SVMTrainer, FastSVMTrainer};
|
| 12 |
-
|
| 13 |
-
/// Python module
|
| 14 |
-
#[pymodule]
|
| 15 |
-
fn underthesea_core_extend(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 16 |
-
m.add_class::<TfIdfVectorizer>()?;
|
| 17 |
-
m.add_class::<LinearSVM>()?;
|
| 18 |
-
m.add_class::<SVMTrainer>()?;
|
| 19 |
-
m.add_class::<FastSVMTrainer>()?;
|
| 20 |
-
Ok(())
|
| 21 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/src/svm.rs
DELETED
|
@@ -1,512 +0,0 @@
|
|
| 1 |
-
//! Optimized Linear SVM - LIBLINEAR-style Dual Coordinate Descent
|
| 2 |
-
//!
|
| 3 |
-
//! Pure Rust implementation of L2-regularized L2-loss SVM (dual form)
|
| 4 |
-
//! Reference: "A Dual Coordinate Descent Method for Large-scale Linear SVM"
|
| 5 |
-
//! Hsieh et al., ICML 2008
|
| 6 |
-
|
| 7 |
-
use hashbrown::HashMap;
|
| 8 |
-
use pyo3::prelude::*;
|
| 9 |
-
use rayon::prelude::*;
|
| 10 |
-
use serde::{Deserialize, Serialize};
|
| 11 |
-
use std::fs::File;
|
| 12 |
-
use std::io::{BufReader, BufWriter};
|
| 13 |
-
|
| 14 |
-
/// Sparse feature vector
|
| 15 |
-
pub type SparseVec = Vec<(u32, f32)>; // Use u32/f32 for memory efficiency
|
| 16 |
-
|
| 17 |
-
/// Linear SVM Model
|
| 18 |
-
#[pyclass]
|
| 19 |
-
#[derive(Clone, Serialize, Deserialize)]
|
| 20 |
-
pub struct LinearSVM {
|
| 21 |
-
weights: Vec<Vec<f32>>,
|
| 22 |
-
biases: Vec<f32>,
|
| 23 |
-
classes: Vec<String>,
|
| 24 |
-
n_features: usize,
|
| 25 |
-
}
|
| 26 |
-
|
| 27 |
-
#[pymethods]
|
| 28 |
-
impl LinearSVM {
|
| 29 |
-
#[new]
|
| 30 |
-
pub fn new() -> Self {
|
| 31 |
-
Self {
|
| 32 |
-
weights: Vec::new(),
|
| 33 |
-
biases: Vec::new(),
|
| 34 |
-
classes: Vec::new(),
|
| 35 |
-
n_features: 0,
|
| 36 |
-
}
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
pub fn predict(&self, features: Vec<f64>) -> String {
|
| 40 |
-
let idx = self.predict_idx(&features);
|
| 41 |
-
self.classes[idx].clone()
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
pub fn predict_with_score(&self, features: Vec<f64>) -> (String, f64) {
|
| 45 |
-
let scores = self.decision_scores(&features);
|
| 46 |
-
let (idx, &max_score) = scores
|
| 47 |
-
.iter()
|
| 48 |
-
.enumerate()
|
| 49 |
-
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
| 50 |
-
.unwrap();
|
| 51 |
-
let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
|
| 52 |
-
(self.classes[idx].clone(), confidence)
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
pub fn predict_batch(&self, features_batch: Vec<Vec<f64>>) -> Vec<String> {
|
| 56 |
-
features_batch
|
| 57 |
-
.par_iter()
|
| 58 |
-
.map(|f| {
|
| 59 |
-
let idx = self.predict_idx(f);
|
| 60 |
-
self.classes[idx].clone()
|
| 61 |
-
})
|
| 62 |
-
.collect()
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
pub fn predict_batch_sparse(&self, features_batch: Vec<Vec<(usize, f64)>>) -> Vec<String> {
|
| 66 |
-
features_batch
|
| 67 |
-
.par_iter()
|
| 68 |
-
.map(|f| {
|
| 69 |
-
let idx = self.predict_idx_sparse(f);
|
| 70 |
-
self.classes[idx].clone()
|
| 71 |
-
})
|
| 72 |
-
.collect()
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
pub fn predict_sparse_with_score(&self, features: Vec<(usize, f64)>) -> (String, f64) {
|
| 76 |
-
let scores = self.decision_scores_sparse(&features);
|
| 77 |
-
let (idx, &max_score) = scores
|
| 78 |
-
.iter()
|
| 79 |
-
.enumerate()
|
| 80 |
-
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
| 81 |
-
.unwrap();
|
| 82 |
-
let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
|
| 83 |
-
(self.classes[idx].clone(), confidence)
|
| 84 |
-
}
|
| 85 |
-
|
| 86 |
-
pub fn decision_function(&self, features: Vec<f64>) -> Vec<f64> {
|
| 87 |
-
self.decision_scores(&features).into_iter().map(|x| x as f64).collect()
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
#[getter]
|
| 91 |
-
pub fn classes(&self) -> Vec<String> {
|
| 92 |
-
self.classes.clone()
|
| 93 |
-
}
|
| 94 |
-
|
| 95 |
-
#[getter]
|
| 96 |
-
pub fn n_classes(&self) -> usize {
|
| 97 |
-
self.classes.len()
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
#[getter]
|
| 101 |
-
pub fn n_features(&self) -> usize {
|
| 102 |
-
self.n_features
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
pub fn save(&self, path: &str) -> PyResult<()> {
|
| 106 |
-
let file = File::create(path)
|
| 107 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 108 |
-
let writer = BufWriter::new(file);
|
| 109 |
-
serde_json::to_writer(writer, self)
|
| 110 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 111 |
-
Ok(())
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
#[staticmethod]
|
| 115 |
-
pub fn load(path: &str) -> PyResult<Self> {
|
| 116 |
-
let file = File::open(path)
|
| 117 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 118 |
-
let reader = BufReader::new(file);
|
| 119 |
-
let model: Self = serde_json::from_reader(reader)
|
| 120 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 121 |
-
Ok(model)
|
| 122 |
-
}
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
impl LinearSVM {
|
| 126 |
-
#[inline]
|
| 127 |
-
fn predict_idx(&self, features: &[f64]) -> usize {
|
| 128 |
-
let mut best_idx = 0;
|
| 129 |
-
let mut best_score = f32::NEG_INFINITY;
|
| 130 |
-
|
| 131 |
-
for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
|
| 132 |
-
let score: f32 = w.iter()
|
| 133 |
-
.zip(features.iter())
|
| 134 |
-
.map(|(&wi, &fi)| wi * fi as f32)
|
| 135 |
-
.sum::<f32>() + b;
|
| 136 |
-
|
| 137 |
-
if score > best_score {
|
| 138 |
-
best_score = score;
|
| 139 |
-
best_idx = idx;
|
| 140 |
-
}
|
| 141 |
-
}
|
| 142 |
-
best_idx
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
#[inline]
|
| 146 |
-
fn predict_idx_sparse(&self, features: &[(usize, f64)]) -> usize {
|
| 147 |
-
let mut best_idx = 0;
|
| 148 |
-
let mut best_score = f32::NEG_INFINITY;
|
| 149 |
-
|
| 150 |
-
for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
|
| 151 |
-
let score: f32 = features.iter()
|
| 152 |
-
.map(|&(j, v)| w[j] * v as f32)
|
| 153 |
-
.sum::<f32>() + b;
|
| 154 |
-
|
| 155 |
-
if score > best_score {
|
| 156 |
-
best_score = score;
|
| 157 |
-
best_idx = idx;
|
| 158 |
-
}
|
| 159 |
-
}
|
| 160 |
-
best_idx
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
fn decision_scores(&self, features: &[f64]) -> Vec<f32> {
|
| 164 |
-
self.weights
|
| 165 |
-
.iter()
|
| 166 |
-
.zip(self.biases.iter())
|
| 167 |
-
.map(|(w, &b)| {
|
| 168 |
-
w.iter()
|
| 169 |
-
.zip(features.iter())
|
| 170 |
-
.map(|(&wi, &fi)| wi * fi as f32)
|
| 171 |
-
.sum::<f32>() + b
|
| 172 |
-
})
|
| 173 |
-
.collect()
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
fn decision_scores_sparse(&self, features: &[(usize, f64)]) -> Vec<f32> {
|
| 177 |
-
self.weights
|
| 178 |
-
.iter()
|
| 179 |
-
.zip(self.biases.iter())
|
| 180 |
-
.map(|(w, &b)| {
|
| 181 |
-
features.iter()
|
| 182 |
-
.map(|&(j, v)| w[j] * v as f32)
|
| 183 |
-
.sum::<f32>() + b
|
| 184 |
-
})
|
| 185 |
-
.collect()
|
| 186 |
-
}
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
/// LIBLINEAR-style SVM Trainer
|
| 190 |
-
#[pyclass]
|
| 191 |
-
pub struct SVMTrainer {
|
| 192 |
-
c: f64,
|
| 193 |
-
max_iter: usize,
|
| 194 |
-
tol: f64,
|
| 195 |
-
verbose: bool,
|
| 196 |
-
}
|
| 197 |
-
|
| 198 |
-
#[pymethods]
|
| 199 |
-
impl SVMTrainer {
|
| 200 |
-
#[new]
|
| 201 |
-
#[pyo3(signature = (c=1.0, max_iter=1000, tol=0.1, verbose=false))]
|
| 202 |
-
pub fn new(c: f64, max_iter: usize, tol: f64, verbose: bool) -> Self {
|
| 203 |
-
Self { c, max_iter, tol, verbose }
|
| 204 |
-
}
|
| 205 |
-
|
| 206 |
-
pub fn set_c(&mut self, c: f64) {
|
| 207 |
-
self.c = c;
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
pub fn set_max_iter(&mut self, max_iter: usize) {
|
| 211 |
-
self.max_iter = max_iter;
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
pub fn set_verbose(&mut self, verbose: bool) {
|
| 215 |
-
self.verbose = verbose;
|
| 216 |
-
}
|
| 217 |
-
|
| 218 |
-
pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
|
| 219 |
-
let n_samples = features.len();
|
| 220 |
-
let n_features = if n_samples > 0 { features[0].len() } else { 0 };
|
| 221 |
-
|
| 222 |
-
// Convert to compact sparse format (f32 for memory/cache efficiency)
|
| 223 |
-
let sparse_features: Vec<SparseVec> = features
|
| 224 |
-
.par_iter()
|
| 225 |
-
.map(|dense| {
|
| 226 |
-
dense
|
| 227 |
-
.iter()
|
| 228 |
-
.enumerate()
|
| 229 |
-
.filter(|&(_, &v)| v.abs() > 1e-10)
|
| 230 |
-
.map(|(i, &v)| (i as u32, v as f32))
|
| 231 |
-
.collect()
|
| 232 |
-
})
|
| 233 |
-
.collect();
|
| 234 |
-
|
| 235 |
-
// Precompute ||x_i||^2
|
| 236 |
-
let x_sq_norms: Vec<f32> = sparse_features
|
| 237 |
-
.par_iter()
|
| 238 |
-
.map(|x| x.iter().map(|&(_, v)| v * v).sum())
|
| 239 |
-
.collect();
|
| 240 |
-
|
| 241 |
-
// Get unique classes
|
| 242 |
-
let mut classes: Vec<String> = labels.iter().cloned().collect();
|
| 243 |
-
classes.sort();
|
| 244 |
-
classes.dedup();
|
| 245 |
-
let n_classes = classes.len();
|
| 246 |
-
|
| 247 |
-
let class_to_idx: HashMap<String, usize> = classes
|
| 248 |
-
.iter()
|
| 249 |
-
.enumerate()
|
| 250 |
-
.map(|(i, c)| (c.clone(), i))
|
| 251 |
-
.collect();
|
| 252 |
-
|
| 253 |
-
let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
|
| 254 |
-
|
| 255 |
-
// Train binary classifiers in parallel (one-vs-rest)
|
| 256 |
-
let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
|
| 257 |
-
.into_par_iter()
|
| 258 |
-
.map(|class_idx| {
|
| 259 |
-
let y_binary: Vec<i8> = y_idx
|
| 260 |
-
.iter()
|
| 261 |
-
.map(|&idx| if idx == class_idx { 1 } else { -1 })
|
| 262 |
-
.collect();
|
| 263 |
-
|
| 264 |
-
solve_l2r_l2_svc(
|
| 265 |
-
&sparse_features,
|
| 266 |
-
&y_binary,
|
| 267 |
-
&x_sq_norms,
|
| 268 |
-
n_features,
|
| 269 |
-
self.c as f32,
|
| 270 |
-
self.tol as f32,
|
| 271 |
-
self.max_iter,
|
| 272 |
-
)
|
| 273 |
-
})
|
| 274 |
-
.collect();
|
| 275 |
-
|
| 276 |
-
let weights = results.iter().map(|(w, _)| w.clone()).collect();
|
| 277 |
-
let biases = results.iter().map(|(_, b)| *b).collect();
|
| 278 |
-
|
| 279 |
-
LinearSVM {
|
| 280 |
-
weights,
|
| 281 |
-
biases,
|
| 282 |
-
classes,
|
| 283 |
-
n_features,
|
| 284 |
-
}
|
| 285 |
-
}
|
| 286 |
-
}
|
| 287 |
-
|
| 288 |
-
/// LIBLINEAR's solve_l2r_l2_svc - Dual Coordinate Descent for L2-loss SVM
|
| 289 |
-
///
|
| 290 |
-
/// Solves: min_α 0.5 * α^T * Q * α - e^T * α, s.t. α_i ≥ 0
|
| 291 |
-
/// where Q_ij = y_i * y_j * x_i^T * x_j + δ_ij / (2C)
|
| 292 |
-
///
|
| 293 |
-
/// Primal-dual relationship: w = Σ α_i * y_i * x_i
|
| 294 |
-
#[inline(never)]
|
| 295 |
-
fn solve_l2r_l2_svc(
|
| 296 |
-
x: &[SparseVec],
|
| 297 |
-
y: &[i8],
|
| 298 |
-
x_sq_norms: &[f32],
|
| 299 |
-
n_features: usize,
|
| 300 |
-
c: f32,
|
| 301 |
-
eps: f32,
|
| 302 |
-
max_iter: usize,
|
| 303 |
-
) -> (Vec<f32>, f32) {
|
| 304 |
-
let n = x.len();
|
| 305 |
-
|
| 306 |
-
// D_ii = 1/(2C) for L2-loss SVM
|
| 307 |
-
let diag = 0.5 / c;
|
| 308 |
-
|
| 309 |
-
// QD[i] = ||x_i||^2 + D_ii
|
| 310 |
-
let qd: Vec<f32> = x_sq_norms.iter().map(|&xn| xn + diag).collect();
|
| 311 |
-
|
| 312 |
-
// Initialize α = 0
|
| 313 |
-
let mut alpha = vec![0.0f32; n];
|
| 314 |
-
|
| 315 |
-
// w = Σ α_i * y_i * x_i (initially 0)
|
| 316 |
-
let mut w = vec![0.0f32; n_features];
|
| 317 |
-
|
| 318 |
-
// Index for permutation
|
| 319 |
-
let mut index: Vec<usize> = (0..n).collect();
|
| 320 |
-
|
| 321 |
-
// Main loop
|
| 322 |
-
for iter in 0..max_iter {
|
| 323 |
-
// Shuffle indices
|
| 324 |
-
for i in 0..n {
|
| 325 |
-
let j = i + (iter * 1103515245 + 12345) % (n - i).max(1);
|
| 326 |
-
index.swap(i, j);
|
| 327 |
-
}
|
| 328 |
-
|
| 329 |
-
let mut max_violation = 0.0f32;
|
| 330 |
-
|
| 331 |
-
for &i in &index {
|
| 332 |
-
let yi = y[i] as f32;
|
| 333 |
-
let xi = &x[i];
|
| 334 |
-
|
| 335 |
-
// G = y_i * (w · x_i) - 1 + D_ii * α_i
|
| 336 |
-
let wxi: f32 = xi.iter().map(|&(j, v)| w[j as usize] * v).sum();
|
| 337 |
-
let g = yi * wxi - 1.0 + diag * alpha[i];
|
| 338 |
-
|
| 339 |
-
// Projected gradient (α ≥ 0, no upper bound for L2-loss)
|
| 340 |
-
let pg = if alpha[i] == 0.0 { g.min(0.0) } else { g };
|
| 341 |
-
|
| 342 |
-
max_violation = max_violation.max(pg.abs());
|
| 343 |
-
|
| 344 |
-
if pg.abs() > 1e-12 {
|
| 345 |
-
let alpha_old = alpha[i];
|
| 346 |
-
|
| 347 |
-
// α_i = max(0, α_i - G/Q_ii)
|
| 348 |
-
alpha[i] = (alpha[i] - g / qd[i]).max(0.0);
|
| 349 |
-
|
| 350 |
-
// Update w: w += (α_new - α_old) * y_i * x_i
|
| 351 |
-
let d = (alpha[i] - alpha_old) * yi;
|
| 352 |
-
if d.abs() > 1e-12 {
|
| 353 |
-
for &(j, v) in xi.iter() {
|
| 354 |
-
w[j as usize] += d * v;
|
| 355 |
-
}
|
| 356 |
-
}
|
| 357 |
-
}
|
| 358 |
-
}
|
| 359 |
-
|
| 360 |
-
// Stopping criterion
|
| 361 |
-
if max_violation <= eps {
|
| 362 |
-
break;
|
| 363 |
-
}
|
| 364 |
-
}
|
| 365 |
-
|
| 366 |
-
// Compute bias from KKT conditions
|
| 367 |
-
// For α_i > 0: y_i * (w · x_i + b) = 1 - α_i / (2C)
|
| 368 |
-
let mut bias_sum = 0.0f32;
|
| 369 |
-
let mut n_sv = 0;
|
| 370 |
-
|
| 371 |
-
for i in 0..n {
|
| 372 |
-
if alpha[i] > 1e-8 {
|
| 373 |
-
let yi = y[i] as f32;
|
| 374 |
-
let wxi: f32 = x[i].iter().map(|&(j, v)| w[j as usize] * v).sum();
|
| 375 |
-
// b = y_i * (1 - α_i * diag) - w · x_i
|
| 376 |
-
bias_sum += yi * (1.0 - alpha[i] * diag) - wxi;
|
| 377 |
-
n_sv += 1;
|
| 378 |
-
}
|
| 379 |
-
}
|
| 380 |
-
|
| 381 |
-
let bias = if n_sv > 0 { bias_sum / n_sv as f32 } else { 0.0 };
|
| 382 |
-
|
| 383 |
-
(w, bias)
|
| 384 |
-
}
|
| 385 |
-
|
| 386 |
-
/// Fast SVM using Pegasos algorithm
|
| 387 |
-
#[pyclass]
|
| 388 |
-
pub struct FastSVMTrainer {
|
| 389 |
-
c: f64,
|
| 390 |
-
max_iter: usize,
|
| 391 |
-
}
|
| 392 |
-
|
| 393 |
-
#[pymethods]
|
| 394 |
-
impl FastSVMTrainer {
|
| 395 |
-
#[new]
|
| 396 |
-
#[pyo3(signature = (c=1.0, max_iter=100))]
|
| 397 |
-
pub fn new(c: f64, max_iter: usize) -> Self {
|
| 398 |
-
Self { c, max_iter }
|
| 399 |
-
}
|
| 400 |
-
|
| 401 |
-
pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
|
| 402 |
-
let n_samples = features.len();
|
| 403 |
-
let n_features = if n_samples > 0 { features[0].len() } else { 0 };
|
| 404 |
-
|
| 405 |
-
let sparse_features: Vec<SparseVec> = features
|
| 406 |
-
.par_iter()
|
| 407 |
-
.map(|dense| {
|
| 408 |
-
dense
|
| 409 |
-
.iter()
|
| 410 |
-
.enumerate()
|
| 411 |
-
.filter(|&(_, &v)| v.abs() > 1e-10)
|
| 412 |
-
.map(|(i, &v)| (i as u32, v as f32))
|
| 413 |
-
.collect()
|
| 414 |
-
})
|
| 415 |
-
.collect();
|
| 416 |
-
|
| 417 |
-
let mut classes: Vec<String> = labels.iter().cloned().collect();
|
| 418 |
-
classes.sort();
|
| 419 |
-
classes.dedup();
|
| 420 |
-
let n_classes = classes.len();
|
| 421 |
-
|
| 422 |
-
let class_to_idx: HashMap<String, usize> = classes
|
| 423 |
-
.iter()
|
| 424 |
-
.enumerate()
|
| 425 |
-
.map(|(i, c)| (c.clone(), i))
|
| 426 |
-
.collect();
|
| 427 |
-
|
| 428 |
-
let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
|
| 429 |
-
|
| 430 |
-
let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
|
| 431 |
-
.into_par_iter()
|
| 432 |
-
.map(|class_idx| {
|
| 433 |
-
let y_binary: Vec<i8> = y_idx
|
| 434 |
-
.iter()
|
| 435 |
-
.map(|&idx| if idx == class_idx { 1 } else { -1 })
|
| 436 |
-
.collect();
|
| 437 |
-
|
| 438 |
-
pegasos(&sparse_features, &y_binary, n_features, self.c as f32, self.max_iter)
|
| 439 |
-
})
|
| 440 |
-
.collect();
|
| 441 |
-
|
| 442 |
-
LinearSVM {
|
| 443 |
-
weights: results.iter().map(|(w, _)| w.clone()).collect(),
|
| 444 |
-
biases: results.iter().map(|(_, b)| *b).collect(),
|
| 445 |
-
classes,
|
| 446 |
-
n_features,
|
| 447 |
-
}
|
| 448 |
-
}
|
| 449 |
-
}
|
| 450 |
-
|
| 451 |
-
/// Pegasos algorithm with lazy scaling
|
| 452 |
-
#[inline(never)]
|
| 453 |
-
fn pegasos(
|
| 454 |
-
x: &[SparseVec],
|
| 455 |
-
y: &[i8],
|
| 456 |
-
n_features: usize,
|
| 457 |
-
c: f32,
|
| 458 |
-
max_iter: usize,
|
| 459 |
-
) -> (Vec<f32>, f32) {
|
| 460 |
-
let n = x.len();
|
| 461 |
-
let lambda = 1.0 / c;
|
| 462 |
-
|
| 463 |
-
let mut w = vec![0.0f32; n_features];
|
| 464 |
-
let mut scale = 1.0f32;
|
| 465 |
-
let mut b = 0.0f32;
|
| 466 |
-
|
| 467 |
-
let eta0 = 0.5;
|
| 468 |
-
let t0 = 1.0 / (eta0 * lambda);
|
| 469 |
-
|
| 470 |
-
let mut indices: Vec<usize> = (0..n).collect();
|
| 471 |
-
|
| 472 |
-
for epoch in 0..max_iter {
|
| 473 |
-
// Shuffle
|
| 474 |
-
for i in 0..n {
|
| 475 |
-
let j = (i + epoch * 1103515245 + 12345) % n;
|
| 476 |
-
indices.swap(i, j);
|
| 477 |
-
}
|
| 478 |
-
|
| 479 |
-
for (t_inner, &i) in indices.iter().enumerate() {
|
| 480 |
-
let t = (epoch * n + t_inner) as f32;
|
| 481 |
-
let eta = 1.0 / (lambda * (t + t0));
|
| 482 |
-
|
| 483 |
-
let yi = y[i] as f32;
|
| 484 |
-
let xi = &x[i];
|
| 485 |
-
|
| 486 |
-
let margin: f32 = scale * xi.iter().map(|&(j, v)| w[j as usize] * v).sum::<f32>() + b;
|
| 487 |
-
|
| 488 |
-
scale *= 1.0 - eta * lambda;
|
| 489 |
-
|
| 490 |
-
if scale < 1e-9 {
|
| 491 |
-
for wj in w.iter_mut() {
|
| 492 |
-
*wj *= scale;
|
| 493 |
-
}
|
| 494 |
-
scale = 1.0;
|
| 495 |
-
}
|
| 496 |
-
|
| 497 |
-
if yi * margin < 1.0 {
|
| 498 |
-
let update = eta / scale;
|
| 499 |
-
for &(j, v) in xi.iter() {
|
| 500 |
-
w[j as usize] += update * yi * v;
|
| 501 |
-
}
|
| 502 |
-
b += eta * yi * 0.1;
|
| 503 |
-
}
|
| 504 |
-
}
|
| 505 |
-
}
|
| 506 |
-
|
| 507 |
-
for wj in w.iter_mut() {
|
| 508 |
-
*wj *= scale;
|
| 509 |
-
}
|
| 510 |
-
|
| 511 |
-
(w, b)
|
| 512 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/src/tfidf.rs
DELETED
|
@@ -1,235 +0,0 @@
|
|
| 1 |
-
//! TF-IDF Vectorizer implementation
|
| 2 |
-
|
| 3 |
-
use hashbrown::HashMap;
|
| 4 |
-
use pyo3::prelude::*;
|
| 5 |
-
use rayon::prelude::*;
|
| 6 |
-
use serde::{Deserialize, Serialize};
|
| 7 |
-
use std::fs::File;
|
| 8 |
-
use std::io::{BufReader, BufWriter};
|
| 9 |
-
|
| 10 |
-
/// TF-IDF Vectorizer
|
| 11 |
-
///
|
| 12 |
-
/// Converts text documents into TF-IDF feature vectors.
|
| 13 |
-
#[pyclass]
|
| 14 |
-
#[derive(Clone, Serialize, Deserialize)]
|
| 15 |
-
pub struct TfIdfVectorizer {
|
| 16 |
-
/// Vocabulary: word -> index
|
| 17 |
-
vocab: HashMap<String, usize>,
|
| 18 |
-
/// Inverse vocabulary: index -> word
|
| 19 |
-
inv_vocab: Vec<String>,
|
| 20 |
-
/// IDF values for each term
|
| 21 |
-
idf: Vec<f64>,
|
| 22 |
-
/// Number of documents used for fitting
|
| 23 |
-
n_docs: usize,
|
| 24 |
-
/// Maximum number of features
|
| 25 |
-
max_features: usize,
|
| 26 |
-
/// N-gram range (min, max)
|
| 27 |
-
ngram_range: (usize, usize),
|
| 28 |
-
/// Minimum document frequency
|
| 29 |
-
min_df: usize,
|
| 30 |
-
/// Maximum document frequency (as ratio)
|
| 31 |
-
max_df: f64,
|
| 32 |
-
/// Whether the vectorizer is fitted
|
| 33 |
-
is_fitted: bool,
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
#[pymethods]
|
| 37 |
-
impl TfIdfVectorizer {
|
| 38 |
-
/// Create a new TfIdfVectorizer
|
| 39 |
-
#[new]
|
| 40 |
-
#[pyo3(signature = (max_features=20000, ngram_range=(1, 2), min_df=1, max_df=1.0))]
|
| 41 |
-
pub fn new(
|
| 42 |
-
max_features: usize,
|
| 43 |
-
ngram_range: (usize, usize),
|
| 44 |
-
min_df: usize,
|
| 45 |
-
max_df: f64,
|
| 46 |
-
) -> Self {
|
| 47 |
-
Self {
|
| 48 |
-
vocab: HashMap::new(),
|
| 49 |
-
inv_vocab: Vec::new(),
|
| 50 |
-
idf: Vec::new(),
|
| 51 |
-
n_docs: 0,
|
| 52 |
-
max_features,
|
| 53 |
-
ngram_range,
|
| 54 |
-
min_df,
|
| 55 |
-
max_df,
|
| 56 |
-
is_fitted: false,
|
| 57 |
-
}
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
/// Fit the vectorizer on a list of documents
|
| 61 |
-
pub fn fit(&mut self, documents: Vec<String>) {
|
| 62 |
-
let n_docs = documents.len();
|
| 63 |
-
self.n_docs = n_docs;
|
| 64 |
-
|
| 65 |
-
// Count document frequency for each term
|
| 66 |
-
let mut df: HashMap<String, usize> = HashMap::new();
|
| 67 |
-
|
| 68 |
-
for doc in &documents {
|
| 69 |
-
let tokens = self.tokenize(doc);
|
| 70 |
-
let unique_tokens: std::collections::HashSet<_> = tokens.into_iter().collect();
|
| 71 |
-
for token in unique_tokens {
|
| 72 |
-
*df.entry(token).or_insert(0) += 1;
|
| 73 |
-
}
|
| 74 |
-
}
|
| 75 |
-
|
| 76 |
-
// Filter by min_df and max_df
|
| 77 |
-
let max_df_count = (self.max_df * n_docs as f64) as usize;
|
| 78 |
-
let mut filtered: Vec<(String, usize)> = df
|
| 79 |
-
.into_iter()
|
| 80 |
-
.filter(|(_, count)| *count >= self.min_df && *count <= max_df_count)
|
| 81 |
-
.collect();
|
| 82 |
-
|
| 83 |
-
// Sort by frequency (descending) and take top max_features
|
| 84 |
-
filtered.sort_by(|a, b| b.1.cmp(&a.1));
|
| 85 |
-
filtered.truncate(self.max_features);
|
| 86 |
-
|
| 87 |
-
// Build vocabulary
|
| 88 |
-
self.vocab.clear();
|
| 89 |
-
self.inv_vocab.clear();
|
| 90 |
-
self.idf.clear();
|
| 91 |
-
|
| 92 |
-
for (idx, (term, doc_freq)) in filtered.into_iter().enumerate() {
|
| 93 |
-
self.vocab.insert(term.clone(), idx);
|
| 94 |
-
self.inv_vocab.push(term);
|
| 95 |
-
// IDF with smoothing: log((n_docs + 1) / (df + 1)) + 1
|
| 96 |
-
let idf_value = ((n_docs as f64 + 1.0) / (doc_freq as f64 + 1.0)).ln() + 1.0;
|
| 97 |
-
self.idf.push(idf_value);
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
self.is_fitted = true;
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
/// Transform a single document to TF-IDF vector (sparse format)
|
| 104 |
-
///
|
| 105 |
-
/// Returns list of (index, value) tuples
|
| 106 |
-
pub fn transform(&self, document: &str) -> Vec<(usize, f64)> {
|
| 107 |
-
if !self.is_fitted {
|
| 108 |
-
return Vec::new();
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
let tokens = self.tokenize(document);
|
| 112 |
-
let mut tf: HashMap<usize, usize> = HashMap::new();
|
| 113 |
-
|
| 114 |
-
for token in &tokens {
|
| 115 |
-
if let Some(&idx) = self.vocab.get(token) {
|
| 116 |
-
*tf.entry(idx).or_insert(0) += 1;
|
| 117 |
-
}
|
| 118 |
-
}
|
| 119 |
-
|
| 120 |
-
let n_tokens = tokens.len() as f64;
|
| 121 |
-
if n_tokens == 0.0 {
|
| 122 |
-
return Vec::new();
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
let mut result: Vec<(usize, f64)> = tf
|
| 126 |
-
.into_iter()
|
| 127 |
-
.map(|(idx, count)| {
|
| 128 |
-
let tf_value = count as f64 / n_tokens;
|
| 129 |
-
let tfidf = tf_value * self.idf[idx];
|
| 130 |
-
(idx, tfidf)
|
| 131 |
-
})
|
| 132 |
-
.collect();
|
| 133 |
-
|
| 134 |
-
// L2 normalize
|
| 135 |
-
let norm: f64 = result.iter().map(|(_, v)| v * v).sum::<f64>().sqrt();
|
| 136 |
-
if norm > 0.0 {
|
| 137 |
-
for (_, v) in &mut result {
|
| 138 |
-
*v /= norm;
|
| 139 |
-
}
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
-
result.sort_by_key(|(idx, _)| *idx);
|
| 143 |
-
result
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
/// Transform a single document to dense TF-IDF vector
|
| 147 |
-
pub fn transform_dense(&self, document: &str) -> Vec<f64> {
|
| 148 |
-
let sparse = self.transform(document);
|
| 149 |
-
let mut dense = vec![0.0; self.vocab.len()];
|
| 150 |
-
for (idx, val) in sparse {
|
| 151 |
-
dense[idx] = val;
|
| 152 |
-
}
|
| 153 |
-
dense
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
/// Transform multiple documents to dense TF-IDF vectors (parallel)
|
| 157 |
-
pub fn transform_batch(&self, documents: Vec<String>) -> Vec<Vec<f64>> {
|
| 158 |
-
documents
|
| 159 |
-
.par_iter()
|
| 160 |
-
.map(|doc| self.transform_dense(doc))
|
| 161 |
-
.collect()
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
/// Transform multiple documents to sparse TF-IDF vectors (parallel)
|
| 165 |
-
pub fn transform_batch_sparse(&self, documents: Vec<String>) -> Vec<Vec<(usize, f64)>> {
|
| 166 |
-
documents
|
| 167 |
-
.par_iter()
|
| 168 |
-
.map(|doc| self.transform(doc))
|
| 169 |
-
.collect()
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
/// Fit and transform in one step
|
| 173 |
-
pub fn fit_transform(&mut self, documents: Vec<String>) -> Vec<Vec<f64>> {
|
| 174 |
-
self.fit(documents.clone());
|
| 175 |
-
self.transform_batch(documents)
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
/// Get vocabulary size
|
| 179 |
-
#[getter]
|
| 180 |
-
pub fn vocab_size(&self) -> usize {
|
| 181 |
-
self.vocab.len()
|
| 182 |
-
}
|
| 183 |
-
|
| 184 |
-
/// Get feature names (vocabulary terms)
|
| 185 |
-
pub fn get_feature_names(&self) -> Vec<String> {
|
| 186 |
-
self.inv_vocab.clone()
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
/// Check if vectorizer is fitted
|
| 190 |
-
#[getter]
|
| 191 |
-
pub fn is_fitted(&self) -> bool {
|
| 192 |
-
self.is_fitted
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
/// Save vectorizer to file
|
| 196 |
-
pub fn save(&self, path: &str) -> PyResult<()> {
|
| 197 |
-
let file = File::create(path)
|
| 198 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 199 |
-
let writer = BufWriter::new(file);
|
| 200 |
-
serde_json::to_writer(writer, self)
|
| 201 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 202 |
-
Ok(())
|
| 203 |
-
}
|
| 204 |
-
|
| 205 |
-
/// Load vectorizer from file
|
| 206 |
-
#[staticmethod]
|
| 207 |
-
pub fn load(path: &str) -> PyResult<Self> {
|
| 208 |
-
let file = File::open(path)
|
| 209 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 210 |
-
let reader = BufReader::new(file);
|
| 211 |
-
let vectorizer: Self = serde_json::from_reader(reader)
|
| 212 |
-
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 213 |
-
Ok(vectorizer)
|
| 214 |
-
}
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
impl TfIdfVectorizer {
|
| 218 |
-
/// Tokenize document into n-grams
|
| 219 |
-
fn tokenize(&self, document: &str) -> Vec<String> {
|
| 220 |
-
let words: Vec<&str> = document.split_whitespace().collect();
|
| 221 |
-
let mut tokens = Vec::new();
|
| 222 |
-
|
| 223 |
-
for n in self.ngram_range.0..=self.ngram_range.1 {
|
| 224 |
-
if n > words.len() {
|
| 225 |
-
continue;
|
| 226 |
-
}
|
| 227 |
-
for i in 0..=(words.len() - n) {
|
| 228 |
-
let ngram = words[i..i + n].join(" ");
|
| 229 |
-
tokens.push(ngram);
|
| 230 |
-
}
|
| 231 |
-
}
|
| 232 |
-
|
| 233 |
-
tokens
|
| 234 |
-
}
|
| 235 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
extensions/underthesea_core_extend/uv.lock
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
version = 1
|
| 2 |
-
revision = 3
|
| 3 |
-
requires-python = ">=3.10"
|
| 4 |
-
|
| 5 |
-
[[package]]
|
| 6 |
-
name = "underthesea-core-extend"
|
| 7 |
-
version = "0.1.0"
|
| 8 |
-
source = { editable = "." }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
[project]
|
| 2 |
name = "sen"
|
| 3 |
version = "1.1.0"
|
| 4 |
-
description = "Vietnamese Text Classification
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.10"
|
| 7 |
license = "Apache-2.0"
|
|
@@ -10,14 +10,16 @@ authors = [
|
|
| 10 |
]
|
| 11 |
keywords = ["vietnamese", "nlp", "text-classification", "rust", "svm"]
|
| 12 |
dependencies = [
|
| 13 |
-
"
|
|
|
|
| 14 |
]
|
| 15 |
|
| 16 |
[project.optional-dependencies]
|
| 17 |
dev = [
|
| 18 |
"pytest>=7.0.0",
|
| 19 |
"huggingface-hub>=0.20.0",
|
| 20 |
-
"
|
|
|
|
| 21 |
]
|
| 22 |
|
| 23 |
[project.urls]
|
|
@@ -27,6 +29,3 @@ Repository = "https://github.com/undertheseanlp/sen"
|
|
| 27 |
[build-system]
|
| 28 |
requires = ["hatchling"]
|
| 29 |
build-backend = "hatchling.build"
|
| 30 |
-
|
| 31 |
-
[tool.hatch.build.targets.wheel]
|
| 32 |
-
packages = ["src/sen"]
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "sen"
|
| 3 |
version = "1.1.0"
|
| 4 |
+
description = "Vietnamese Text Classification - Training scripts for underthesea_core"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.10"
|
| 7 |
license = "Apache-2.0"
|
|
|
|
| 10 |
]
|
| 11 |
keywords = ["vietnamese", "nlp", "text-classification", "rust", "svm"]
|
| 12 |
dependencies = [
|
| 13 |
+
"underthesea>=6.0.0",
|
| 14 |
+
"click>=8.0.0",
|
| 15 |
]
|
| 16 |
|
| 17 |
[project.optional-dependencies]
|
| 18 |
dev = [
|
| 19 |
"pytest>=7.0.0",
|
| 20 |
"huggingface-hub>=0.20.0",
|
| 21 |
+
"scikit-learn>=1.0.0",
|
| 22 |
+
"datasets>=2.0.0",
|
| 23 |
]
|
| 24 |
|
| 25 |
[project.urls]
|
|
|
|
| 29 |
[build-system]
|
| 30 |
requires = ["hatchling"]
|
| 31 |
build-backend = "hatchling.build"
|
|
|
|
|
|
|
|
|
src/bench.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark CLI for Vietnamese Text Classification.
|
| 3 |
+
|
| 4 |
+
Compares Rust TextClassifier vs sklearn.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python bench.py vntc
|
| 8 |
+
python bench.py bank
|
| 9 |
+
python bench.py synthetic
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import time
|
| 14 |
+
import random
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import click
|
| 18 |
+
from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVectorizer
|
| 19 |
+
from sklearn.svm import LinearSVC as SklearnLinearSVC
|
| 20 |
+
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
| 21 |
+
|
| 22 |
+
from underthesea_core import TextClassifier
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_file(filepath):
|
| 26 |
+
"""Read text file with multiple encoding attempts."""
|
| 27 |
+
for enc in ['utf-16', 'utf-16-le', 'utf-8', 'latin-1']:
|
| 28 |
+
try:
|
| 29 |
+
with open(filepath, 'r', encoding=enc) as f:
|
| 30 |
+
text = ' '.join(f.read().split())
|
| 31 |
+
if len(text) > 10:
|
| 32 |
+
return text
|
| 33 |
+
except (UnicodeDecodeError, UnicodeError):
|
| 34 |
+
continue
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=20000):
|
| 39 |
+
"""Benchmark scikit-learn TF-IDF + LinearSVC."""
|
| 40 |
+
click.echo("\n" + "=" * 70)
|
| 41 |
+
click.echo("scikit-learn: TfidfVectorizer + LinearSVC")
|
| 42 |
+
click.echo("=" * 70)
|
| 43 |
+
|
| 44 |
+
# Vectorize
|
| 45 |
+
click.echo(" Vectorizing...")
|
| 46 |
+
t0 = time.perf_counter()
|
| 47 |
+
vectorizer = SklearnTfidfVectorizer(max_features=max_features, ngram_range=(1, 2), min_df=2)
|
| 48 |
+
X_train = vectorizer.fit_transform(train_texts)
|
| 49 |
+
X_test = vectorizer.transform(test_texts)
|
| 50 |
+
vec_time = time.perf_counter() - t0
|
| 51 |
+
click.echo(f" Vectorization time: {vec_time:.2f}s")
|
| 52 |
+
click.echo(f" Vocabulary size: {len(vectorizer.vocabulary_)}")
|
| 53 |
+
|
| 54 |
+
# Train
|
| 55 |
+
click.echo(" Training LinearSVC...")
|
| 56 |
+
t0 = time.perf_counter()
|
| 57 |
+
clf = SklearnLinearSVC(C=1.0, max_iter=2000)
|
| 58 |
+
clf.fit(X_train, train_labels)
|
| 59 |
+
train_time = time.perf_counter() - t0
|
| 60 |
+
click.echo(f" Training time: {train_time:.2f}s")
|
| 61 |
+
|
| 62 |
+
# End-to-end inference
|
| 63 |
+
click.echo(" End-to-end inference...")
|
| 64 |
+
t0 = time.perf_counter()
|
| 65 |
+
X_test_e2e = vectorizer.transform(test_texts)
|
| 66 |
+
preds = clf.predict(X_test_e2e)
|
| 67 |
+
e2e_time = time.perf_counter() - t0
|
| 68 |
+
e2e_throughput = len(test_texts) / e2e_time
|
| 69 |
+
click.echo(f" E2E time: {e2e_time:.2f}s ({e2e_throughput:.0f} samples/sec)")
|
| 70 |
+
|
| 71 |
+
# Metrics
|
| 72 |
+
acc = accuracy_score(test_labels, preds)
|
| 73 |
+
f1_w = f1_score(test_labels, preds, average='weighted')
|
| 74 |
+
click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")
|
| 75 |
+
|
| 76 |
+
return {
|
| 77 |
+
"total_train": vec_time + train_time,
|
| 78 |
+
"e2e_throughput": e2e_throughput,
|
| 79 |
+
"accuracy": acc,
|
| 80 |
+
"f1_weighted": f1_w,
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=20000):
|
| 85 |
+
"""Benchmark Rust TextClassifier."""
|
| 86 |
+
click.echo("\n" + "=" * 70)
|
| 87 |
+
click.echo("Rust: TextClassifier (underthesea_core)")
|
| 88 |
+
click.echo("=" * 70)
|
| 89 |
+
|
| 90 |
+
clf = TextClassifier(
|
| 91 |
+
max_features=max_features,
|
| 92 |
+
ngram_range=(1, 2),
|
| 93 |
+
min_df=2,
|
| 94 |
+
c=1.0,
|
| 95 |
+
max_iter=1000,
|
| 96 |
+
tol=0.1,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Train
|
| 100 |
+
click.echo(" Training...")
|
| 101 |
+
t0 = time.perf_counter()
|
| 102 |
+
clf.fit(list(train_texts), list(train_labels))
|
| 103 |
+
train_time = time.perf_counter() - t0
|
| 104 |
+
click.echo(f" Training time: {train_time:.2f}s")
|
| 105 |
+
click.echo(f" Vocabulary size: {clf.n_features}")
|
| 106 |
+
|
| 107 |
+
# Inference
|
| 108 |
+
click.echo(" Inference...")
|
| 109 |
+
t0 = time.perf_counter()
|
| 110 |
+
preds = clf.predict_batch(list(test_texts))
|
| 111 |
+
infer_time = time.perf_counter() - t0
|
| 112 |
+
throughput = len(test_texts) / infer_time
|
| 113 |
+
click.echo(f" Inference time: {infer_time:.2f}s ({throughput:.0f} samples/sec)")
|
| 114 |
+
|
| 115 |
+
# Metrics
|
| 116 |
+
acc = accuracy_score(test_labels, preds)
|
| 117 |
+
f1_w = f1_score(test_labels, preds, average='weighted')
|
| 118 |
+
click.echo(f" Results: Accuracy={acc:.4f}, F1={f1_w:.4f}")
|
| 119 |
+
|
| 120 |
+
return {
|
| 121 |
+
"total_train": train_time,
|
| 122 |
+
"throughput": throughput,
|
| 123 |
+
"accuracy": acc,
|
| 124 |
+
"f1_weighted": f1_w,
|
| 125 |
+
"clf": clf,
|
| 126 |
+
"preds": preds,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def print_comparison(sklearn_results, rust_results):
|
| 131 |
+
"""Print comparison summary."""
|
| 132 |
+
click.echo("\n" + "=" * 70)
|
| 133 |
+
click.echo("COMPARISON SUMMARY")
|
| 134 |
+
click.echo("=" * 70)
|
| 135 |
+
click.echo(f"{'Metric':<30} {'sklearn':<20} {'Rust':<20}")
|
| 136 |
+
click.echo("-" * 70)
|
| 137 |
+
|
| 138 |
+
click.echo(f"{'Training time (s)':<30} {sklearn_results['total_train']:<20.2f} {rust_results['total_train']:<20.2f}")
|
| 139 |
+
click.echo(f"{'Inference (samples/sec)':<30} {sklearn_results['e2e_throughput']:<20.0f} {rust_results['throughput']:<20.0f}")
|
| 140 |
+
click.echo(f"{'Accuracy':<30} {sklearn_results['accuracy']:<20.4f} {rust_results['accuracy']:<20.4f}")
|
| 141 |
+
click.echo(f"{'F1 (weighted)':<30} {sklearn_results['f1_weighted']:<20.4f} {rust_results['f1_weighted']:<20.4f}")
|
| 142 |
+
|
| 143 |
+
click.echo("-" * 70)
|
| 144 |
+
train_speedup = sklearn_results['total_train'] / rust_results['total_train'] if rust_results['total_train'] > 0 else 0
|
| 145 |
+
infer_speedup = rust_results['throughput'] / sklearn_results['e2e_throughput'] if sklearn_results['e2e_throughput'] > 0 else 0
|
| 146 |
+
click.echo(f"Speedup: Training {train_speedup:.2f}x, Inference {infer_speedup:.2f}x")
|
| 147 |
+
click.echo("=" * 70)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@click.group()
|
| 151 |
+
def cli():
|
| 152 |
+
"""Benchmark Vietnamese text classification models."""
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@cli.command()
|
| 157 |
+
@click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
|
| 158 |
+
help='Path to VNTC dataset')
|
| 159 |
+
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
|
| 160 |
+
@click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
|
| 161 |
+
def vntc(data_dir, save_model, output):
|
| 162 |
+
"""Benchmark on VNTC dataset (10 topics, ~84k documents)."""
|
| 163 |
+
click.echo("=" * 70)
|
| 164 |
+
click.echo("VNTC Full Dataset Benchmark")
|
| 165 |
+
click.echo("Vietnamese News Text Classification (10 Topics)")
|
| 166 |
+
click.echo("=" * 70)
|
| 167 |
+
|
| 168 |
+
train_dir = os.path.join(data_dir, "Train_Full")
|
| 169 |
+
test_dir = os.path.join(data_dir, "Test_Full")
|
| 170 |
+
|
| 171 |
+
# Load data
|
| 172 |
+
click.echo("\nLoading training data...")
|
| 173 |
+
t0 = time.perf_counter()
|
| 174 |
+
train_texts, train_labels = [], []
|
| 175 |
+
for folder in sorted(os.listdir(train_dir)):
|
| 176 |
+
folder_path = os.path.join(train_dir, folder)
|
| 177 |
+
if not os.path.isdir(folder_path):
|
| 178 |
+
continue
|
| 179 |
+
for fname in os.listdir(folder_path):
|
| 180 |
+
if fname.endswith('.txt'):
|
| 181 |
+
text = read_file(os.path.join(folder_path, fname))
|
| 182 |
+
if text:
|
| 183 |
+
train_texts.append(text)
|
| 184 |
+
train_labels.append(folder)
|
| 185 |
+
click.echo(f" Loaded {len(train_texts)} training samples in {time.perf_counter()-t0:.1f}s")
|
| 186 |
+
|
| 187 |
+
click.echo("Loading test data...")
|
| 188 |
+
t0 = time.perf_counter()
|
| 189 |
+
test_texts, test_labels = [], []
|
| 190 |
+
for folder in sorted(os.listdir(test_dir)):
|
| 191 |
+
folder_path = os.path.join(test_dir, folder)
|
| 192 |
+
if not os.path.isdir(folder_path):
|
| 193 |
+
continue
|
| 194 |
+
for fname in os.listdir(folder_path):
|
| 195 |
+
if fname.endswith('.txt'):
|
| 196 |
+
text = read_file(os.path.join(folder_path, fname))
|
| 197 |
+
if text:
|
| 198 |
+
test_texts.append(text)
|
| 199 |
+
test_labels.append(folder)
|
| 200 |
+
click.echo(f" Loaded {len(test_texts)} test samples in {time.perf_counter()-t0:.1f}s")
|
| 201 |
+
|
| 202 |
+
# Run benchmarks
|
| 203 |
+
sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels)
|
| 204 |
+
rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels)
|
| 205 |
+
|
| 206 |
+
print_comparison(sklearn_results, rust_results)
|
| 207 |
+
|
| 208 |
+
if save_model:
|
| 209 |
+
model_path = Path(output)
|
| 210 |
+
model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 211 |
+
rust_results['clf'].save(str(model_path))
|
| 212 |
+
size_mb = model_path.stat().st_size / (1024 * 1024)
|
| 213 |
+
click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@cli.command()
|
| 217 |
+
@click.option('--save-model', is_flag=True, help='Save the trained Rust model')
|
| 218 |
+
@click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
|
| 219 |
+
def bank(save_model, output):
|
| 220 |
+
"""Benchmark on UTS2017_Bank dataset (14 categories, banking domain)."""
|
| 221 |
+
from datasets import load_dataset
|
| 222 |
+
|
| 223 |
+
click.echo("=" * 70)
|
| 224 |
+
click.echo("UTS2017_Bank Dataset Benchmark")
|
| 225 |
+
click.echo("Vietnamese Banking Domain Text Classification (14 Categories)")
|
| 226 |
+
click.echo("=" * 70)
|
| 227 |
+
|
| 228 |
+
# Load data
|
| 229 |
+
click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
|
| 230 |
+
dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")
|
| 231 |
+
|
| 232 |
+
train_texts = list(dataset["train"]["text"])
|
| 233 |
+
train_labels = list(dataset["train"]["label"])
|
| 234 |
+
test_texts = list(dataset["test"]["text"])
|
| 235 |
+
test_labels = list(dataset["test"]["label"])
|
| 236 |
+
|
| 237 |
+
click.echo(f" Train samples: {len(train_texts)}")
|
| 238 |
+
click.echo(f" Test samples: {len(test_texts)}")
|
| 239 |
+
click.echo(f" Categories: {len(set(train_labels))}")
|
| 240 |
+
|
| 241 |
+
# Run benchmarks (smaller max_features for smaller dataset)
|
| 242 |
+
sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
|
| 243 |
+
rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)
|
| 244 |
+
|
| 245 |
+
print_comparison(sklearn_results, rust_results)
|
| 246 |
+
|
| 247 |
+
click.echo("\nClassification Report (Rust):")
|
| 248 |
+
click.echo(classification_report(test_labels, rust_results['preds']))
|
| 249 |
+
|
| 250 |
+
if save_model:
|
| 251 |
+
model_path = Path(output)
|
| 252 |
+
model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 253 |
+
rust_results['clf'].save(str(model_path))
|
| 254 |
+
size_mb = model_path.stat().st_size / (1024 * 1024)
|
| 255 |
+
click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@cli.command()
|
| 259 |
+
@click.option('--train-per-cat', default=340, help='Training samples per category')
|
| 260 |
+
@click.option('--test-per-cat', default=500, help='Test samples per category')
|
| 261 |
+
@click.option('--seed', default=42, help='Random seed')
|
| 262 |
+
def synthetic(train_per_cat, test_per_cat, seed):
|
| 263 |
+
"""Benchmark on synthetic VNTC-like data."""
|
| 264 |
+
# Vietnamese text templates by category
|
| 265 |
+
TEMPLATES = {
|
| 266 |
+
"the_thao": ["Đội tuyển {} thắng {} với tỷ số {}", "Cầu thủ {} ghi bàn đẹp mắt"],
|
| 267 |
+
"kinh_doanh": ["Chứng khoán {} điểm trong phiên giao dịch", "Ngân hàng {} công bố lãi suất {}"],
|
| 268 |
+
"cong_nghe": ["Apple ra mắt {} với nhiều tính năng", "Trí tuệ nhân tạo đang thay đổi {}"],
|
| 269 |
+
"chinh_tri": ["Quốc hội thông qua nghị quyết về {}", "Chủ tịch {} tiếp đón phái đoàn"],
|
| 270 |
+
"van_hoa": ["Nghệ sĩ {} ra mắt album mới", "Liên hoan phim {} trao giải"],
|
| 271 |
+
"khoa_hoc": ["Nhà khoa học phát hiện {} mới", "Nghiên cứu cho thấy {} có tác dụng"],
|
| 272 |
+
"suc_khoe": ["Bộ Y tế cảnh báo về {} trong mùa", "Vaccine {} đạt hiệu quả cao"],
|
| 273 |
+
"giao_duc": ["Trường {} công bố điểm chuẩn", "Học sinh đoạt huy chương tại Olympic"],
|
| 274 |
+
"phap_luat": ["Tòa án xét xử vụ án {} với bị cáo", "Công an triệt phá đường dây"],
|
| 275 |
+
"doi_song": ["Giá {} tăng trong tháng", "Người dân đổ xô đi mua {}"],
|
| 276 |
+
}
|
| 277 |
+
FILLS = {
|
| 278 |
+
"the_thao": ["Việt Nam", "Thái Lan", "3-0", "bóng đá", "AFF Cup"],
|
| 279 |
+
"kinh_doanh": ["tăng", "giảm", "VN-Index", "Vietcombank", "8%"],
|
| 280 |
+
"cong_nghe": ["iPhone 16", "ChatGPT", "công việc", "VinAI", "5G"],
|
| 281 |
+
"chinh_tri": ["kinh tế", "nước", "Trung Quốc", "Hà Nội", "phát triển"],
|
| 282 |
+
"van_hoa": ["Mỹ Tâm", "Cannes", "nghệ thuật", "Hà Nội", "Bố Già"],
|
| 283 |
+
"khoa_hoc": ["loài sinh vật", "trà xanh", "VNREDSat-1", "Nobel", "robot"],
|
| 284 |
+
"suc_khoe": ["dịch cúm", "COVID-19", "Bạch Mai", "dinh dưỡng", "tiểu đường"],
|
| 285 |
+
"giao_duc": ["Bách Khoa", "Việt Nam", "Toán", "THPT", "STEM"],
|
| 286 |
+
"phap_luat": ["tham nhũng", "TP.HCM", "ma túy", "Hình sự", "gian lận"],
|
| 287 |
+
"doi_song": ["xăng", "vàng", "nắng nóng", "Trung thu", "bún chả"],
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
def generate_sample(category):
|
| 291 |
+
template = random.choice(TEMPLATES[category])
|
| 292 |
+
fills = FILLS[category]
|
| 293 |
+
n = template.count("{}")
|
| 294 |
+
return template.format(*random.choices(fills, k=n))
|
| 295 |
+
|
| 296 |
+
def generate_dataset(n_per_cat, categories):
|
| 297 |
+
texts, labels = [], []
|
| 298 |
+
for cat in categories:
|
| 299 |
+
for _ in range(n_per_cat):
|
| 300 |
+
texts.append(generate_sample(cat))
|
| 301 |
+
labels.append(cat)
|
| 302 |
+
combined = list(zip(texts, labels))
|
| 303 |
+
random.shuffle(combined)
|
| 304 |
+
return [t for t, _ in combined], [l for _, l in combined]
|
| 305 |
+
|
| 306 |
+
click.echo("=" * 70)
|
| 307 |
+
click.echo("Synthetic VNTC-like Benchmark")
|
| 308 |
+
click.echo("=" * 70)
|
| 309 |
+
|
| 310 |
+
random.seed(seed)
|
| 311 |
+
categories = list(TEMPLATES.keys())
|
| 312 |
+
|
| 313 |
+
click.echo(f"\nConfiguration:")
|
| 314 |
+
click.echo(f" Categories: {len(categories)}")
|
| 315 |
+
click.echo(f" Train samples: {train_per_cat * len(categories)}")
|
| 316 |
+
click.echo(f" Test samples: {test_per_cat * len(categories)}")
|
| 317 |
+
|
| 318 |
+
train_texts, train_labels = generate_dataset(train_per_cat, categories)
|
| 319 |
+
test_texts, test_labels = generate_dataset(test_per_cat, categories)
|
| 320 |
+
|
| 321 |
+
sklearn_results = benchmark_sklearn(train_texts, train_labels, test_texts, test_labels, max_features=10000)
|
| 322 |
+
rust_results = benchmark_rust(train_texts, train_labels, test_texts, test_labels, max_features=10000)
|
| 323 |
+
|
| 324 |
+
print_comparison(sklearn_results, rust_results)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
if __name__ == "__main__":
|
| 328 |
+
cli()
|
src/scripts/train.py
DELETED
|
@@ -1,221 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Train Sen Text Classifier on Vietnamese news data.
|
| 3 |
-
|
| 4 |
-
Categories based on VNTC dataset:
|
| 5 |
-
- Chinh tri Xa hoi (Politics/Society)
|
| 6 |
-
- Doi song (Lifestyle)
|
| 7 |
-
- Khoa hoc (Science)
|
| 8 |
-
- Kinh doanh (Business)
|
| 9 |
-
- Phap luat (Law)
|
| 10 |
-
- Suc khoe (Health)
|
| 11 |
-
- The gioi (World)
|
| 12 |
-
- The thao (Sports)
|
| 13 |
-
- Van hoa (Culture)
|
| 14 |
-
- Vi tinh (Technology)
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
import json
|
| 18 |
-
import os
|
| 19 |
-
import sys
|
| 20 |
-
|
| 21 |
-
sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
|
| 22 |
-
|
| 23 |
-
from sen import SenTextClassifier
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# Sample Vietnamese news data for each category
|
| 27 |
-
SAMPLE_DATA = {
|
| 28 |
-
"chinh_tri_xa_hoi": [
|
| 29 |
-
"Quốc hội thông qua nghị quyết về phát triển kinh tế xã hội",
|
| 30 |
-
"Thủ tướng chủ trì họp Chính phủ thường kỳ tháng này",
|
| 31 |
-
"Đại hội Đảng toàn quốc lần thứ XIII thành công tốt đẹp",
|
| 32 |
-
"Chủ tịch nước tiếp đoàn đại biểu quốc tế",
|
| 33 |
-
"Bộ Nội vụ triển khai cải cách hành chính",
|
| 34 |
-
"Ủy ban Mặt trận Tổ quốc tổ chức hội nghị toàn quốc",
|
| 35 |
-
"Đoàn thanh niên phát động phong trào tình nguyện",
|
| 36 |
-
"Hội Liên hiệp Phụ nữ tổ chức đại hội đại biểu",
|
| 37 |
-
],
|
| 38 |
-
"doi_song": [
|
| 39 |
-
"Mẹo hay giúp tiết kiệm chi phí sinh hoạt hàng ngày",
|
| 40 |
-
"Xu hướng thời trang mới nhất mùa thu đông năm nay",
|
| 41 |
-
"Cách trang trí nhà cửa đón Tết đẹp và tiết kiệm",
|
| 42 |
-
"Bí quyết nấu ăn ngon cho cả gia đình",
|
| 43 |
-
"Kinh nghiệm du lịch Đà Nẵng tiết kiệm chi phí",
|
| 44 |
-
"Cách chăm sóc cây cảnh trong nhà hiệu quả",
|
| 45 |
-
"Chia sẻ cách dạy con học tập hiệu quả",
|
| 46 |
-
"Mẹo vặt hay cho cuộc sống hàng ngày",
|
| 47 |
-
],
|
| 48 |
-
"khoa_hoc": [
|
| 49 |
-
"Các nhà khoa học phát hiện hành tinh mới ngoài hệ mặt trời",
|
| 50 |
-
"Nghiên cứu mới về biến đổi khí hậu toàn cầu",
|
| 51 |
-
"Vệ tinh nhân tạo được phóng thành công lên quỹ đạo",
|
| 52 |
-
"Khám phá mới về nguồn gốc vũ trụ",
|
| 53 |
-
"Công nghệ nano ứng dụng trong y học",
|
| 54 |
-
"Phát hiện loài động vật mới ở rừng Amazon",
|
| 55 |
-
"Nghiên cứu về trí tuệ nhân tạo và học máy",
|
| 56 |
-
"Thí nghiệm vật lý hạt nhân tại CERN",
|
| 57 |
-
],
|
| 58 |
-
"kinh_doanh": [
|
| 59 |
-
"Chứng khoán Việt Nam tăng điểm mạnh phiên đầu tuần",
|
| 60 |
-
"Ngân hàng Nhà nước điều chỉnh lãi suất điều hành",
|
| 61 |
-
"Doanh nghiệp xuất khẩu gặp khó khăn do tỷ giá",
|
| 62 |
-
"Thị trường bất động sản có dấu hiệu phục hồi",
|
| 63 |
-
"VN-Index vượt mốc 1200 điểm trong phiên giao dịch",
|
| 64 |
-
"FDI vào Việt Nam tăng trưởng ấn tượng",
|
| 65 |
-
"Giá vàng thế giới biến động mạnh trong tuần",
|
| 66 |
-
"Startup công nghệ Việt gọi vốn thành công Series A",
|
| 67 |
-
],
|
| 68 |
-
"phap_luat": [
|
| 69 |
-
"Tòa án xét xử vụ án tham nhũng lớn",
|
| 70 |
-
"Công an triệt phá đường dây buôn lậu xuyên quốc gia",
|
| 71 |
-
"Luật mới về bảo vệ môi trường có hiệu lực",
|
| 72 |
-
"Khởi tố vụ án lừa đảo chiếm đoạt tài sản",
|
| 73 |
-
"Bộ Công an cảnh báo thủ đoạn lừa đảo qua mạng",
|
| 74 |
-
"Tòa án tuyên án vụ vi phạm an toàn giao thông",
|
| 75 |
-
"Viện Kiểm sát truy tố các bị cáo trong vụ án kinh tế",
|
| 76 |
-
"Cảnh sát giao thông xử lý vi phạm nồng độ cồn",
|
| 77 |
-
],
|
| 78 |
-
"suc_khoe": [
|
| 79 |
-
"Bộ Y tế khuyến cáo phòng chống dịch bệnh mùa đông",
|
| 80 |
-
"Phát hiện phương pháp điều trị ung thư mới",
|
| 81 |
-
"Cách phòng tránh các bệnh về đường hô hấp",
|
| 82 |
-
"Chế độ ăn uống lành mạnh cho người cao tuổi",
|
| 83 |
-
"Vaccine mới được phê duyệt sử dụng tại Việt Nam",
|
| 84 |
-
"Bệnh viện Bạch Mai ứng dụng kỹ thuật mổ nội soi",
|
| 85 |
-
"Cách chăm sóc sức khỏe tinh thần hiệu quả",
|
| 86 |
-
"Tập thể dục đúng cách để có sức khỏe tốt",
|
| 87 |
-
],
|
| 88 |
-
"the_gioi": [
|
| 89 |
-
"Tổng thống Mỹ công bố chính sách đối ngoại mới",
|
| 90 |
-
"Hội nghị thượng đỉnh G20 thảo luận về biến đổi khí hậu",
|
| 91 |
-
"Xung đột vũ trang leo thang tại Trung Đông",
|
| 92 |
-
"Liên Hợp Quốc họp khẩn về tình hình nhân đạo",
|
| 93 |
-
"Châu Âu đối mặt với khủng hoảng năng lượng",
|
| 94 |
-
"Trung Quốc công bố số liệu tăng trưởng kinh tế",
|
| 95 |
-
"Nhật Bản bầu cử thủ tướng mới",
|
| 96 |
-
"Nga và Ukraine tiếp tục đàm phán hòa bình",
|
| 97 |
-
],
|
| 98 |
-
"the_thao": [
|
| 99 |
-
"Đội tuyển Việt Nam thắng đậm trong trận giao hữu",
|
| 100 |
-
"Cầu thủ Nguyễn Quang Hải ghi bàn đẹp mắt",
|
| 101 |
-
"V-League 2024 khởi tranh vào tháng tới",
|
| 102 |
-
"HLV Park Hang-seo chia tay bóng đá Việt Nam",
|
| 103 |
-
"U23 Việt Nam vô địch giải Đông Nam Á",
|
| 104 |
-
"Hoàng Xuân Vinh giành huy chương vàng Olympic",
|
| 105 |
-
"Ánh Viên phá kỷ lục quốc gia môn bơi lội",
|
| 106 |
-
"SEA Games 31 tổ chức thành công tại Việt Nam",
|
| 107 |
-
],
|
| 108 |
-
"van_hoa": [
|
| 109 |
-
"Liên hoan phim Việt Nam lần thứ 23 khai mạc",
|
| 110 |
-
"Nghệ sĩ nhân dân được phong tặng danh hiệu cao quý",
|
| 111 |
-
"Triển lãm tranh của họa sĩ nổi tiếng tại Hà Nội",
|
| 112 |
-
"Ca sĩ Việt Nam giành giải thưởng âm nhạc châu Á",
|
| 113 |
-
"Lễ hội truyền thống thu hút đông đảo du khách",
|
| 114 |
-
"Phim Việt Nam được đề cử tại liên hoan phim quốc tế",
|
| 115 |
-
"Nhạc sĩ sáng tác ca khúc mới về quê hương",
|
| 116 |
-
"Bảo tàng lịch sử khai trương triển lãm mới",
|
| 117 |
-
],
|
| 118 |
-
"vi_tinh": [
|
| 119 |
-
"Apple ra mắt iPhone mới với nhiều tính năng",
|
| 120 |
-
"Trí tuệ nhân tạo đang thay đổi cuộc sống",
|
| 121 |
-
"Samsung công bố điện thoại gập thế hệ mới",
|
| 122 |
-
"Microsoft phát hành bản cập nhật Windows",
|
| 123 |
-
"ChatGPT và cuộc cách mạng trí tuệ nhân tạo",
|
| 124 |
-
"5G được triển khai rộng rãi tại các thành phố lớn",
|
| 125 |
-
"Ứng dụng di động phổ biến nhất trong năm",
|
| 126 |
-
"An ninh mạng trước nguy cơ tấn công hacker",
|
| 127 |
-
],
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def prepare_data():
|
| 132 |
-
"""Prepare training and validation data."""
|
| 133 |
-
train_texts = []
|
| 134 |
-
train_labels = []
|
| 135 |
-
val_texts = []
|
| 136 |
-
val_labels = []
|
| 137 |
-
|
| 138 |
-
for label, texts in SAMPLE_DATA.items():
|
| 139 |
-
# Use first 6 for training, last 2 for validation
|
| 140 |
-
for text in texts[:6]:
|
| 141 |
-
train_texts.append(text)
|
| 142 |
-
train_labels.append(label)
|
| 143 |
-
for text in texts[6:]:
|
| 144 |
-
val_texts.append(text)
|
| 145 |
-
val_labels.append(label)
|
| 146 |
-
|
| 147 |
-
return train_texts, train_labels, val_texts, val_labels
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def main():
|
| 151 |
-
print("=" * 60)
|
| 152 |
-
print("Training Sen Text Classifier")
|
| 153 |
-
print("Based on VNTC Vietnamese News Classification")
|
| 154 |
-
print("=" * 60)
|
| 155 |
-
|
| 156 |
-
# Prepare data
|
| 157 |
-
train_texts, train_labels, val_texts, val_labels = prepare_data()
|
| 158 |
-
print(f"\nDataset:")
|
| 159 |
-
print(f" - Training samples: {len(train_texts)}")
|
| 160 |
-
print(f" - Validation samples: {len(val_texts)}")
|
| 161 |
-
print(f" - Categories: {len(SAMPLE_DATA)}")
|
| 162 |
-
|
| 163 |
-
# Initialize classifier
|
| 164 |
-
classifier = SenTextClassifier(
|
| 165 |
-
max_features=5000,
|
| 166 |
-
ngram_range=(1, 2),
|
| 167 |
-
min_df=1,
|
| 168 |
-
max_df=0.95,
|
| 169 |
-
sublinear_tf=True,
|
| 170 |
-
C=1.0,
|
| 171 |
-
max_iter=1000,
|
| 172 |
-
)
|
| 173 |
-
|
| 174 |
-
# Train
|
| 175 |
-
print("\n" + "=" * 60)
|
| 176 |
-
print("Training...")
|
| 177 |
-
print("=" * 60)
|
| 178 |
-
results = classifier.train(
|
| 179 |
-
train_texts=train_texts,
|
| 180 |
-
train_labels=train_labels,
|
| 181 |
-
val_texts=val_texts,
|
| 182 |
-
val_labels=val_labels,
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
# Evaluate
|
| 186 |
-
print("\n" + "=" * 60)
|
| 187 |
-
print("Evaluation on validation set:")
|
| 188 |
-
print("=" * 60)
|
| 189 |
-
classifier.evaluate(val_texts, val_labels)
|
| 190 |
-
|
| 191 |
-
# Test predictions
|
| 192 |
-
print("\n" + "=" * 60)
|
| 193 |
-
print("Sample Predictions:")
|
| 194 |
-
print("=" * 60)
|
| 195 |
-
test_texts = [
|
| 196 |
-
"Đội tuyển bóng đá Việt Nam chiến thắng",
|
| 197 |
-
"Giá vàng tăng mạnh trong phiên giao dịch",
|
| 198 |
-
"Apple công bố sản phẩm mới tại sự kiện",
|
| 199 |
-
"Bộ Y tế cảnh báo dịch cúm mùa",
|
| 200 |
-
"Quốc hội họp phiên bất thường",
|
| 201 |
-
]
|
| 202 |
-
|
| 203 |
-
for text in test_texts:
|
| 204 |
-
from sen import Sentence
|
| 205 |
-
sentence = Sentence(text)
|
| 206 |
-
classifier.predict(sentence)
|
| 207 |
-
print(f" '{text}' -> {sentence.labels[0]}")
|
| 208 |
-
|
| 209 |
-
# Save model
|
| 210 |
-
save_path = "/home/anhvu2/projects/workspace_underthesea/sen/trained_model"
|
| 211 |
-
print(f"\n" + "=" * 60)
|
| 212 |
-
print(f"Saving model to: {save_path}")
|
| 213 |
-
print("=" * 60)
|
| 214 |
-
classifier.save(save_path)
|
| 215 |
-
|
| 216 |
-
print("\nTraining completed!")
|
| 217 |
-
return classifier
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
if __name__ == "__main__":
|
| 221 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/scripts/train_sonar.py
DELETED
|
@@ -1,234 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Reproduce sonar_core_1 training configuration.
|
| 3 |
-
Uses CountVectorizer + TfidfTransformer + SVC pipeline.
|
| 4 |
-
Target: 92.80% accuracy on VNTC (matching sonar_core_1)
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import os
|
| 8 |
-
import sys
|
| 9 |
-
import time
|
| 10 |
-
import json
|
| 11 |
-
from datetime import datetime
|
| 12 |
-
|
| 13 |
-
import numpy as np
|
| 14 |
-
import joblib
|
| 15 |
-
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
|
| 16 |
-
from sklearn.svm import SVC, LinearSVC
|
| 17 |
-
from sklearn.linear_model import LogisticRegression
|
| 18 |
-
from sklearn.pipeline import Pipeline
|
| 19 |
-
from sklearn.metrics import accuracy_score, classification_report, f1_score
|
| 20 |
-
|
| 21 |
-
sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
|
| 22 |
-
|
| 23 |
-
# VNTC data paths
|
| 24 |
-
VNTC_BASE = "/home/anhvu2/projects/workspace_underthesea/VNTC_github/Data/10Topics/Ver1.1"
|
| 25 |
-
TRAIN_DIR = os.path.join(VNTC_BASE, "Train_Full")
|
| 26 |
-
TEST_DIR = os.path.join(VNTC_BASE, "Test_Full")
|
| 27 |
-
|
| 28 |
-
# Category mapping
|
| 29 |
-
CATEGORY_MAP = {
|
| 30 |
-
"Chinh tri Xa hoi": "Chinh tri Xa hoi",
|
| 31 |
-
"Doi song": "Doi song",
|
| 32 |
-
"Khoa hoc": "Khoa hoc",
|
| 33 |
-
"Kinh doanh": "Kinh doanh",
|
| 34 |
-
"Phap luat": "Phap luat",
|
| 35 |
-
"Suc khoe": "Suc khoe",
|
| 36 |
-
"The gioi": "The gioi",
|
| 37 |
-
"The thao": "The thao",
|
| 38 |
-
"Van hoa": "Van hoa",
|
| 39 |
-
"Vi tinh": "Vi tinh",
|
| 40 |
-
}
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def read_file(filepath):
|
| 44 |
-
"""Read text file with multiple encoding attempts."""
|
| 45 |
-
encodings = ['utf-16', 'utf-16-le', 'utf-8', 'latin-1']
|
| 46 |
-
for encoding in encodings:
|
| 47 |
-
try:
|
| 48 |
-
with open(filepath, 'r', encoding=encoding) as f:
|
| 49 |
-
text = f.read()
|
| 50 |
-
text = ' '.join(text.split())
|
| 51 |
-
if len(text) > 10:
|
| 52 |
-
return text
|
| 53 |
-
except (UnicodeDecodeError, UnicodeError):
|
| 54 |
-
continue
|
| 55 |
-
return None
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def load_vntc_data(data_dir):
|
| 59 |
-
"""Load VNTC data from directory."""
|
| 60 |
-
texts = []
|
| 61 |
-
labels = []
|
| 62 |
-
|
| 63 |
-
for folder_name, label in CATEGORY_MAP.items():
|
| 64 |
-
folder_path = os.path.join(data_dir, folder_name)
|
| 65 |
-
if not os.path.exists(folder_path):
|
| 66 |
-
print(f" Warning: {folder_path} not found")
|
| 67 |
-
continue
|
| 68 |
-
|
| 69 |
-
files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]
|
| 70 |
-
for filename in files:
|
| 71 |
-
filepath = os.path.join(folder_path, filename)
|
| 72 |
-
text = read_file(filepath)
|
| 73 |
-
if text:
|
| 74 |
-
texts.append(text)
|
| 75 |
-
labels.append(label)
|
| 76 |
-
|
| 77 |
-
return np.array(texts), np.array(labels)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def train_sonar_config():
|
| 81 |
-
"""Train with sonar_core_1 configuration."""
|
| 82 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 83 |
-
|
| 84 |
-
print("=" * 70)
|
| 85 |
-
print("Reproducing sonar_core_1 Configuration")
|
| 86 |
-
print("=" * 70)
|
| 87 |
-
|
| 88 |
-
# Load data
|
| 89 |
-
print("\n[1/4] Loading VNTC data...")
|
| 90 |
-
start = time.time()
|
| 91 |
-
X_train, y_train = load_vntc_data(TRAIN_DIR)
|
| 92 |
-
X_test, y_test = load_vntc_data(TEST_DIR)
|
| 93 |
-
print(f" Train: {len(X_train)} samples")
|
| 94 |
-
print(f" Test: {len(X_test)} samples")
|
| 95 |
-
print(f" Classes: {len(set(y_train))}")
|
| 96 |
-
print(f" Load time: {time.time()-start:.1f}s")
|
| 97 |
-
|
| 98 |
-
# sonar_core_1 configuration
|
| 99 |
-
configs = [
|
| 100 |
-
{
|
| 101 |
-
"name": "SVC (sonar_core_1 config)",
|
| 102 |
-
"max_features": 20000,
|
| 103 |
-
"ngram_range": (1, 2),
|
| 104 |
-
"classifier": SVC(kernel='linear', probability=True, random_state=42),
|
| 105 |
-
},
|
| 106 |
-
{
|
| 107 |
-
"name": "LinearSVC (faster)",
|
| 108 |
-
"max_features": 20000,
|
| 109 |
-
"ngram_range": (1, 2),
|
| 110 |
-
"classifier": LinearSVC(C=1.0, max_iter=2000, random_state=42),
|
| 111 |
-
},
|
| 112 |
-
{
|
| 113 |
-
"name": "LogisticRegression",
|
| 114 |
-
"max_features": 20000,
|
| 115 |
-
"ngram_range": (1, 2),
|
| 116 |
-
"classifier": LogisticRegression(max_iter=1000, random_state=42),
|
| 117 |
-
},
|
| 118 |
-
]
|
| 119 |
-
|
| 120 |
-
results = []
|
| 121 |
-
best_model = None
|
| 122 |
-
best_accuracy = 0
|
| 123 |
-
|
| 124 |
-
for i, config in enumerate(configs):
|
| 125 |
-
print(f"\n[{i+2}/4] Training: {config['name']}")
|
| 126 |
-
print("-" * 50)
|
| 127 |
-
|
| 128 |
-
# Create pipeline (sonar_core_1 style)
|
| 129 |
-
pipeline = Pipeline([
|
| 130 |
-
('vect', CountVectorizer(
|
| 131 |
-
max_features=config['max_features'],
|
| 132 |
-
ngram_range=config['ngram_range']
|
| 133 |
-
)),
|
| 134 |
-
('tfidf', TfidfTransformer(use_idf=True)),
|
| 135 |
-
('clf', config['classifier']),
|
| 136 |
-
])
|
| 137 |
-
|
| 138 |
-
# Train
|
| 139 |
-
start = time.time()
|
| 140 |
-
pipeline.fit(X_train, y_train)
|
| 141 |
-
train_time = time.time() - start
|
| 142 |
-
|
| 143 |
-
# Evaluate
|
| 144 |
-
train_pred = pipeline.predict(X_train)
|
| 145 |
-
test_pred = pipeline.predict(X_test)
|
| 146 |
-
|
| 147 |
-
train_acc = accuracy_score(y_train, train_pred)
|
| 148 |
-
test_acc = accuracy_score(y_test, test_pred)
|
| 149 |
-
test_f1 = f1_score(y_test, test_pred, average='weighted')
|
| 150 |
-
|
| 151 |
-
print(f" Train accuracy: {train_acc:.4f}")
|
| 152 |
-
print(f" Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
|
| 153 |
-
print(f" Test F1: {test_f1:.4f}")
|
| 154 |
-
print(f" Train time: {train_time:.1f}s")
|
| 155 |
-
|
| 156 |
-
results.append({
|
| 157 |
-
"name": config['name'],
|
| 158 |
-
"train_acc": train_acc,
|
| 159 |
-
"test_acc": test_acc,
|
| 160 |
-
"test_f1": test_f1,
|
| 161 |
-
"train_time": train_time,
|
| 162 |
-
})
|
| 163 |
-
|
| 164 |
-
if test_acc > best_accuracy:
|
| 165 |
-
best_accuracy = test_acc
|
| 166 |
-
best_model = pipeline
|
| 167 |
-
best_config = config
|
| 168 |
-
|
| 169 |
-
# Print comparison
|
| 170 |
-
print("\n" + "=" * 70)
|
| 171 |
-
print("Results Comparison")
|
| 172 |
-
print("=" * 70)
|
| 173 |
-
print(f"{'Model':<30} {'Test Acc':>10} {'Test F1':>10} {'Time':>10}")
|
| 174 |
-
print("-" * 70)
|
| 175 |
-
for r in results:
|
| 176 |
-
print(f"{r['name']:<30} {r['test_acc']*100:>9.2f}% {r['test_f1']:>10.4f} {r['train_time']:>9.1f}s")
|
| 177 |
-
print("-" * 70)
|
| 178 |
-
print(f"sonar_core_1 reference: {92.80:>9.2f}%")
|
| 179 |
-
print("=" * 70)
|
| 180 |
-
|
| 181 |
-
# Save best model
|
| 182 |
-
save_dir = "/home/anhvu2/projects/workspace_underthesea/sen/sen-general-1.0.0-20260202"
|
| 183 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 184 |
-
|
| 185 |
-
# Save pipeline
|
| 186 |
-
joblib.dump(best_model, os.path.join(save_dir, "pipeline.joblib"))
|
| 187 |
-
|
| 188 |
-
# Save label encoder (for compatibility)
|
| 189 |
-
from sklearn.preprocessing import LabelEncoder
|
| 190 |
-
le = LabelEncoder()
|
| 191 |
-
le.fit(y_train)
|
| 192 |
-
joblib.dump(le, os.path.join(save_dir, "label_encoder.joblib"))
|
| 193 |
-
|
| 194 |
-
# Save metadata
|
| 195 |
-
metadata = {
|
| 196 |
-
"model_type": "sonar_core_1_reproduction",
|
| 197 |
-
"architecture": "CountVectorizer + TfidfTransformer + LinearSVC",
|
| 198 |
-
"max_features": best_config['max_features'],
|
| 199 |
-
"ngram_range": list(best_config['ngram_range']),
|
| 200 |
-
"train_samples": len(X_train),
|
| 201 |
-
"test_samples": len(X_test),
|
| 202 |
-
"train_accuracy": float(results[1]['train_acc']), # LinearSVC
|
| 203 |
-
"test_accuracy": float(results[1]['test_acc']),
|
| 204 |
-
"test_f1_weighted": float(results[1]['test_f1']),
|
| 205 |
-
"labels": sorted(list(set(y_train))),
|
| 206 |
-
"timestamp": timestamp,
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
with open(os.path.join(save_dir, "metadata.json"), 'w') as f:
|
| 210 |
-
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
| 211 |
-
|
| 212 |
-
print(f"\nBest model saved to: {save_dir}")
|
| 213 |
-
|
| 214 |
-
# Print detailed classification report for best model
|
| 215 |
-
print("\n" + "=" * 70)
|
| 216 |
-
print("Classification Report (LinearSVC)")
|
| 217 |
-
print("=" * 70)
|
| 218 |
-
|
| 219 |
-
# Retrain LinearSVC for report
|
| 220 |
-
pipeline = Pipeline([
|
| 221 |
-
('vect', CountVectorizer(max_features=20000, ngram_range=(1, 2))),
|
| 222 |
-
('tfidf', TfidfTransformer(use_idf=True)),
|
| 223 |
-
('clf', LinearSVC(C=1.0, max_iter=2000, random_state=42)),
|
| 224 |
-
])
|
| 225 |
-
pipeline.fit(X_train, y_train)
|
| 226 |
-
test_pred = pipeline.predict(X_test)
|
| 227 |
-
|
| 228 |
-
print(classification_report(y_test, test_pred))
|
| 229 |
-
|
| 230 |
-
return results
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
if __name__ == "__main__":
|
| 234 |
-
train_sonar_config()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/scripts/train_vntc.py
DELETED
|
@@ -1,181 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Train Sen Text Classifier on full VNTC dataset.
|
| 3 |
-
|
| 4 |
-
VNTC: Vietnamese News Text Classification Corpus
|
| 5 |
-
- 10 Topics, ~33,759 train / ~50,373 test documents
|
| 6 |
-
- Reference: Vu et al. (2007) RIVF
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import os
|
| 10 |
-
import sys
|
| 11 |
-
import time
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
|
| 14 |
-
sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
|
| 15 |
-
|
| 16 |
-
from sen import SenTextClassifier
|
| 17 |
-
|
| 18 |
-
# VNTC data paths
|
| 19 |
-
VNTC_BASE = "/home/anhvu2/projects/workspace_underthesea/VNTC_github/Data/10Topics/Ver1.1"
|
| 20 |
-
TRAIN_DIR = os.path.join(VNTC_BASE, "Train_Full")
|
| 21 |
-
TEST_DIR = os.path.join(VNTC_BASE, "Test_Full")
|
| 22 |
-
|
| 23 |
-
# Category mapping (folder name -> normalized label)
|
| 24 |
-
CATEGORY_MAP = {
|
| 25 |
-
"Chinh tri Xa hoi": "chinh_tri_xa_hoi",
|
| 26 |
-
"Doi song": "doi_song",
|
| 27 |
-
"Khoa hoc": "khoa_hoc",
|
| 28 |
-
"Kinh doanh": "kinh_doanh",
|
| 29 |
-
"Phap luat": "phap_luat",
|
| 30 |
-
"Suc khoe": "suc_khoe",
|
| 31 |
-
"The gioi": "the_gioi",
|
| 32 |
-
"The thao": "the_thao",
|
| 33 |
-
"Van hoa": "van_hoa",
|
| 34 |
-
"Vi tinh": "vi_tinh",
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def read_file(filepath):
|
| 39 |
-
"""Read text file with multiple encoding attempts."""
|
| 40 |
-
encodings = ['utf-16', 'utf-16-le', 'utf-8', 'latin-1']
|
| 41 |
-
|
| 42 |
-
for encoding in encodings:
|
| 43 |
-
try:
|
| 44 |
-
with open(filepath, 'r', encoding=encoding) as f:
|
| 45 |
-
text = f.read()
|
| 46 |
-
# Clean up text (remove extra whitespace)
|
| 47 |
-
text = ' '.join(text.split())
|
| 48 |
-
if len(text) > 10: # Valid text
|
| 49 |
-
return text
|
| 50 |
-
except (UnicodeDecodeError, UnicodeError):
|
| 51 |
-
continue
|
| 52 |
-
|
| 53 |
-
return None
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def load_vntc_data(data_dir, max_per_category=None):
|
| 57 |
-
"""Load VNTC data from directory."""
|
| 58 |
-
texts = []
|
| 59 |
-
labels = []
|
| 60 |
-
stats = {}
|
| 61 |
-
|
| 62 |
-
for folder_name, label in CATEGORY_MAP.items():
|
| 63 |
-
folder_path = os.path.join(data_dir, folder_name)
|
| 64 |
-
if not os.path.exists(folder_path):
|
| 65 |
-
print(f" Warning: {folder_path} not found")
|
| 66 |
-
continue
|
| 67 |
-
|
| 68 |
-
files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]
|
| 69 |
-
if max_per_category:
|
| 70 |
-
files = files[:max_per_category]
|
| 71 |
-
|
| 72 |
-
count = 0
|
| 73 |
-
for filename in files:
|
| 74 |
-
filepath = os.path.join(folder_path, filename)
|
| 75 |
-
text = read_file(filepath)
|
| 76 |
-
if text:
|
| 77 |
-
texts.append(text)
|
| 78 |
-
labels.append(label)
|
| 79 |
-
count += 1
|
| 80 |
-
|
| 81 |
-
stats[label] = count
|
| 82 |
-
|
| 83 |
-
return texts, labels, stats
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def main():
|
| 87 |
-
print("=" * 70)
|
| 88 |
-
print("Training Sen Text Classifier on VNTC Dataset")
|
| 89 |
-
print("Vietnamese News Text Classification Corpus")
|
| 90 |
-
print("=" * 70)
|
| 91 |
-
|
| 92 |
-
# Load training data
|
| 93 |
-
print("\n[1/5] Loading training data...")
|
| 94 |
-
start = time.time()
|
| 95 |
-
train_texts, train_labels, train_stats = load_vntc_data(TRAIN_DIR)
|
| 96 |
-
print(f" Loaded {len(train_texts)} training samples in {time.time()-start:.1f}s")
|
| 97 |
-
print(" Per category:")
|
| 98 |
-
for label, count in sorted(train_stats.items()):
|
| 99 |
-
print(f" - {label}: {count}")
|
| 100 |
-
|
| 101 |
-
# Load test data
|
| 102 |
-
print("\n[2/5] Loading test data...")
|
| 103 |
-
start = time.time()
|
| 104 |
-
test_texts, test_labels, test_stats = load_vntc_data(TEST_DIR)
|
| 105 |
-
print(f" Loaded {len(test_texts)} test samples in {time.time()-start:.1f}s")
|
| 106 |
-
print(" Per category:")
|
| 107 |
-
for label, count in sorted(test_stats.items()):
|
| 108 |
-
print(f" - {label}: {count}")
|
| 109 |
-
|
| 110 |
-
# Initialize classifier
|
| 111 |
-
print("\n[3/5] Initializing classifier...")
|
| 112 |
-
classifier = SenTextClassifier(
|
| 113 |
-
max_features=10000, # Increased for larger dataset
|
| 114 |
-
ngram_range=(1, 2),
|
| 115 |
-
min_df=2, # Require term in at least 2 docs
|
| 116 |
-
max_df=0.95,
|
| 117 |
-
sublinear_tf=True,
|
| 118 |
-
C=1.0,
|
| 119 |
-
max_iter=2000,
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
# Train
|
| 123 |
-
print("\n[4/5] Training...")
|
| 124 |
-
start = time.time()
|
| 125 |
-
results = classifier.train(
|
| 126 |
-
train_texts=train_texts,
|
| 127 |
-
train_labels=train_labels,
|
| 128 |
-
val_texts=test_texts[:5000], # Use subset for validation during training
|
| 129 |
-
val_labels=test_labels[:5000],
|
| 130 |
-
)
|
| 131 |
-
train_time = time.time() - start
|
| 132 |
-
print(f" Training completed in {train_time:.1f}s")
|
| 133 |
-
|
| 134 |
-
# Evaluate on full test set
|
| 135 |
-
print("\n[5/5] Evaluating on full test set...")
|
| 136 |
-
start = time.time()
|
| 137 |
-
eval_results = classifier.evaluate(test_texts, test_labels)
|
| 138 |
-
eval_time = time.time() - start
|
| 139 |
-
|
| 140 |
-
print("\n" + "=" * 70)
|
| 141 |
-
print("VNTC Benchmark Results (10 Topics)")
|
| 142 |
-
print("=" * 70)
|
| 143 |
-
print(f" Test samples: {len(test_texts)}")
|
| 144 |
-
print(f" Accuracy: {eval_results['accuracy']:.4f} ({eval_results['accuracy']*100:.2f}%)")
|
| 145 |
-
print(f" F1 (weighted):{eval_results['f1_weighted']:.4f}")
|
| 146 |
-
print(f" Train time: {train_time:.1f}s")
|
| 147 |
-
print(f" Eval time: {eval_time:.1f}s")
|
| 148 |
-
print("=" * 70)
|
| 149 |
-
|
| 150 |
-
# Save model
|
| 151 |
-
save_path = "/home/anhvu2/projects/workspace_underthesea/sen/sen-1.0.0-vntc"
|
| 152 |
-
print(f"\nSaving model to: {save_path}")
|
| 153 |
-
classifier.save(save_path)
|
| 154 |
-
|
| 155 |
-
# Sample predictions
|
| 156 |
-
print("\nSample Predictions:")
|
| 157 |
-
print("-" * 70)
|
| 158 |
-
test_samples = [
|
| 159 |
-
"Đội tuyển Việt Nam thắng đậm 3-0 trước Indonesia",
|
| 160 |
-
"Giá vàng tăng mạnh trong phiên giao dịch hôm nay",
|
| 161 |
-
"Apple ra mắt iPhone mới với nhiều tính năng hấp dẫn",
|
| 162 |
-
"Bộ Y tế cảnh báo về dịch cúm mùa đông",
|
| 163 |
-
"Quốc hội thông qua nghị quyết phát triển kinh tế",
|
| 164 |
-
]
|
| 165 |
-
|
| 166 |
-
from sen import Sentence
|
| 167 |
-
for text in test_samples:
|
| 168 |
-
sentence = Sentence(text)
|
| 169 |
-
classifier.predict(sentence)
|
| 170 |
-
label = sentence.labels[0]
|
| 171 |
-
print(f" '{text[:50]}...' -> {label.value} ({label.score:.2f})")
|
| 172 |
-
|
| 173 |
-
print("\n" + "=" * 70)
|
| 174 |
-
print("Training completed successfully!")
|
| 175 |
-
print("=" * 70)
|
| 176 |
-
|
| 177 |
-
return eval_results
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
if __name__ == "__main__":
|
| 181 |
-
results = main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/sen/__init__.py
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Sen-1: Vietnamese Text Classification by UnderTheSea NLP.
|
| 3 |
-
|
| 4 |
-
Based on: "A Comparative Study on Vietnamese Text Classification Methods"
|
| 5 |
-
Vu et al., RIVF 2007
|
| 6 |
-
|
| 7 |
-
Methods:
|
| 8 |
-
- TF-IDF vectorization (sklearn)
|
| 9 |
-
- SVM (Support Vector Machine) classifier
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from .text_classifier import (
|
| 13 |
-
Label,
|
| 14 |
-
Sentence,
|
| 15 |
-
SenTextClassifier,
|
| 16 |
-
classify,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
__version__ = "1.0.0"
|
| 20 |
-
|
| 21 |
-
__all__ = [
|
| 22 |
-
"Label",
|
| 23 |
-
"Sentence",
|
| 24 |
-
"SenTextClassifier",
|
| 25 |
-
"classify",
|
| 26 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/sen/text_classifier.py
DELETED
|
@@ -1,374 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Sen Text Classifier - Rust-based classifier using underthesea_core_extend.
|
| 3 |
-
|
| 4 |
-
Based on: "A Comparative Study on Vietnamese Text Classification Methods"
|
| 5 |
-
Vu et al., RIVF 2007
|
| 6 |
-
https://ieeexplore.ieee.org/document/4223084/
|
| 7 |
-
|
| 8 |
-
Methods:
|
| 9 |
-
- TF-IDF vectorization (Rust: underthesea_core_extend.TfIdfVectorizer)
|
| 10 |
-
- Linear SVM classifier (Rust: underthesea_core_extend.LinearSVM)
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
import json
|
| 14 |
-
import os
|
| 15 |
-
from typing import List, Optional, Union
|
| 16 |
-
|
| 17 |
-
from underthesea_core_extend import TfIdfVectorizer, LinearSVM, SVMTrainer
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class Label:
|
| 21 |
-
"""Label class compatible with underthesea."""
|
| 22 |
-
|
| 23 |
-
def __init__(self, value: str, score: float = 1.0):
|
| 24 |
-
self.value = value
|
| 25 |
-
self.score = min(max(score, 0.0), 1.0)
|
| 26 |
-
|
| 27 |
-
def __str__(self):
|
| 28 |
-
return f"{self.value} ({self.score:.4f})"
|
| 29 |
-
|
| 30 |
-
def __repr__(self):
|
| 31 |
-
return f"{self.value} ({self.score:.4f})"
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class Sentence:
|
| 35 |
-
"""Sentence class compatible with underthesea."""
|
| 36 |
-
|
| 37 |
-
def __init__(self, text: str = None, labels: List[Label] = None):
|
| 38 |
-
self.text = text
|
| 39 |
-
self.labels = labels or []
|
| 40 |
-
|
| 41 |
-
def __str__(self):
|
| 42 |
-
return f'Sentence: "{self.text}" - Labels: {self.labels}'
|
| 43 |
-
|
| 44 |
-
def __repr__(self):
|
| 45 |
-
return f'Sentence: "{self.text}" - Labels: {self.labels}'
|
| 46 |
-
|
| 47 |
-
def add_labels(self, labels: List[Union[Label, str]]):
|
| 48 |
-
for label in labels:
|
| 49 |
-
if isinstance(label, str):
|
| 50 |
-
label = Label(label)
|
| 51 |
-
self.labels.append(label)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class SenTextClassifier:
|
| 55 |
-
"""
|
| 56 |
-
Rust-based text classifier using TF-IDF + Linear SVM.
|
| 57 |
-
|
| 58 |
-
Uses underthesea_core_extend for fast training and inference.
|
| 59 |
-
Compatible with underthesea API.
|
| 60 |
-
|
| 61 |
-
Reference:
|
| 62 |
-
Vu et al. "A Comparative Study on Vietnamese Text Classification Methods"
|
| 63 |
-
RIVF 2007
|
| 64 |
-
"""
|
| 65 |
-
|
| 66 |
-
def __init__(
|
| 67 |
-
self,
|
| 68 |
-
# TF-IDF parameters
|
| 69 |
-
max_features: int = 20000,
|
| 70 |
-
ngram_range: tuple = (1, 2),
|
| 71 |
-
min_df: int = 1,
|
| 72 |
-
max_df: float = 1.0,
|
| 73 |
-
# SVM parameters
|
| 74 |
-
c: float = 1.0,
|
| 75 |
-
max_iter: int = 1000,
|
| 76 |
-
tol: float = 0.1,
|
| 77 |
-
verbose: bool = True,
|
| 78 |
-
):
|
| 79 |
-
self.max_features = max_features
|
| 80 |
-
self.ngram_range = ngram_range
|
| 81 |
-
self.min_df = min_df
|
| 82 |
-
self.max_df = max_df
|
| 83 |
-
self.c = c
|
| 84 |
-
self.max_iter = max_iter
|
| 85 |
-
self.tol = tol
|
| 86 |
-
self.verbose = verbose
|
| 87 |
-
|
| 88 |
-
self.vectorizer: Optional[TfIdfVectorizer] = None
|
| 89 |
-
self.classifier: Optional[LinearSVM] = None
|
| 90 |
-
self.labels_: Optional[List[str]] = None
|
| 91 |
-
|
| 92 |
-
def train(
|
| 93 |
-
self,
|
| 94 |
-
train_texts: List[str],
|
| 95 |
-
train_labels: List[str],
|
| 96 |
-
val_texts: List[str] = None,
|
| 97 |
-
val_labels: List[str] = None,
|
| 98 |
-
) -> dict:
|
| 99 |
-
"""
|
| 100 |
-
Train the classifier.
|
| 101 |
-
|
| 102 |
-
Args:
|
| 103 |
-
train_texts: List of training texts
|
| 104 |
-
train_labels: List of training labels
|
| 105 |
-
val_texts: Optional validation texts
|
| 106 |
-
val_labels: Optional validation labels
|
| 107 |
-
|
| 108 |
-
Returns:
|
| 109 |
-
Dictionary with training metrics
|
| 110 |
-
"""
|
| 111 |
-
# Get unique labels
|
| 112 |
-
self.labels_ = sorted(list(set(train_labels)))
|
| 113 |
-
|
| 114 |
-
# Build and fit vectorizer
|
| 115 |
-
self.vectorizer = TfIdfVectorizer(
|
| 116 |
-
max_features=self.max_features,
|
| 117 |
-
ngram_range=self.ngram_range,
|
| 118 |
-
min_df=self.min_df,
|
| 119 |
-
max_df=self.max_df,
|
| 120 |
-
)
|
| 121 |
-
self.vectorizer.fit(train_texts)
|
| 122 |
-
|
| 123 |
-
# Transform to features
|
| 124 |
-
train_features = self.vectorizer.transform_batch(train_texts)
|
| 125 |
-
|
| 126 |
-
# Build and train SVM model
|
| 127 |
-
trainer = SVMTrainer(
|
| 128 |
-
c=self.c,
|
| 129 |
-
max_iter=self.max_iter,
|
| 130 |
-
tol=self.tol,
|
| 131 |
-
verbose=self.verbose,
|
| 132 |
-
)
|
| 133 |
-
self.classifier = trainer.train(train_features, train_labels)
|
| 134 |
-
|
| 135 |
-
# Calculate training metrics
|
| 136 |
-
train_preds = self.classifier.predict_batch(train_features)
|
| 137 |
-
train_acc = sum(1 for p, t in zip(train_preds, train_labels) if p == t) / len(train_labels)
|
| 138 |
-
|
| 139 |
-
# Calculate F1 score
|
| 140 |
-
train_f1 = self._calculate_f1(train_labels, train_preds)
|
| 141 |
-
|
| 142 |
-
results = {
|
| 143 |
-
"train_accuracy": train_acc,
|
| 144 |
-
"train_f1": train_f1,
|
| 145 |
-
"num_classes": len(self.labels_),
|
| 146 |
-
"num_samples": len(train_texts),
|
| 147 |
-
"vocab_size": self.vectorizer.vocab_size,
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
print(f"Training completed:")
|
| 151 |
-
print(f" - Samples: {len(train_texts)}")
|
| 152 |
-
print(f" - Classes: {len(self.labels_)}")
|
| 153 |
-
print(f" - Vocab size: {self.vectorizer.vocab_size}")
|
| 154 |
-
print(f" - Train accuracy: {train_acc:.4f}")
|
| 155 |
-
print(f" - Train F1: {train_f1:.4f}")
|
| 156 |
-
|
| 157 |
-
# Validation metrics
|
| 158 |
-
if val_texts and val_labels:
|
| 159 |
-
val_features = self.vectorizer.transform_batch(val_texts)
|
| 160 |
-
val_preds = self.classifier.predict_batch(val_features)
|
| 161 |
-
val_acc = sum(1 for p, t in zip(val_preds, val_labels) if p == t) / len(val_labels)
|
| 162 |
-
val_f1 = self._calculate_f1(val_labels, val_preds)
|
| 163 |
-
|
| 164 |
-
results["val_accuracy"] = val_acc
|
| 165 |
-
results["val_f1"] = val_f1
|
| 166 |
-
|
| 167 |
-
print(f" - Val accuracy: {val_acc:.4f}")
|
| 168 |
-
print(f" - Val F1: {val_f1:.4f}")
|
| 169 |
-
|
| 170 |
-
return results
|
| 171 |
-
|
| 172 |
-
def _calculate_f1(self, y_true: List[str], y_pred: List[str]) -> float:
|
| 173 |
-
"""Calculate weighted F1 score."""
|
| 174 |
-
from collections import Counter
|
| 175 |
-
|
| 176 |
-
label_counts = Counter(y_true)
|
| 177 |
-
total = len(y_true)
|
| 178 |
-
|
| 179 |
-
f1_sum = 0.0
|
| 180 |
-
for label in self.labels_:
|
| 181 |
-
tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
|
| 182 |
-
fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
|
| 183 |
-
fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
|
| 184 |
-
|
| 185 |
-
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
| 186 |
-
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 187 |
-
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
| 188 |
-
|
| 189 |
-
weight = label_counts[label] / total
|
| 190 |
-
f1_sum += f1 * weight
|
| 191 |
-
|
| 192 |
-
return f1_sum
|
| 193 |
-
|
| 194 |
-
def predict(self, sentence: Sentence) -> None:
|
| 195 |
-
"""
|
| 196 |
-
Predict label for a sentence (underthesea-compatible API).
|
| 197 |
-
|
| 198 |
-
Args:
|
| 199 |
-
sentence: Sentence object with text attribute
|
| 200 |
-
"""
|
| 201 |
-
if self.classifier is None or self.vectorizer is None:
|
| 202 |
-
raise ValueError("Model not trained. Call train() first or load a model.")
|
| 203 |
-
|
| 204 |
-
features = self.vectorizer.transform_dense(sentence.text)
|
| 205 |
-
label_value, score = self.classifier.predict_with_score(features)
|
| 206 |
-
|
| 207 |
-
sentence.labels = []
|
| 208 |
-
sentence.add_labels([Label(label_value, score)])
|
| 209 |
-
|
| 210 |
-
def predict_batch(self, texts: List[str]) -> List[Label]:
|
| 211 |
-
"""
|
| 212 |
-
Predict labels for multiple texts.
|
| 213 |
-
|
| 214 |
-
Args:
|
| 215 |
-
texts: List of texts to classify
|
| 216 |
-
|
| 217 |
-
Returns:
|
| 218 |
-
List of Label objects
|
| 219 |
-
"""
|
| 220 |
-
if self.classifier is None or self.vectorizer is None:
|
| 221 |
-
raise ValueError("Model not trained. Call train() first or load a model.")
|
| 222 |
-
|
| 223 |
-
# Use dense transform (faster Python-Rust interface)
|
| 224 |
-
features = self.vectorizer.transform_batch(texts)
|
| 225 |
-
results = []
|
| 226 |
-
for feat in features:
|
| 227 |
-
label_value, score = self.classifier.predict_with_score(feat)
|
| 228 |
-
results.append(Label(label_value, float(score)))
|
| 229 |
-
|
| 230 |
-
return results
|
| 231 |
-
|
| 232 |
-
def evaluate(self, texts: List[str], labels: List[str]) -> dict:
|
| 233 |
-
"""
|
| 234 |
-
Evaluate model on test data.
|
| 235 |
-
|
| 236 |
-
Args:
|
| 237 |
-
texts: List of texts
|
| 238 |
-
labels: List of true labels
|
| 239 |
-
|
| 240 |
-
Returns:
|
| 241 |
-
Dictionary with evaluation metrics
|
| 242 |
-
"""
|
| 243 |
-
# Use dense transform (faster Python-Rust interface)
|
| 244 |
-
features = self.vectorizer.transform_batch(texts)
|
| 245 |
-
y_pred = self.classifier.predict_batch(features)
|
| 246 |
-
|
| 247 |
-
acc = sum(1 for p, t in zip(y_pred, labels) if p == t) / len(labels)
|
| 248 |
-
f1 = self._calculate_f1(labels, y_pred)
|
| 249 |
-
|
| 250 |
-
print(f"Evaluation:")
|
| 251 |
-
print(f" - Accuracy: {acc:.4f}")
|
| 252 |
-
print(f" - F1 (weighted): {f1:.4f}")
|
| 253 |
-
|
| 254 |
-
# Print classification report
|
| 255 |
-
self._print_classification_report(labels, y_pred)
|
| 256 |
-
|
| 257 |
-
return {"accuracy": acc, "f1": f1}
|
| 258 |
-
|
| 259 |
-
def _print_classification_report(self, y_true: List[str], y_pred: List[str]):
|
| 260 |
-
"""Print classification report."""
|
| 261 |
-
from collections import Counter
|
| 262 |
-
|
| 263 |
-
print("\nClassification Report:")
|
| 264 |
-
print(f"{'':>20} {'precision':>10} {'recall':>10} {'f1-score':>10} {'support':>10}")
|
| 265 |
-
print()
|
| 266 |
-
|
| 267 |
-
label_counts = Counter(y_true)
|
| 268 |
-
|
| 269 |
-
for label in self.labels_:
|
| 270 |
-
tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
|
| 271 |
-
fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
|
| 272 |
-
fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
|
| 273 |
-
|
| 274 |
-
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
| 275 |
-
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 276 |
-
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
| 277 |
-
support = label_counts[label]
|
| 278 |
-
|
| 279 |
-
print(f"{label:>20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f} {support:>10}")
|
| 280 |
-
|
| 281 |
-
print()
|
| 282 |
-
|
| 283 |
-
def save(self, path: str) -> None:
|
| 284 |
-
"""
|
| 285 |
-
Save model to disk.
|
| 286 |
-
|
| 287 |
-
Args:
|
| 288 |
-
path: Directory path to save model
|
| 289 |
-
"""
|
| 290 |
-
os.makedirs(path, exist_ok=True)
|
| 291 |
-
|
| 292 |
-
# Save vectorizer
|
| 293 |
-
self.vectorizer.save(os.path.join(path, "vectorizer.json"))
|
| 294 |
-
|
| 295 |
-
# Save classifier
|
| 296 |
-
self.classifier.save(os.path.join(path, "classifier.json"))
|
| 297 |
-
|
| 298 |
-
# Save metadata
|
| 299 |
-
metadata = {
|
| 300 |
-
"estimator": "RUST_SVM",
|
| 301 |
-
"max_features": self.max_features,
|
| 302 |
-
"ngram_range": self.ngram_range,
|
| 303 |
-
"min_df": self.min_df,
|
| 304 |
-
"max_df": self.max_df,
|
| 305 |
-
"c": self.c,
|
| 306 |
-
"max_iter": self.max_iter,
|
| 307 |
-
"tol": self.tol,
|
| 308 |
-
"labels": self.labels_,
|
| 309 |
-
"vocab_size": self.vectorizer.vocab_size,
|
| 310 |
-
"n_classes": self.classifier.n_classes,
|
| 311 |
-
}
|
| 312 |
-
with open(os.path.join(path, "metadata.json"), "w", encoding="utf-8") as f:
|
| 313 |
-
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
| 314 |
-
|
| 315 |
-
print(f"Model saved to: {path}")
|
| 316 |
-
|
| 317 |
-
@classmethod
|
| 318 |
-
def load(cls, path: str) -> "SenTextClassifier":
|
| 319 |
-
"""
|
| 320 |
-
Load model from disk.
|
| 321 |
-
|
| 322 |
-
Args:
|
| 323 |
-
path: Directory path containing saved model
|
| 324 |
-
|
| 325 |
-
Returns:
|
| 326 |
-
Loaded SenTextClassifier instance
|
| 327 |
-
"""
|
| 328 |
-
# Load metadata
|
| 329 |
-
with open(os.path.join(path, "metadata.json"), "r", encoding="utf-8") as f:
|
| 330 |
-
metadata = json.load(f)
|
| 331 |
-
|
| 332 |
-
# Create instance with saved parameters
|
| 333 |
-
classifier = cls(
|
| 334 |
-
max_features=metadata.get("max_features", 20000),
|
| 335 |
-
ngram_range=tuple(metadata.get("ngram_range", (1, 2))),
|
| 336 |
-
min_df=metadata.get("min_df", 1),
|
| 337 |
-
max_df=metadata.get("max_df", 1.0),
|
| 338 |
-
c=metadata.get("c", 1.0),
|
| 339 |
-
max_iter=metadata.get("max_iter", 1000),
|
| 340 |
-
tol=metadata.get("tol", 0.1),
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
# Load vectorizer
|
| 344 |
-
classifier.vectorizer = TfIdfVectorizer.load(os.path.join(path, "vectorizer.json"))
|
| 345 |
-
|
| 346 |
-
# Load SVM model
|
| 347 |
-
classifier.classifier = LinearSVM.load(os.path.join(path, "classifier.json"))
|
| 348 |
-
classifier.labels_ = metadata.get("labels", [])
|
| 349 |
-
|
| 350 |
-
print(f"Model loaded from: {path}")
|
| 351 |
-
return classifier
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def classify(text: str, model_path: str = None) -> List[str]:
|
| 355 |
-
"""
|
| 356 |
-
Classify text using Sen model.
|
| 357 |
-
|
| 358 |
-
Args:
|
| 359 |
-
text: Input text to classify
|
| 360 |
-
model_path: Path to trained model
|
| 361 |
-
|
| 362 |
-
Returns:
|
| 363 |
-
List of predicted labels
|
| 364 |
-
"""
|
| 365 |
-
if not hasattr(classify, "_classifier") or classify._model_path != model_path:
|
| 366 |
-
if model_path:
|
| 367 |
-
classify._classifier = SenTextClassifier.load(model_path)
|
| 368 |
-
classify._model_path = model_path
|
| 369 |
-
else:
|
| 370 |
-
raise ValueError("model_path is required")
|
| 371 |
-
|
| 372 |
-
sentence = Sentence(text)
|
| 373 |
-
classify._classifier.predict(sentence)
|
| 374 |
-
return [label.value for label in sentence.labels]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/train.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training CLI for Vietnamese Text Classification.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python train.py vntc --output models/sen-vntc.bin
|
| 6 |
+
python train.py bank --output models/sen-bank.bin
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import click
|
| 14 |
+
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
| 15 |
+
|
| 16 |
+
from underthesea_core import TextClassifier
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def read_file(filepath):
|
| 20 |
+
"""Read text file with multiple encoding attempts."""
|
| 21 |
+
for enc in ['utf-16', 'utf-8', 'latin-1']:
|
| 22 |
+
try:
|
| 23 |
+
with open(filepath, 'r', encoding=enc) as f:
|
| 24 |
+
text = ' '.join(f.read().split())
|
| 25 |
+
if len(text) > 10:
|
| 26 |
+
return text
|
| 27 |
+
except (UnicodeDecodeError, UnicodeError):
|
| 28 |
+
continue
|
| 29 |
+
return None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_vntc_data(data_dir):
|
| 33 |
+
"""Load VNTC data from directory."""
|
| 34 |
+
texts, labels = [], []
|
| 35 |
+
|
| 36 |
+
for folder in sorted(os.listdir(data_dir)):
|
| 37 |
+
folder_path = os.path.join(data_dir, folder)
|
| 38 |
+
if not os.path.isdir(folder_path):
|
| 39 |
+
continue
|
| 40 |
+
|
| 41 |
+
for fname in os.listdir(folder_path):
|
| 42 |
+
if fname.endswith('.txt'):
|
| 43 |
+
text = read_file(os.path.join(folder_path, fname))
|
| 44 |
+
if text:
|
| 45 |
+
texts.append(text)
|
| 46 |
+
labels.append(folder)
|
| 47 |
+
|
| 48 |
+
return texts, labels
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@click.group()
|
| 52 |
+
def cli():
|
| 53 |
+
"""Train Vietnamese text classification models."""
|
| 54 |
+
pass
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@cli.command()
|
| 58 |
+
@click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
|
| 59 |
+
help='Path to VNTC dataset')
|
| 60 |
+
@click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
|
| 61 |
+
@click.option('--max-features', default=20000, help='Maximum vocabulary size')
|
| 62 |
+
@click.option('--ngram-min', default=1, help='Minimum n-gram')
|
| 63 |
+
@click.option('--ngram-max', default=2, help='Maximum n-gram')
|
| 64 |
+
@click.option('--min-df', default=2, help='Minimum document frequency')
|
| 65 |
+
@click.option('--c', default=1.0, help='SVM regularization parameter')
|
| 66 |
+
@click.option('--max-iter', default=1000, help='Maximum iterations')
|
| 67 |
+
@click.option('--tol', default=0.1, help='Convergence tolerance')
|
| 68 |
+
def vntc(data_dir, output, max_features, ngram_min, ngram_max, min_df, c, max_iter, tol):
|
| 69 |
+
"""Train on VNTC dataset (10 topics, ~84k documents)."""
|
| 70 |
+
click.echo("=" * 70)
|
| 71 |
+
click.echo("VNTC Dataset Training (10 Topics)")
|
| 72 |
+
click.echo("=" * 70)
|
| 73 |
+
|
| 74 |
+
train_dir = os.path.join(data_dir, "Train_Full")
|
| 75 |
+
test_dir = os.path.join(data_dir, "Test_Full")
|
| 76 |
+
|
| 77 |
+
# Load data
|
| 78 |
+
click.echo("\nLoading data...")
|
| 79 |
+
t0 = time.perf_counter()
|
| 80 |
+
train_texts, train_labels = load_vntc_data(train_dir)
|
| 81 |
+
test_texts, test_labels = load_vntc_data(test_dir)
|
| 82 |
+
load_time = time.perf_counter() - t0
|
| 83 |
+
|
| 84 |
+
click.echo(f" Train samples: {len(train_texts)}")
|
| 85 |
+
click.echo(f" Test samples: {len(test_texts)}")
|
| 86 |
+
click.echo(f" Categories: {len(set(train_labels))}")
|
| 87 |
+
click.echo(f" Load time: {load_time:.2f}s")
|
| 88 |
+
|
| 89 |
+
# Train
|
| 90 |
+
click.echo("\nTraining Rust TextClassifier...")
|
| 91 |
+
clf = TextClassifier(
|
| 92 |
+
max_features=max_features,
|
| 93 |
+
ngram_range=(ngram_min, ngram_max),
|
| 94 |
+
min_df=min_df,
|
| 95 |
+
c=c,
|
| 96 |
+
max_iter=max_iter,
|
| 97 |
+
tol=tol,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
t0 = time.perf_counter()
|
| 101 |
+
clf.fit(train_texts, train_labels)
|
| 102 |
+
train_time = time.perf_counter() - t0
|
| 103 |
+
click.echo(f" Training time: {train_time:.2f}s")
|
| 104 |
+
click.echo(f" Vocabulary size: {clf.n_features}")
|
| 105 |
+
|
| 106 |
+
# Evaluate
|
| 107 |
+
click.echo("\nEvaluating...")
|
| 108 |
+
t0 = time.perf_counter()
|
| 109 |
+
preds = clf.predict_batch(test_texts)
|
| 110 |
+
infer_time = time.perf_counter() - t0
|
| 111 |
+
throughput = len(test_texts) / infer_time
|
| 112 |
+
|
| 113 |
+
acc = accuracy_score(test_labels, preds)
|
| 114 |
+
f1_w = f1_score(test_labels, preds, average='weighted')
|
| 115 |
+
f1_m = f1_score(test_labels, preds, average='macro')
|
| 116 |
+
|
| 117 |
+
click.echo(f" Inference: {infer_time:.3f}s ({throughput:.0f} samples/sec)")
|
| 118 |
+
|
| 119 |
+
click.echo("\n" + "=" * 70)
|
| 120 |
+
click.echo("RESULTS")
|
| 121 |
+
click.echo("=" * 70)
|
| 122 |
+
click.echo(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
|
| 123 |
+
click.echo(f" F1 (weighted): {f1_w:.4f}")
|
| 124 |
+
click.echo(f" F1 (macro): {f1_m:.4f}")
|
| 125 |
+
|
| 126 |
+
click.echo("\nClassification Report:")
|
| 127 |
+
click.echo(classification_report(test_labels, preds))
|
| 128 |
+
|
| 129 |
+
# Save model
|
| 130 |
+
model_path = Path(output)
|
| 131 |
+
model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 132 |
+
clf.save(str(model_path))
|
| 133 |
+
|
| 134 |
+
size_mb = model_path.stat().st_size / (1024 * 1024)
|
| 135 |
+
click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _print_metrics(test_labels, preds):
    """Echo accuracy, weighted/macro F1 and a full classification report."""
    acc = accuracy_score(test_labels, preds)
    f1_w = f1_score(test_labels, preds, average='weighted')
    f1_m = f1_score(test_labels, preds, average='macro')

    click.echo("\n" + "=" * 70)
    click.echo("RESULTS")
    click.echo("=" * 70)
    click.echo(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
    click.echo(f" F1 (weighted): {f1_w:.4f}")
    click.echo(f" F1 (macro): {f1_m:.4f}")

    click.echo("\nClassification Report:")
    click.echo(classification_report(test_labels, preds))


def _save_model(clf, output):
    """Persist the trained classifier to *output* and echo its on-disk size."""
    model_path = Path(output)
    # Create parent directories so a fresh checkout can save without setup.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    clf.save(str(model_path))

    size_mb = model_path.stat().st_size / (1024 * 1024)
    click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")


@cli.command()
@click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
@click.option('--max-features', default=10000, help='Maximum vocabulary size')
@click.option('--ngram-min', default=1, help='Minimum n-gram')
@click.option('--ngram-max', default=2, help='Maximum n-gram')
@click.option('--min-df', default=1, help='Minimum document frequency')
@click.option('--c', default=1.0, help='SVM regularization parameter')
@click.option('--max-iter', default=1000, help='Maximum iterations')
@click.option('--tol', default=0.1, help='Convergence tolerance')
def bank(output, max_features, ngram_min, ngram_max, min_df, c, max_iter, tol):
    """Train on UTS2017_Bank dataset (14 categories, banking domain).

    Downloads the dataset from HuggingFace, trains the Rust-backed
    TextClassifier (TF-IDF + linear SVM), prints evaluation metrics on
    the held-out test split, and saves the model to --output.
    """
    # Imported lazily so the CLI stays importable without `datasets` installed.
    from datasets import load_dataset

    click.echo("=" * 70)
    click.echo("UTS2017_Bank Dataset Training (14 Categories)")
    click.echo("=" * 70)

    # Load data
    click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
    dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")

    train_texts = list(dataset["train"]["text"])
    train_labels = list(dataset["train"]["label"])
    test_texts = list(dataset["test"]["text"])
    test_labels = list(dataset["test"]["label"])

    click.echo(f" Train samples: {len(train_texts)}")
    click.echo(f" Test samples: {len(test_texts)}")
    click.echo(f" Categories: {len(set(train_labels))}")

    # Train
    click.echo("\nTraining Rust TextClassifier...")
    clf = TextClassifier(
        max_features=max_features,
        ngram_range=(ngram_min, ngram_max),
        min_df=min_df,
        c=c,
        max_iter=max_iter,
        tol=tol,
    )

    t0 = time.perf_counter()
    clf.fit(train_texts, train_labels)
    train_time = time.perf_counter() - t0
    click.echo(f" Training time: {train_time:.3f}s")
    click.echo(f" Vocabulary size: {clf.n_features}")

    # Evaluate
    click.echo("\nEvaluating...")
    preds = clf.predict_batch(test_texts)

    _print_metrics(test_labels, preds)

    # Save model
    _save_model(clf, output)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# CLI entry point: dispatch to one of the registered click commands.
if __name__ == "__main__":
    cli()
|
tests/test_classifier.py
DELETED
|
@@ -1,165 +0,0 @@
|
|
| 1 |
-
"""Test Sen Text Classifier (sklearn-based)."""
|
| 2 |
-
|
| 3 |
-
import sys
|
| 4 |
-
sys.path.insert(0, "/home/anhvu2/projects/workspace_underthesea")
|
| 5 |
-
|
| 6 |
-
from sen import SenTextClassifier, Sentence, Label
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def test_training():
|
| 10 |
-
"""Test training on sample Vietnamese sentiment data."""
|
| 11 |
-
print("=" * 60)
|
| 12 |
-
print("Test: Training sklearn-based Text Classifier")
|
| 13 |
-
print("=" * 60)
|
| 14 |
-
|
| 15 |
-
# Sample Vietnamese sentiment data
|
| 16 |
-
train_texts = [
|
| 17 |
-
"Sản phẩm rất tốt, tôi hài lòng!",
|
| 18 |
-
"Chất lượng tuyệt vời, giao hàng nhanh",
|
| 19 |
-
"Hàng đẹp, đóng gói cẩn thận",
|
| 20 |
-
"Mình rất thích sản phẩm này",
|
| 21 |
-
"Shop phục vụ nhiệt tình, sẽ ủng hộ tiếp",
|
| 22 |
-
"Hàng chính hãng, giá tốt",
|
| 23 |
-
"Rất đáng tiền, recommend cho mọi người",
|
| 24 |
-
"Chất liệu tốt, may đẹp",
|
| 25 |
-
"Hàng tệ quá, không như mô tả",
|
| 26 |
-
"Chất lượng kém, không đáng tiền",
|
| 27 |
-
"Giao hàng chậm, đóng gói cẩu thả",
|
| 28 |
-
"Sản phẩm lỗi, shop không hỗ trợ",
|
| 29 |
-
"Thất vọng, không bao giờ mua lại",
|
| 30 |
-
"Hàng giả, không nên mua",
|
| 31 |
-
"Tệ lắm, phí tiền",
|
| 32 |
-
"Màu không đúng, size sai",
|
| 33 |
-
]
|
| 34 |
-
train_labels = [
|
| 35 |
-
"positive", "positive", "positive", "positive",
|
| 36 |
-
"positive", "positive", "positive", "positive",
|
| 37 |
-
"negative", "negative", "negative", "negative",
|
| 38 |
-
"negative", "negative", "negative", "negative",
|
| 39 |
-
]
|
| 40 |
-
|
| 41 |
-
val_texts = [
|
| 42 |
-
"Hàng ok, sẽ mua lại",
|
| 43 |
-
"Tệ lắm, không nên mua",
|
| 44 |
-
]
|
| 45 |
-
val_labels = ["positive", "negative"]
|
| 46 |
-
|
| 47 |
-
# Initialize and train
|
| 48 |
-
classifier = SenTextClassifier(
|
| 49 |
-
max_features=1000,
|
| 50 |
-
ngram_range=(1, 2),
|
| 51 |
-
min_df=1,
|
| 52 |
-
C=1.0,
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
print("\nTraining...")
|
| 56 |
-
history = classifier.train(
|
| 57 |
-
train_texts=train_texts,
|
| 58 |
-
train_labels=train_labels,
|
| 59 |
-
val_texts=val_texts,
|
| 60 |
-
val_labels=val_labels,
|
| 61 |
-
)
|
| 62 |
-
print()
|
| 63 |
-
|
| 64 |
-
# Test predictions
|
| 65 |
-
print("Testing predictions:")
|
| 66 |
-
test_texts = [
|
| 67 |
-
"Sản phẩm tuyệt vời!",
|
| 68 |
-
"Hàng rất tệ",
|
| 69 |
-
"Giao hàng nhanh, hàng đẹp",
|
| 70 |
-
"Thất vọng với chất lượng",
|
| 71 |
-
"Chất lượng tốt, giá hợp lý",
|
| 72 |
-
"Không đáng tiền, hàng kém",
|
| 73 |
-
]
|
| 74 |
-
|
| 75 |
-
for text in test_texts:
|
| 76 |
-
sentence = Sentence(text)
|
| 77 |
-
classifier.predict(sentence)
|
| 78 |
-
print(f" '{text}' -> {sentence.labels[0]}")
|
| 79 |
-
print()
|
| 80 |
-
|
| 81 |
-
# Test batch prediction
|
| 82 |
-
print("Batch prediction:")
|
| 83 |
-
labels = classifier.predict_batch(test_texts)
|
| 84 |
-
for text, label in zip(test_texts, labels):
|
| 85 |
-
print(f" '{text}' -> {label}")
|
| 86 |
-
print()
|
| 87 |
-
|
| 88 |
-
# Save model
|
| 89 |
-
save_path = "/tmp/sen-classifier-sklearn"
|
| 90 |
-
classifier.save(save_path)
|
| 91 |
-
|
| 92 |
-
# Load and test
|
| 93 |
-
print("\nLoading saved model...")
|
| 94 |
-
loaded_classifier = SenTextClassifier.load(save_path)
|
| 95 |
-
|
| 96 |
-
sentence = Sentence("Rất hài lòng với sản phẩm")
|
| 97 |
-
loaded_classifier.predict(sentence)
|
| 98 |
-
print(f"Loaded model prediction: '{sentence.text}' -> {sentence.labels[0]}")
|
| 99 |
-
print()
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def test_multiclass():
|
| 103 |
-
"""Test multi-class classification (news categories)."""
|
| 104 |
-
print("=" * 60)
|
| 105 |
-
print("Test: Multi-class Classification (News Categories)")
|
| 106 |
-
print("=" * 60)
|
| 107 |
-
|
| 108 |
-
# Sample Vietnamese news data (simulating VNTC categories)
|
| 109 |
-
train_texts = [
|
| 110 |
-
# Thể thao (Sports)
|
| 111 |
-
"Đội tuyển Việt Nam thắng 3-0 trước Indonesia",
|
| 112 |
-
"Cầu thủ Nguyễn Quang Hải ghi bàn đẹp mắt",
|
| 113 |
-
"V-League 2024 khởi tranh vào tháng tới",
|
| 114 |
-
"HLV Park Hang-seo chia tay bóng đá Việt Nam",
|
| 115 |
-
# Kinh doanh (Business)
|
| 116 |
-
"Chứng khoán tăng điểm mạnh phiên đầu tuần",
|
| 117 |
-
"Ngân hàng Nhà nước điều chỉnh lãi suất",
|
| 118 |
-
"Doanh nghiệp xuất khẩu gặp khó khăn",
|
| 119 |
-
"Thị trường bất động sản phục hồi",
|
| 120 |
-
# Công nghệ (Technology)
|
| 121 |
-
"Apple ra mắt iPhone 16 với nhiều tính năng mới",
|
| 122 |
-
"Trí tuệ nhân tạo đang thay đổi cuộc sống",
|
| 123 |
-
"Startup công nghệ Việt Nam gọi vốn thành công",
|
| 124 |
-
"5G được triển khai rộng rãi tại Việt Nam",
|
| 125 |
-
]
|
| 126 |
-
train_labels = [
|
| 127 |
-
"the_thao", "the_thao", "the_thao", "the_thao",
|
| 128 |
-
"kinh_doanh", "kinh_doanh", "kinh_doanh", "kinh_doanh",
|
| 129 |
-
"cong_nghe", "cong_nghe", "cong_nghe", "cong_nghe",
|
| 130 |
-
]
|
| 131 |
-
|
| 132 |
-
# Initialize and train
|
| 133 |
-
classifier = SenTextClassifier(
|
| 134 |
-
max_features=500,
|
| 135 |
-
ngram_range=(1, 2),
|
| 136 |
-
min_df=1,
|
| 137 |
-
)
|
| 138 |
-
|
| 139 |
-
print("\nTraining...")
|
| 140 |
-
classifier.train(train_texts, train_labels)
|
| 141 |
-
print()
|
| 142 |
-
|
| 143 |
-
# Test predictions
|
| 144 |
-
print("Testing predictions:")
|
| 145 |
-
test_texts = [
|
| 146 |
-
"Ronaldo ghi hat-trick trong trận đấu",
|
| 147 |
-
"VN-Index tăng 10 điểm hôm nay",
|
| 148 |
-
"Samsung ra mắt điện thoại mới",
|
| 149 |
-
"Đội bóng đá Việt Nam vô đ���ch AFF Cup",
|
| 150 |
-
"Lãi suất ngân hàng giảm mạnh",
|
| 151 |
-
]
|
| 152 |
-
|
| 153 |
-
for text in test_texts:
|
| 154 |
-
sentence = Sentence(text)
|
| 155 |
-
classifier.predict(sentence)
|
| 156 |
-
print(f" '{text}' -> {sentence.labels[0]}")
|
| 157 |
-
print()
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
if __name__ == "__main__":
|
| 161 |
-
test_training()
|
| 162 |
-
test_multiclass()
|
| 163 |
-
print("=" * 60)
|
| 164 |
-
print("All tests completed!")
|
| 165 |
-
print("=" * 60)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|