Add Rust extensions for TF-IDF and Linear SVM
Browse files- Implement TF-IDF vectorizer in pure Rust with PyO3 bindings
- Implement Linear SVM with LIBLINEAR-style Dual Coordinate Descent
- Achieve 88.15% accuracy on VNTC (vs sklearn's 89.48%)
- Update text_classifier.py to use new Rust backend
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- extensions/underthesea_core_extend/.gitignore +5 -0
- extensions/underthesea_core_extend/Cargo.lock +351 -0
- extensions/underthesea_core_extend/Cargo.toml +23 -0
- extensions/underthesea_core_extend/pyproject.toml +22 -0
- extensions/underthesea_core_extend/src/lib.rs +21 -0
- extensions/underthesea_core_extend/src/svm.rs +512 -0
- extensions/underthesea_core_extend/src/tfidf.rs +235 -0
- extensions/underthesea_core_extend/uv.lock +8 -0
- pyproject.toml +5 -6
- src/sen/text_classifier.py +136 -105
extensions/underthesea_core_extend/.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
target/
|
| 2 |
+
.venv/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.so
|
| 5 |
+
*.pyc
|
extensions/underthesea_core_extend/Cargo.lock
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is automatically @generated by Cargo.
|
| 2 |
+
# It is not intended for manual editing.
|
| 3 |
+
version = 4
|
| 4 |
+
|
| 5 |
+
[[package]]
|
| 6 |
+
name = "allocator-api2"
|
| 7 |
+
version = "0.2.21"
|
| 8 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 9 |
+
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
| 10 |
+
|
| 11 |
+
[[package]]
|
| 12 |
+
name = "autocfg"
|
| 13 |
+
version = "1.5.0"
|
| 14 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 15 |
+
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
| 16 |
+
|
| 17 |
+
[[package]]
|
| 18 |
+
name = "cfg-if"
|
| 19 |
+
version = "1.0.4"
|
| 20 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 21 |
+
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
| 22 |
+
|
| 23 |
+
[[package]]
|
| 24 |
+
name = "crossbeam-deque"
|
| 25 |
+
version = "0.8.6"
|
| 26 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 27 |
+
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
| 28 |
+
dependencies = [
|
| 29 |
+
"crossbeam-epoch",
|
| 30 |
+
"crossbeam-utils",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[[package]]
|
| 34 |
+
name = "crossbeam-epoch"
|
| 35 |
+
version = "0.9.18"
|
| 36 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 37 |
+
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
| 38 |
+
dependencies = [
|
| 39 |
+
"crossbeam-utils",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
[[package]]
|
| 43 |
+
name = "crossbeam-utils"
|
| 44 |
+
version = "0.8.21"
|
| 45 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 46 |
+
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
| 47 |
+
|
| 48 |
+
[[package]]
|
| 49 |
+
name = "either"
|
| 50 |
+
version = "1.15.0"
|
| 51 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 52 |
+
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
| 53 |
+
|
| 54 |
+
[[package]]
|
| 55 |
+
name = "equivalent"
|
| 56 |
+
version = "1.0.2"
|
| 57 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 58 |
+
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
| 59 |
+
|
| 60 |
+
[[package]]
|
| 61 |
+
name = "foldhash"
|
| 62 |
+
version = "0.1.5"
|
| 63 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 64 |
+
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
| 65 |
+
|
| 66 |
+
[[package]]
|
| 67 |
+
name = "hashbrown"
|
| 68 |
+
version = "0.15.5"
|
| 69 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 70 |
+
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
| 71 |
+
dependencies = [
|
| 72 |
+
"allocator-api2",
|
| 73 |
+
"equivalent",
|
| 74 |
+
"foldhash",
|
| 75 |
+
"serde",
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
[[package]]
|
| 79 |
+
name = "heck"
|
| 80 |
+
version = "0.5.0"
|
| 81 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 82 |
+
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
| 83 |
+
|
| 84 |
+
[[package]]
|
| 85 |
+
name = "indoc"
|
| 86 |
+
version = "2.0.7"
|
| 87 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 88 |
+
checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
|
| 89 |
+
dependencies = [
|
| 90 |
+
"rustversion",
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
[[package]]
|
| 94 |
+
name = "itoa"
|
| 95 |
+
version = "1.0.17"
|
| 96 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 97 |
+
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
|
| 98 |
+
|
| 99 |
+
[[package]]
|
| 100 |
+
name = "libc"
|
| 101 |
+
version = "0.2.180"
|
| 102 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 103 |
+
checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
|
| 104 |
+
|
| 105 |
+
[[package]]
|
| 106 |
+
name = "memchr"
|
| 107 |
+
version = "2.7.6"
|
| 108 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 109 |
+
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
|
| 110 |
+
|
| 111 |
+
[[package]]
|
| 112 |
+
name = "memoffset"
|
| 113 |
+
version = "0.9.1"
|
| 114 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 115 |
+
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
| 116 |
+
dependencies = [
|
| 117 |
+
"autocfg",
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
[[package]]
|
| 121 |
+
name = "once_cell"
|
| 122 |
+
version = "1.21.3"
|
| 123 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 124 |
+
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
| 125 |
+
|
| 126 |
+
[[package]]
|
| 127 |
+
name = "portable-atomic"
|
| 128 |
+
version = "1.13.1"
|
| 129 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 130 |
+
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
| 131 |
+
|
| 132 |
+
[[package]]
|
| 133 |
+
name = "proc-macro2"
|
| 134 |
+
version = "1.0.106"
|
| 135 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 136 |
+
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
|
| 137 |
+
dependencies = [
|
| 138 |
+
"unicode-ident",
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
[[package]]
|
| 142 |
+
name = "pyo3"
|
| 143 |
+
version = "0.22.6"
|
| 144 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 145 |
+
checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
|
| 146 |
+
dependencies = [
|
| 147 |
+
"cfg-if",
|
| 148 |
+
"indoc",
|
| 149 |
+
"libc",
|
| 150 |
+
"memoffset",
|
| 151 |
+
"once_cell",
|
| 152 |
+
"portable-atomic",
|
| 153 |
+
"pyo3-build-config",
|
| 154 |
+
"pyo3-ffi",
|
| 155 |
+
"pyo3-macros",
|
| 156 |
+
"unindent",
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
[[package]]
|
| 160 |
+
name = "pyo3-build-config"
|
| 161 |
+
version = "0.22.6"
|
| 162 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 163 |
+
checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
|
| 164 |
+
dependencies = [
|
| 165 |
+
"once_cell",
|
| 166 |
+
"target-lexicon",
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
[[package]]
|
| 170 |
+
name = "pyo3-ffi"
|
| 171 |
+
version = "0.22.6"
|
| 172 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 173 |
+
checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
|
| 174 |
+
dependencies = [
|
| 175 |
+
"libc",
|
| 176 |
+
"pyo3-build-config",
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
[[package]]
|
| 180 |
+
name = "pyo3-macros"
|
| 181 |
+
version = "0.22.6"
|
| 182 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 183 |
+
checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
|
| 184 |
+
dependencies = [
|
| 185 |
+
"proc-macro2",
|
| 186 |
+
"pyo3-macros-backend",
|
| 187 |
+
"quote",
|
| 188 |
+
"syn",
|
| 189 |
+
]
|
| 190 |
+
|
| 191 |
+
[[package]]
|
| 192 |
+
name = "pyo3-macros-backend"
|
| 193 |
+
version = "0.22.6"
|
| 194 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 195 |
+
checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
|
| 196 |
+
dependencies = [
|
| 197 |
+
"heck",
|
| 198 |
+
"proc-macro2",
|
| 199 |
+
"pyo3-build-config",
|
| 200 |
+
"quote",
|
| 201 |
+
"syn",
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
[[package]]
|
| 205 |
+
name = "quote"
|
| 206 |
+
version = "1.0.44"
|
| 207 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 208 |
+
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
|
| 209 |
+
dependencies = [
|
| 210 |
+
"proc-macro2",
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
[[package]]
|
| 214 |
+
name = "rayon"
|
| 215 |
+
version = "1.11.0"
|
| 216 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 217 |
+
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
|
| 218 |
+
dependencies = [
|
| 219 |
+
"either",
|
| 220 |
+
"rayon-core",
|
| 221 |
+
]
|
| 222 |
+
|
| 223 |
+
[[package]]
|
| 224 |
+
name = "rayon-core"
|
| 225 |
+
version = "1.13.0"
|
| 226 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 227 |
+
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
|
| 228 |
+
dependencies = [
|
| 229 |
+
"crossbeam-deque",
|
| 230 |
+
"crossbeam-utils",
|
| 231 |
+
]
|
| 232 |
+
|
| 233 |
+
[[package]]
|
| 234 |
+
name = "rustversion"
|
| 235 |
+
version = "1.0.22"
|
| 236 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 237 |
+
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
| 238 |
+
|
| 239 |
+
[[package]]
|
| 240 |
+
name = "serde"
|
| 241 |
+
version = "1.0.228"
|
| 242 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 243 |
+
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
|
| 244 |
+
dependencies = [
|
| 245 |
+
"serde_core",
|
| 246 |
+
"serde_derive",
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
[[package]]
|
| 250 |
+
name = "serde_core"
|
| 251 |
+
version = "1.0.228"
|
| 252 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 253 |
+
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
|
| 254 |
+
dependencies = [
|
| 255 |
+
"serde_derive",
|
| 256 |
+
]
|
| 257 |
+
|
| 258 |
+
[[package]]
|
| 259 |
+
name = "serde_derive"
|
| 260 |
+
version = "1.0.228"
|
| 261 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 262 |
+
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
| 263 |
+
dependencies = [
|
| 264 |
+
"proc-macro2",
|
| 265 |
+
"quote",
|
| 266 |
+
"syn",
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
[[package]]
|
| 270 |
+
name = "serde_json"
|
| 271 |
+
version = "1.0.149"
|
| 272 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 273 |
+
checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
|
| 274 |
+
dependencies = [
|
| 275 |
+
"itoa",
|
| 276 |
+
"memchr",
|
| 277 |
+
"serde",
|
| 278 |
+
"serde_core",
|
| 279 |
+
"zmij",
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
[[package]]
|
| 283 |
+
name = "syn"
|
| 284 |
+
version = "2.0.114"
|
| 285 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 286 |
+
checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
|
| 287 |
+
dependencies = [
|
| 288 |
+
"proc-macro2",
|
| 289 |
+
"quote",
|
| 290 |
+
"unicode-ident",
|
| 291 |
+
]
|
| 292 |
+
|
| 293 |
+
[[package]]
|
| 294 |
+
name = "target-lexicon"
|
| 295 |
+
version = "0.12.16"
|
| 296 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 297 |
+
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
| 298 |
+
|
| 299 |
+
[[package]]
|
| 300 |
+
name = "tinyvec"
|
| 301 |
+
version = "1.10.0"
|
| 302 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 303 |
+
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
| 304 |
+
dependencies = [
|
| 305 |
+
"tinyvec_macros",
|
| 306 |
+
]
|
| 307 |
+
|
| 308 |
+
[[package]]
|
| 309 |
+
name = "tinyvec_macros"
|
| 310 |
+
version = "0.1.1"
|
| 311 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 312 |
+
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
| 313 |
+
|
| 314 |
+
[[package]]
|
| 315 |
+
name = "underthesea_core_extend"
|
| 316 |
+
version = "0.1.0"
|
| 317 |
+
dependencies = [
|
| 318 |
+
"hashbrown",
|
| 319 |
+
"pyo3",
|
| 320 |
+
"rayon",
|
| 321 |
+
"serde",
|
| 322 |
+
"serde_json",
|
| 323 |
+
"unicode-normalization",
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
[[package]]
|
| 327 |
+
name = "unicode-ident"
|
| 328 |
+
version = "1.0.22"
|
| 329 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 330 |
+
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
| 331 |
+
|
| 332 |
+
[[package]]
|
| 333 |
+
name = "unicode-normalization"
|
| 334 |
+
version = "0.1.25"
|
| 335 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 336 |
+
checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
|
| 337 |
+
dependencies = [
|
| 338 |
+
"tinyvec",
|
| 339 |
+
]
|
| 340 |
+
|
| 341 |
+
[[package]]
|
| 342 |
+
name = "unindent"
|
| 343 |
+
version = "0.2.4"
|
| 344 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 345 |
+
checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
|
| 346 |
+
|
| 347 |
+
[[package]]
|
| 348 |
+
name = "zmij"
|
| 349 |
+
version = "1.0.19"
|
| 350 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 351 |
+
checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445"
|
extensions/underthesea_core_extend/Cargo.toml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[package]
|
| 2 |
+
name = "underthesea_core_extend"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
edition = "2021"
|
| 5 |
+
description = "Rust extensions for underthesea - Text Classification"
|
| 6 |
+
license = "Apache-2.0"
|
| 7 |
+
|
| 8 |
+
[lib]
|
| 9 |
+
name = "underthesea_core_extend"
|
| 10 |
+
crate-type = ["cdylib"]
|
| 11 |
+
|
| 12 |
+
[dependencies]
|
| 13 |
+
pyo3 = { version = "0.22", features = ["extension-module"] }
|
| 14 |
+
serde = { version = "1.0", features = ["derive"] }
|
| 15 |
+
serde_json = "1.0"
|
| 16 |
+
rayon = "1.10"
|
| 17 |
+
hashbrown = { version = "0.15", features = ["serde"] }
|
| 18 |
+
unicode-normalization = "0.1"
|
| 19 |
+
|
| 20 |
+
[profile.release]
|
| 21 |
+
lto = true
|
| 22 |
+
codegen-units = 1
|
| 23 |
+
opt-level = 3
|
extensions/underthesea_core_extend/pyproject.toml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["maturin>=1.0,<2.0"]
|
| 3 |
+
build-backend = "maturin"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "underthesea_core_extend"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Rust extensions for underthesea - Text Classification"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
license = { text = "Apache-2.0" }
|
| 11 |
+
authors = [{ name = "UnderTheSea NLP", email = "anhv.ict91@gmail.com" }]
|
| 12 |
+
classifiers = [
|
| 13 |
+
"Programming Language :: Rust",
|
| 14 |
+
"Programming Language :: Python :: Implementation :: CPython",
|
| 15 |
+
"Programming Language :: Python :: 3.10",
|
| 16 |
+
"Programming Language :: Python :: 3.11",
|
| 17 |
+
"Programming Language :: Python :: 3.12",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[tool.maturin]
|
| 21 |
+
features = ["pyo3/extension-module"]
|
| 22 |
+
module-name = "underthesea_core_extend"
|
extensions/underthesea_core_extend/src/lib.rs
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! underthesea_core_extend - Rust extensions for Vietnamese Text Classification
|
| 2 |
+
//!
|
| 3 |
+
//! Provides fast TF-IDF vectorization and Linear SVM classification.
|
| 4 |
+
|
| 5 |
+
use pyo3::prelude::*;
|
| 6 |
+
|
| 7 |
+
mod tfidf;
|
| 8 |
+
mod svm;
|
| 9 |
+
|
| 10 |
+
pub use tfidf::TfIdfVectorizer;
|
| 11 |
+
pub use svm::{LinearSVM, SVMTrainer, FastSVMTrainer};
|
| 12 |
+
|
| 13 |
+
/// Python module
|
| 14 |
+
#[pymodule]
|
| 15 |
+
fn underthesea_core_extend(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 16 |
+
m.add_class::<TfIdfVectorizer>()?;
|
| 17 |
+
m.add_class::<LinearSVM>()?;
|
| 18 |
+
m.add_class::<SVMTrainer>()?;
|
| 19 |
+
m.add_class::<FastSVMTrainer>()?;
|
| 20 |
+
Ok(())
|
| 21 |
+
}
|
extensions/underthesea_core_extend/src/svm.rs
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Optimized Linear SVM - LIBLINEAR-style Dual Coordinate Descent
|
| 2 |
+
//!
|
| 3 |
+
//! Pure Rust implementation of L2-regularized L2-loss SVM (dual form)
|
| 4 |
+
//! Reference: "A Dual Coordinate Descent Method for Large-scale Linear SVM"
|
| 5 |
+
//! Hsieh et al., ICML 2008
|
| 6 |
+
|
| 7 |
+
use hashbrown::HashMap;
|
| 8 |
+
use pyo3::prelude::*;
|
| 9 |
+
use rayon::prelude::*;
|
| 10 |
+
use serde::{Deserialize, Serialize};
|
| 11 |
+
use std::fs::File;
|
| 12 |
+
use std::io::{BufReader, BufWriter};
|
| 13 |
+
|
| 14 |
+
/// Sparse feature vector
|
| 15 |
+
pub type SparseVec = Vec<(u32, f32)>; // Use u32/f32 for memory efficiency
|
| 16 |
+
|
| 17 |
+
/// Linear SVM Model
|
| 18 |
+
#[pyclass]
|
| 19 |
+
#[derive(Clone, Serialize, Deserialize)]
|
| 20 |
+
pub struct LinearSVM {
|
| 21 |
+
weights: Vec<Vec<f32>>,
|
| 22 |
+
biases: Vec<f32>,
|
| 23 |
+
classes: Vec<String>,
|
| 24 |
+
n_features: usize,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
#[pymethods]
|
| 28 |
+
impl LinearSVM {
|
| 29 |
+
#[new]
|
| 30 |
+
pub fn new() -> Self {
|
| 31 |
+
Self {
|
| 32 |
+
weights: Vec::new(),
|
| 33 |
+
biases: Vec::new(),
|
| 34 |
+
classes: Vec::new(),
|
| 35 |
+
n_features: 0,
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
pub fn predict(&self, features: Vec<f64>) -> String {
|
| 40 |
+
let idx = self.predict_idx(&features);
|
| 41 |
+
self.classes[idx].clone()
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
pub fn predict_with_score(&self, features: Vec<f64>) -> (String, f64) {
|
| 45 |
+
let scores = self.decision_scores(&features);
|
| 46 |
+
let (idx, &max_score) = scores
|
| 47 |
+
.iter()
|
| 48 |
+
.enumerate()
|
| 49 |
+
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
| 50 |
+
.unwrap();
|
| 51 |
+
let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
|
| 52 |
+
(self.classes[idx].clone(), confidence)
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
pub fn predict_batch(&self, features_batch: Vec<Vec<f64>>) -> Vec<String> {
|
| 56 |
+
features_batch
|
| 57 |
+
.par_iter()
|
| 58 |
+
.map(|f| {
|
| 59 |
+
let idx = self.predict_idx(f);
|
| 60 |
+
self.classes[idx].clone()
|
| 61 |
+
})
|
| 62 |
+
.collect()
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
pub fn predict_batch_sparse(&self, features_batch: Vec<Vec<(usize, f64)>>) -> Vec<String> {
|
| 66 |
+
features_batch
|
| 67 |
+
.par_iter()
|
| 68 |
+
.map(|f| {
|
| 69 |
+
let idx = self.predict_idx_sparse(f);
|
| 70 |
+
self.classes[idx].clone()
|
| 71 |
+
})
|
| 72 |
+
.collect()
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
pub fn predict_sparse_with_score(&self, features: Vec<(usize, f64)>) -> (String, f64) {
|
| 76 |
+
let scores = self.decision_scores_sparse(&features);
|
| 77 |
+
let (idx, &max_score) = scores
|
| 78 |
+
.iter()
|
| 79 |
+
.enumerate()
|
| 80 |
+
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
| 81 |
+
.unwrap();
|
| 82 |
+
let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
|
| 83 |
+
(self.classes[idx].clone(), confidence)
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
pub fn decision_function(&self, features: Vec<f64>) -> Vec<f64> {
|
| 87 |
+
self.decision_scores(&features).into_iter().map(|x| x as f64).collect()
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
#[getter]
|
| 91 |
+
pub fn classes(&self) -> Vec<String> {
|
| 92 |
+
self.classes.clone()
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
#[getter]
|
| 96 |
+
pub fn n_classes(&self) -> usize {
|
| 97 |
+
self.classes.len()
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
#[getter]
|
| 101 |
+
pub fn n_features(&self) -> usize {
|
| 102 |
+
self.n_features
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
pub fn save(&self, path: &str) -> PyResult<()> {
|
| 106 |
+
let file = File::create(path)
|
| 107 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 108 |
+
let writer = BufWriter::new(file);
|
| 109 |
+
serde_json::to_writer(writer, self)
|
| 110 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 111 |
+
Ok(())
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
#[staticmethod]
|
| 115 |
+
pub fn load(path: &str) -> PyResult<Self> {
|
| 116 |
+
let file = File::open(path)
|
| 117 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 118 |
+
let reader = BufReader::new(file);
|
| 119 |
+
let model: Self = serde_json::from_reader(reader)
|
| 120 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 121 |
+
Ok(model)
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
impl LinearSVM {
|
| 126 |
+
#[inline]
|
| 127 |
+
fn predict_idx(&self, features: &[f64]) -> usize {
|
| 128 |
+
let mut best_idx = 0;
|
| 129 |
+
let mut best_score = f32::NEG_INFINITY;
|
| 130 |
+
|
| 131 |
+
for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
|
| 132 |
+
let score: f32 = w.iter()
|
| 133 |
+
.zip(features.iter())
|
| 134 |
+
.map(|(&wi, &fi)| wi * fi as f32)
|
| 135 |
+
.sum::<f32>() + b;
|
| 136 |
+
|
| 137 |
+
if score > best_score {
|
| 138 |
+
best_score = score;
|
| 139 |
+
best_idx = idx;
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
best_idx
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
#[inline]
|
| 146 |
+
fn predict_idx_sparse(&self, features: &[(usize, f64)]) -> usize {
|
| 147 |
+
let mut best_idx = 0;
|
| 148 |
+
let mut best_score = f32::NEG_INFINITY;
|
| 149 |
+
|
| 150 |
+
for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
|
| 151 |
+
let score: f32 = features.iter()
|
| 152 |
+
.map(|&(j, v)| w[j] * v as f32)
|
| 153 |
+
.sum::<f32>() + b;
|
| 154 |
+
|
| 155 |
+
if score > best_score {
|
| 156 |
+
best_score = score;
|
| 157 |
+
best_idx = idx;
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
best_idx
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
fn decision_scores(&self, features: &[f64]) -> Vec<f32> {
|
| 164 |
+
self.weights
|
| 165 |
+
.iter()
|
| 166 |
+
.zip(self.biases.iter())
|
| 167 |
+
.map(|(w, &b)| {
|
| 168 |
+
w.iter()
|
| 169 |
+
.zip(features.iter())
|
| 170 |
+
.map(|(&wi, &fi)| wi * fi as f32)
|
| 171 |
+
.sum::<f32>() + b
|
| 172 |
+
})
|
| 173 |
+
.collect()
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
fn decision_scores_sparse(&self, features: &[(usize, f64)]) -> Vec<f32> {
|
| 177 |
+
self.weights
|
| 178 |
+
.iter()
|
| 179 |
+
.zip(self.biases.iter())
|
| 180 |
+
.map(|(w, &b)| {
|
| 181 |
+
features.iter()
|
| 182 |
+
.map(|&(j, v)| w[j] * v as f32)
|
| 183 |
+
.sum::<f32>() + b
|
| 184 |
+
})
|
| 185 |
+
.collect()
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
/// LIBLINEAR-style SVM Trainer
|
| 190 |
+
#[pyclass]
|
| 191 |
+
pub struct SVMTrainer {
|
| 192 |
+
c: f64,
|
| 193 |
+
max_iter: usize,
|
| 194 |
+
tol: f64,
|
| 195 |
+
verbose: bool,
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
#[pymethods]
|
| 199 |
+
impl SVMTrainer {
|
| 200 |
+
#[new]
|
| 201 |
+
#[pyo3(signature = (c=1.0, max_iter=1000, tol=0.1, verbose=false))]
|
| 202 |
+
pub fn new(c: f64, max_iter: usize, tol: f64, verbose: bool) -> Self {
|
| 203 |
+
Self { c, max_iter, tol, verbose }
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
pub fn set_c(&mut self, c: f64) {
|
| 207 |
+
self.c = c;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
pub fn set_max_iter(&mut self, max_iter: usize) {
|
| 211 |
+
self.max_iter = max_iter;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
pub fn set_verbose(&mut self, verbose: bool) {
|
| 215 |
+
self.verbose = verbose;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
|
| 219 |
+
let n_samples = features.len();
|
| 220 |
+
let n_features = if n_samples > 0 { features[0].len() } else { 0 };
|
| 221 |
+
|
| 222 |
+
// Convert to compact sparse format (f32 for memory/cache efficiency)
|
| 223 |
+
let sparse_features: Vec<SparseVec> = features
|
| 224 |
+
.par_iter()
|
| 225 |
+
.map(|dense| {
|
| 226 |
+
dense
|
| 227 |
+
.iter()
|
| 228 |
+
.enumerate()
|
| 229 |
+
.filter(|&(_, &v)| v.abs() > 1e-10)
|
| 230 |
+
.map(|(i, &v)| (i as u32, v as f32))
|
| 231 |
+
.collect()
|
| 232 |
+
})
|
| 233 |
+
.collect();
|
| 234 |
+
|
| 235 |
+
// Precompute ||x_i||^2
|
| 236 |
+
let x_sq_norms: Vec<f32> = sparse_features
|
| 237 |
+
.par_iter()
|
| 238 |
+
.map(|x| x.iter().map(|&(_, v)| v * v).sum())
|
| 239 |
+
.collect();
|
| 240 |
+
|
| 241 |
+
// Get unique classes
|
| 242 |
+
let mut classes: Vec<String> = labels.iter().cloned().collect();
|
| 243 |
+
classes.sort();
|
| 244 |
+
classes.dedup();
|
| 245 |
+
let n_classes = classes.len();
|
| 246 |
+
|
| 247 |
+
let class_to_idx: HashMap<String, usize> = classes
|
| 248 |
+
.iter()
|
| 249 |
+
.enumerate()
|
| 250 |
+
.map(|(i, c)| (c.clone(), i))
|
| 251 |
+
.collect();
|
| 252 |
+
|
| 253 |
+
let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
|
| 254 |
+
|
| 255 |
+
// Train binary classifiers in parallel (one-vs-rest)
|
| 256 |
+
let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
|
| 257 |
+
.into_par_iter()
|
| 258 |
+
.map(|class_idx| {
|
| 259 |
+
let y_binary: Vec<i8> = y_idx
|
| 260 |
+
.iter()
|
| 261 |
+
.map(|&idx| if idx == class_idx { 1 } else { -1 })
|
| 262 |
+
.collect();
|
| 263 |
+
|
| 264 |
+
solve_l2r_l2_svc(
|
| 265 |
+
&sparse_features,
|
| 266 |
+
&y_binary,
|
| 267 |
+
&x_sq_norms,
|
| 268 |
+
n_features,
|
| 269 |
+
self.c as f32,
|
| 270 |
+
self.tol as f32,
|
| 271 |
+
self.max_iter,
|
| 272 |
+
)
|
| 273 |
+
})
|
| 274 |
+
.collect();
|
| 275 |
+
|
| 276 |
+
let weights = results.iter().map(|(w, _)| w.clone()).collect();
|
| 277 |
+
let biases = results.iter().map(|(_, b)| *b).collect();
|
| 278 |
+
|
| 279 |
+
LinearSVM {
|
| 280 |
+
weights,
|
| 281 |
+
biases,
|
| 282 |
+
classes,
|
| 283 |
+
n_features,
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
/// LIBLINEAR's solve_l2r_l2_svc - Dual Coordinate Descent for L2-loss SVM
///
/// Solves: min_α 0.5 * α^T * Q * α - e^T * α, s.t. α_i ≥ 0
/// where Q_ij = y_i * y_j * x_i^T * x_j + δ_ij / (2C)
///
/// Primal-dual relationship: w = Σ α_i * y_i * x_i
///
/// Returns the primal weight vector `w` and a bias `b` recovered from the
/// KKT conditions, averaged over the support vectors.
#[inline(never)]
fn solve_l2r_l2_svc(
    x: &[SparseVec],
    y: &[i8],
    x_sq_norms: &[f32],
    n_features: usize,
    c: f32,
    eps: f32,
    max_iter: usize,
) -> (Vec<f32>, f32) {
    let n = x.len();

    // D_ii = 1/(2C) for L2-loss SVM
    let diag = 0.5 / c;

    // QD[i] = ||x_i||^2 + D_ii  (diagonal of Q, used as the Newton step size)
    let qd: Vec<f32> = x_sq_norms.iter().map(|&xn| xn + diag).collect();

    // Initialize α = 0
    let mut alpha = vec![0.0f32; n];

    // w = Σ α_i * y_i * x_i (initially 0)
    let mut w = vec![0.0f32; n_features];

    // Index for permutation
    let mut index: Vec<usize> = (0..n).collect();

    // Main loop
    for iter in 0..max_iter {
        // Shuffle indices with a cheap deterministic LCG-style swap.
        // NOTE(review): the swap offset is constant within an epoch, so this
        // is closer to a rotation than a uniform shuffle — adequate for
        // convergence but not a true random permutation; confirm intent.
        for i in 0..n {
            let j = i + (iter * 1103515245 + 12345) % (n - i).max(1);
            index.swap(i, j);
        }

        let mut max_violation = 0.0f32;

        for &i in &index {
            let yi = y[i] as f32;
            let xi = &x[i];

            // G = y_i * (w · x_i) - 1 + D_ii * α_i
            let wxi: f32 = xi.iter().map(|&(j, v)| w[j as usize] * v).sum();
            let g = yi * wxi - 1.0 + diag * alpha[i];

            // Projected gradient (α ≥ 0, no upper bound for L2-loss)
            let pg = if alpha[i] == 0.0 { g.min(0.0) } else { g };

            max_violation = max_violation.max(pg.abs());

            if pg.abs() > 1e-12 {
                let alpha_old = alpha[i];

                // Coordinate update: α_i = max(0, α_i - G/Q_ii)
                alpha[i] = (alpha[i] - g / qd[i]).max(0.0);

                // Maintain the primal vector: w += (α_new - α_old) * y_i * x_i
                let d = (alpha[i] - alpha_old) * yi;
                if d.abs() > 1e-12 {
                    for &(j, v) in xi.iter() {
                        w[j as usize] += d * v;
                    }
                }
            }
        }

        // Stopping criterion: largest projected-gradient violation this sweep.
        if max_violation <= eps {
            break;
        }
    }

    // Compute bias from KKT conditions
    // For α_i > 0: y_i * (w · x_i + b) = 1 - α_i / (2C)
    let mut bias_sum = 0.0f32;
    let mut n_sv = 0;

    for i in 0..n {
        if alpha[i] > 1e-8 {
            let yi = y[i] as f32;
            let wxi: f32 = x[i].iter().map(|&(j, v)| w[j as usize] * v).sum();
            // b = y_i * (1 - α_i * diag) - w · x_i
            bias_sum += yi * (1.0 - alpha[i] * diag) - wxi;
            n_sv += 1;
        }
    }

    // Average over support vectors; zero when there are none.
    let bias = if n_sv > 0 { bias_sum / n_sv as f32 } else { 0.0 };

    (w, bias)
}
|
| 385 |
+
|
| 386 |
+
/// Fast SVM using Pegasos algorithm
///
/// Trades some accuracy for speed compared to the dual coordinate descent
/// trainer; each binary subproblem is solved by stochastic sub-gradient
/// descent (see `pegasos`).
#[pyclass]
pub struct FastSVMTrainer {
    // Regularization parameter C; Pegasos internally uses λ = 1/C.
    c: f64,
    // Number of epochs (full passes over the training set).
    max_iter: usize,
}
|
| 392 |
+
|
| 393 |
+
#[pymethods]
|
| 394 |
+
impl FastSVMTrainer {
|
| 395 |
+
#[new]
|
| 396 |
+
#[pyo3(signature = (c=1.0, max_iter=100))]
|
| 397 |
+
pub fn new(c: f64, max_iter: usize) -> Self {
|
| 398 |
+
Self { c, max_iter }
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
|
| 402 |
+
let n_samples = features.len();
|
| 403 |
+
let n_features = if n_samples > 0 { features[0].len() } else { 0 };
|
| 404 |
+
|
| 405 |
+
let sparse_features: Vec<SparseVec> = features
|
| 406 |
+
.par_iter()
|
| 407 |
+
.map(|dense| {
|
| 408 |
+
dense
|
| 409 |
+
.iter()
|
| 410 |
+
.enumerate()
|
| 411 |
+
.filter(|&(_, &v)| v.abs() > 1e-10)
|
| 412 |
+
.map(|(i, &v)| (i as u32, v as f32))
|
| 413 |
+
.collect()
|
| 414 |
+
})
|
| 415 |
+
.collect();
|
| 416 |
+
|
| 417 |
+
let mut classes: Vec<String> = labels.iter().cloned().collect();
|
| 418 |
+
classes.sort();
|
| 419 |
+
classes.dedup();
|
| 420 |
+
let n_classes = classes.len();
|
| 421 |
+
|
| 422 |
+
let class_to_idx: HashMap<String, usize> = classes
|
| 423 |
+
.iter()
|
| 424 |
+
.enumerate()
|
| 425 |
+
.map(|(i, c)| (c.clone(), i))
|
| 426 |
+
.collect();
|
| 427 |
+
|
| 428 |
+
let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
|
| 429 |
+
|
| 430 |
+
let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
|
| 431 |
+
.into_par_iter()
|
| 432 |
+
.map(|class_idx| {
|
| 433 |
+
let y_binary: Vec<i8> = y_idx
|
| 434 |
+
.iter()
|
| 435 |
+
.map(|&idx| if idx == class_idx { 1 } else { -1 })
|
| 436 |
+
.collect();
|
| 437 |
+
|
| 438 |
+
pegasos(&sparse_features, &y_binary, n_features, self.c as f32, self.max_iter)
|
| 439 |
+
})
|
| 440 |
+
.collect();
|
| 441 |
+
|
| 442 |
+
LinearSVM {
|
| 443 |
+
weights: results.iter().map(|(w, _)| w.clone()).collect(),
|
| 444 |
+
biases: results.iter().map(|(_, b)| *b).collect(),
|
| 445 |
+
classes,
|
| 446 |
+
n_features,
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
/// Pegasos algorithm with lazy scaling
///
/// Stochastic sub-gradient descent on the primal SVM objective with
/// regularization λ = 1/C. The weight vector is stored as `scale * w`
/// so the per-step L2 shrink is O(1) instead of O(n_features).
/// Returns the materialized weight vector and the learned bias.
#[inline(never)]
fn pegasos(
    x: &[SparseVec],
    y: &[i8],
    n_features: usize,
    c: f32,
    max_iter: usize,
) -> (Vec<f32>, f32) {
    let n = x.len();
    // Pegasos regularization: λ = 1/C.
    let lambda = 1.0 / c;

    let mut w = vec![0.0f32; n_features];
    // Lazy multiplicative factor applied to all of `w`.
    let mut scale = 1.0f32;
    let mut b = 0.0f32;

    // Learning-rate schedule: η_t = 1 / (λ (t + t0)), with t0 chosen so the
    // very first step size is eta0.
    let eta0 = 0.5;
    let t0 = 1.0 / (eta0 * lambda);

    let mut indices: Vec<usize> = (0..n).collect();

    for epoch in 0..max_iter {
        // Shuffle with a cheap deterministic swap pattern.
        // NOTE(review): the offset is constant within an epoch, so this is
        // closer to a rotation than a uniform shuffle — confirm intent.
        for i in 0..n {
            let j = (i + epoch * 1103515245 + 12345) % n;
            indices.swap(i, j);
        }

        for (t_inner, &i) in indices.iter().enumerate() {
            let t = (epoch * n + t_inner) as f32;
            let eta = 1.0 / (lambda * (t + t0));

            let yi = y[i] as f32;
            let xi = &x[i];

            // Margin is computed against the effective weights scale * w.
            let margin: f32 = scale * xi.iter().map(|&(j, v)| w[j as usize] * v).sum::<f32>() + b;

            // L2 shrink step, applied lazily through the scale factor.
            scale *= 1.0 - eta * lambda;

            // Re-materialize w before the scale factor underflows to zero.
            if scale < 1e-9 {
                for wj in w.iter_mut() {
                    *wj *= scale;
                }
                scale = 1.0;
            }

            // Hinge-loss sub-gradient step on margin violations; divide by
            // scale so the update lands correctly in the scaled coordinates.
            if yi * margin < 1.0 {
                let update = eta / scale;
                for &(j, v) in xi.iter() {
                    w[j as usize] += update * yi * v;
                }
                // Damped bias update (0.1 factor appears to be a heuristic —
                // TODO confirm against experiments).
                b += eta * yi * 0.1;
            }
        }
    }

    // Fold the outstanding scale factor back into w before returning.
    for wj in w.iter_mut() {
        *wj *= scale;
    }

    (w, b)
}
|
extensions/underthesea_core_extend/src/tfidf.rs
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! TF-IDF Vectorizer implementation
|
| 2 |
+
|
| 3 |
+
use hashbrown::HashMap;
|
| 4 |
+
use pyo3::prelude::*;
|
| 5 |
+
use rayon::prelude::*;
|
| 6 |
+
use serde::{Deserialize, Serialize};
|
| 7 |
+
use std::fs::File;
|
| 8 |
+
use std::io::{BufReader, BufWriter};
|
| 9 |
+
|
| 10 |
+
/// TF-IDF Vectorizer
///
/// Converts text documents into TF-IDF feature vectors.
/// Serializable with serde so fitted state can be saved/loaded as JSON.
#[pyclass]
#[derive(Clone, Serialize, Deserialize)]
pub struct TfIdfVectorizer {
    /// Vocabulary: word -> index
    vocab: HashMap<String, usize>,
    /// Inverse vocabulary: index -> word (parallel to `idf`)
    inv_vocab: Vec<String>,
    /// IDF values for each term, indexed by vocabulary index
    idf: Vec<f64>,
    /// Number of documents used for fitting
    n_docs: usize,
    /// Maximum number of features kept in the vocabulary
    max_features: usize,
    /// N-gram range (min, max), inclusive on both ends
    ngram_range: (usize, usize),
    /// Minimum document frequency (absolute count)
    min_df: usize,
    /// Maximum document frequency (as a ratio of n_docs)
    max_df: f64,
    /// Whether the vectorizer is fitted (transform returns empty otherwise)
    is_fitted: bool,
}
|
| 35 |
+
|
| 36 |
+
#[pymethods]
|
| 37 |
+
impl TfIdfVectorizer {
|
| 38 |
+
/// Create a new TfIdfVectorizer
|
| 39 |
+
#[new]
|
| 40 |
+
#[pyo3(signature = (max_features=20000, ngram_range=(1, 2), min_df=1, max_df=1.0))]
|
| 41 |
+
pub fn new(
|
| 42 |
+
max_features: usize,
|
| 43 |
+
ngram_range: (usize, usize),
|
| 44 |
+
min_df: usize,
|
| 45 |
+
max_df: f64,
|
| 46 |
+
) -> Self {
|
| 47 |
+
Self {
|
| 48 |
+
vocab: HashMap::new(),
|
| 49 |
+
inv_vocab: Vec::new(),
|
| 50 |
+
idf: Vec::new(),
|
| 51 |
+
n_docs: 0,
|
| 52 |
+
max_features,
|
| 53 |
+
ngram_range,
|
| 54 |
+
min_df,
|
| 55 |
+
max_df,
|
| 56 |
+
is_fitted: false,
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
/// Fit the vectorizer on a list of documents
|
| 61 |
+
pub fn fit(&mut self, documents: Vec<String>) {
|
| 62 |
+
let n_docs = documents.len();
|
| 63 |
+
self.n_docs = n_docs;
|
| 64 |
+
|
| 65 |
+
// Count document frequency for each term
|
| 66 |
+
let mut df: HashMap<String, usize> = HashMap::new();
|
| 67 |
+
|
| 68 |
+
for doc in &documents {
|
| 69 |
+
let tokens = self.tokenize(doc);
|
| 70 |
+
let unique_tokens: std::collections::HashSet<_> = tokens.into_iter().collect();
|
| 71 |
+
for token in unique_tokens {
|
| 72 |
+
*df.entry(token).or_insert(0) += 1;
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// Filter by min_df and max_df
|
| 77 |
+
let max_df_count = (self.max_df * n_docs as f64) as usize;
|
| 78 |
+
let mut filtered: Vec<(String, usize)> = df
|
| 79 |
+
.into_iter()
|
| 80 |
+
.filter(|(_, count)| *count >= self.min_df && *count <= max_df_count)
|
| 81 |
+
.collect();
|
| 82 |
+
|
| 83 |
+
// Sort by frequency (descending) and take top max_features
|
| 84 |
+
filtered.sort_by(|a, b| b.1.cmp(&a.1));
|
| 85 |
+
filtered.truncate(self.max_features);
|
| 86 |
+
|
| 87 |
+
// Build vocabulary
|
| 88 |
+
self.vocab.clear();
|
| 89 |
+
self.inv_vocab.clear();
|
| 90 |
+
self.idf.clear();
|
| 91 |
+
|
| 92 |
+
for (idx, (term, doc_freq)) in filtered.into_iter().enumerate() {
|
| 93 |
+
self.vocab.insert(term.clone(), idx);
|
| 94 |
+
self.inv_vocab.push(term);
|
| 95 |
+
// IDF with smoothing: log((n_docs + 1) / (df + 1)) + 1
|
| 96 |
+
let idf_value = ((n_docs as f64 + 1.0) / (doc_freq as f64 + 1.0)).ln() + 1.0;
|
| 97 |
+
self.idf.push(idf_value);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
self.is_fitted = true;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
/// Transform a single document to TF-IDF vector (sparse format)
|
| 104 |
+
///
|
| 105 |
+
/// Returns list of (index, value) tuples
|
| 106 |
+
pub fn transform(&self, document: &str) -> Vec<(usize, f64)> {
|
| 107 |
+
if !self.is_fitted {
|
| 108 |
+
return Vec::new();
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
let tokens = self.tokenize(document);
|
| 112 |
+
let mut tf: HashMap<usize, usize> = HashMap::new();
|
| 113 |
+
|
| 114 |
+
for token in &tokens {
|
| 115 |
+
if let Some(&idx) = self.vocab.get(token) {
|
| 116 |
+
*tf.entry(idx).or_insert(0) += 1;
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
let n_tokens = tokens.len() as f64;
|
| 121 |
+
if n_tokens == 0.0 {
|
| 122 |
+
return Vec::new();
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
let mut result: Vec<(usize, f64)> = tf
|
| 126 |
+
.into_iter()
|
| 127 |
+
.map(|(idx, count)| {
|
| 128 |
+
let tf_value = count as f64 / n_tokens;
|
| 129 |
+
let tfidf = tf_value * self.idf[idx];
|
| 130 |
+
(idx, tfidf)
|
| 131 |
+
})
|
| 132 |
+
.collect();
|
| 133 |
+
|
| 134 |
+
// L2 normalize
|
| 135 |
+
let norm: f64 = result.iter().map(|(_, v)| v * v).sum::<f64>().sqrt();
|
| 136 |
+
if norm > 0.0 {
|
| 137 |
+
for (_, v) in &mut result {
|
| 138 |
+
*v /= norm;
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
result.sort_by_key(|(idx, _)| *idx);
|
| 143 |
+
result
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/// Transform a single document to dense TF-IDF vector
|
| 147 |
+
pub fn transform_dense(&self, document: &str) -> Vec<f64> {
|
| 148 |
+
let sparse = self.transform(document);
|
| 149 |
+
let mut dense = vec![0.0; self.vocab.len()];
|
| 150 |
+
for (idx, val) in sparse {
|
| 151 |
+
dense[idx] = val;
|
| 152 |
+
}
|
| 153 |
+
dense
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
/// Transform multiple documents to dense TF-IDF vectors (parallel)
|
| 157 |
+
pub fn transform_batch(&self, documents: Vec<String>) -> Vec<Vec<f64>> {
|
| 158 |
+
documents
|
| 159 |
+
.par_iter()
|
| 160 |
+
.map(|doc| self.transform_dense(doc))
|
| 161 |
+
.collect()
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// Transform multiple documents to sparse TF-IDF vectors (parallel)
|
| 165 |
+
pub fn transform_batch_sparse(&self, documents: Vec<String>) -> Vec<Vec<(usize, f64)>> {
|
| 166 |
+
documents
|
| 167 |
+
.par_iter()
|
| 168 |
+
.map(|doc| self.transform(doc))
|
| 169 |
+
.collect()
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/// Fit and transform in one step
|
| 173 |
+
pub fn fit_transform(&mut self, documents: Vec<String>) -> Vec<Vec<f64>> {
|
| 174 |
+
self.fit(documents.clone());
|
| 175 |
+
self.transform_batch(documents)
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
/// Get vocabulary size
|
| 179 |
+
#[getter]
|
| 180 |
+
pub fn vocab_size(&self) -> usize {
|
| 181 |
+
self.vocab.len()
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
/// Get feature names (vocabulary terms)
|
| 185 |
+
pub fn get_feature_names(&self) -> Vec<String> {
|
| 186 |
+
self.inv_vocab.clone()
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
/// Check if vectorizer is fitted
|
| 190 |
+
#[getter]
|
| 191 |
+
pub fn is_fitted(&self) -> bool {
|
| 192 |
+
self.is_fitted
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
/// Save vectorizer to file
|
| 196 |
+
pub fn save(&self, path: &str) -> PyResult<()> {
|
| 197 |
+
let file = File::create(path)
|
| 198 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 199 |
+
let writer = BufWriter::new(file);
|
| 200 |
+
serde_json::to_writer(writer, self)
|
| 201 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 202 |
+
Ok(())
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
/// Load vectorizer from file
|
| 206 |
+
#[staticmethod]
|
| 207 |
+
pub fn load(path: &str) -> PyResult<Self> {
|
| 208 |
+
let file = File::open(path)
|
| 209 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 210 |
+
let reader = BufReader::new(file);
|
| 211 |
+
let vectorizer: Self = serde_json::from_reader(reader)
|
| 212 |
+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
|
| 213 |
+
Ok(vectorizer)
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
impl TfIdfVectorizer {
|
| 218 |
+
/// Tokenize document into n-grams
|
| 219 |
+
fn tokenize(&self, document: &str) -> Vec<String> {
|
| 220 |
+
let words: Vec<&str> = document.split_whitespace().collect();
|
| 221 |
+
let mut tokens = Vec::new();
|
| 222 |
+
|
| 223 |
+
for n in self.ngram_range.0..=self.ngram_range.1 {
|
| 224 |
+
if n > words.len() {
|
| 225 |
+
continue;
|
| 226 |
+
}
|
| 227 |
+
for i in 0..=(words.len() - n) {
|
| 228 |
+
let ngram = words[i..i + n].join(" ");
|
| 229 |
+
tokens.push(ngram);
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
tokens
|
| 234 |
+
}
|
| 235 |
+
}
|
extensions/underthesea_core_extend/uv.lock
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version = 1
|
| 2 |
+
revision = 3
|
| 3 |
+
requires-python = ">=3.10"
|
| 4 |
+
|
| 5 |
+
[[package]]
|
| 6 |
+
name = "underthesea-core-extend"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
source = { editable = "." }
|
pyproject.toml
CHANGED
|
@@ -1,24 +1,23 @@
|
|
| 1 |
[project]
|
| 2 |
name = "sen"
|
| 3 |
-
version = "1.
|
| 4 |
-
description = "Vietnamese Text Classification Model"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.10"
|
| 7 |
license = "Apache-2.0"
|
| 8 |
authors = [
|
| 9 |
{name = "UnderTheSea NLP", email = "undertheseanlp@gmail.com"}
|
| 10 |
]
|
| 11 |
-
keywords = ["vietnamese", "nlp", "text-classification", "
|
| 12 |
dependencies = [
|
| 13 |
-
"
|
| 14 |
-
"joblib>=1.0.0",
|
| 15 |
-
"numpy>=1.20.0",
|
| 16 |
]
|
| 17 |
|
| 18 |
[project.optional-dependencies]
|
| 19 |
dev = [
|
| 20 |
"pytest>=7.0.0",
|
| 21 |
"huggingface-hub>=0.20.0",
|
|
|
|
| 22 |
]
|
| 23 |
|
| 24 |
[project.urls]
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "sen"
|
| 3 |
+
version = "1.1.0"
|
| 4 |
+
description = "Vietnamese Text Classification Model - Rust-powered"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.10"
|
| 7 |
license = "Apache-2.0"
|
| 8 |
authors = [
|
| 9 |
{name = "UnderTheSea NLP", email = "undertheseanlp@gmail.com"}
|
| 10 |
]
|
| 11 |
+
keywords = ["vietnamese", "nlp", "text-classification", "rust", "svm"]
|
| 12 |
dependencies = [
|
| 13 |
+
"underthesea_core_extend>=0.1.0",
|
|
|
|
|
|
|
| 14 |
]
|
| 15 |
|
| 16 |
[project.optional-dependencies]
|
| 17 |
dev = [
|
| 18 |
"pytest>=7.0.0",
|
| 19 |
"huggingface-hub>=0.20.0",
|
| 20 |
+
"maturin>=1.0.0",
|
| 21 |
]
|
| 22 |
|
| 23 |
[project.urls]
|
src/sen/text_classifier.py
CHANGED
|
@@ -1,26 +1,20 @@
|
|
| 1 |
"""
|
| 2 |
-
Sen Text Classifier -
|
| 3 |
|
| 4 |
Based on: "A Comparative Study on Vietnamese Text Classification Methods"
|
| 5 |
Vu et al., RIVF 2007
|
| 6 |
https://ieeexplore.ieee.org/document/4223084/
|
| 7 |
|
| 8 |
Methods:
|
| 9 |
-
- TF-IDF vectorization
|
| 10 |
-
- SVM (
|
| 11 |
"""
|
| 12 |
|
| 13 |
import json
|
| 14 |
import os
|
| 15 |
from typing import List, Optional, Union
|
| 16 |
|
| 17 |
-
import
|
| 18 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 19 |
-
from sklearn.svm import LinearSVC
|
| 20 |
-
from sklearn.pipeline import Pipeline
|
| 21 |
-
from sklearn.preprocessing import LabelEncoder
|
| 22 |
-
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
| 23 |
-
import numpy as np
|
| 24 |
|
| 25 |
|
| 26 |
class Label:
|
|
@@ -59,8 +53,9 @@ class Sentence:
|
|
| 59 |
|
| 60 |
class SenTextClassifier:
|
| 61 |
"""
|
| 62 |
-
|
| 63 |
|
|
|
|
| 64 |
Compatible with underthesea API.
|
| 65 |
|
| 66 |
Reference:
|
|
@@ -71,43 +66,28 @@ class SenTextClassifier:
|
|
| 71 |
def __init__(
|
| 72 |
self,
|
| 73 |
# TF-IDF parameters
|
| 74 |
-
max_features: int =
|
| 75 |
ngram_range: tuple = (1, 2),
|
| 76 |
-
min_df: int =
|
| 77 |
-
max_df: float =
|
| 78 |
-
sublinear_tf: bool = True,
|
| 79 |
# SVM parameters
|
| 80 |
-
|
| 81 |
max_iter: int = 1000,
|
|
|
|
|
|
|
| 82 |
):
|
| 83 |
self.max_features = max_features
|
| 84 |
self.ngram_range = ngram_range
|
| 85 |
self.min_df = min_df
|
| 86 |
self.max_df = max_df
|
| 87 |
-
self.
|
| 88 |
-
self.C = C
|
| 89 |
self.max_iter = max_iter
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
self.
|
| 92 |
-
self.
|
| 93 |
-
self.labels_ = None
|
| 94 |
-
|
| 95 |
-
def _build_pipeline(self) -> Pipeline:
|
| 96 |
-
"""Build sklearn pipeline with TF-IDF + SVM."""
|
| 97 |
-
return Pipeline([
|
| 98 |
-
("tfidf", TfidfVectorizer(
|
| 99 |
-
max_features=self.max_features,
|
| 100 |
-
ngram_range=self.ngram_range,
|
| 101 |
-
min_df=self.min_df,
|
| 102 |
-
max_df=self.max_df,
|
| 103 |
-
sublinear_tf=self.sublinear_tf,
|
| 104 |
-
)),
|
| 105 |
-
("clf", LinearSVC(
|
| 106 |
-
C=self.C,
|
| 107 |
-
max_iter=self.max_iter,
|
| 108 |
-
random_state=42,
|
| 109 |
-
)),
|
| 110 |
-
])
|
| 111 |
|
| 112 |
def train(
|
| 113 |
self,
|
|
@@ -128,38 +108,58 @@ class SenTextClassifier:
|
|
| 128 |
Returns:
|
| 129 |
Dictionary with training metrics
|
| 130 |
"""
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
# Build and train
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# Calculate training metrics
|
| 140 |
-
|
| 141 |
-
train_acc =
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
|
| 144 |
results = {
|
| 145 |
"train_accuracy": train_acc,
|
| 146 |
"train_f1": train_f1,
|
| 147 |
"num_classes": len(self.labels_),
|
| 148 |
"num_samples": len(train_texts),
|
|
|
|
| 149 |
}
|
| 150 |
|
| 151 |
print(f"Training completed:")
|
| 152 |
print(f" - Samples: {len(train_texts)}")
|
| 153 |
print(f" - Classes: {len(self.labels_)}")
|
|
|
|
| 154 |
print(f" - Train accuracy: {train_acc:.4f}")
|
| 155 |
print(f" - Train F1: {train_f1:.4f}")
|
| 156 |
|
| 157 |
# Validation metrics
|
| 158 |
if val_texts and val_labels:
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
val_acc =
|
| 162 |
-
val_f1 =
|
| 163 |
|
| 164 |
results["val_accuracy"] = val_acc
|
| 165 |
results["val_f1"] = val_f1
|
|
@@ -169,6 +169,28 @@ class SenTextClassifier:
|
|
| 169 |
|
| 170 |
return results
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
def predict(self, sentence: Sentence) -> None:
|
| 173 |
"""
|
| 174 |
Predict label for a sentence (underthesea-compatible API).
|
|
@@ -176,23 +198,11 @@ class SenTextClassifier:
|
|
| 176 |
Args:
|
| 177 |
sentence: Sentence object with text attribute
|
| 178 |
"""
|
| 179 |
-
if self.
|
| 180 |
raise ValueError("Model not trained. Call train() first or load a model.")
|
| 181 |
|
| 182 |
-
|
| 183 |
-
label_value = self.
|
| 184 |
-
|
| 185 |
-
# Get confidence score using decision function
|
| 186 |
-
try:
|
| 187 |
-
decision = self.pipeline.decision_function([sentence.text])[0]
|
| 188 |
-
if isinstance(decision, np.ndarray):
|
| 189 |
-
score = float(np.max(np.abs(decision)))
|
| 190 |
-
else:
|
| 191 |
-
score = float(abs(decision))
|
| 192 |
-
# Normalize to 0-1 range using sigmoid
|
| 193 |
-
score = 1 / (1 + np.exp(-score))
|
| 194 |
-
except Exception:
|
| 195 |
-
score = 1.0
|
| 196 |
|
| 197 |
sentence.labels = []
|
| 198 |
sentence.add_labels([Label(label_value, score)])
|
|
@@ -207,23 +217,17 @@ class SenTextClassifier:
|
|
| 207 |
Returns:
|
| 208 |
List of Label objects
|
| 209 |
"""
|
| 210 |
-
if self.
|
| 211 |
raise ValueError("Model not trained. Call train() first or load a model.")
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
if decisions.ndim == 1:
|
| 220 |
-
scores = 1 / (1 + np.exp(-np.abs(decisions)))
|
| 221 |
-
else:
|
| 222 |
-
scores = 1 / (1 + np.exp(-np.max(np.abs(decisions), axis=1)))
|
| 223 |
-
except Exception:
|
| 224 |
-
scores = [1.0] * len(texts)
|
| 225 |
|
| 226 |
-
return
|
| 227 |
|
| 228 |
def evaluate(self, texts: List[str], labels: List[str]) -> dict:
|
| 229 |
"""
|
|
@@ -236,23 +240,46 @@ class SenTextClassifier:
|
|
| 236 |
Returns:
|
| 237 |
Dictionary with evaluation metrics
|
| 238 |
"""
|
| 239 |
-
|
| 240 |
-
|
|
|
|
| 241 |
|
| 242 |
-
acc =
|
| 243 |
-
f1 =
|
| 244 |
|
| 245 |
print(f"Evaluation:")
|
| 246 |
print(f" - Accuracy: {acc:.4f}")
|
| 247 |
print(f" - F1 (weighted): {f1:.4f}")
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
target_names=self.labels_
|
| 252 |
-
))
|
| 253 |
|
| 254 |
return {"accuracy": acc, "f1": f1}
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
def save(self, path: str) -> None:
|
| 257 |
"""
|
| 258 |
Save model to disk.
|
|
@@ -262,23 +289,25 @@ class SenTextClassifier:
|
|
| 262 |
"""
|
| 263 |
os.makedirs(path, exist_ok=True)
|
| 264 |
|
| 265 |
-
# Save
|
| 266 |
-
|
| 267 |
|
| 268 |
-
# Save
|
| 269 |
-
|
| 270 |
|
| 271 |
# Save metadata
|
| 272 |
metadata = {
|
| 273 |
-
"estimator": "
|
| 274 |
"max_features": self.max_features,
|
| 275 |
"ngram_range": self.ngram_range,
|
| 276 |
"min_df": self.min_df,
|
| 277 |
"max_df": self.max_df,
|
| 278 |
-
"
|
| 279 |
-
"C": self.C,
|
| 280 |
"max_iter": self.max_iter,
|
|
|
|
| 281 |
"labels": self.labels_,
|
|
|
|
|
|
|
| 282 |
}
|
| 283 |
with open(os.path.join(path, "metadata.json"), "w", encoding="utf-8") as f:
|
| 284 |
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
|
@@ -302,19 +331,21 @@ class SenTextClassifier:
|
|
| 302 |
|
| 303 |
# Create instance with saved parameters
|
| 304 |
classifier = cls(
|
| 305 |
-
max_features=metadata.get("max_features",
|
| 306 |
ngram_range=tuple(metadata.get("ngram_range", (1, 2))),
|
| 307 |
-
min_df=metadata.get("min_df",
|
| 308 |
-
max_df=metadata.get("max_df",
|
| 309 |
-
|
| 310 |
-
C=metadata.get("C", 1.0),
|
| 311 |
max_iter=metadata.get("max_iter", 1000),
|
|
|
|
| 312 |
)
|
| 313 |
|
| 314 |
-
# Load
|
| 315 |
-
classifier.
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
| 318 |
|
| 319 |
print(f"Model loaded from: {path}")
|
| 320 |
return classifier
|
|
|
|
| 1 |
"""
|
| 2 |
+
Sen Text Classifier - Rust-based classifier using underthesea_core_extend.
|
| 3 |
|
| 4 |
Based on: "A Comparative Study on Vietnamese Text Classification Methods"
|
| 5 |
Vu et al., RIVF 2007
|
| 6 |
https://ieeexplore.ieee.org/document/4223084/
|
| 7 |
|
| 8 |
Methods:
|
| 9 |
+
- TF-IDF vectorization (Rust: underthesea_core_extend.TfIdfVectorizer)
|
| 10 |
+
- Linear SVM classifier (Rust: underthesea_core_extend.LinearSVM)
|
| 11 |
"""
|
| 12 |
|
| 13 |
import json
|
| 14 |
import os
|
| 15 |
from typing import List, Optional, Union
|
| 16 |
|
| 17 |
+
from underthesea_core_extend import TfIdfVectorizer, LinearSVM, SVMTrainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class Label:
|
|
|
|
| 53 |
|
| 54 |
class SenTextClassifier:
|
| 55 |
"""
|
| 56 |
+
Rust-based text classifier using TF-IDF + Linear SVM.
|
| 57 |
|
| 58 |
+
Uses underthesea_core_extend for fast training and inference.
|
| 59 |
Compatible with underthesea API.
|
| 60 |
|
| 61 |
Reference:
|
|
|
|
| 66 |
    def __init__(
        self,
        # TF-IDF parameters
        max_features: int = 20000,
        ngram_range: tuple = (1, 2),
        min_df: int = 1,
        max_df: float = 1.0,
        # SVM parameters
        c: float = 1.0,
        max_iter: int = 1000,
        tol: float = 0.1,
        verbose: bool = True,
    ):
        """Store configuration; no model is built until ``train()`` is called.

        Args:
            max_features: Maximum vocabulary size for the TF-IDF vectorizer.
            ngram_range: Inclusive (min_n, max_n) word n-gram range.
            min_df: Minimum document frequency for a term to be kept.
            max_df: Maximum document frequency, as a fraction of documents.
            c: SVM regularization parameter, passed to the Rust trainer.
            max_iter: Maximum solver iterations.
            tol: Solver stopping tolerance.
            verbose: Passed through to the Rust ``SVMTrainer``.
        """
        self.max_features = max_features
        self.ngram_range = ngram_range
        self.min_df = min_df
        self.max_df = max_df
        self.c = c
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

        # Populated by train() (or a load helper); None until then.
        self.vectorizer: Optional[TfIdfVectorizer] = None
        self.classifier: Optional[LinearSVM] = None
        self.labels_: Optional[List[str]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
def train(
|
| 93 |
self,
|
|
|
|
| 108 |
Returns:
|
| 109 |
Dictionary with training metrics
|
| 110 |
"""
|
| 111 |
+
# Get unique labels
|
| 112 |
+
self.labels_ = sorted(list(set(train_labels)))
|
| 113 |
+
|
| 114 |
+
# Build and fit vectorizer
|
| 115 |
+
self.vectorizer = TfIdfVectorizer(
|
| 116 |
+
max_features=self.max_features,
|
| 117 |
+
ngram_range=self.ngram_range,
|
| 118 |
+
min_df=self.min_df,
|
| 119 |
+
max_df=self.max_df,
|
| 120 |
+
)
|
| 121 |
+
self.vectorizer.fit(train_texts)
|
| 122 |
+
|
| 123 |
+
# Transform to features
|
| 124 |
+
train_features = self.vectorizer.transform_batch(train_texts)
|
| 125 |
|
| 126 |
+
# Build and train SVM model
|
| 127 |
+
trainer = SVMTrainer(
|
| 128 |
+
c=self.c,
|
| 129 |
+
max_iter=self.max_iter,
|
| 130 |
+
tol=self.tol,
|
| 131 |
+
verbose=self.verbose,
|
| 132 |
+
)
|
| 133 |
+
self.classifier = trainer.train(train_features, train_labels)
|
| 134 |
|
| 135 |
# Calculate training metrics
|
| 136 |
+
train_preds = self.classifier.predict_batch(train_features)
|
| 137 |
+
train_acc = sum(1 for p, t in zip(train_preds, train_labels) if p == t) / len(train_labels)
|
| 138 |
+
|
| 139 |
+
# Calculate F1 score
|
| 140 |
+
train_f1 = self._calculate_f1(train_labels, train_preds)
|
| 141 |
|
| 142 |
results = {
|
| 143 |
"train_accuracy": train_acc,
|
| 144 |
"train_f1": train_f1,
|
| 145 |
"num_classes": len(self.labels_),
|
| 146 |
"num_samples": len(train_texts),
|
| 147 |
+
"vocab_size": self.vectorizer.vocab_size,
|
| 148 |
}
|
| 149 |
|
| 150 |
print(f"Training completed:")
|
| 151 |
print(f" - Samples: {len(train_texts)}")
|
| 152 |
print(f" - Classes: {len(self.labels_)}")
|
| 153 |
+
print(f" - Vocab size: {self.vectorizer.vocab_size}")
|
| 154 |
print(f" - Train accuracy: {train_acc:.4f}")
|
| 155 |
print(f" - Train F1: {train_f1:.4f}")
|
| 156 |
|
| 157 |
# Validation metrics
|
| 158 |
if val_texts and val_labels:
|
| 159 |
+
val_features = self.vectorizer.transform_batch(val_texts)
|
| 160 |
+
val_preds = self.classifier.predict_batch(val_features)
|
| 161 |
+
val_acc = sum(1 for p, t in zip(val_preds, val_labels) if p == t) / len(val_labels)
|
| 162 |
+
val_f1 = self._calculate_f1(val_labels, val_preds)
|
| 163 |
|
| 164 |
results["val_accuracy"] = val_acc
|
| 165 |
results["val_f1"] = val_f1
|
|
|
|
| 169 |
|
| 170 |
return results
|
| 171 |
|
| 172 |
+
def _calculate_f1(self, y_true: List[str], y_pred: List[str]) -> float:
|
| 173 |
+
"""Calculate weighted F1 score."""
|
| 174 |
+
from collections import Counter
|
| 175 |
+
|
| 176 |
+
label_counts = Counter(y_true)
|
| 177 |
+
total = len(y_true)
|
| 178 |
+
|
| 179 |
+
f1_sum = 0.0
|
| 180 |
+
for label in self.labels_:
|
| 181 |
+
tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
|
| 182 |
+
fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
|
| 183 |
+
fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
|
| 184 |
+
|
| 185 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
| 186 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 187 |
+
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
| 188 |
+
|
| 189 |
+
weight = label_counts[label] / total
|
| 190 |
+
f1_sum += f1 * weight
|
| 191 |
+
|
| 192 |
+
return f1_sum
|
| 193 |
+
|
| 194 |
def predict(self, sentence: Sentence) -> None:
    """
    Predict the label for a sentence in place (underthesea-compatible API).

    Args:
        sentence: Sentence object with a ``text`` attribute; its ``labels``
            list is replaced by the single predicted Label.

    Raises:
        ValueError: If the model has not been trained or loaded.
    """
    if self.vectorizer is None or self.classifier is None:
        raise ValueError("Model not trained. Call train() first or load a model.")

    # Single-text dense transform, then score against the linear model.
    dense = self.vectorizer.transform_dense(sentence.text)
    predicted, confidence = self.classifier.predict_with_score(dense)

    # Overwrite any previous prediction instead of appending to it.
    sentence.labels = []
    sentence.add_labels([Label(predicted, confidence)])
|
|
|
|
| 217 |
Returns:
|
| 218 |
List of Label objects
|
| 219 |
"""
|
| 220 |
+
if self.classifier is None or self.vectorizer is None:
|
| 221 |
raise ValueError("Model not trained. Call train() first or load a model.")
|
| 222 |
|
| 223 |
+
# Use dense transform (faster Python-Rust interface)
|
| 224 |
+
features = self.vectorizer.transform_batch(texts)
|
| 225 |
+
results = []
|
| 226 |
+
for feat in features:
|
| 227 |
+
label_value, score = self.classifier.predict_with_score(feat)
|
| 228 |
+
results.append(Label(label_value, float(score)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
+
return results
|
| 231 |
|
| 232 |
def evaluate(self, texts: List[str], labels: List[str]) -> dict:
    """
    Evaluate the model on a labelled dataset and print a report.

    Args:
        texts: Input texts.
        labels: Gold labels aligned element-wise with ``texts``.

    Returns:
        Dictionary with ``accuracy`` and weighted ``f1`` metrics.

    Raises:
        ValueError: If ``labels`` is empty (was an opaque ZeroDivisionError).
    """
    if not labels:
        raise ValueError("labels must not be empty")

    # Use dense transform (faster Python-Rust interface)
    features = self.vectorizer.transform_batch(texts)
    y_pred = self.classifier.predict_batch(features)

    acc = sum(1 for p, t in zip(y_pred, labels) if p == t) / len(labels)
    f1 = self._calculate_f1(labels, y_pred)

    print("Evaluation:")
    print(f" - Accuracy: {acc:.4f}")
    print(f" - F1 (weighted): {f1:.4f}")

    # Print classification report
    self._print_classification_report(labels, y_pred)

    return {"accuracy": acc, "f1": f1}
|
| 258 |
|
| 259 |
+
def _print_classification_report(self, y_true: List[str], y_pred: List[str]):
|
| 260 |
+
"""Print classification report."""
|
| 261 |
+
from collections import Counter
|
| 262 |
+
|
| 263 |
+
print("\nClassification Report:")
|
| 264 |
+
print(f"{'':>20} {'precision':>10} {'recall':>10} {'f1-score':>10} {'support':>10}")
|
| 265 |
+
print()
|
| 266 |
+
|
| 267 |
+
label_counts = Counter(y_true)
|
| 268 |
+
|
| 269 |
+
for label in self.labels_:
|
| 270 |
+
tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
|
| 271 |
+
fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
|
| 272 |
+
fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
|
| 273 |
+
|
| 274 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
| 275 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 276 |
+
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
| 277 |
+
support = label_counts[label]
|
| 278 |
+
|
| 279 |
+
print(f"{label:>20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f} {support:>10}")
|
| 280 |
+
|
| 281 |
+
print()
|
| 282 |
+
|
| 283 |
def save(self, path: str) -> None:
|
| 284 |
"""
|
| 285 |
Save model to disk.
|
|
|
|
| 289 |
"""
|
| 290 |
os.makedirs(path, exist_ok=True)
|
| 291 |
|
| 292 |
+
# Save vectorizer
|
| 293 |
+
self.vectorizer.save(os.path.join(path, "vectorizer.json"))
|
| 294 |
|
| 295 |
+
# Save classifier
|
| 296 |
+
self.classifier.save(os.path.join(path, "classifier.json"))
|
| 297 |
|
| 298 |
# Save metadata
|
| 299 |
metadata = {
|
| 300 |
+
"estimator": "RUST_SVM",
|
| 301 |
"max_features": self.max_features,
|
| 302 |
"ngram_range": self.ngram_range,
|
| 303 |
"min_df": self.min_df,
|
| 304 |
"max_df": self.max_df,
|
| 305 |
+
"c": self.c,
|
|
|
|
| 306 |
"max_iter": self.max_iter,
|
| 307 |
+
"tol": self.tol,
|
| 308 |
"labels": self.labels_,
|
| 309 |
+
"vocab_size": self.vectorizer.vocab_size,
|
| 310 |
+
"n_classes": self.classifier.n_classes,
|
| 311 |
}
|
| 312 |
with open(os.path.join(path, "metadata.json"), "w", encoding="utf-8") as f:
|
| 313 |
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
|
|
|
| 331 |
|
| 332 |
# Create instance with saved parameters
|
| 333 |
classifier = cls(
|
| 334 |
+
max_features=metadata.get("max_features", 20000),
|
| 335 |
ngram_range=tuple(metadata.get("ngram_range", (1, 2))),
|
| 336 |
+
min_df=metadata.get("min_df", 1),
|
| 337 |
+
max_df=metadata.get("max_df", 1.0),
|
| 338 |
+
c=metadata.get("c", 1.0),
|
|
|
|
| 339 |
max_iter=metadata.get("max_iter", 1000),
|
| 340 |
+
tol=metadata.get("tol", 0.1),
|
| 341 |
)
|
| 342 |
|
| 343 |
+
# Load vectorizer
|
| 344 |
+
classifier.vectorizer = TfIdfVectorizer.load(os.path.join(path, "vectorizer.json"))
|
| 345 |
+
|
| 346 |
+
# Load SVM model
|
| 347 |
+
classifier.classifier = LinearSVM.load(os.path.join(path, "classifier.json"))
|
| 348 |
+
classifier.labels_ = metadata.get("labels", [])
|
| 349 |
|
| 350 |
print(f"Model loaded from: {path}")
|
| 351 |
return classifier
|