Tiep Claude Opus 4.6 commited on
Commit
903cdb2
·
1 Parent(s): b5fd35d

Refactor training to Hydra config and use underthesea imports

Browse files

- Replace Click CLI with Hydra config system for flexible training
- Add config files for all training tasks (vntc, bank, sentiment_general, sentiment_bank)
- Change imports from underthesea_core to underthesea throughout
- Move preprocessing to Rust TextPreprocessor (built into model binary)
- Delete extends/ directory (code now in underthesea_core v3.2.0)
- Add outputs/ to .gitignore for Hydra run outputs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

.gitignore CHANGED
@@ -26,6 +26,9 @@ Thumbs.db
26
  # Jupyter
27
  .ipynb_checkpoints/
28
 
 
 
 
29
  # Testing
30
  .pytest_cache/
31
  .coverage
 
26
  # Jupyter
27
  .ipynb_checkpoints/
28
 
29
+ # Hydra outputs
30
+ outputs/
31
+
32
  # Testing
33
  .pytest_cache/
34
  .coverage
pyproject.toml CHANGED
@@ -1,7 +1,7 @@
1
  [project]
2
  name = "sen"
3
  version = "1.1.0"
4
- description = "Vietnamese Text Classification - Training scripts for underthesea_core"
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  license = "Apache-2.0"
 
1
  [project]
2
  name = "sen"
3
  version = "1.1.0"
4
+ description = "Vietnamese Text Classification - Training scripts for underthesea"
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  license = "Apache-2.0"
src/bench.py CHANGED
@@ -19,7 +19,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVecto
19
  from sklearn.svm import LinearSVC as SklearnLinearSVC
20
  from sklearn.metrics import accuracy_score, f1_score, classification_report
21
 
22
- from underthesea_core import TextClassifier
23
 
24
 
25
  def read_file(filepath):
 
19
  from sklearn.svm import LinearSVC as SklearnLinearSVC
20
  from sklearn.metrics import accuracy_score, f1_score, classification_report
21
 
22
+ from underthesea import TextClassifier
23
 
24
 
25
  def read_file(filepath):
src/conf/bank.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # python src/train.py --config-name=bank
2
+ defaults:
3
+ - data: bank
4
+ - model: small
5
+ - _self_
6
+
7
+ output: models/sen-bank-1.0.0-${now:%Y%m%d}.bin
src/conf/config.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - data: vntc
3
+ - model: default
4
+ - _self_
5
+
6
+ output: models/sen-${data.name}.bin
src/conf/data/bank.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: bank
2
+ source: huggingface
3
+ dataset: undertheseanlp/UTS2017_Bank
4
+ config: classification
src/conf/data/sentiment_bank.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ name: sentiment_bank
2
+ source: huggingface
3
+ dataset: undertheseanlp/UTS2017_Bank
4
+ config: [classification, sentiment]
5
+ label_format: "{category}#{sentiment}"
src/conf/data/sentiment_general.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: sentiment_general
2
+ source: vlsp2016
3
+ data_dir: /tmp/VLSP2016_SA
src/conf/data/vntc.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: vntc
2
+ source: local
3
+ data_dir: /home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1
src/conf/model/default.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ max_features: 20000
2
+ ngram_range: [1, 2]
3
+ min_df: 1
4
+ max_df: 1.0
5
+ c: 1.0
6
+ max_iter: 1000
7
+ tol: 0.1
8
+ preprocess: false
src/conf/model/sentiment.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ max_features: 200000
2
+ ngram_range: [1, 3]
3
+ min_df: 1
4
+ max_df: 0.9
5
+ c: 0.7
6
+ max_iter: 1000
7
+ tol: 0.0001
8
+ preprocess: true
9
+ preprocessor:
10
+ lowercase: true
11
+ unicode_normalize: true
12
+ remove_urls: true
13
+ normalize_repeated_chars: true
14
+ normalize_punctuation: true
15
+ teencode:
16
+ ko: "không"
17
+ k: "không"
18
+ hok: "không"
19
+ hem: "không"
20
+ dc: "được"
21
+ đc: "được"
22
+ dk: "được"
23
+ ntn: "như thế nào"
24
+ nc: "nói chuyện"
25
+ nt: "nhắn tin"
26
+ cx: "cũng"
27
+ cg: "cũng"
28
+ vs: "với"
29
+ vl: "vãi"
30
+ bt: "bình thường"
31
+ bth: "bình thường"
32
+ lg: "lượng"
33
+ tl: "trả lời"
34
+ ms: "mới"
35
+ r: "rồi"
36
+ mn: "mọi người"
37
+ mk: "mình"
38
+ ok: "tốt"
39
+ oke: "tốt"
40
+ sp: "sản phẩm"
41
+ hqua: "hôm qua"
42
+ hnay: "hôm nay"
43
+ tks: "cảm ơn"
44
+ thanks: "cảm ơn"
45
+ thank: "cảm ơn"
46
+ j: "gì"
47
+ z: "vậy"
48
+ v: "vậy"
49
+ đt: "điện thoại"
50
+ dt: "điện thoại"
51
+ lm: "làm"
52
+ ns: "nói"
53
+ negation_words:
54
+ - "không"
55
+ - "chẳng"
56
+ - "chả"
57
+ - "chưa"
58
+ - "đừng"
59
+ - "ko"
60
+ - "hok"
61
+ - "hem"
62
+ - "chăng"
63
+ negation_window: 2
src/conf/model/small.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ max_features: 10000
2
+ ngram_range: [1, 2]
3
+ min_df: 1
4
+ max_df: 0.9
5
+ c: 1.0
6
+ max_iter: 1000
7
+ tol: 0.0001
8
+ preprocess: false
src/conf/sentiment_bank.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # python src/train.py --config-name=sentiment_bank
2
+ defaults:
3
+ - data: sentiment_bank
4
+ - model: sentiment
5
+ - _self_
6
+
7
+ output: models/sen-sentiment-bank-1.0.0-${now:%Y%m%d}.bin
src/conf/sentiment_general.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # python src/train.py --config-name=sentiment_general
2
+ defaults:
3
+ - data: sentiment_general
4
+ - model: sentiment
5
+ - _self_
6
+
7
+ output: models/sen-sentiment-general-1.0.0-${now:%Y%m%d}.bin
src/conf/vntc.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # python src/train.py --config-name=vntc
2
+ defaults:
3
+ - data: vntc
4
+ - model: default
5
+ - _self_
6
+
7
+ output: models/sen-vntc-1.0.0-${now:%Y%m%d}.bin
src/train.py CHANGED
@@ -1,73 +1,57 @@
1
  """
2
- Training CLI for Vietnamese Text Classification.
3
 
4
  Usage:
5
- python train.py vntc --output models/sen-vntc.bin
6
- python train.py bank --output models/sen-bank.bin
 
 
 
 
 
 
 
7
  """
