Vu Anh and Claude committed
Commit f62e707 · 1 Parent(s): a8bf530

Add dual dataset support and update system documentation

- Enhanced train.py to support both VNTC and UTS2017_Bank datasets
- Added --dataset parameter for dataset selection
- Updated System Card with UTS2017_Bank performance metrics
- Achieved 92.33% accuracy on VNTC and 70.96% on UTS2017_Bank
- Cleaned up redundant files and old training runs

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2)
  1. Sonar Core 1 - System Card.md +126 -24
  2. train.py +114 -12
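One of the bullets above notes the new `--dataset` parameter. As a minimal standalone sketch of the selection logic this commit wires into `train.py` (flag names, choices, and default taken from the diff; the hard-coded argv is for illustration only):

```python
import argparse

# Recreate the --dataset flag added by this commit (choices and default from the diff)
parser = argparse.ArgumentParser(
    description="Train Vietnamese text classification model on VNTC or UTS2017_Bank dataset"
)
parser.add_argument(
    "--dataset",
    type=str,
    choices=["vntc", "uts2017"],
    default="uts2017",
    help="Dataset to use for training (default: uts2017)",
)

# Parse an example command line instead of sys.argv, for illustration
args = parser.parse_args(["--dataset", "vntc"])
dataset_name = "VNTC" if args.dataset == "vntc" else "UTS2017_Bank"
print(f"Training on {dataset_name} dataset")  # Training on VNTC dataset
```

Omitting the flag falls back to `uts2017`, matching the script's previous single-dataset behavior.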
Sonar Core 1 - System Card.md CHANGED
@@ -6,17 +6,22 @@
 
 # Changelog
 
+**2025-09-27**
+
+- Added support for UTS2017_Bank Vietnamese banking text classification dataset
+- Achieved 70.96% accuracy on 14 banking service categories
+
 **2025-09-21**
 
 - Initial release of Sonar Core 1
 
 # Abstract
 
-**Sonar Core 1** is a machine learning-based text classification model designed for Vietnamese language processing. Built on a **TF-IDF** (Term Frequency-Inverse Document Frequency) feature extraction pipeline combined with **Logistic Regression**, this model achieves **92.33% accuracy** on the VNTC (Vietnamese Text Classification) dataset across **10 news categories**. The model is specifically designed for Vietnamese news article classification, content categorization for Vietnamese text, and document organization and tagging. Developed as a base model to provide quick and reliable text classification support for **scikit-learn >=1.6** integration since **underthesea 8.1.0**, it employs optimized feature engineering with **20,000 max features** and bigram support, along with a hash-based caching system for efficient processing. This system card provides comprehensive documentation of the model's architecture, performance metrics, intended uses, and limitations.
+**Sonar Core 1** is a machine learning-based text classification model designed for Vietnamese language processing. Built on a **TF-IDF** (Term Frequency-Inverse Document Frequency) feature extraction pipeline combined with **Logistic Regression**, this model achieves **92.33% accuracy** on the VNTC (Vietnamese Text Classification) dataset across **10 news categories** and **70.96% accuracy** on the UTS2017_Bank dataset across **14 banking service categories**. The model is specifically designed for Vietnamese news article classification, banking text categorization, content categorization for Vietnamese text, and document organization and tagging. Developed as a base model to provide quick and reliable text classification support for **scikit-learn >=1.6** integration since **underthesea 8.1.0**, it employs optimized feature engineering with **20,000 max features** and bigram support, along with a hash-based caching system for efficient processing. This system card provides comprehensive documentation of the model's architecture, performance metrics, intended uses, and limitations.
 
 # 1. Model Details
 
-**Sonar Core 1** is a Vietnamese text classification model built on **scikit-learn >=1.6**, utilizing a TF-IDF pipeline with Logistic Regression to classify text across 10 news categories. The architecture employs:
+**Sonar Core 1** is a Vietnamese text classification model built on **scikit-learn >=1.6**, utilizing a TF-IDF pipeline with Logistic Regression to classify text across multiple domains including news categories and banking services. The architecture employs:
 - CountVectorizer with **20,000 max features** (optimized from the initial 10,000)
 - N-gram extraction: unigram and bigram support
 - TF-IDF transformation with IDF weighting
@@ -27,7 +32,7 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95
 
 # 2. Training Data
 
-## 2.1 Supported Categories (10 classes)
+## 2.1 VNTC Dataset - News Categories (10 classes)
 1. **chinh_tri_xa_hoi** - Politics and Society
 2. **doi_song** - Lifestyle
 3. **khoa_hoc** - Science
@@ -39,27 +44,51 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95
 9. **van_hoa** - Culture
 10. **vi_tinh** - Information Technology
 
-## 2.2 Dataset
+## 2.2 UTS2017_Bank Dataset - Banking Categories (14 classes)
+1. **ACCOUNT** - Account services
+2. **CARD** - Card services
+3. **CUSTOMER_SUPPORT** - Customer support
+4. **DISCOUNT** - Discount offers
+5. **INTEREST_RATE** - Interest rate information
+6. **INTERNET_BANKING** - Internet banking services
+7. **LOAN** - Loan services
+8. **MONEY_TRANSFER** - Money transfer services
+9. **OTHER** - Other services
+10. **PAYMENT** - Payment services
+11. **PROMOTION** - Promotional offers
+12. **SAVING** - Savings accounts
+13. **SECURITY** - Security features
+14. **TRADEMARK** - Trademark/branding
+
+## 2.3 Dataset Details
+
+### VNTC Dataset
 - **Name**: VNTC (Vietnamese Text Classification) Dataset
 - **Training Samples**: 33,759 documents
 - **Test Samples**: 50,373 documents
 - **Language**: Vietnamese
 - **Format**: FastText format (__label__category followed by text)
+- **Distribution**: Balanced across 10 news categories
+- **Average document length**: ~200-500 words
 
-## 2.3 Data Distribution
-- Balanced across 10 news categories
-- Text preprocessing: None (raw Vietnamese text)
-- Average document length: ~200-500 words
+### UTS2017_Bank Dataset
+- **Name**: UTS2017_Bank Classification Dataset
+- **Training Samples**: 1,581 documents
+- **Test Samples**: 396 documents
+- **Language**: Vietnamese
+- **Format**: Text with categorical labels
+- **Distribution**: Imbalanced (CUSTOMER_SUPPORT: 39%, TRADEMARK: 35%, others: 26%)
+- **Text preprocessing**: None (raw Vietnamese text)
 
 # 3. Performance Metrics
 
-## 3.1 Overall Performance (Latest Results - 2025-09-21)
+## 3.1 VNTC Dataset Performance (2025-09-21)
 - **Training Accuracy**: 95.39%
 - **Test Accuracy**: 92.33%
 - **Training Time**: ~27.18 seconds (with caching system)
 - **Inference Time**: ~19.34 seconds for 50,373 samples
 
-## 3.2 Per-Class Performance (Latest Run - All 10 Classes)
+## 3.2 Per-Class Performance - VNTC Dataset
 | Category | Precision | Recall | F1-Score | Support |
 |----------|-----------|---------|-----------|---------|
 | chinh_tri_xa_hoi | 0.86 | 0.93 | 0.89 | 7,567 |
@@ -73,34 +102,82 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95
 | van_hoa | 0.93 | 0.95 | 0.94 | 6,250 |
 | vi_tinh | 0.94 | 0.95 | 0.94 | 4,560 |
 
-## 3.3 Aggregate Metrics (Latest Run)
+## 3.3 UTS2017_Bank Dataset Performance (2025-09-27)
+- **Training Accuracy**: 76.22%
+- **Test Accuracy**: 70.96%
+- **Training Time**: ~0.78 seconds
+- **Inference Time**: ~0.01 seconds for 396 samples
+
+## 3.4 Per-Class Performance - UTS2017_Bank Dataset
+| Category | Precision | Recall | F1-Score | Support |
+|----------|-----------|---------|-----------|---------|
+| ACCOUNT | 0.00 | 0.00 | 0.00 | 1 |
+| CARD | 0.00 | 0.00 | 0.00 | 13 |
+| CUSTOMER_SUPPORT | 0.62 | 0.97 | 0.76 | 155 |
+| DISCOUNT | 0.00 | 0.00 | 0.00 | 8 |
+| INTEREST_RATE | 0.50 | 0.08 | 0.14 | 12 |
+| INTERNET_BANKING | 0.00 | 0.00 | 0.00 | 14 |
+| LOAN | 0.67 | 0.13 | 0.22 | 15 |
+| MONEY_TRANSFER | 0.00 | 0.00 | 0.00 | 7 |
+| OTHER | 0.50 | 0.07 | 0.12 | 14 |
+| PAYMENT | 0.00 | 0.00 | 0.00 | 3 |
+| PROMOTION | 1.00 | 0.18 | 0.31 | 11 |
+| SAVING | 0.00 | 0.00 | 0.00 | 2 |
+| SECURITY | 0.00 | 0.00 | 0.00 | 1 |
+| TRADEMARK | 0.87 | 0.89 | 0.88 | 140 |
+
+## 3.5 Aggregate Metrics
+
+### VNTC Dataset
 - **Overall Accuracy**: 92%
 - **Macro Average**: Precision: 0.91, Recall: 0.90, F1: 0.91
 - **Weighted Average**: Precision: 0.92, Recall: 0.92, F1: 0.92
 
-## 3.4 Performance Analysis
+### UTS2017_Bank Dataset
+- **Overall Accuracy**: 71%
+- **Macro Average**: Precision: 0.30, Recall: 0.17, F1: 0.17
+- **Weighted Average**: Precision: 0.64, Recall: 0.71, F1: 0.63
+
+## 3.6 Performance Analysis
+
+### VNTC Dataset
 - **Best Performing Categories**: Sports (the_thao) achieves 98% F1-score, followed by Health, World, Culture, and IT (all 94% F1-score)
 - **Lowest Performing Category**: Lifestyle (doi_song) with 76% F1-score due to lower recall (71%)
+
+### UTS2017_Bank Dataset
+- **Best Performing Categories**: TRADEMARK (88% F1-score) and CUSTOMER_SUPPORT (76% F1-score)
+- **Challenges**: Many minority classes with insufficient training data result in zero predictions
+- **Data Imbalance**: Significant class imbalance with CUSTOMER_SUPPORT and TRADEMARK dominating (74% of data)
+
+### General Observations
 - **Feature Count**: Uses 20,000 max features with bigram support
 - **Caching System**: Hash-based caching for efficient vectorizer and TF-IDF processing
+- **Model performs better on balanced datasets** (VNTC) compared to imbalanced ones (UTS2017_Bank)
 
 # 4. Limitations
 
 ## 4.1 Known Limitations
 1. **Language Specificity**: Only works with Vietnamese text
-2. **Domain Specificity**: Optimized for news articles, may not perform well on:
-   - Social media posts
-   - Technical documentation
+2. **Domain Specificity**: Optimized for specific domains, may not generalize well to:
+   - Social media posts (unless trained on specific datasets)
+   - Technical documentation outside IT/banking domains
    - Conversational text
 3. **Feature Limitations**:
    - Limited to 20,000 most frequent features
   - May miss rare but important terms
-4. **Class Confusion**: Lower performance on lifestyle (doi_song) category (71% recall)
+4. **Class Imbalance Sensitivity**:
+   - Performance degrades significantly with imbalanced datasets
+   - Minority classes may receive zero predictions (as seen in UTS2017_Bank)
+5. **Specific Category Weaknesses**:
+   - VNTC: Lower performance on lifestyle (doi_song) category (71% recall)
+   - UTS2017_Bank: Poor performance on minority classes (ACCOUNT, CARD, PAYMENT, etc.)
 
 ## 4.2 Biases
-- Trained on news articles which may have formal writing style bias
-- May reflect biases present in the original VNTC dataset
-- Performance varies across categories (best on sports at 98% F1-score, weakest on lifestyle at 76% F1-score)
+- Trained on specific domains (news and banking) which may have formal writing style bias
+- May reflect biases present in the original datasets
+- Performance varies significantly across categories:
+  - VNTC: Best on sports at 98% F1-score, weakest on lifestyle at 76% F1-score
+  - UTS2017_Bank: Best on TRADEMARK at 88% F1-score, many categories at 0% F1-score
 
 # 5. Future Improvements
 
@@ -109,6 +186,11 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95
 3. Add support for longer documents
 4. Implement confidence thresholds for uncertain predictions
 5. Fine-tune on domain-specific data if needed
+6. Address class imbalance issues through:
+   - Oversampling minority classes
+   - Class weight adjustments
+   - Synthetic data generation (SMOTE)
+7. Expand to more Vietnamese text domains
 
 # 6. Usage
 
@@ -118,8 +200,26 @@ pip install scikit-learn>=1.6 joblib
 ```
 
 ## 6.2 Training
+
+### VNTC Dataset (News Classification)
 ```bash
+# Default training with VNTC dataset
 uv run --no-project --with 'scikit-learn>=1.6' python train.py
+
+# With specific parameters
+uv run --no-project --with 'scikit-learn>=1.6' python train.py --model logistic --max-features 20000
+```
+
+### UTS2017_Bank Dataset (Banking Text Classification)
+```bash
+# Train with UTS2017_Bank dataset (assuming train.py is modified for UTS2017_Bank)
+python train.py --model logistic
+
+# With specific parameters
+python train.py --model logistic --max-features 20000 --ngram-min 1 --ngram-max 2
+
+# Compare multiple configurations
+python train.py --compare
 ```
 
 ## 6.3 Inference
@@ -151,16 +251,18 @@ probabilities = model.predict_proba([text])[0]
 
 1. VNTC Dataset: Hoang, Cong Duy Vu, Dien Dinh, Le Nguyen Nguyen, and Quoc Hung Ngo. (2007). A Comparative Study on Vietnamese Text Classification Methods. In Proceedings of IEEE International Conference on Research, Innovation and Vision for the Future (RIVF 2007), pp. 267-273. IEEE. DOI: 10.1109/RIVF.2007.369167
 
-2. TF-IDF (Term Frequency-Inverse Document Frequency): Salton, Gerard, and Michael J. McGill. (1983). Introduction to Modern Information Retrieval. McGraw-Hill, New York. ISBN: 978-0070544840
+2. UTS2017_Bank Dataset: Available from Hugging Face Datasets: https://huggingface.co/datasets/undertheseanlp/UTS2017_Bank
+
+3. TF-IDF (Term Frequency-Inverse Document Frequency): Salton, Gerard, and Michael J. McGill. (1983). Introduction to Modern Information Retrieval. McGraw-Hill, New York. ISBN: 978-0070544840
 
-3. Logistic Regression for Text Classification: Hastie, Trevor, Robert Tibshirani, and Jerome Friedman. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction (2nd ed.). Springer Series in Statistics. Springer, New York. DOI: 10.1007/978-0-387-84858-7
+4. Logistic Regression for Text Classification: Hastie, Trevor, Robert Tibshirani, and Jerome Friedman. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction (2nd ed.). Springer Series in Statistics. Springer, New York. DOI: 10.1007/978-0-387-84858-7
 
-4. Scikit-learn: Pedregosa, Fabian, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, Jake Vanderplas, Alexandre Passos, David Cournapeau, Matthieu Brucher, Matthieu Perrot, and Édouard Duchesnay. (2011). Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research, 12(85), 2825-2830. Retrieved from https://www.jmlr.org/papers/v12/pedregosa11a.html
+5. Scikit-learn: Pedregosa, Fabian, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, Jake Vanderplas, Alexandre Passos, David Cournapeau, Matthieu Brucher, Matthieu Perrot, and Édouard Duchesnay. (2011). Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research, 12(85), 2825-2830. Retrieved from https://www.jmlr.org/papers/v12/pedregosa11a.html
 
-5. N-gram Language Models: Brown, Peter F., Vincent J. Della Pietra, Peter V. deSouza, Jenifer C. Lai, and Robert L. Mercer. (1992). Class-Based n-gram Models of Natural Language. Computational Linguistics, 18(4), 467-480. Retrieved from https://aclanthology.org/J92-4003/
+6. N-gram Language Models: Brown, Peter F., Vincent J. Della Pietra, Peter V. deSouza, Jenifer C. Lai, and Robert L. Mercer. (1992). Class-Based n-gram Models of Natural Language. Computational Linguistics, 18(4), 467-480. Retrieved from https://aclanthology.org/J92-4003/
 
 # License
-Model trained on publicly available VNTC dataset. Please refer to original dataset license for usage terms.
+Model trained on publicly available VNTC and UTS2017_Bank datasets. Please refer to original dataset licenses for usage terms.
 
 # Citation
 
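Section 5 of the updated System Card lists class-weight adjustments as one remedy for the UTS2017_Bank imbalance. As a sketch of what scikit-learn's `class_weight="balanced"` heuristic computes, n_samples / (n_classes * count), applied to support counts from the per-class table (only three classes shown; the helper name is ours):

```python
from collections import Counter

def balanced_class_weights(labels):
    """scikit-learn-style 'balanced' weights: n_samples / (n_classes * count[c])."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * m) for c, m in counts.items()}

# Imbalance pattern from the UTS2017_Bank test split (three classes shown)
weights = balanced_class_weights(
    ["CUSTOMER_SUPPORT"] * 155 + ["TRADEMARK"] * 140 + ["LOAN"] * 15
)
# Minority classes receive proportionally larger weights
print(weights["LOAN"] > weights["CUSTOMER_SUPPORT"])  # True
```

Passing such weights (or simply `class_weight="balanced"`) to `LogisticRegression` would penalize errors on minority classes like LOAN more heavily, which is one way the zero-prediction classes in the table above could be addressed.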
train.py CHANGED
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 """
-Training script for Vietnamese text classification using UTS2017_Bank dataset.
-This script trains a TF-IDF + Logistic Regression model on the UTS2017_Bank classification dataset.
+Training script for Vietnamese text classification.
+Supports both VNTC (news) and UTS2017_Bank (banking) datasets.
+This script trains a TF-IDF + Logistic Regression model on Vietnamese text classification datasets.
 """
 
 import argparse
@@ -13,6 +14,9 @@ from datetime import datetime
 
 import numpy as np
 from datasets import load_dataset
+import requests
+import zipfile
+from io import BytesIO
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
@@ -41,6 +45,83 @@ def setup_logging(run_name):
     return run_dir
 
 
+def load_vntc_data(split_ratio=0.2, random_state=42, n_samples=None):
+    """Load and prepare VNTC dataset
+
+    Args:
+        split_ratio: Not used for VNTC (has predefined train/test split)
+        random_state: Not used for VNTC (has predefined train/test split)
+        n_samples: Optional limit on number of samples
+
+    Returns:
+        Tuple of (X_train, y_train), (X_test, y_test)
+    """
+    print("Loading VNTC dataset...")
+
+    # Define dataset folder
+    dataset_folder = os.path.expanduser("~/.underthesea/VNTC")
+    os.makedirs(dataset_folder, exist_ok=True)
+
+    train_file = os.path.join(dataset_folder, "train.txt")
+    test_file = os.path.join(dataset_folder, "test.txt")
+
+    # Download if not exists
+    if not os.path.exists(train_file) or not os.path.exists(test_file):
+        print("Downloading VNTC dataset...")
+        url = "https://github.com/undertheseanlp/underthesea/releases/download/resources/VNTC.zip"
+
+        response = requests.get(url)
+        with zipfile.ZipFile(BytesIO(response.content)) as zip_file:
+            zip_file.extractall(dataset_folder)
+        print("Dataset downloaded and extracted.")
+
+    # Load train data
+    X_train = []
+    y_train = []
+    with open(train_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                parts = line.strip().split(' ', 1)
+                if len(parts) == 2:
+                    label = parts[0].replace('__label__', '')
+                    text = parts[1]
+                    y_train.append(label)
+                    X_train.append(text)
+
+    # Load test data
+    X_test = []
+    y_test = []
+    with open(test_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                parts = line.strip().split(' ', 1)
+                if len(parts) == 2:
+                    label = parts[0].replace('__label__', '')
+                    text = parts[1]
+                    y_test.append(label)
+                    X_test.append(text)
+
+    # Apply sample limit if specified
+    if n_samples:
+        if n_samples < len(X_train):
+            X_train = X_train[:n_samples]
+            y_train = y_train[:n_samples]
+        if n_samples < len(X_test):
+            X_test = X_test[:n_samples]
+            y_test = y_test[:n_samples]
+
+    # Convert to numpy arrays
+    X_train = np.array(X_train)
+    y_train = np.array(y_train)
+    X_test = np.array(X_test)
+    y_test = np.array(y_test)
+
+    print(f"Dataset loaded: {len(X_train)} train samples, {len(X_test)} test samples")
+    print(f"Number of unique labels: {len(set(y_train))}")
+
+    return (X_train, y_train), (X_test, y_test)
+
+
 def load_uts2017_data(split_ratio=0.2, random_state=42, n_samples=None):
     """Load and prepare UTS2017_Bank classification dataset
 
@@ -106,6 +187,7 @@ def get_available_models():
 
 
 def train_model(
+    dataset="uts2017",
     model_name="logistic",
     max_features=20000,
     ngram_range=(1, 2),
@@ -115,6 +197,7 @@ def train_model(
     """Train a single model with specified parameters
 
     Args:
+        dataset: Dataset to use ('vntc' or 'uts2017')
        model_name: Name of the model to train ('logistic' or 'svc')
        max_features: Maximum number of features for TF-IDF vectorizer
        ngram_range: N-gram range for feature extraction
@@ -139,10 +222,18 @@ def train_model(
     os.makedirs(output_folder, exist_ok=True)
 
     # Load data
-    logging.info("Loading UTS2017_Bank dataset...")
-    (X_train, y_train), (X_test, y_test) = load_uts2017_data(
-        split_ratio=split_ratio, n_samples=n_samples
-    )
+    if dataset == "vntc":
+        logging.info("Loading VNTC dataset...")
+        (X_train, y_train), (X_test, y_test) = load_vntc_data(
+            split_ratio=split_ratio, n_samples=n_samples
+        )
+        dataset_name = "VNTC"
+    else:
+        logging.info("Loading UTS2017_Bank dataset...")
+        (X_train, y_train), (X_test, y_test) = load_uts2017_data(
+            split_ratio=split_ratio, n_samples=n_samples
+        )
+        dataset_name = "UTS2017_Bank"
 
     # Get unique labels for reporting
     unique_labels = sorted(set(y_train))
@@ -167,7 +258,7 @@ def train_model(
     logging.info(f"Selected classifier: {clf_name}")
 
     # Configuration name
-    config_name = f"UTS2017_Bank_{clf_name}_feat{max_features // 1000}k_ngram{ngram_range[0]}-{ngram_range[1]}"
+    config_name = f"{dataset_name}_{clf_name}_feat{max_features // 1000}k_ngram{ngram_range[0]}-{ngram_range[1]}"
 
     logging.info("=" * 60)
     logging.info(f"Training: {config_name}")
@@ -354,23 +445,25 @@ def train_all_configurations():
     return results
 
 
-def train_notebook(model_name="logistic", max_features=20000, ngram_min=1, ngram_max=2,
+def train_notebook(dataset="uts2017", model_name="logistic", max_features=20000, ngram_min=1, ngram_max=2,
                    split_ratio=0.2, n_samples=None, compare=False):
     """
     Convenience function for training in Jupyter/Colab notebooks without argparse.
 
     Example usage:
         from train import train_notebook
-        train_notebook(model_name="logistic", max_features=20000)
+        train_notebook(dataset="vntc", model_name="logistic", max_features=20000)
     """
     if compare:
         print("Training and comparing multiple configurations...")
         return train_all_configurations()
     else:
-        print(f"Training {model_name} model on UTS2017_Bank dataset...")
+        dataset_name = "VNTC" if dataset == "vntc" else "UTS2017_Bank"
+        print(f"Training {model_name} model on {dataset_name} dataset...")
         print(f"Configuration: max_features={max_features}, ngram=({ngram_min}, {ngram_max})")
 
         return train_model(
+            dataset=dataset,
             model_name=model_name,
             max_features=max_features,
             ngram_range=(ngram_min, ngram_max),
@@ -386,7 +479,14 @@ def main():
     in_notebook = hasattr(sys, 'ps1') or 'ipykernel' in sys.modules or 'google.colab' in sys.modules
 
     parser = argparse.ArgumentParser(
-        description="Train Vietnamese text classification model on UTS2017_Bank dataset"
+        description="Train Vietnamese text classification model on VNTC or UTS2017_Bank dataset"
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        choices=["vntc", "uts2017"],
+        default="uts2017",
+        help="Dataset to use for training (default: uts2017)",
     )
     parser.add_argument(
         "--model",
@@ -433,12 +533,14 @@ def main():
         print("Training and comparing multiple configurations...")
         train_all_configurations()
     else:
-        print(f"Training {args.model} model on UTS2017_Bank dataset...")
+        dataset_name = "VNTC" if args.dataset == "vntc" else "UTS2017_Bank"
+        print(f"Training {args.model} model on {dataset_name} dataset...")
         print(
             f"Configuration: max_features={args.max_features}, ngram=({args.ngram_min}, {args.ngram_max})"
        )
 
        train_model(
+            dataset=args.dataset,
            model_name=args.model,
            max_features=args.max_features,
            ngram_range=(args.ngram_min, args.ngram_max),
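The `load_vntc_data` helper added in this diff splits each FastText-format line (`__label__category` followed by the document text) into a label and a document. That parsing step in isolation, with the same checks the diff applies (the function name `parse_fasttext_line` is ours, not the script's):

```python
def parse_fasttext_line(line):
    """Split a VNTC FastText-format line into (label, text); None if unusable."""
    line = line.strip()
    if not line:
        return None  # skip blank lines, as load_vntc_data does
    parts = line.split(' ', 1)
    if len(parts) != 2:
        return None  # a label with no text is dropped
    label = parts[0].replace('__label__', '')
    return label, parts[1]

print(parse_fasttext_line("__label__the_thao Trận đấu diễn ra tối qua"))
# ('the_thao', 'Trận đấu diễn ra tối qua')
```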