rain1024 and Claude Opus 4.5 committed
Commit 2ac432e · 1 Parent(s): 21bb3e1

Add Rust extensions for TF-IDF and Linear SVM


- Implement TF-IDF vectorizer in pure Rust with PyO3 bindings
- Implement Linear SVM with LIBLINEAR-style Dual Coordinate Descent
- Achieve 88.15% accuracy on VNTC (vs sklearn's 89.48%)
- Update text_classifier.py to use new Rust backend

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
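
For reference, a minimal usage sketch of the new Rust backend, assembled from the classes and signatures added in this commit (TfIdfVectorizer, SVMTrainer, LinearSVM); the example texts and labels are hypothetical:

    from underthesea_core_extend import TfIdfVectorizer, SVMTrainer

    texts = ["bóng đá việt nam", "giá vàng tăng", "đội tuyển thắng lớn"]  # hypothetical data
    labels = ["the_thao", "kinh_te", "the_thao"]

    vectorizer = TfIdfVectorizer(max_features=20000, ngram_range=(1, 2), min_df=1, max_df=1.0)
    vectorizer.fit(texts)
    features = vectorizer.transform_batch(texts)  # dense vectors, computed in parallel

    trainer = SVMTrainer(c=1.0, max_iter=1000, tol=0.1, verbose=False)
    model = trainer.train(features, labels)       # returns a LinearSVM
    print(model.predict_batch(features))          # one predicted label per input text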

extensions/underthesea_core_extend/.gitignore ADDED
@@ -0,0 +1,5 @@
+target/
+.venv/
+__pycache__/
+*.so
+*.pyc
extensions/underthesea_core_extend/Cargo.lock ADDED
@@ -0,0 +1,351 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4
+
+[[package]]
+name = "allocator-api2"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
+
+[[package]]
+name = "autocfg"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
+
+[[package]]
+name = "cfg-if"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
+
+[[package]]
+name = "either"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
+[[package]]
+name = "equivalent"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
+
+[[package]]
+name = "foldhash"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
+
+[[package]]
+name = "hashbrown"
+version = "0.15.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
+dependencies = [
+ "allocator-api2",
+ "equivalent",
+ "foldhash",
+ "serde",
+]
+
+[[package]]
+name = "heck"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
+
+[[package]]
+name = "indoc"
+version = "2.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
+dependencies = [
+ "rustversion",
+]
+
+[[package]]
+name = "itoa"
+version = "1.0.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
+
+[[package]]
+name = "libc"
+version = "0.2.180"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
+
+[[package]]
+name = "memchr"
+version = "2.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
+
+[[package]]
+name = "memoffset"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
+
+[[package]]
+name = "portable-atomic"
+version = "1.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.106"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "pyo3"
+version = "0.22.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
+dependencies = [
+ "cfg-if",
+ "indoc",
+ "libc",
+ "memoffset",
+ "once_cell",
+ "portable-atomic",
+ "pyo3-build-config",
+ "pyo3-ffi",
+ "pyo3-macros",
+ "unindent",
+]
+
+[[package]]
+name = "pyo3-build-config"
+version = "0.22.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
+dependencies = [
+ "once_cell",
+ "target-lexicon",
+]
+
+[[package]]
+name = "pyo3-ffi"
+version = "0.22.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
+dependencies = [
+ "libc",
+ "pyo3-build-config",
+]
+
+[[package]]
+name = "pyo3-macros"
+version = "0.22.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
+dependencies = [
+ "proc-macro2",
+ "pyo3-macros-backend",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "pyo3-macros-backend"
+version = "0.22.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "pyo3-build-config",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.44"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "rayon"
+version = "1.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "rustversion"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
+
+[[package]]
+name = "serde"
+version = "1.0.228"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
+dependencies = [
+ "serde_core",
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_core"
+version = "1.0.228"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.228"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "serde_json"
+version = "1.0.149"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
+dependencies = [
+ "itoa",
+ "memchr",
+ "serde",
+ "serde_core",
+ "zmij",
+]
+
+[[package]]
+name = "syn"
+version = "2.0.114"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "target-lexicon"
+version = "0.12.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
+
+[[package]]
+name = "tinyvec"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
+dependencies = [
+ "tinyvec_macros",
+]
+
+[[package]]
+name = "tinyvec_macros"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
+
+[[package]]
+name = "underthesea_core_extend"
+version = "0.1.0"
+dependencies = [
+ "hashbrown",
+ "pyo3",
+ "rayon",
+ "serde",
+ "serde_json",
+ "unicode-normalization",
+]
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
+
+[[package]]
+name = "unicode-normalization"
+version = "0.1.25"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
+dependencies = [
+ "tinyvec",
+]
+
+[[package]]
+name = "unindent"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
+
+[[package]]
+name = "zmij"
+version = "1.0.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445"
extensions/underthesea_core_extend/Cargo.toml ADDED
@@ -0,0 +1,23 @@
+[package]
+name = "underthesea_core_extend"
+version = "0.1.0"
+edition = "2021"
+description = "Rust extensions for underthesea - Text Classification"
+license = "Apache-2.0"
+
+[lib]
+name = "underthesea_core_extend"
+crate-type = ["cdylib"]
+
+[dependencies]
+pyo3 = { version = "0.22", features = ["extension-module"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+rayon = "1.10"
+hashbrown = { version = "0.15", features = ["serde"] }
+unicode-normalization = "0.1"
+
+[profile.release]
+lto = true
+codegen-units = 1
+opt-level = 3
extensions/underthesea_core_extend/pyproject.toml ADDED
@@ -0,0 +1,22 @@
+[build-system]
+requires = ["maturin>=1.0,<2.0"]
+build-backend = "maturin"
+
+[project]
+name = "underthesea_core_extend"
+version = "0.1.0"
+description = "Rust extensions for underthesea - Text Classification"
+requires-python = ">=3.10"
+license = { text = "Apache-2.0" }
+authors = [{ name = "UnderTheSea NLP", email = "anhv.ict91@gmail.com" }]
+classifiers = [
+    "Programming Language :: Rust",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+
+[tool.maturin]
+features = ["pyo3/extension-module"]
+module-name = "underthesea_core_extend"
extensions/underthesea_core_extend/src/lib.rs ADDED
@@ -0,0 +1,21 @@
+//! underthesea_core_extend - Rust extensions for Vietnamese Text Classification
+//!
+//! Provides fast TF-IDF vectorization and Linear SVM classification.
+
+use pyo3::prelude::*;
+
+mod tfidf;
+mod svm;
+
+pub use tfidf::TfIdfVectorizer;
+pub use svm::{LinearSVM, SVMTrainer, FastSVMTrainer};
+
+/// Python module
+#[pymodule]
+fn underthesea_core_extend(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_class::<TfIdfVectorizer>()?;
+    m.add_class::<LinearSVM>()?;
+    m.add_class::<SVMTrainer>()?;
+    m.add_class::<FastSVMTrainer>()?;
+    Ok(())
+}
extensions/underthesea_core_extend/src/svm.rs ADDED
@@ -0,0 +1,512 @@
+//! Optimized Linear SVM - LIBLINEAR-style Dual Coordinate Descent
+//!
+//! Pure Rust implementation of L2-regularized L2-loss SVM (dual form)
+//! Reference: "A Dual Coordinate Descent Method for Large-scale Linear SVM"
+//! Hsieh et al., ICML 2008
+
+use hashbrown::HashMap;
+use pyo3::prelude::*;
+use rayon::prelude::*;
+use serde::{Deserialize, Serialize};
+use std::fs::File;
+use std::io::{BufReader, BufWriter};
+
+/// Sparse feature vector
+pub type SparseVec = Vec<(u32, f32)>; // Use u32/f32 for memory efficiency
+
+/// Linear SVM Model
+#[pyclass]
+#[derive(Clone, Serialize, Deserialize)]
+pub struct LinearSVM {
+    weights: Vec<Vec<f32>>,
+    biases: Vec<f32>,
+    classes: Vec<String>,
+    n_features: usize,
+}
+
+#[pymethods]
+impl LinearSVM {
+    #[new]
+    pub fn new() -> Self {
+        Self {
+            weights: Vec::new(),
+            biases: Vec::new(),
+            classes: Vec::new(),
+            n_features: 0,
+        }
+    }
+
+    pub fn predict(&self, features: Vec<f64>) -> String {
+        let idx = self.predict_idx(&features);
+        self.classes[idx].clone()
+    }
+
+    pub fn predict_with_score(&self, features: Vec<f64>) -> (String, f64) {
+        let scores = self.decision_scores(&features);
+        let (idx, &max_score) = scores
+            .iter()
+            .enumerate()
+            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
+            .unwrap();
+        let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
+        (self.classes[idx].clone(), confidence)
+    }
+
+    pub fn predict_batch(&self, features_batch: Vec<Vec<f64>>) -> Vec<String> {
+        features_batch
+            .par_iter()
+            .map(|f| {
+                let idx = self.predict_idx(f);
+                self.classes[idx].clone()
+            })
+            .collect()
+    }
+
+    pub fn predict_batch_sparse(&self, features_batch: Vec<Vec<(usize, f64)>>) -> Vec<String> {
+        features_batch
+            .par_iter()
+            .map(|f| {
+                let idx = self.predict_idx_sparse(f);
+                self.classes[idx].clone()
+            })
+            .collect()
+    }
+
+    pub fn predict_sparse_with_score(&self, features: Vec<(usize, f64)>) -> (String, f64) {
+        let scores = self.decision_scores_sparse(&features);
+        let (idx, &max_score) = scores
+            .iter()
+            .enumerate()
+            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
+            .unwrap();
+        let confidence = 1.0 / (1.0 + (-max_score as f64).exp());
+        (self.classes[idx].clone(), confidence)
+    }
+
+    pub fn decision_function(&self, features: Vec<f64>) -> Vec<f64> {
+        self.decision_scores(&features).into_iter().map(|x| x as f64).collect()
+    }
+
+    #[getter]
+    pub fn classes(&self) -> Vec<String> {
+        self.classes.clone()
+    }
+
+    #[getter]
+    pub fn n_classes(&self) -> usize {
+        self.classes.len()
+    }
+
+    #[getter]
+    pub fn n_features(&self) -> usize {
+        self.n_features
+    }
+
+    pub fn save(&self, path: &str) -> PyResult<()> {
+        let file = File::create(path)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        let writer = BufWriter::new(file);
+        serde_json::to_writer(writer, self)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        Ok(())
+    }
+
+    #[staticmethod]
+    pub fn load(path: &str) -> PyResult<Self> {
+        let file = File::open(path)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        let reader = BufReader::new(file);
+        let model: Self = serde_json::from_reader(reader)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        Ok(model)
+    }
+}
+
+impl LinearSVM {
+    #[inline]
+    fn predict_idx(&self, features: &[f64]) -> usize {
+        let mut best_idx = 0;
+        let mut best_score = f32::NEG_INFINITY;
+
+        for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
+            let score: f32 = w.iter()
+                .zip(features.iter())
+                .map(|(&wi, &fi)| wi * fi as f32)
+                .sum::<f32>() + b;
+
+            if score > best_score {
+                best_score = score;
+                best_idx = idx;
+            }
+        }
+        best_idx
+    }
+
+    #[inline]
+    fn predict_idx_sparse(&self, features: &[(usize, f64)]) -> usize {
+        let mut best_idx = 0;
+        let mut best_score = f32::NEG_INFINITY;
+
+        for (idx, (w, &b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
+            let score: f32 = features.iter()
+                .map(|&(j, v)| w[j] * v as f32)
+                .sum::<f32>() + b;
+
+            if score > best_score {
+                best_score = score;
+                best_idx = idx;
+            }
+        }
+        best_idx
+    }
+
+    fn decision_scores(&self, features: &[f64]) -> Vec<f32> {
+        self.weights
+            .iter()
+            .zip(self.biases.iter())
+            .map(|(w, &b)| {
+                w.iter()
+                    .zip(features.iter())
+                    .map(|(&wi, &fi)| wi * fi as f32)
+                    .sum::<f32>() + b
+            })
+            .collect()
+    }
+
+    fn decision_scores_sparse(&self, features: &[(usize, f64)]) -> Vec<f32> {
+        self.weights
+            .iter()
+            .zip(self.biases.iter())
+            .map(|(w, &b)| {
+                features.iter()
+                    .map(|&(j, v)| w[j] * v as f32)
+                    .sum::<f32>() + b
+            })
+            .collect()
+    }
+}
+
+/// LIBLINEAR-style SVM Trainer
+#[pyclass]
+pub struct SVMTrainer {
+    c: f64,
+    max_iter: usize,
+    tol: f64,
+    verbose: bool,
+}
+
+#[pymethods]
+impl SVMTrainer {
+    #[new]
+    #[pyo3(signature = (c=1.0, max_iter=1000, tol=0.1, verbose=false))]
+    pub fn new(c: f64, max_iter: usize, tol: f64, verbose: bool) -> Self {
+        Self { c, max_iter, tol, verbose }
+    }
+
+    pub fn set_c(&mut self, c: f64) {
+        self.c = c;
+    }
+
+    pub fn set_max_iter(&mut self, max_iter: usize) {
+        self.max_iter = max_iter;
+    }
+
+    pub fn set_verbose(&mut self, verbose: bool) {
+        self.verbose = verbose;
+    }
+
+    pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
+        let n_samples = features.len();
+        let n_features = if n_samples > 0 { features[0].len() } else { 0 };
+
+        // Convert to compact sparse format (f32 for memory/cache efficiency)
+        let sparse_features: Vec<SparseVec> = features
+            .par_iter()
+            .map(|dense| {
+                dense
+                    .iter()
+                    .enumerate()
+                    .filter(|&(_, &v)| v.abs() > 1e-10)
+                    .map(|(i, &v)| (i as u32, v as f32))
+                    .collect()
+            })
+            .collect();
+
+        // Precompute ||x_i||^2
+        let x_sq_norms: Vec<f32> = sparse_features
+            .par_iter()
+            .map(|x| x.iter().map(|&(_, v)| v * v).sum())
+            .collect();
+
+        // Get unique classes
+        let mut classes: Vec<String> = labels.iter().cloned().collect();
+        classes.sort();
+        classes.dedup();
+        let n_classes = classes.len();
+
+        let class_to_idx: HashMap<String, usize> = classes
+            .iter()
+            .enumerate()
+            .map(|(i, c)| (c.clone(), i))
+            .collect();
+
+        let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
+
+        // Train binary classifiers in parallel (one-vs-rest)
+        let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
+            .into_par_iter()
+            .map(|class_idx| {
+                let y_binary: Vec<i8> = y_idx
+                    .iter()
+                    .map(|&idx| if idx == class_idx { 1 } else { -1 })
+                    .collect();
+
+                solve_l2r_l2_svc(
+                    &sparse_features,
+                    &y_binary,
+                    &x_sq_norms,
+                    n_features,
+                    self.c as f32,
+                    self.tol as f32,
+                    self.max_iter,
+                )
+            })
+            .collect();
+
+        let weights = results.iter().map(|(w, _)| w.clone()).collect();
+        let biases = results.iter().map(|(_, b)| *b).collect();
+
+        LinearSVM {
+            weights,
+            biases,
+            classes,
+            n_features,
+        }
+    }
+}
+
+/// LIBLINEAR's solve_l2r_l2_svc - Dual Coordinate Descent for L2-loss SVM
+///
+/// Solves: min_α 0.5 * α^T * Q * α - e^T * α, s.t. α_i ≥ 0
+/// where Q_ij = y_i * y_j * x_i^T * x_j + δ_ij / (2C)
+///
+/// Primal-dual relationship: w = Σ α_i * y_i * x_i
+#[inline(never)]
+fn solve_l2r_l2_svc(
+    x: &[SparseVec],
+    y: &[i8],
+    x_sq_norms: &[f32],
+    n_features: usize,
+    c: f32,
+    eps: f32,
+    max_iter: usize,
+) -> (Vec<f32>, f32) {
+    let n = x.len();
+
+    // D_ii = 1/(2C) for L2-loss SVM
+    let diag = 0.5 / c;
+
+    // QD[i] = ||x_i||^2 + D_ii
+    let qd: Vec<f32> = x_sq_norms.iter().map(|&xn| xn + diag).collect();
+
+    // Initialize α = 0
+    let mut alpha = vec![0.0f32; n];
+
+    // w = Σ α_i * y_i * x_i (initially 0)
+    let mut w = vec![0.0f32; n_features];
+
+    // Index for permutation
+    let mut index: Vec<usize> = (0..n).collect();
+
+    // Main loop
+    for iter in 0..max_iter {
+        // Shuffle indices
+        for i in 0..n {
+            let j = i + (iter * 1103515245 + 12345) % (n - i).max(1);
+            index.swap(i, j);
+        }
+
+        let mut max_violation = 0.0f32;
+
+        for &i in &index {
+            let yi = y[i] as f32;
+            let xi = &x[i];
+
+            // G = y_i * (w · x_i) - 1 + D_ii * α_i
+            let wxi: f32 = xi.iter().map(|&(j, v)| w[j as usize] * v).sum();
+            let g = yi * wxi - 1.0 + diag * alpha[i];
+
+            // Projected gradient (α ≥ 0, no upper bound for L2-loss)
+            let pg = if alpha[i] == 0.0 { g.min(0.0) } else { g };
+
+            max_violation = max_violation.max(pg.abs());
+
+            if pg.abs() > 1e-12 {
+                let alpha_old = alpha[i];
+
+                // α_i = max(0, α_i - G/Q_ii)
+                alpha[i] = (alpha[i] - g / qd[i]).max(0.0);
+
+                // Update w: w += (α_new - α_old) * y_i * x_i
+                let d = (alpha[i] - alpha_old) * yi;
+                if d.abs() > 1e-12 {
+                    for &(j, v) in xi.iter() {
+                        w[j as usize] += d * v;
+                    }
+                }
+            }
+        }
+
+        // Stopping criterion
+        if max_violation <= eps {
+            break;
+        }
+    }
+
+    // Compute bias from KKT conditions
+    // For α_i > 0: y_i * (w · x_i + b) = 1 - α_i / (2C)
+    let mut bias_sum = 0.0f32;
+    let mut n_sv = 0;
+
+    for i in 0..n {
+        if alpha[i] > 1e-8 {
+            let yi = y[i] as f32;
+            let wxi: f32 = x[i].iter().map(|&(j, v)| w[j as usize] * v).sum();
+            // b = y_i * (1 - α_i * diag) - w · x_i
+            bias_sum += yi * (1.0 - alpha[i] * diag) - wxi;
+            n_sv += 1;
+        }
+    }
+
+    let bias = if n_sv > 0 { bias_sum / n_sv as f32 } else { 0.0 };
+
+    (w, bias)
+}
+
+/// Fast SVM using Pegasos algorithm
+#[pyclass]
+pub struct FastSVMTrainer {
+    c: f64,
+    max_iter: usize,
+}
+
+#[pymethods]
+impl FastSVMTrainer {
+    #[new]
+    #[pyo3(signature = (c=1.0, max_iter=100))]
+    pub fn new(c: f64, max_iter: usize) -> Self {
+        Self { c, max_iter }
+    }
+
+    pub fn train(&self, features: Vec<Vec<f64>>, labels: Vec<String>) -> LinearSVM {
+        let n_samples = features.len();
+        let n_features = if n_samples > 0 { features[0].len() } else { 0 };
+
+        let sparse_features: Vec<SparseVec> = features
+            .par_iter()
+            .map(|dense| {
+                dense
+                    .iter()
+                    .enumerate()
+                    .filter(|&(_, &v)| v.abs() > 1e-10)
+                    .map(|(i, &v)| (i as u32, v as f32))
+                    .collect()
+            })
+            .collect();
+
+        let mut classes: Vec<String> = labels.iter().cloned().collect();
+        classes.sort();
+        classes.dedup();
+        let n_classes = classes.len();
+
+        let class_to_idx: HashMap<String, usize> = classes
+            .iter()
+            .enumerate()
+            .map(|(i, c)| (c.clone(), i))
+            .collect();
+
+        let y_idx: Vec<usize> = labels.iter().map(|l| class_to_idx[l]).collect();
+
+        let results: Vec<(Vec<f32>, f32)> = (0..n_classes)
+            .into_par_iter()
+            .map(|class_idx| {
+                let y_binary: Vec<i8> = y_idx
+                    .iter()
+                    .map(|&idx| if idx == class_idx { 1 } else { -1 })
+                    .collect();
+
+                pegasos(&sparse_features, &y_binary, n_features, self.c as f32, self.max_iter)
+            })
+            .collect();
+
+        LinearSVM {
+            weights: results.iter().map(|(w, _)| w.clone()).collect(),
+            biases: results.iter().map(|(_, b)| *b).collect(),
+            classes,
+            n_features,
+        }
+    }
+}
+
+/// Pegasos algorithm with lazy scaling
+#[inline(never)]
+fn pegasos(
+    x: &[SparseVec],
+    y: &[i8],
+    n_features: usize,
+    c: f32,
+    max_iter: usize,
+) -> (Vec<f32>, f32) {
+    let n = x.len();
+    let lambda = 1.0 / c;
+
+    let mut w = vec![0.0f32; n_features];
+    let mut scale = 1.0f32;
+    let mut b = 0.0f32;
+
+    let eta0 = 0.5;
+    let t0 = 1.0 / (eta0 * lambda);
+
+    let mut indices: Vec<usize> = (0..n).collect();
+
+    for epoch in 0..max_iter {
+        // Shuffle
+        for i in 0..n {
+            let j = (i + epoch * 1103515245 + 12345) % n;
+            indices.swap(i, j);
+        }
+
+        for (t_inner, &i) in indices.iter().enumerate() {
+            let t = (epoch * n + t_inner) as f32;
+            let eta = 1.0 / (lambda * (t + t0));
+
+            let yi = y[i] as f32;
+            let xi = &x[i];
+
+            let margin: f32 = scale * xi.iter().map(|&(j, v)| w[j as usize] * v).sum::<f32>() + b;
+
+            scale *= 1.0 - eta * lambda;
+
+            if scale < 1e-9 {
+                for wj in w.iter_mut() {
+                    *wj *= scale;
+                }
+                scale = 1.0;
+            }
+
+            if yi * margin < 1.0 {
+                let update = eta / scale;
+                for &(j, v) in xi.iter() {
+                    w[j as usize] += update * yi * v;
+                }
+                b += eta * yi * 0.1;
+            }
+        }
+    }
+
+    for wj in w.iter_mut() {
+        *wj *= scale;
+    }
+
+    (w, b)
+}
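
As a cross-check, the single-coordinate update that solve_l2r_l2_svc performs can be replayed in a few lines of NumPy; this is a sketch on a hypothetical toy dataset, mirroring the Rust variable names, and is not part of the commit:

    import numpy as np

    X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 samples, 2 features
    y = np.array([1.0, -1.0, 1.0])
    C = 1.0
    diag = 0.5 / C                     # D_ii = 1/(2C) for the L2-loss SVM
    qd = (X * X).sum(axis=1) + diag    # QD[i] = ||x_i||^2 + D_ii

    alpha = np.zeros(len(X))
    w = np.zeros(X.shape[1])

    for _ in range(100):               # sweeps over all dual coordinates
        for i in range(len(X)):
            g = y[i] * w.dot(X[i]) - 1.0 + diag * alpha[i]
            pg = min(g, 0.0) if alpha[i] == 0.0 else g   # projected gradient at the bound
            if abs(pg) > 1e-12:
                old = alpha[i]
                alpha[i] = max(alpha[i] - g / qd[i], 0.0)
                w += (alpha[i] - old) * y[i] * X[i]      # maintain w = sum_i alpha_i y_i x_i

    print(w, alpha)  # the dual is strictly convex (diag > 0), so the optimum is unique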
extensions/underthesea_core_extend/src/tfidf.rs ADDED
@@ -0,0 +1,235 @@
+//! TF-IDF Vectorizer implementation
+
+use hashbrown::HashMap;
+use pyo3::prelude::*;
+use rayon::prelude::*;
+use serde::{Deserialize, Serialize};
+use std::fs::File;
+use std::io::{BufReader, BufWriter};
+
+/// TF-IDF Vectorizer
+///
+/// Converts text documents into TF-IDF feature vectors.
+#[pyclass]
+#[derive(Clone, Serialize, Deserialize)]
+pub struct TfIdfVectorizer {
+    /// Vocabulary: word -> index
+    vocab: HashMap<String, usize>,
+    /// Inverse vocabulary: index -> word
+    inv_vocab: Vec<String>,
+    /// IDF values for each term
+    idf: Vec<f64>,
+    /// Number of documents used for fitting
+    n_docs: usize,
+    /// Maximum number of features
+    max_features: usize,
+    /// N-gram range (min, max)
+    ngram_range: (usize, usize),
+    /// Minimum document frequency
+    min_df: usize,
+    /// Maximum document frequency (as ratio)
+    max_df: f64,
+    /// Whether the vectorizer is fitted
+    is_fitted: bool,
+}
+
+#[pymethods]
+impl TfIdfVectorizer {
+    /// Create a new TfIdfVectorizer
+    #[new]
+    #[pyo3(signature = (max_features=20000, ngram_range=(1, 2), min_df=1, max_df=1.0))]
+    pub fn new(
+        max_features: usize,
+        ngram_range: (usize, usize),
+        min_df: usize,
+        max_df: f64,
+    ) -> Self {
+        Self {
+            vocab: HashMap::new(),
+            inv_vocab: Vec::new(),
+            idf: Vec::new(),
+            n_docs: 0,
+            max_features,
+            ngram_range,
+            min_df,
+            max_df,
+            is_fitted: false,
+        }
+    }
+
+    /// Fit the vectorizer on a list of documents
+    pub fn fit(&mut self, documents: Vec<String>) {
+        let n_docs = documents.len();
+        self.n_docs = n_docs;
+
+        // Count document frequency for each term
+        let mut df: HashMap<String, usize> = HashMap::new();
+
+        for doc in &documents {
+            let tokens = self.tokenize(doc);
+            let unique_tokens: std::collections::HashSet<_> = tokens.into_iter().collect();
+            for token in unique_tokens {
+                *df.entry(token).or_insert(0) += 1;
+            }
+        }
+
+        // Filter by min_df and max_df
+        let max_df_count = (self.max_df * n_docs as f64) as usize;
+        let mut filtered: Vec<(String, usize)> = df
+            .into_iter()
+            .filter(|(_, count)| *count >= self.min_df && *count <= max_df_count)
+            .collect();
+
+        // Sort by frequency (descending) and take top max_features
+        filtered.sort_by(|a, b| b.1.cmp(&a.1));
+        filtered.truncate(self.max_features);
+
+        // Build vocabulary
+        self.vocab.clear();
+        self.inv_vocab.clear();
+        self.idf.clear();
+
+        for (idx, (term, doc_freq)) in filtered.into_iter().enumerate() {
+            self.vocab.insert(term.clone(), idx);
+            self.inv_vocab.push(term);
+            // IDF with smoothing: log((n_docs + 1) / (df + 1)) + 1
+            let idf_value = ((n_docs as f64 + 1.0) / (doc_freq as f64 + 1.0)).ln() + 1.0;
+            self.idf.push(idf_value);
+        }
+
+        self.is_fitted = true;
+    }
+
+    /// Transform a single document to TF-IDF vector (sparse format)
+    ///
+    /// Returns list of (index, value) tuples
+    pub fn transform(&self, document: &str) -> Vec<(usize, f64)> {
+        if !self.is_fitted {
+            return Vec::new();
+        }
+
+        let tokens = self.tokenize(document);
+        let mut tf: HashMap<usize, usize> = HashMap::new();
+
+        for token in &tokens {
+            if let Some(&idx) = self.vocab.get(token) {
+                *tf.entry(idx).or_insert(0) += 1;
+            }
+        }
+
+        let n_tokens = tokens.len() as f64;
+        if n_tokens == 0.0 {
+            return Vec::new();
+        }
+
+        let mut result: Vec<(usize, f64)> = tf
+            .into_iter()
+            .map(|(idx, count)| {
+                let tf_value = count as f64 / n_tokens;
+                let tfidf = tf_value * self.idf[idx];
+                (idx, tfidf)
+            })
+            .collect();
+
+        // L2 normalize
+        let norm: f64 = result.iter().map(|(_, v)| v * v).sum::<f64>().sqrt();
+        if norm > 0.0 {
+            for (_, v) in &mut result {
+                *v /= norm;
+            }
+        }
+
+        result.sort_by_key(|(idx, _)| *idx);
+        result
+    }
+
+    /// Transform a single document to dense TF-IDF vector
+    pub fn transform_dense(&self, document: &str) -> Vec<f64> {
+        let sparse = self.transform(document);
+        let mut dense = vec![0.0; self.vocab.len()];
+        for (idx, val) in sparse {
+            dense[idx] = val;
+        }
+        dense
+    }
+
+    /// Transform multiple documents to dense TF-IDF vectors (parallel)
+    pub fn transform_batch(&self, documents: Vec<String>) -> Vec<Vec<f64>> {
+        documents
+            .par_iter()
+            .map(|doc| self.transform_dense(doc))
+            .collect()
+    }
+
+    /// Transform multiple documents to sparse TF-IDF vectors (parallel)
+    pub fn transform_batch_sparse(&self, documents: Vec<String>) -> Vec<Vec<(usize, f64)>> {
+        documents
+            .par_iter()
+            .map(|doc| self.transform(doc))
+            .collect()
+    }
+
+    /// Fit and transform in one step
+    pub fn fit_transform(&mut self, documents: Vec<String>) -> Vec<Vec<f64>> {
+        self.fit(documents.clone());
+        self.transform_batch(documents)
+    }
+
+    /// Get vocabulary size
+    #[getter]
+    pub fn vocab_size(&self) -> usize {
+        self.vocab.len()
+    }
+
+    /// Get feature names (vocabulary terms)
+    pub fn get_feature_names(&self) -> Vec<String> {
+        self.inv_vocab.clone()
+    }
+
+    /// Check if vectorizer is fitted
+    #[getter]
+    pub fn is_fitted(&self) -> bool {
+        self.is_fitted
+    }
+
+    /// Save vectorizer to file
+    pub fn save(&self, path: &str) -> PyResult<()> {
+        let file = File::create(path)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        let writer = BufWriter::new(file);
+        serde_json::to_writer(writer, self)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        Ok(())
+    }
+
+    /// Load vectorizer from file
+    #[staticmethod]
+    pub fn load(path: &str) -> PyResult<Self> {
+        let file = File::open(path)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        let reader = BufReader::new(file);
+        let vectorizer: Self = serde_json::from_reader(reader)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()))?;
+        Ok(vectorizer)
+    }
+}
+
+impl TfIdfVectorizer {
+    /// Tokenize document into n-grams
+    fn tokenize(&self, document: &str) -> Vec<String> {
+        let words: Vec<&str> = document.split_whitespace().collect();
+        let mut tokens = Vec::new();
+
+        for n in self.ngram_range.0..=self.ngram_range.1 {
+            if n > words.len() {
+                continue;
+            }
+            for i in 0..=(words.len() - n) {
+                let ngram = words[i..i + n].join(" ");
+                tokens.push(ngram);
+            }
+        }
+
+        tokens
+    }
+}
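
The weighting implemented in transform() above is standard smoothed TF-IDF with L2 normalization; the following is a minimal Python sketch of the same arithmetic on a hypothetical pre-tokenized toy corpus (no n-grams or vocabulary cap), not part of the commit:

    import math
    from collections import Counter

    docs = [["a", "b"], ["a", "c"], ["a", "b", "b"]]
    n_docs = len(docs)

    # Document frequency, then smoothed IDF as in fit(): ln((N + 1) / (df + 1)) + 1
    df = Counter(t for doc in docs for t in set(doc))
    idf = {t: math.log((n_docs + 1) / (df[t] + 1)) + 1 for t in df}

    def transform(doc):
        tf = Counter(doc)
        vec = {t: (count / len(doc)) * idf[t] for t, count in tf.items()}  # tf = count / n_tokens
        norm = math.sqrt(sum(v * v for v in vec.values()))
        return {t: v / norm for t, v in vec.items()} if norm > 0 else vec

    print(transform(["a", "b", "b"]))  # L2-normalized sparse weights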
extensions/underthesea_core_extend/uv.lock ADDED
@@ -0,0 +1,8 @@
+version = 1
+revision = 3
+requires-python = ">=3.10"
+
+[[package]]
+name = "underthesea-core-extend"
+version = "0.1.0"
+source = { editable = "." }
pyproject.toml CHANGED
@@ -1,24 +1,23 @@
 [project]
 name = "sen"
-version = "1.0.0"
-description = "Vietnamese Text Classification Model"
+version = "1.1.0"
+description = "Vietnamese Text Classification Model - Rust-powered"
 readme = "README.md"
 requires-python = ">=3.10"
 license = "Apache-2.0"
 authors = [
     {name = "UnderTheSea NLP", email = "undertheseanlp@gmail.com"}
 ]
-keywords = ["vietnamese", "nlp", "text-classification", "sklearn"]
+keywords = ["vietnamese", "nlp", "text-classification", "rust", "svm"]
 dependencies = [
-    "scikit-learn>=1.0.0",
-    "joblib>=1.0.0",
-    "numpy>=1.20.0",
+    "underthesea_core_extend>=0.1.0",
 ]

 [project.optional-dependencies]
 dev = [
     "pytest>=7.0.0",
     "huggingface-hub>=0.20.0",
+    "maturin>=1.0.0",
 ]

 [project.urls]
src/sen/text_classifier.py CHANGED
@@ -1,26 +1,20 @@
 """
-Sen Text Classifier - sklearn-based classifier compatible with underthesea API.
+Sen Text Classifier - Rust-based classifier using underthesea_core_extend.

 Based on: "A Comparative Study on Vietnamese Text Classification Methods"
 Vu et al., RIVF 2007
 https://ieeexplore.ieee.org/document/4223084/

 Methods:
-- TF-IDF vectorization
-- SVM (Support Vector Machine) classifier
+- TF-IDF vectorization (Rust: underthesea_core_extend.TfIdfVectorizer)
+- Linear SVM classifier (Rust: underthesea_core_extend.LinearSVM)
 """

 import json
 import os
 from typing import List, Optional, Union

-import joblib
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.svm import LinearSVC
-from sklearn.pipeline import Pipeline
-from sklearn.preprocessing import LabelEncoder
-from sklearn.metrics import accuracy_score, f1_score, classification_report
-import numpy as np
+from underthesea_core_extend import TfIdfVectorizer, LinearSVM, SVMTrainer


 class Label:
@@ -59,8 +53,9 @@ class Sentence:

 class SenTextClassifier:
     """
-    sklearn-based text classifier using TF-IDF + SVM.
+    Rust-based text classifier using TF-IDF + Linear SVM.

+    Uses underthesea_core_extend for fast training and inference.
     Compatible with underthesea API.

     Reference:
@@ -71,43 +66,28 @@ class SenTextClassifier:
     def __init__(
         self,
         # TF-IDF parameters
-        max_features: int = 10000,
+        max_features: int = 20000,
         ngram_range: tuple = (1, 2),
-        min_df: int = 2,
-        max_df: float = 0.95,
-        sublinear_tf: bool = True,
+        min_df: int = 1,
+        max_df: float = 1.0,
         # SVM parameters
-        C: float = 1.0,
+        c: float = 1.0,
         max_iter: int = 1000,
+        tol: float = 0.1,
+        verbose: bool = True,
     ):
         self.max_features = max_features
        self.ngram_range = ngram_range
         self.min_df = min_df
         self.max_df = max_df
-        self.sublinear_tf = sublinear_tf
-        self.C = C
+        self.c = c
         self.max_iter = max_iter
+        self.tol = tol
+        self.verbose = verbose

-        self.label_encoder = LabelEncoder()
-        self.pipeline = None
-        self.labels_ = None
-
-    def _build_pipeline(self) -> Pipeline:
-        """Build sklearn pipeline with TF-IDF + SVM."""
-        return Pipeline([
-            ("tfidf", TfidfVectorizer(
-                max_features=self.max_features,
-                ngram_range=self.ngram_range,
-                min_df=self.min_df,
-                max_df=self.max_df,
-                sublinear_tf=self.sublinear_tf,
-            )),
-            ("clf", LinearSVC(
-                C=self.C,
-                max_iter=self.max_iter,
-                random_state=42,
-            )),
-        ])
+        self.vectorizer: Optional[TfIdfVectorizer] = None
+        self.classifier: Optional[LinearSVM] = None
+        self.labels_: Optional[List[str]] = None

     def train(
         self,
@@ -128,38 +108,58 @@ class SenTextClassifier:
         Returns:
             Dictionary with training metrics
         """
-        # Encode labels
-        y_train = self.label_encoder.fit_transform(train_labels)
-        self.labels_ = list(self.label_encoder.classes_)
+        # Get unique labels
+        self.labels_ = sorted(list(set(train_labels)))
+
+        # Build and fit vectorizer
+        self.vectorizer = TfIdfVectorizer(
+            max_features=self.max_features,
+            ngram_range=self.ngram_range,
+            min_df=self.min_df,
+            max_df=self.max_df,
+        )
+        self.vectorizer.fit(train_texts)
+
+        # Transform to features
+        train_features = self.vectorizer.transform_batch(train_texts)

-        # Build and train pipeline
-        self.pipeline = self._build_pipeline()
-        self.pipeline.fit(train_texts, y_train)
+        # Build and train SVM model
+        trainer = SVMTrainer(
+            c=self.c,
+            max_iter=self.max_iter,
+            tol=self.tol,
+            verbose=self.verbose,
+        )
+        self.classifier = trainer.train(train_features, train_labels)

         # Calculate training metrics
-        y_train_pred = self.pipeline.predict(train_texts)
-        train_acc = accuracy_score(y_train, y_train_pred)
-        train_f1 = f1_score(y_train, y_train_pred, average="weighted")
+        train_preds = self.classifier.predict_batch(train_features)
+        train_acc = sum(1 for p, t in zip(train_preds, train_labels) if p == t) / len(train_labels)
+
+        # Calculate F1 score
+        train_f1 = self._calculate_f1(train_labels, train_preds)

         results = {
             "train_accuracy": train_acc,
             "train_f1": train_f1,
             "num_classes": len(self.labels_),
             "num_samples": len(train_texts),
+            "vocab_size": self.vectorizer.vocab_size,
         }

         print(f"Training completed:")
         print(f" - Samples: {len(train_texts)}")
         print(f" - Classes: {len(self.labels_)}")
+        print(f" - Vocab size: {self.vectorizer.vocab_size}")
         print(f" - Train accuracy: {train_acc:.4f}")
         print(f" - Train F1: {train_f1:.4f}")

         # Validation metrics
         if val_texts and val_labels:
-            y_val = self.label_encoder.transform(val_labels)
-            y_val_pred = self.pipeline.predict(val_texts)
-            val_acc = accuracy_score(y_val, y_val_pred)
-            val_f1 = f1_score(y_val, y_val_pred, average="weighted")
+            val_features = self.vectorizer.transform_batch(val_texts)
+            val_preds = self.classifier.predict_batch(val_features)
+            val_acc = sum(1 for p, t in zip(val_preds, val_labels) if p == t) / len(val_labels)
+            val_f1 = self._calculate_f1(val_labels, val_preds)

             results["val_accuracy"] = val_acc
             results["val_f1"] = val_f1
@@ -169,6 +169,28 @@ class SenTextClassifier:

         return results

+    def _calculate_f1(self, y_true: List[str], y_pred: List[str]) -> float:
+        """Calculate weighted F1 score."""
+        from collections import Counter
+
+        label_counts = Counter(y_true)
+        total = len(y_true)
+
+        f1_sum = 0.0
+        for label in self.labels_:
+            tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
+            fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
+            fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
+
+            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+            weight = label_counts[label] / total
+            f1_sum += f1 * weight
+
+        return f1_sum
+
     def predict(self, sentence: Sentence) -> None:
         """
         Predict label for a sentence (underthesea-compatible API).
@@ -176,23 +198,11 @@ class SenTextClassifier:
         Args:
            sentence: Sentence object with text attribute
         """
-        if self.pipeline is None:
+        if self.classifier is None or self.vectorizer is None:
             raise ValueError("Model not trained. Call train() first or load a model.")

-        pred_idx = self.pipeline.predict([sentence.text])[0]
-        label_value = self.label_encoder.inverse_transform([pred_idx])[0]
-
-        # Get confidence score using decision function
-        try:
-            decision = self.pipeline.decision_function([sentence.text])[0]
-            if isinstance(decision, np.ndarray):
-                score = float(np.max(np.abs(decision)))
-            else:
-                score = float(abs(decision))
-            # Normalize to 0-1 range using sigmoid
-            score = 1 / (1 + np.exp(-score))
-        except Exception:
-            score = 1.0
+        features = self.vectorizer.transform_dense(sentence.text)
+        label_value, score = self.classifier.predict_with_score(features)

         sentence.labels = []
         sentence.add_labels([Label(label_value, score)])
@@ -207,23 +217,17 @@ class SenTextClassifier:
         Returns:
             List of Label objects
         """
-        if self.pipeline is None:
+        if self.classifier is None or self.vectorizer is None:
             raise ValueError("Model not trained. Call train() first or load a model.")

-        pred_indices = self.pipeline.predict(texts)
-        label_values = self.label_encoder.inverse_transform(pred_indices)
-
-        # Get confidence scores
-        try:
-            decisions = self.pipeline.decision_function(texts)
-            if decisions.ndim == 1:
-                scores = 1 / (1 + np.exp(-np.abs(decisions)))
-            else:
-                scores = 1 / (1 + np.exp(-np.max(np.abs(decisions), axis=1)))
-        except Exception:
-            scores = [1.0] * len(texts)
+        # Use dense transform (faster Python-Rust interface)
+        features = self.vectorizer.transform_batch(texts)
+        results = []
+        for feat in features:
+            label_value, score = self.classifier.predict_with_score(feat)
+            results.append(Label(label_value, float(score)))

-        return [Label(val, float(score)) for val, score in zip(label_values, scores)]
+        return results

     def evaluate(self, texts: List[str], labels: List[str]) -> dict:
         """
@@ -236,23 +240,46 @@ class SenTextClassifier:
         Returns:
             Dictionary with evaluation metrics
         """
-        y_true = self.label_encoder.transform(labels)
-        y_pred = self.pipeline.predict(texts)
+        # Use dense transform (faster Python-Rust interface)
+        features = self.vectorizer.transform_batch(texts)
+        y_pred = self.classifier.predict_batch(features)

-        acc = accuracy_score(y_true, y_pred)
-        f1 = f1_score(y_true, y_pred, average="weighted")
+        acc = sum(1 for p, t in zip(y_pred, labels) if p == t) / len(labels)
+        f1 = self._calculate_f1(labels, y_pred)

         print(f"Evaluation:")
         print(f" - Accuracy: {acc:.4f}")
         print(f" - F1 (weighted): {f1:.4f}")
-        print("\nClassification Report:")
-        print(classification_report(
-            y_true, y_pred,
-            target_names=self.labels_
-        ))
+
+        # Print classification report
+        self._print_classification_report(labels, y_pred)

         return {"accuracy": acc, "f1": f1}

+    def _print_classification_report(self, y_true: List[str], y_pred: List[str]):
+        """Print classification report."""
+        from collections import Counter
+
+        print("\nClassification Report:")
+        print(f"{'':>20} {'precision':>10} {'recall':>10} {'f1-score':>10} {'support':>10}")
+        print()
+
+        label_counts = Counter(y_true)
+
+        for label in self.labels_:
+            tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
+            fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
+            fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
+
+            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+            support = label_counts[label]
+
+            print(f"{label:>20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f} {support:>10}")
+
+        print()
+
     def save(self, path: str) -> None:
         """
         Save model to disk.
@@ -262,23 +289,25 @@ class SenTextClassifier:
         """
         os.makedirs(path, exist_ok=True)

-        # Save pipeline
-        joblib.dump(self.pipeline, os.path.join(path, "pipeline.joblib"))
+        # Save vectorizer
+        self.vectorizer.save(os.path.join(path, "vectorizer.json"))

-        # Save label encoder
-        joblib.dump(self.label_encoder, os.path.join(path, "label_encoder.joblib"))
+        # Save classifier
+        self.classifier.save(os.path.join(path, "classifier.json"))

         # Save metadata
         metadata = {
-            "estimator": "PIPELINE",
+            "estimator": "RUST_SVM",
             "max_features": self.max_features,
             "ngram_range": self.ngram_range,
             "min_df": self.min_df,
             "max_df": self.max_df,
-            "sublinear_tf": self.sublinear_tf,
-            "C": self.C,
+            "c": self.c,
             "max_iter": self.max_iter,
+            "tol": self.tol,
             "labels": self.labels_,
+            "vocab_size": self.vectorizer.vocab_size,
+            "n_classes": self.classifier.n_classes,
         }
         with open(os.path.join(path, "metadata.json"), "w", encoding="utf-8") as f:
             json.dump(metadata, f, ensure_ascii=False, indent=2)
@@ -302,19 +331,21 @@ class SenTextClassifier:

         # Create instance with saved parameters
         classifier = cls(
-            max_features=metadata.get("max_features", 10000),
+            max_features=metadata.get("max_features", 20000),
             ngram_range=tuple(metadata.get("ngram_range", (1, 2))),
-            min_df=metadata.get("min_df", 2),
-            max_df=metadata.get("max_df", 0.95),
-            sublinear_tf=metadata.get("sublinear_tf", True),
-            C=metadata.get("C", 1.0),
+            min_df=metadata.get("min_df", 1),
+            max_df=metadata.get("max_df", 1.0),
+            c=metadata.get("c", 1.0),
             max_iter=metadata.get("max_iter", 1000),
+            tol=metadata.get("tol", 0.1),
         )

-        # Load pipeline and label encoder
-        classifier.pipeline = joblib.load(os.path.join(path, "pipeline.joblib"))
-        classifier.label_encoder = joblib.load(os.path.join(path, "label_encoder.joblib"))
-        classifier.labels_ = metadata.get("labels", list(classifier.label_encoder.classes_))
+        # Load vectorizer
+        classifier.vectorizer = TfIdfVectorizer.load(os.path.join(path, "vectorizer.json"))
+
+        # Load SVM model
+        classifier.classifier = LinearSVM.load(os.path.join(path, "classifier.json"))
+        classifier.labels_ = metadata.get("labels", [])

         print(f"Model loaded from: {path}")
         return classifier
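
A short round-trip sketch of the updated save/load path (the import path, directory name, and data are hypothetical):

    from sen.text_classifier import SenTextClassifier

    clf = SenTextClassifier(max_features=20000, c=1.0, max_iter=1000)
    clf.train(["văn bản một", "văn bản hai"], ["a", "b"])
    clf.save("model_dir")  # writes vectorizer.json, classifier.json, metadata.json
    restored = SenTextClassifier.load("model_dir")
    print(restored.predict_batch(["văn bản một"]))  # list with one Label object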