8
 
9
  import os
10
- import re
11
  import time
12
- import unicodedata
13
  from pathlib import Path
14
 
15
- import click
 
16
  from sklearn.metrics import accuracy_score, f1_score, classification_report
17
 
18
- from underthesea_core import TextClassifier
19
-
20
- # Vietnamese teencode dictionary
21
- _TEENCODE = {
22
- 'ko': 'không', 'k': 'không', 'hok': 'không', 'hem': 'không',
23
- 'dc': 'được', 'đc': 'được', 'dk': 'được',
24
- 'ntn': 'như thế nào',
25
- 'nc': 'nói chuyện', 'nt': 'nhắn tin',
26
- 'cx': 'cũng', 'cg': 'cũng',
27
- 'vs': 'với', 'vl': 'vãi',
28
- 'bt': 'bình thường', 'bth': 'bình thường',
29
- 'lg': 'lượng', 'tl': 'trả lời',
30
- 'ms': 'mới', 'r': 'rồi',
31
- 'mn': 'mọi người', 'mk': 'mình',
32
- 'ok': 'tốt', 'oke': 'tốt',
33
- 'sp': 'sản phẩm',
34
- 'hqua': 'hôm qua', 'hnay': 'hôm nay',
35
- 'tks': 'cảm ơn', 'thanks': 'cảm ơn', 'thank': 'cảm ơn',
36
- 'j': 'gì', 'z': 'vậy', 'v': 'vậy',
37
- 'đt': 'điện thoại', 'dt': 'điện thoại',
38
- 'lm': 'làm', 'ns': 'nói',
39
- }
40
-
41
- _NEG_WORDS = {'không', 'chẳng', 'chả', 'chưa', 'đừng', 'ko', 'hok', 'hem', 'chăng'}
42
-
43
-
44
- def preprocess_sentiment(text):
45
- """Preprocess Vietnamese text for sentiment analysis."""
46
- text = unicodedata.normalize('NFC', text)
47
- text = text.lower()
48
- text = re.sub(r'https?://\S+|www\.\S+', ' ', text)
49
- text = re.sub(r'(.)\1{2,}', r'\1\1', text)
50
- text = re.sub(r'!{2,}', '!', text)
51
- text = re.sub(r'\?{2,}', '?', text)
52
- text = re.sub(r'\.{4,}', '...', text)
53
- # Teencode expansion
54
- words = text.split()
55
- expanded = []
56
- for w in words:
57
- wl = w.strip('.,!?;:')
58
- if wl in _TEENCODE:
59
- expanded.append(_TEENCODE[wl])
60
- else:
61
- expanded.append(w)
62
- # Negation marking (2-word window)
63
- new_words = list(expanded)
64
- for i, w in enumerate(expanded):
65
- wl = w.strip('.,!?;:')
66
- if wl in _NEG_WORDS:
67
- for j in range(i + 1, min(i + 3, len(expanded))):
68
- new_words[j] = 'NEG_' + expanded[j]
69
- return ' '.join(new_words)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def read_file(filepath):
73
  """Read text file with multiple encoding attempts."""
