Vu Anh Claude committed on
Commit · f62e707
Parent(s): a8bf530

Add dual dataset support and update system documentation

- Enhanced train.py to support both VNTC and UTS2017_Bank datasets
- Added --dataset parameter for dataset selection
- Updated System Card with UTS2017_Bank performance metrics
- Achieved 92.33% accuracy on VNTC and 70.96% on UTS2017_Bank
- Cleaned up redundant files and old training runs

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
- Sonar Core 1 - System Card.md +126 -24
- train.py +114 -12
Sonar Core 1 - System Card.md
CHANGED
@@ -6,17 +6,22 @@

# Changelog

**2025-09-21**

- Initial release of Sonar Core 1

# Abstract

-**Sonar Core 1** is a machine learning-based text classification model designed for Vietnamese language processing. Built on a **TF-IDF** (Term Frequency-Inverse Document Frequency) feature extraction pipeline combined with **Logistic Regression**, this model achieves **92.33% accuracy** on the VNTC (Vietnamese Text Classification) dataset across **10 news categories**. The model is specifically designed for Vietnamese news article classification, content categorization for Vietnamese text, and document organization and tagging. Developed as a base model to provide quick and reliable text classification support for **scikit-learn >=1.6** integration since **underthesea 8.1.0**, it employs optimized feature engineering with **20,000 max features** and bigram support, along with a hash-based caching system for efficient processing. This system card provides comprehensive documentation of the model's architecture, performance metrics, intended uses, and limitations.

# 1. Model Details

-**Sonar Core 1** is a Vietnamese text classification model built on **scikit-learn >=1.6**, utilizing a TF-IDF pipeline with Logistic Regression to classify text across
- CountVectorizer with **20,000 max features** (optimized from the initial 10,000)
- N-gram extraction: unigram and bigram support
- TF-IDF transformation with IDF weighting
@@ -27,7 +32,7 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95

# 2. Training Data

-## 2.1
1. **chinh_tri_xa_hoi** - Politics and Society
2. **doi_song** - Lifestyle
3. **khoa_hoc** - Science
@@ -39,27 +44,51 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95

9. **van_hoa** - Culture
10. **vi_tinh** - Information Technology

-## 2.2 Dataset
- **Name**: VNTC (Vietnamese Text Classification) Dataset
- **Training Samples**: 33,759 documents
- **Test Samples**: 50,373 documents
- **Language**: Vietnamese
- **Format**: FastText format (__label__category followed by text)

-
-
-
-

# 3. Performance Metrics

-## 3.1
- **Training Accuracy**: 95.39%
- **Test Accuracy**: 92.33%
- **Training Time**: ~27.18 seconds (with caching system)
- **Inference Time**: ~19.34 seconds for 50,373 samples

-## 3.2 Per-Class Performance
| Category | Precision | Recall | F1-Score | Support |
|----------|-----------|--------|----------|---------|
| chinh_tri_xa_hoi | 0.86 | 0.93 | 0.89 | 7,567 |
@@ -73,34 +102,82 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95

| van_hoa | 0.93 | 0.95 | 0.94 | 6,250 |
| vi_tinh | 0.94 | 0.95 | 0.94 | 4,560 |

-## 3.3
- **Overall Accuracy**: 92%
- **Macro Average**: Precision: 0.91, Recall: 0.90, F1: 0.91
- **Weighted Average**: Precision: 0.92, Recall: 0.92, F1: 0.92

-
- **Best Performing Categories**: Sports (the_thao) achieves 98% F1-score, followed by Health, World, Culture, and IT (all 94% F1-score)
- **Lowest Performing Category**: Lifestyle (doi_song) with 76% F1-score due to lower recall (71%)
- **Feature Count**: Uses 20,000 max features with bigram support
- **Caching System**: Hash-based caching for efficient vectorizer and TF-IDF processing

# 4. Limitations

## 4.1 Known Limitations
1. **Language Specificity**: Only works with Vietnamese text
-2. **Domain Specificity**: Optimized for
-   - Social media posts
-   - Technical documentation
   - Conversational text
3. **Feature Limitations**:
   - Limited to 20,000 most frequent features
   - May miss rare but important terms
-4. **Class

## 4.2 Biases
-- Trained on news
-- May reflect biases present in the original
-- Performance varies across categories

# 5. Future Improvements
@@ -109,6 +186,11 @@ Released on **2025-09-21**, the model achieves **92.33% test accuracy** and **95

3. Add support for longer documents
4. Implement confidence thresholds for uncertain predictions
5. Fine-tune on domain-specific data if needed

# 6. Usage
@@ -118,8 +200,26 @@ pip install scikit-learn>=1.6 joblib

```

## 6.2 Training
```bash
uv run --no-project --with 'scikit-learn>=1.6' python train.py
```

## 6.3 Inference
@@ -151,16 +251,18 @@ probabilities = model.predict_proba([text])[0]

1. VNTC Dataset: Hoang, Cong Duy Vu, Dien Dinh, Le Nguyen Nguyen, and Quoc Hung Ngo. (2007). A Comparative Study on Vietnamese Text Classification Methods. In Proceedings of IEEE International Conference on Research, Innovation and Vision for the Future (RIVF 2007), pp. 267-273. IEEE. DOI: 10.1109/RIVF.2007.369167

-2.
-
-
-

# License
-Model trained on publicly available VNTC

# Citation

# Changelog

+**2025-09-27**
+
+- Added support for the UTS2017_Bank Vietnamese banking text classification dataset
+- Achieved 70.96% accuracy on 14 banking service categories
+
**2025-09-21**

- Initial release of Sonar Core 1

# Abstract

+**Sonar Core 1** is a machine learning-based text classification model designed for Vietnamese language processing. Built on a **TF-IDF** (Term Frequency-Inverse Document Frequency) feature extraction pipeline combined with **Logistic Regression**, the model achieves **92.33% accuracy** on the VNTC (Vietnamese Text Classification) dataset across **10 news categories** and **70.96% accuracy** on the UTS2017_Bank dataset across **14 banking service categories**. It is designed for Vietnamese news article classification, banking text categorization, general content categorization, and document organization and tagging. Developed as a base model to provide quick and reliable text classification in **underthesea** (since version **8.1.0**), built on **scikit-learn >=1.6**, it employs optimized feature engineering with **20,000 max features** and bigram support, along with a hash-based caching system for efficient processing. This system card documents the model's architecture, performance metrics, intended uses, and limitations.

# 1. Model Details

+**Sonar Core 1** is a Vietnamese text classification model built on **scikit-learn >=1.6**, utilizing a TF-IDF pipeline with Logistic Regression to classify text across multiple domains, including news categories and banking services. The architecture employs:
- CountVectorizer with **20,000 max features** (optimized from the initial 10,000)
- N-gram extraction: unigram and bigram support
- TF-IDF transformation with IDF weighting
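The three components listed above map directly onto scikit-learn building blocks. A minimal sketch of the pipeline, with toy stand-in data rather than the actual VNTC corpus (the hyperparameters mirror the ones listed; the solver settings are assumptions):

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Counts -> TF-IDF weighting -> logistic regression, as described above.
pipeline = Pipeline([
    ("counts", CountVectorizer(max_features=20000, ngram_range=(1, 2))),
    ("tfidf", TfidfTransformer(use_idf=True)),
    ("clf", LogisticRegression(max_iter=1000)),
])

# Toy corpus standing in for VNTC (two made-up categories).
X = ["bóng đá việt nam", "lãi suất ngân hàng", "trận đấu tối nay", "vay vốn ngân hàng"]
y = ["the_thao", "ngan_hang", "the_thao", "ngan_hang"]
pipeline.fit(X, y)
print(pipeline.predict(["bóng đá tối nay"])[0])
```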

# 2. Training Data

+## 2.1 VNTC Dataset - News Categories (10 classes)
1. **chinh_tri_xa_hoi** - Politics and Society
2. **doi_song** - Lifestyle
3. **khoa_hoc** - Science

9. **van_hoa** - Culture
10. **vi_tinh** - Information Technology

+## 2.2 UTS2017_Bank Dataset - Banking Categories (14 classes)
+1. **ACCOUNT** - Account services
+2. **CARD** - Card services
+3. **CUSTOMER_SUPPORT** - Customer support
+4. **DISCOUNT** - Discount offers
+5. **INTEREST_RATE** - Interest rate information
+6. **INTERNET_BANKING** - Internet banking services
+7. **LOAN** - Loan services
+8. **MONEY_TRANSFER** - Money transfer services
+9. **OTHER** - Other services
+10. **PAYMENT** - Payment services
+11. **PROMOTION** - Promotional offers
+12. **SAVING** - Savings accounts
+13. **SECURITY** - Security features
+14. **TRADEMARK** - Trademark/branding
+
+## 2.3 Dataset Details
+
+### VNTC Dataset
- **Name**: VNTC (Vietnamese Text Classification) Dataset
- **Training Samples**: 33,759 documents
- **Test Samples**: 50,373 documents
- **Language**: Vietnamese
- **Format**: FastText format (`__label__category` followed by text)
+- **Distribution**: Balanced across 10 news categories
+- **Average document length**: ~200-500 words

+### UTS2017_Bank Dataset
+- **Name**: UTS2017_Bank Classification Dataset
+- **Training Samples**: 1,581 documents
+- **Test Samples**: 396 documents
+- **Language**: Vietnamese
+- **Format**: Text with categorical labels
+- **Distribution**: Imbalanced (CUSTOMER_SUPPORT: 39%, TRADEMARK: 35%, others: 26%)
+- **Text preprocessing**: None (raw Vietnamese text)
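The FastText-style format described above (a `__label__category` prefix followed by the raw text) takes only a couple of lines to parse; the helper below is illustrative and not part of train.py:

```python
# Each VNTC line looks like: "__label__the_thao some Vietnamese text ..."
# parse_fasttext_line is a hypothetical helper for illustration only.
def parse_fasttext_line(line: str):
    label, _, text = line.strip().partition(" ")
    return label.removeprefix("__label__"), text

label, text = parse_fasttext_line("__label__the_thao Đội tuyển thắng 2-0\n")
print(label)  # the_thao
```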

# 3. Performance Metrics

+## 3.1 VNTC Dataset Performance (2025-09-21)
- **Training Accuracy**: 95.39%
- **Test Accuracy**: 92.33%
- **Training Time**: ~27.18 seconds (with caching system)
- **Inference Time**: ~19.34 seconds for 50,373 samples

+## 3.2 Per-Class Performance - VNTC Dataset
| Category | Precision | Recall | F1-Score | Support |
|----------|-----------|--------|----------|---------|
| chinh_tri_xa_hoi | 0.86 | 0.93 | 0.89 | 7,567 |

| van_hoa | 0.93 | 0.95 | 0.94 | 6,250 |
| vi_tinh | 0.94 | 0.95 | 0.94 | 4,560 |

+## 3.3 UTS2017_Bank Dataset Performance (2025-09-27)
+- **Training Accuracy**: 76.22%
+- **Test Accuracy**: 70.96%
+- **Training Time**: ~0.78 seconds
+- **Inference Time**: ~0.01 seconds for 396 samples
+
+## 3.4 Per-Class Performance - UTS2017_Bank Dataset
+| Category | Precision | Recall | F1-Score | Support |
+|----------|-----------|--------|----------|---------|
+| ACCOUNT | 0.00 | 0.00 | 0.00 | 1 |
+| CARD | 0.00 | 0.00 | 0.00 | 13 |
+| CUSTOMER_SUPPORT | 0.62 | 0.97 | 0.76 | 155 |
+| DISCOUNT | 0.00 | 0.00 | 0.00 | 8 |
+| INTEREST_RATE | 0.50 | 0.08 | 0.14 | 12 |
+| INTERNET_BANKING | 0.00 | 0.00 | 0.00 | 14 |
+| LOAN | 0.67 | 0.13 | 0.22 | 15 |
+| MONEY_TRANSFER | 0.00 | 0.00 | 0.00 | 7 |
+| OTHER | 0.50 | 0.07 | 0.12 | 14 |
+| PAYMENT | 0.00 | 0.00 | 0.00 | 3 |
+| PROMOTION | 1.00 | 0.18 | 0.31 | 11 |
+| SAVING | 0.00 | 0.00 | 0.00 | 2 |
+| SECURITY | 0.00 | 0.00 | 0.00 | 1 |
+| TRADEMARK | 0.87 | 0.89 | 0.88 | 140 |
+
+## 3.5 Aggregate Metrics
+
+### VNTC Dataset
- **Overall Accuracy**: 92%
- **Macro Average**: Precision: 0.91, Recall: 0.90, F1: 0.91
- **Weighted Average**: Precision: 0.92, Recall: 0.92, F1: 0.92

+### UTS2017_Bank Dataset
+- **Overall Accuracy**: 71%
+- **Macro Average**: Precision: 0.30, Recall: 0.17, F1: 0.17
+- **Weighted Average**: Precision: 0.64, Recall: 0.71, F1: 0.63
+
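The wide gap between macro and weighted averages on UTS2017_Bank is the signature of class imbalance: the majority classes dominate the weighted average while the missed minority classes drag the macro average down. A toy illustration with made-up labels, mimicking a classifier that only ever predicts the majority class:

```python
from sklearn.metrics import f1_score

# Made-up labels: 9 majority-class samples, 1 minority sample, and a
# degenerate classifier that always predicts the majority class.
y_true = ["MAJOR"] * 9 + ["MINOR"]
y_pred = ["MAJOR"] * 10

macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
print(round(macro, 2), round(weighted, 2))  # macro collapses; weighted stays high
```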
+## 3.6 Performance Analysis
+
+### VNTC Dataset
- **Best Performing Categories**: Sports (the_thao) achieves a 98% F1-score, followed by Health, World, Culture, and IT (all 94% F1-score)
- **Lowest Performing Category**: Lifestyle (doi_song) with a 76% F1-score due to lower recall (71%)
+
+### UTS2017_Bank Dataset
+- **Best Performing Categories**: TRADEMARK (88% F1-score) and CUSTOMER_SUPPORT (76% F1-score)
+- **Challenges**: Many minority classes have too little training data and receive zero predictions
+- **Data Imbalance**: Significant class imbalance, with CUSTOMER_SUPPORT and TRADEMARK together accounting for 74% of the data
+
+### General Observations
- **Feature Count**: Uses 20,000 max features with bigram support
- **Caching System**: Hash-based caching for efficient vectorizer and TF-IDF processing
+- **Dataset Balance**: The model performs markedly better on the balanced VNTC dataset than on the imbalanced UTS2017_Bank dataset

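The hash-based caching mentioned above can be sketched roughly as follows; the key derivation, file layout, and function names here are hypothetical, not the repository's actual implementation:

```python
import hashlib
import os
import pickle

# Idea: key cached vectorizer/TF-IDF artifacts by a hash of the corpus and
# the feature settings, so repeated runs with identical inputs skip recomputation.
def cache_key(texts, max_features, ngram_range):
    h = hashlib.sha256(f"{max_features}:{ngram_range}".encode())
    for t in texts:
        h.update(t.encode("utf-8"))
    return h.hexdigest()[:16]

def load_or_compute(texts, settings, compute, cache_dir=".tfidf_cache"):
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, cache_key(texts, *settings) + ".pkl")
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)  # cache hit: skip recomputation
    result = compute(texts)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result

corpus = ["xin chào", "tạm biệt"]
first = load_or_compute(corpus, (20000, (1, 2)), lambda ts: [t.split()[0] for t in ts])
second = load_or_compute(corpus, (20000, (1, 2)), lambda ts: [t.split()[0] for t in ts])
print(first == second)
```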
# 4. Limitations

## 4.1 Known Limitations
1. **Language Specificity**: Only works with Vietnamese text
+2. **Domain Specificity**: Optimized for specific domains and may not generalize well to:
+   - Social media posts (unless trained on matching data)
+   - Technical documentation outside the IT/banking domains
   - Conversational text
3. **Feature Limitations**:
   - Limited to the 20,000 most frequent features
   - May miss rare but important terms
+4. **Class Imbalance Sensitivity**:
+   - Performance degrades significantly on imbalanced datasets
+   - Minority classes may receive zero predictions (as seen in UTS2017_Bank)
+5. **Specific Category Weaknesses**:
+   - VNTC: Lower performance on the lifestyle (doi_song) category (71% recall)
+   - UTS2017_Bank: Poor performance on minority classes (ACCOUNT, CARD, PAYMENT, etc.)

## 4.2 Biases
+- Trained on specific domains (news and banking), which may introduce a formal writing-style bias
+- May reflect biases present in the original datasets
+- Performance varies significantly across categories:
+  - VNTC: Best on sports at 98% F1-score, weakest on lifestyle at 76% F1-score
+  - UTS2017_Bank: Best on TRADEMARK at 88% F1-score, with many categories at 0% F1-score

# 5. Future Improvements

3. Add support for longer documents
4. Implement confidence thresholds for uncertain predictions
5. Fine-tune on domain-specific data if needed
+6. Address class imbalance through:
+   - Oversampling minority classes
+   - Class weight adjustments
+   - Synthetic data generation (SMOTE)
+7. Expand to more Vietnamese text domains

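Of the mitigations listed, class-weight adjustment is the cheapest to try. A sketch using scikit-learn's `balanced` weighting; the class counts below echo the UTS2017_Bank per-class supports and are used purely for illustration:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight

# Illustrative counts echoing the UTS2017_Bank imbalance (not the full dataset).
y = np.array(["CUSTOMER_SUPPORT"] * 155 + ["TRADEMARK"] * 140 + ["CARD"] * 13)

# 'balanced' weighting: n_samples / (n_classes * count(class)),
# so the rare CARD class is weighted far above the two majority classes.
weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
print(dict(zip(np.unique(y), np.round(weights, 2))))

# In the training pipeline, the same effect is one keyword argument away:
clf = LogisticRegression(class_weight="balanced", max_iter=1000)
```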
# 6. Usage

```

## 6.2 Training
+
+### VNTC Dataset (News Classification)
```bash
+# Train on the VNTC dataset
uv run --no-project --with 'scikit-learn>=1.6' python train.py --dataset vntc
+
+# With specific parameters
+uv run --no-project --with 'scikit-learn>=1.6' python train.py --dataset vntc --model logistic --max-features 20000
+```
+
+### UTS2017_Bank Dataset (Banking Text Classification)
+```bash
+# Train on the UTS2017_Bank dataset (the default)
+python train.py --model logistic
+
+# With specific parameters
+python train.py --model logistic --max-features 20000 --ngram-min 1 --ngram-max 2
+
+# Compare multiple configurations
+python train.py --compare
```

## 6.3 Inference

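The inference calls shown in this section (`predict` and `predict_proba` on a joblib-loaded model) can be exercised with a self-contained stand-in; the tiny training corpus and the file name below are hypothetical, not artifacts of an actual training run:

```python
import joblib
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Stand-in for a saved Sonar Core 1 model: train a tiny pipeline,
# persist it with joblib, then reload it for inference.
model = Pipeline([
    ("counts", CountVectorizer()),
    ("tfidf", TfidfTransformer()),
    ("clf", LogisticRegression()),
]).fit(["tin bóng đá", "lãi suất vay"], ["the_thao", "ngan_hang"])
joblib.dump(model, "sonar_demo.joblib")  # hypothetical file name

model = joblib.load("sonar_demo.joblib")
text = "bóng đá"
label = model.predict([text])[0]
probabilities = model.predict_proba([text])[0]
print(label)
```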
1. VNTC Dataset: Hoang, Cong Duy Vu, Dien Dinh, Le Nguyen Nguyen, and Quoc Hung Ngo. (2007). A Comparative Study on Vietnamese Text Classification Methods. In Proceedings of IEEE International Conference on Research, Innovation and Vision for the Future (RIVF 2007), pp. 267-273. IEEE. DOI: 10.1109/RIVF.2007.369167

+2. UTS2017_Bank Dataset: Available from Hugging Face Datasets: https://huggingface.co/datasets/undertheseanlp/UTS2017_Bank
+
+3. TF-IDF (Term Frequency-Inverse Document Frequency): Salton, Gerard, and Michael J. McGill. (1983). Introduction to Modern Information Retrieval. McGraw-Hill, New York. ISBN: 978-0070544840

+4. Logistic Regression for Text Classification: Hastie, Trevor, Robert Tibshirani, and Jerome Friedman. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction (2nd ed.). Springer Series in Statistics. Springer, New York. DOI: 10.1007/978-0-387-84858-7

+5. Scikit-learn: Pedregosa, Fabian, et al. (2011). Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research, 12(85), 2825-2830. https://www.jmlr.org/papers/v12/pedregosa11a.html

+6. N-gram Language Models: Brown, Peter F., Vincent J. Della Pietra, Peter V. deSouza, Jenifer C. Lai, and Robert L. Mercer. (1992). Class-Based n-gram Models of Natural Language. Computational Linguistics, 18(4), 467-480. https://aclanthology.org/J92-4003/

# License
+Model trained on the publicly available VNTC and UTS2017_Bank datasets. Please refer to the original dataset licenses for usage terms.

# Citation
train.py
CHANGED
@@ -1,7 +1,8 @@
#!/usr/bin/env python3
"""
-Training script for Vietnamese text classification
-
"""

import argparse

@@ -13,6 +14,9 @@ from datetime import datetime

import numpy as np
from datasets import load_dataset
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

@@ -41,6 +45,83 @@ def setup_logging(run_name):
    return run_dir


def load_uts2017_data(split_ratio=0.2, random_state=42, n_samples=None):
    """Load and prepare UTS2017_Bank classification dataset

@@ -106,6 +187,7 @@ def get_available_models():


def train_model(
    model_name="logistic",
    max_features=20000,
    ngram_range=(1, 2),

@@ -115,6 +197,7 @@ def train_model(
    """Train a single model with specified parameters

    Args:
        model_name: Name of the model to train ('logistic' or 'svc')
        max_features: Maximum number of features for TF-IDF vectorizer
        ngram_range: N-gram range for feature extraction

@@ -139,10 +222,18 @@ def train_model(
    os.makedirs(output_folder, exist_ok=True)

    # Load data
-
-
-
-

    # Get unique labels for reporting
    unique_labels = sorted(set(y_train))

@@ -167,7 +258,7 @@ def train_model(
    logging.info(f"Selected classifier: {clf_name}")

    # Configuration name
-    config_name = f"

    logging.info("=" * 60)
    logging.info(f"Training: {config_name}")

@@ -354,23 +445,25 @@ def train_all_configurations():
    return results


-def train_notebook(model_name="logistic", max_features=20000, ngram_min=1, ngram_max=2,
                   split_ratio=0.2, n_samples=None, compare=False):
    """
    Convenience function for training in Jupyter/Colab notebooks without argparse.

    Example usage:
        from train import train_notebook
-        train_notebook(model_name="logistic", max_features=20000)
    """
    if compare:
        print("Training and comparing multiple configurations...")
        return train_all_configurations()
    else:
-
        print(f"Configuration: max_features={max_features}, ngram=({ngram_min}, {ngram_max})")

        return train_model(
            model_name=model_name,
            max_features=max_features,
            ngram_range=(ngram_min, ngram_max),

@@ -386,7 +479,14 @@ def main():
    in_notebook = hasattr(sys, 'ps1') or 'ipykernel' in sys.modules or 'google.colab' in sys.modules

    parser = argparse.ArgumentParser(
-        description="Train Vietnamese text classification model on UTS2017_Bank dataset"
    )
    parser.add_argument(
        "--model",

@@ -433,12 +533,14 @@ def main():
        print("Training and comparing multiple configurations...")
        train_all_configurations()
    else:
-
        print(
            f"Configuration: max_features={args.max_features}, ngram=({args.ngram_min}, {args.ngram_max})"
        )

        train_model(
            model_name=args.model,
            max_features=args.max_features,
            ngram_range=(args.ngram_min, args.ngram_max),
#!/usr/bin/env python3
"""
+Training script for Vietnamese text classification.
+Supports both the VNTC (news) and UTS2017_Bank (banking) datasets
+with a TF-IDF + Logistic Regression pipeline.
"""

import argparse

import numpy as np
from datasets import load_dataset
+import requests
+import zipfile
+from io import BytesIO
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

    return run_dir

+def load_vntc_data(split_ratio=0.2, random_state=42, n_samples=None):
+    """Load and prepare the VNTC dataset
+
+    Args:
+        split_ratio: Not used for VNTC (predefined train/test split)
+        random_state: Not used for VNTC (predefined train/test split)
+        n_samples: Optional limit on the number of samples
+
+    Returns:
+        Tuple of (X_train, y_train), (X_test, y_test)
+    """
+    print("Loading VNTC dataset...")
+
+    # Define dataset folder
+    dataset_folder = os.path.expanduser("~/.underthesea/VNTC")
+    os.makedirs(dataset_folder, exist_ok=True)
+
+    train_file = os.path.join(dataset_folder, "train.txt")
+    test_file = os.path.join(dataset_folder, "test.txt")
+
+    # Download if not present
+    if not os.path.exists(train_file) or not os.path.exists(test_file):
+        print("Downloading VNTC dataset...")
+        url = "https://github.com/undertheseanlp/underthesea/releases/download/resources/VNTC.zip"
+
+        response = requests.get(url)
+        with zipfile.ZipFile(BytesIO(response.content)) as zip_file:
+            zip_file.extractall(dataset_folder)
+        print("Dataset downloaded and extracted.")
+
+    # Load train data
+    X_train = []
+    y_train = []
+    with open(train_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                parts = line.strip().split(' ', 1)
+                if len(parts) == 2:
+                    label = parts[0].replace('__label__', '')
+                    text = parts[1]
+                    y_train.append(label)
+                    X_train.append(text)
+
+    # Load test data
+    X_test = []
+    y_test = []
+    with open(test_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                parts = line.strip().split(' ', 1)
+                if len(parts) == 2:
+                    label = parts[0].replace('__label__', '')
+                    text = parts[1]
+                    y_test.append(label)
+                    X_test.append(text)
+
+    # Apply sample limit if specified
+    if n_samples:
+        if n_samples < len(X_train):
+            X_train = X_train[:n_samples]
+            y_train = y_train[:n_samples]
+        if n_samples < len(X_test):
+            X_test = X_test[:n_samples]
+            y_test = y_test[:n_samples]
+
+    # Convert to numpy arrays
+    X_train = np.array(X_train)
+    y_train = np.array(y_train)
+    X_test = np.array(X_test)
+    y_test = np.array(y_test)
+
+    print(f"Dataset loaded: {len(X_train)} train samples, {len(X_test)} test samples")
+    print(f"Number of unique labels: {len(set(y_train))}")
+
+    return (X_train, y_train), (X_test, y_test)
+
+
def load_uts2017_data(split_ratio=0.2, random_state=42, n_samples=None):
    """Load and prepare UTS2017_Bank classification dataset

|
| 187 |
|
| 188 |
|
| 189 |
def train_model(
|
| 190 |
+
dataset="uts2017",
|
| 191 |
model_name="logistic",
|
| 192 |
max_features=20000,
|
| 193 |
ngram_range=(1, 2),
|
|
|
|
| 197 |
"""Train a single model with specified parameters
|
| 198 |
|
| 199 |
Args:
|
| 200 |
+
dataset: Dataset to use ('vntc' or 'uts2017')
|
| 201 |
model_name: Name of the model to train ('logistic' or 'svc')
|
| 202 |
max_features: Maximum number of features for TF-IDF vectorizer
|
| 203 |
ngram_range: N-gram range for feature extraction
|
|
|
|
| 222 |
os.makedirs(output_folder, exist_ok=True)
|
| 223 |
|
| 224 |
# Load data
|
| 225 |
+
if dataset == "vntc":
|
| 226 |
+
logging.info("Loading VNTC dataset...")
|
| 227 |
+
(X_train, y_train), (X_test, y_test) = load_vntc_data(
|
| 228 |
+
split_ratio=split_ratio, n_samples=n_samples
|
| 229 |
+
)
|
| 230 |
+
dataset_name = "VNTC"
|
| 231 |
+
else:
|
| 232 |
+
logging.info("Loading UTS2017_Bank dataset...")
|
| 233 |
+
(X_train, y_train), (X_test, y_test) = load_uts2017_data(
|
| 234 |
+
split_ratio=split_ratio, n_samples=n_samples
|
| 235 |
+
)
|
| 236 |
+
dataset_name = "UTS2017_Bank"
|
| 237 |
|
| 238 |
# Get unique labels for reporting
|
| 239 |
unique_labels = sorted(set(y_train))
|
|
|
|
    logging.info(f"Selected classifier: {clf_name}")

    # Configuration name
+    config_name = f"{dataset_name}_{clf_name}_feat{max_features // 1000}k_ngram{ngram_range[0]}-{ngram_range[1]}"

    logging.info("=" * 60)
    logging.info(f"Training: {config_name}")

    return results


+def train_notebook(dataset="uts2017", model_name="logistic", max_features=20000, ngram_min=1, ngram_max=2,
                   split_ratio=0.2, n_samples=None, compare=False):
    """
    Convenience function for training in Jupyter/Colab notebooks without argparse.

    Example usage:
        from train import train_notebook
+        train_notebook(dataset="vntc", model_name="logistic", max_features=20000)
    """
    if compare:
        print("Training and comparing multiple configurations...")
        return train_all_configurations()
    else:
+        dataset_name = "VNTC" if dataset == "vntc" else "UTS2017_Bank"
+        print(f"Training {model_name} model on {dataset_name} dataset...")
        print(f"Configuration: max_features={max_features}, ngram=({ngram_min}, {ngram_max})")

        return train_model(
+            dataset=dataset,
            model_name=model_name,
            max_features=max_features,
            ngram_range=(ngram_min, ngram_max),

    in_notebook = hasattr(sys, 'ps1') or 'ipykernel' in sys.modules or 'google.colab' in sys.modules

    parser = argparse.ArgumentParser(
+        description="Train Vietnamese text classification model on VNTC or UTS2017_Bank dataset"
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        choices=["vntc", "uts2017"],
+        default="uts2017",
+        help="Dataset to use for training (default: uts2017)",
    )
    parser.add_argument(
        "--model",

        print("Training and comparing multiple configurations...")
        train_all_configurations()
    else:
+        dataset_name = "VNTC" if args.dataset == "vntc" else "UTS2017_Bank"
+        print(f"Training {args.model} model on {dataset_name} dataset...")
        print(
            f"Configuration: max_features={args.max_features}, ngram=({args.ngram_min}, {args.ngram_max})"
        )

        train_model(
+            dataset=args.dataset,
            model_name=args.model,
            max_features=args.max_features,
            ngram_range=(args.ngram_min, args.ngram_max),