@@ -85,184 +69,20 @@ def read_file(filepath):
85
  def load_vntc_data(data_dir):
86
  """Load VNTC data from directory."""
87
  texts, labels = [], []
88
-
89
  for folder in sorted(os.listdir(data_dir)):
90
  folder_path = os.path.join(data_dir, folder)
91
  if not os.path.isdir(folder_path):
92
  continue
93
-
94
  for fname in os.listdir(folder_path):
95
  if fname.endswith('.txt'):
96
  text = read_file(os.path.join(folder_path, fname))
97
  if text:
98
  texts.append(text)
99
  labels.append(folder)
100
-
101
  return texts, labels
102
 
103
 
104
- @click.group()
105
- def cli():
106
- """Train Vietnamese text classification models."""
107
- pass
108
-
109
-
110
- @cli.command()
111
- @click.option('--data-dir', default='/home/claude-user/projects/workspace_underthesea/VNTC/Data/10Topics/Ver1.1',
112
- help='Path to VNTC dataset')
113
- @click.option('--output', '-o', default='models/sen-vntc.bin', help='Output model path')
114
- @click.option('--max-features', default=20000, help='Maximum vocabulary size')
115
- @click.option('--ngram-min', default=1, help='Minimum n-gram')
116
- @click.option('--ngram-max', default=2, help='Maximum n-gram')
117
- @click.option('--min-df', default=2, help='Minimum document frequency')
118
- @click.option('--c', default=1.0, help='SVM regularization parameter')
119
- @click.option('--max-iter', default=1000, help='Maximum iterations')
120
- @click.option('--tol', default=0.1, help='Convergence tolerance')
121
- def vntc(data_dir, output, max_features, ngram_min, ngram_max, min_df, c, max_iter, tol):
122
- """Train on VNTC dataset (10 topics, ~84k documents)."""
123
- click.echo("=" * 70)
124
- click.echo("VNTC Dataset Training (10 Topics)")
125
- click.echo("=" * 70)
126
-
127
- train_dir = os.path.join(data_dir, "Train_Full")
128
- test_dir = os.path.join(data_dir, "Test_Full")
129
-
130
- # Load data
131
- click.echo("\nLoading data...")
132
- t0 = time.perf_counter()
133
- train_texts, train_labels = load_vntc_data(train_dir)
134
- test_texts, test_labels = load_vntc_data(test_dir)
135
- load_time = time.perf_counter() - t0
136
-
137
- click.echo(f" Train samples: {len(train_texts)}")
138
- click.echo(f" Test samples: {len(test_texts)}")
139
- click.echo(f" Categories: {len(set(train_labels))}")
140
- click.echo(f" Load time: {load_time:.2f}s")
141
-
142
- # Train
143
- click.echo("\nTraining Rust TextClassifier...")
144
- clf = TextClassifier(
145
- max_features=max_features,
146
- ngram_range=(ngram_min, ngram_max),
147
- min_df=min_df,
148
- c=c,
149
- max_iter=max_iter,
150
- tol=tol,
151
- )
152
-
153
- t0 = time.perf_counter()
154
- clf.fit(train_texts, train_labels)
155
- train_time = time.perf_counter() - t0
156
- click.echo(f" Training time: {train_time:.2f}s")
157
- click.echo(f" Vocabulary size: {clf.n_features}")
158
-
159
- # Evaluate
160
- click.echo("\nEvaluating...")
161
- t0 = time.perf_counter()
162
- preds = clf.predict_batch(test_texts)
163
- infer_time = time.perf_counter() - t0
164
- throughput = len(test_texts) / infer_time
165
-
166
- acc = accuracy_score(test_labels, preds)
167
- f1_w = f1_score(test_labels, preds, average='weighted')
168
- f1_m = f1_score(test_labels, preds, average='macro')
169
-
170
- click.echo(f" Inference: {infer_time:.3f}s ({throughput:.0f} samples/sec)")
171
-
172
- click.echo("\n" + "=" * 70)
173
- click.echo("RESULTS")
174
- click.echo("=" * 70)
175
- click.echo(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
176
- click.echo(f" F1 (weighted): {f1_w:.4f}")
177
- click.echo(f" F1 (macro): {f1_m:.4f}")
178
-
179
- click.echo("\nClassification Report:")
180
- click.echo(classification_report(test_labels, preds))
181
-
182
- # Save model
183
- model_path = Path(output)
184
- model_path.parent.mkdir(parents=True, exist_ok=True)
185
- clf.save(str(model_path))
186
-
187
- size_mb = model_path.stat().st_size / (1024 * 1024)
188
- click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
189
-
190
-
191
- @cli.command()
192
- @click.option('--output', '-o', default='models/sen-bank.bin', help='Output model path')
193
- @click.option('--max-features', default=10000, help='Maximum vocabulary size')
194
- @click.option('--ngram-min', default=1, help='Minimum n-gram')
195
- @click.option('--ngram-max', default=2, help='Maximum n-gram')
196
- @click.option('--min-df', default=1, help='Minimum document frequency')
197
- @click.option('--c', default=1.0, help='SVM regularization parameter')
198
- @click.option('--max-iter', default=1000, help='Maximum iterations')
199
- @click.option('--tol', default=0.1, help='Convergence tolerance')
200
- def bank(output, max_features, ngram_min, ngram_max, min_df, c, max_iter, tol):
201
- """Train on UTS2017_Bank dataset (14 categories, banking domain)."""
202
- from datasets import load_dataset
203
-
204
- click.echo("=" * 70)
205
- click.echo("UTS2017_Bank Dataset Training (14 Categories)")
206
- click.echo("=" * 70)
207
-
208
- # Load data
209
- click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
210
- dataset = load_dataset("undertheseanlp/UTS2017_Bank", "classification")
211
-
212
- train_texts = list(dataset["train"]["text"])
213
- train_labels = list(dataset["train"]["label"])
214
- test_texts = list(dataset["test"]["text"])
215
- test_labels = list(dataset["test"]["label"])
216
-
217
- click.echo(f" Train samples: {len(train_texts)}")
218
- click.echo(f" Test samples: {len(test_texts)}")
219
- click.echo(f" Categories: {len(set(train_labels))}")
220
-
221
- # Train
222
- click.echo("\nTraining Rust TextClassifier...")
223
- clf = TextClassifier(
224
- max_features=max_features,
225
- ngram_range=(ngram_min, ngram_max),
226
- min_df=min_df,
227
- c=c,
228
- max_iter=max_iter,
229
- tol=tol,
230
- )
231
-
232
- t0 = time.perf_counter()
233
- clf.fit(train_texts, train_labels)
234
- train_time = time.perf_counter() - t0
235
- click.echo(f" Training time: {train_time:.3f}s")
236
- click.echo(f" Vocabulary size: {clf.n_features}")
237
-
238
- # Evaluate
239
- click.echo("\nEvaluating...")
240
- preds = clf.predict_batch(test_texts)
241
-
242
- acc = accuracy_score(test_labels, preds)
243
- f1_w = f1_score(test_labels, preds, average='weighted')
244
- f1_m = f1_score(test_labels, preds, average='macro')
245
-
246
- click.echo("\n" + "=" * 70)
247
- click.echo("RESULTS")
248
- click.echo("=" * 70)
249
- click.echo(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
250
- click.echo(f" F1 (weighted): {f1_w:.4f}")
251
- click.echo(f" F1 (macro): {f1_m:.4f}")
252
-
253
- click.echo("\nClassification Report:")
254
- click.echo(classification_report(test_labels, preds))
255
-
256
- # Save model
257
- model_path = Path(output)
258
- model_path.parent.mkdir(parents=True, exist_ok=True)
259
- clf.save(str(model_path))
260
-
261
- size_mb = model_path.stat().st_size / (1024 * 1024)
262
- click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
263
-
264
-
265
- def _load_vlsp2016(data_dir):
266
  """Load VLSP2016 sentiment data from directory."""
267
  label_map = {'POS': 'positive', 'NEG': 'negative', 'NEU': 'neutral'}
268
  texts, labels = [], []
@@ -283,209 +103,142 @@ def _load_vlsp2016(data_dir):
283
  return texts[0], labels[0], texts[1], labels[1]
284
 
285
 
286
- @cli.command('sentiment-general')
287
- @click.option('--output', '-o', default=None, help='Output model path')
288
- @click.option('--vlsp2016-dir', default=None, help='Path to VLSP2016_SA directory (adds to training data)')
289
- @click.option('--max-features', default=200000, help='Maximum vocabulary size')
290
- @click.option('--ngram-min', default=1, help='Minimum n-gram')
291
- @click.option('--ngram-max', default=3, help='Maximum n-gram')
292
- @click.option('--min-df', default=1, help='Minimum document frequency')
293
- @click.option('--max-df', default=0.9, help='Maximum document frequency')
294
- @click.option('--c', default=0.7, help='SVM regularization parameter')
295
- @click.option('--max-iter', default=1000, help='Maximum iterations')
296
- @click.option('--tol', default=0.0001, help='Convergence tolerance')
297
- def sentiment_general(output, vlsp2016_dir, max_features, ngram_min, ngram_max, min_df, max_df, c, max_iter, tol):
298
- """Train sentiment-general model (3 classes: positive/negative/neutral).
299
-
300
- Uses UTS2017_Bank sentiment data by default. Optionally adds VLSP2016 data
301
- with --vlsp2016-dir for improved general-domain coverage.
302
- """
303
- from datetime import datetime
304
- from datasets import load_dataset
305
-
306
- if output is None:
307
- date_str = datetime.now().strftime('%Y%m%d')
308
- output = f'models/sen-sentiment-general-1.0.0-{date_str}.bin'
309
-
310
- click.echo("=" * 70)
311
- click.echo("Sentiment General Training (positive/negative/neutral)")
312
- click.echo("=" * 70)
313
-
314
- # Load UTS2017_Bank sentiment data
315
- click.echo("\nLoading UTS2017_Bank sentiment dataset from HuggingFace...")
316
- dataset = load_dataset("undertheseanlp/UTS2017_Bank", "sentiment")
317
-
318
- train_texts = list(dataset["train"]["text"])
319
- train_labels = list(dataset["train"]["sentiment"])
320
- test_texts = list(dataset["test"]["text"])
321
- test_labels = list(dataset["test"]["sentiment"])
322
-
323
- vlsp_test_texts, vlsp_test_labels = None, None
324
-
325
- # Optionally add VLSP2016 data
326
- if vlsp2016_dir:
327
- click.echo(f"\nLoading VLSP2016 data from {vlsp2016_dir}...")
328
- vlsp_train_texts, vlsp_train_labels, vlsp_test_texts, vlsp_test_labels = _load_vlsp2016(vlsp2016_dir)
329
- train_texts.extend(vlsp_train_texts)
330
- train_labels.extend(vlsp_train_labels)
331
- click.echo(f" VLSP2016 train: {len(vlsp_train_texts)}, test: {len(vlsp_test_texts)}")
332
-
333
- click.echo(f" Total train samples: {len(train_texts)}")
334
- click.echo(f" UTS2017 test samples: {len(test_texts)}")
335
- click.echo(f" Labels: {sorted(set(train_labels))}")
336
-
337
- # Preprocess
338
- click.echo("\nPreprocessing...")
339
- proc_train = [preprocess_sentiment(t) for t in train_texts]
340
- proc_test = [preprocess_sentiment(t) for t in test_texts]
341
-
342
- # Train
343
- click.echo("\nTraining Rust TextClassifier...")
344
- clf = TextClassifier(
345
- max_features=max_features,
346
- ngram_range=(ngram_min, ngram_max),
347
- min_df=min_df,
348
- max_df=max_df,
349
- c=c,
350
- max_iter=max_iter,
351
- tol=tol,
352
- )
 
 
 
 
 
353
 
 
 
354
  t0 = time.perf_counter()
355
- clf.fit(proc_train, train_labels)
356
- train_time = time.perf_counter() - t0
357
- click.echo(f" Training time: {train_time:.3f}s")
358
- click.echo(f" Vocabulary size: {clf.n_features}")
359
 
360
- # Evaluate on UTS2017
361
- click.echo("\nEvaluating on UTS2017_Bank test set...")
362
- preds = clf.predict_batch(proc_test)
 
363
 
364
- acc = accuracy_score(test_labels, preds)
365
- f1_w = f1_score(test_labels, preds, average='weighted', zero_division=0)
366
- f1_m = f1_score(test_labels, preds, average='macro', zero_division=0)
 
 
 
367
 
368
- click.echo("\n" + "=" * 70)
369
- click.echo("RESULTS (UTS2017_Bank)")
370
- click.echo("=" * 70)
371
- click.echo(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
372
- click.echo(f" F1 (weighted): {f1_w:.4f}")
373
- click.echo(f" F1 (macro): {f1_m:.4f}")
374
- click.echo("\nClassification Report:")
375
- click.echo(classification_report(test_labels, preds, zero_division=0))
376
-
377
- # Evaluate on VLSP2016 if available
378
- if vlsp_test_texts:
379
- proc_vlsp_test = [preprocess_sentiment(t) for t in vlsp_test_texts]
380
- vlsp_preds = clf.predict_batch(proc_vlsp_test)
381
- vlsp_acc = accuracy_score(vlsp_test_labels, vlsp_preds)
382
- vlsp_f1w = f1_score(vlsp_test_labels, vlsp_preds, average='weighted', zero_division=0)
383
- vlsp_f1m = f1_score(vlsp_test_labels, vlsp_preds, average='macro', zero_division=0)
384
-
385
- click.echo("=" * 70)
386
- click.echo("RESULTS (VLSP2016)")
387
- click.echo("=" * 70)
388
- click.echo(f" Accuracy: {vlsp_acc:.4f} ({vlsp_acc*100:.2f}%)")
389
- click.echo(f" F1 (weighted): {vlsp_f1w:.4f}")
390
- click.echo(f" F1 (macro): {vlsp_f1m:.4f}")
391
- click.echo("\nClassification Report:")
392
- click.echo(classification_report(vlsp_test_labels, vlsp_preds, zero_division=0))
393
 
394
- # Save model
395
- model_path = Path(output)
396
- model_path.parent.mkdir(parents=True, exist_ok=True)
397
- clf.save(str(model_path))
398
 
399
- size_mb = model_path.stat().st_size / (1024 * 1024)
400
- click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
401
-
402
-
403
- @cli.command('sentiment-bank')
404
- @click.option('--output', '-o', default=None, help='Output model path')
405
- @click.option('--max-features', default=200000, help='Maximum vocabulary size')
406
- @click.option('--ngram-min', default=1, help='Minimum n-gram')
407
- @click.option('--ngram-max', default=3, help='Maximum n-gram')
408
- @click.option('--min-df', default=1, help='Minimum document frequency')
409
- @click.option('--max-df', default=0.9, help='Maximum document frequency')
410
- @click.option('--c', default=0.7, help='SVM regularization parameter')
411
- @click.option('--max-iter', default=1000, help='Maximum iterations')
412
- @click.option('--tol', default=0.0001, help='Convergence tolerance')
413
- def sentiment_bank(output, max_features, ngram_min, ngram_max, min_df, max_df, c, max_iter, tol):
414
- """Train sentiment-bank model on UTS2017_Bank (36 combined category#sentiment labels)."""
415
- from datetime import datetime
416
- from datasets import load_dataset
417
-
418
- if output is None:
419
- date_str = datetime.now().strftime('%Y%m%d')
420
- output = f'models/sen-sentiment-bank-1.0.0-{date_str}.bin'
421
-
422
- click.echo("=" * 70)
423
- click.echo("Sentiment Bank Training (category#sentiment, 36 labels)")
424
- click.echo("=" * 70)
425
-
426
- # Load and merge classification + sentiment configs
427
- click.echo("\nLoading UTS2017_Bank dataset from HuggingFace...")
428
- ds_class = load_dataset("undertheseanlp/UTS2017_Bank", "classification")
429
- ds_sent = load_dataset("undertheseanlp/UTS2017_Bank", "sentiment")
430
-
431
- train_texts = list(ds_class["train"]["text"])
432
- train_labels = [f'{c}#{s}' for c, s in zip(ds_class["train"]["label"], ds_sent["train"]["sentiment"])]
433
- test_texts = list(ds_class["test"]["text"])
434
- test_labels = [f'{c}#{s}' for c, s in zip(ds_class["test"]["label"], ds_sent["test"]["sentiment"])]
435
-
436
- click.echo(f" Train samples: {len(train_texts)}")
437
- click.echo(f" Test samples: {len(test_texts)}")
438
- click.echo(f" Labels: {len(set(train_labels))}")
439
-
440
- # Preprocess
441
- click.echo("\nPreprocessing...")
442
- proc_train = [preprocess_sentiment(t) for t in train_texts]
443
- proc_test = [preprocess_sentiment(t) for t in test_texts]
444
-
445
- # Train
446
- click.echo("\nTraining Rust TextClassifier...")
447
  clf = TextClassifier(
448
- max_features=max_features,
449
- ngram_range=(ngram_min, ngram_max),
450
- min_df=min_df,
451
- max_df=max_df,
452
- c=c,
453
- max_iter=max_iter,
454
- tol=tol,
 
455
  )
456
 
457
  t0 = time.perf_counter()
458
- clf.fit(proc_train, train_labels)
459
  train_time = time.perf_counter() - t0
460
- click.echo(f" Training time: {train_time:.3f}s")
461
- click.echo(f" Vocabulary size: {clf.n_features}")
462
-
463
- # Evaluate
464
- click.echo("\nEvaluating...")
465
- preds = clf.predict_batch(proc_test)
466
 
467
- acc = accuracy_score(test_labels, preds)
468
- f1_w = f1_score(test_labels, preds, average='weighted', zero_division=0)
469
- f1_m = f1_score(test_labels, preds, average='macro', zero_division=0)
470
-
471
- click.echo("\n" + "=" * 70)
472
- click.echo("RESULTS")
473
- click.echo("=" * 70)
474
- click.echo(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
475
- click.echo(f" F1 (weighted): {f1_w:.4f}")
476
- click.echo(f" F1 (macro): {f1_m:.4f}")
477
 
478
- click.echo("\nClassification Report:")
479
- click.echo(classification_report(test_labels, preds, zero_division=0))
 
 
480
 
481
  # Save model
 
482
  model_path = Path(output)
483
  model_path.parent.mkdir(parents=True, exist_ok=True)
484
  clf.save(str(model_path))
485
 
486
  size_mb = model_path.stat().st_size / (1024 * 1024)
487
- click.echo(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
488
 
489
 
490
  if __name__ == "__main__":
491
- cli()
 
1
  """
2
+ Training CLI for Vietnamese Text Classification using Hydra.
3
 
4
  Usage:
5
+ python src/train.py --config-name=vntc
6
+ python src/train.py --config-name=sentiment_general
7
+ python src/train.py --config-name=sentiment_bank
8
+ python src/train.py --config-name=bank
9
+
10
+ Override params from CLI:
11
+ python src/train.py --config-name=sentiment_general model.c=0.5 model.max_features=100000
12
+ python src/train.py --config-name=vntc model=sentiment
13
+ python src/train.py --config-name=sentiment_general data.data_dir=/path/to/VLSP2016_SA
14
  """
15
 
16
  import os
 
17
  import time
18
+ import logging
19
  from pathlib import Path
20
 
21
+ import hydra
22
+ from omegaconf import DictConfig, OmegaConf
23
  from sklearn.metrics import accuracy_score, f1_score, classification_report
24
 
25
+ from underthesea import TextClassifier, TextPreprocessor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ log = logging.getLogger(__name__)
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Preprocessor
32
+ # ---------------------------------------------------------------------------
33
+
34
+ def build_preprocessor(pp_cfg):
35
+ """Build a Rust TextPreprocessor from model.preprocessor config."""
36
+ teencode = dict(pp_cfg.get("teencode", {})) or None
37
+ neg_words = list(pp_cfg.get("negation_words", [])) or None
38
+ neg_window = pp_cfg.get("negation_window", 2)
39
+
40
+ return TextPreprocessor(
41
+ lowercase=pp_cfg.get("lowercase", True),
42
+ unicode_normalize=pp_cfg.get("unicode_normalize", True),
43
+ remove_urls=pp_cfg.get("remove_urls", True),
44
+ normalize_repeated_chars=pp_cfg.get("normalize_repeated_chars", True),
45
+ normalize_punctuation=pp_cfg.get("normalize_punctuation", True),
46
+ teencode=teencode,
47
+ negation_words=neg_words,
48
+ negation_window=neg_window,
49
+ )
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Data loaders
54
+ # ---------------------------------------------------------------------------
55
 
56
  def read_file(filepath):
57
  """Read text file with multiple encoding attempts."""
 
69
  def load_vntc_data(data_dir):
70
  """Load VNTC data from directory."""
71
  texts, labels = [], []
 
72
  for folder in sorted(os.listdir(data_dir)):
73
  folder_path = os.path.join(data_dir, folder)
74
  if not os.path.isdir(folder_path):
75
  continue
 
76
  for fname in os.listdir(folder_path):
77
  if fname.endswith('.txt'):
78
  text = read_file(os.path.join(folder_path, fname))
79
  if text:
80
  texts.append(text)
81
  labels.append(folder)
 
82
  return texts, labels
83
 
84
 
85
+ def load_vlsp2016(data_dir):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  """Load VLSP2016 sentiment data from directory."""
87
  label_map = {'POS': 'positive', 'NEG': 'negative', 'NEU': 'neutral'}
88
  texts, labels = [], []
 
103
  return texts[0], labels[0], texts[1], labels[1]
104
 
105
 
106
+ def load_data(cfg):
107
+ """Load train/test data based on Hydra data config."""
108
+ data_cfg = cfg.data
109
+ name = data_cfg.name
110
+ extra_test = {}
111
+
112
+ if name == "vntc":
113
+ train_texts, train_labels = load_vntc_data(
114
+ os.path.join(data_cfg.data_dir, "Train_Full"))
115
+ test_texts, test_labels = load_vntc_data(
116
+ os.path.join(data_cfg.data_dir, "Test_Full"))
117
+
118
+ elif name == "bank":
119
+ from datasets import load_dataset
120
+ dataset = load_dataset(data_cfg.dataset, data_cfg.config)
121
+ train_texts = list(dataset["train"]["text"])
122
+ train_labels = list(dataset["train"]["label"])
123
+ test_texts = list(dataset["test"]["text"])
124
+ test_labels = list(dataset["test"]["label"])
125
+
126
+ elif name == "sentiment_general":
127
+ train_texts, train_labels, test_texts, test_labels = load_vlsp2016(
128
+ data_cfg.data_dir)
129
+
130
+ elif name == "sentiment_bank":
131
+ from datasets import load_dataset
132
+ ds_class = load_dataset(data_cfg.dataset, "classification")
133
+ ds_sent = load_dataset(data_cfg.dataset, "sentiment")
134
+ train_texts = list(ds_class["train"]["text"])
135
+ train_labels = [f'{c}#{s}' for c, s in
136
+ zip(ds_class["train"]["label"], ds_sent["train"]["sentiment"])]
137
+ test_texts = list(ds_class["test"]["text"])
138
+ test_labels = [f'{c}#{s}' for c, s in
139
+ zip(ds_class["test"]["label"], ds_sent["test"]["sentiment"])]
140
+ else:
141
+ raise ValueError(f"Unknown data: {name}")
142
+
143
+ return train_texts, train_labels, test_texts, test_labels, extra_test
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # Evaluate
148
+ # ---------------------------------------------------------------------------
149
+
150
+ def evaluate(test_labels, preds, name=""):
151
+ """Print evaluation metrics."""
152
+ acc = accuracy_score(test_labels, preds)
153
+ f1_w = f1_score(test_labels, preds, average='weighted', zero_division=0)
154
+ f1_m = f1_score(test_labels, preds, average='macro', zero_division=0)
155
+
156
+ header = f"RESULTS ({name})" if name else "RESULTS"
157
+ log.info("=" * 70)
158
+ log.info(header)
159
+ log.info("=" * 70)
160
+ log.info(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
161
+ log.info(f" F1 (weighted): {f1_w:.4f}")
162
+ log.info(f" F1 (macro): {f1_m:.4f}")
163
+ log.info("\n" + classification_report(test_labels, preds, zero_division=0))
164
+ return acc
165
+
166
+
167
+ # ---------------------------------------------------------------------------
168
+ # Main
169
+ # ---------------------------------------------------------------------------
170
+
171
+ @hydra.main(version_base=None, config_path="conf", config_name="config")
172
+ def train(cfg: DictConfig):
173
+ """Train Vietnamese text classification model."""
174
+ log.info("=" * 70)
175
+ log.info(f"Training: {cfg.data.name}")
176
+ log.info("=" * 70)
177
+ log.info(f"\nConfig:\n{OmegaConf.to_yaml(cfg)}")
178
 
179
+ # Load data
180
+ log.info("Loading data...")
181
  t0 = time.perf_counter()
182
+ train_texts, train_labels, test_texts, test_labels, extra_test = load_data(cfg)
183
+ load_time = time.perf_counter() - t0
 
 
184
 
185
+ log.info(f" Train samples: {len(train_texts)}")
186
+ log.info(f" Test samples: {len(test_texts)}")
187
+ log.info(f" Labels: {len(set(train_labels))}")
188
+ log.info(f" Load time: {load_time:.2f}s")
189
 
190
+ # Build preprocessor — model.preprocess=true activates model.preprocessor config
191
+ # Preprocessor is passed to TextClassifier and packed into the .bin model
192
+ preprocessor = None
193
+ if cfg.model.get("preprocess", False):
194
+ preprocessor = build_preprocessor(cfg.model.preprocessor)
195
+ log.info(f"\nPreprocessor: {preprocessor}")
196
 
197
+ # Build classifier from config
198
+ model_cfg = cfg.model
199
+ ngram_range = tuple(model_cfg.ngram_range)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ log.info("\nTraining TextClassifier...")
202
+ log.info(f" max_features={model_cfg.max_features}, ngram_range={ngram_range}, "
203
+ f"max_df={model_cfg.max_df}, C={model_cfg.c}")
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  clf = TextClassifier(
206
+ max_features=model_cfg.max_features,
207
+ ngram_range=ngram_range,
208
+ min_df=model_cfg.min_df,
209
+ max_df=model_cfg.max_df,
210
+ c=model_cfg.c,
211
+ max_iter=model_cfg.max_iter,
212
+ tol=model_cfg.tol,
213
+ preprocessor=preprocessor,
214
  )
215
 
216
  t0 = time.perf_counter()
217
+ clf.fit(train_texts, train_labels)
218
  train_time = time.perf_counter() - t0
219
+ log.info(f" Training time: {train_time:.3f}s")
220
+ log.info(f" Vocabulary size: {clf.n_features}")
 
 
 
 
221
 
222
+ # Evaluate on primary test set
223
+ # TextClassifier auto-preprocesses via its built-in preprocessor
224
+ log.info("\nEvaluating...")
225
+ preds = clf.predict_batch(test_texts)
226
+ evaluate(test_labels, preds, cfg.data.name)
 
 
 
 
 
227
 
228
+ # Evaluate on extra test sets (e.g. VLSP2016)
229
+ for name, (et_texts, et_labels) in extra_test.items():
230
+ et_preds = clf.predict_batch(et_texts)
231
+ evaluate(et_labels, et_preds, name)
232
 
233
  # Save model
234
+ output = cfg.output
235
  model_path = Path(output)
236
  model_path.parent.mkdir(parents=True, exist_ok=True)
237
  clf.save(str(model_path))
238
 
239
  size_mb = model_path.stat().st_size / (1024 * 1024)
240
+ log.info(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
241
 
242
 
243
  if __name__ == "__main__":
244
+ train()