File size: 8,951 Bytes
b059f86
903cdb2
b059f86
 
903cdb2
 
 
 
 
 
 
 
 
b059f86
 
 
 
903cdb2
b059f86
 
903cdb2
 
b059f86
 
903cdb2
b5fd35d
903cdb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b059f86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903cdb2
b5fd35d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903cdb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5fd35d
903cdb2
 
b5fd35d
903cdb2
 
b5fd35d
903cdb2
 
 
 
b5fd35d
903cdb2
 
 
 
 
 
b5fd35d
903cdb2
 
 
b5fd35d
903cdb2
 
 
b5fd35d
 
903cdb2
 
 
 
 
 
 
 
b5fd35d
 
 
903cdb2
b5fd35d
903cdb2
 
b5fd35d
903cdb2
 
 
 
 
b5fd35d
903cdb2
 
 
 
b5fd35d
 
903cdb2
b5fd35d
 
 
 
 
903cdb2
b5fd35d
 
b059f86
903cdb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
Training CLI for Vietnamese Text Classification using Hydra.

Usage:
    python src/train.py --config-name=vntc
    python src/train.py --config-name=sentiment_general
    python src/train.py --config-name=sentiment_bank
    python src/train.py --config-name=bank

Override params from CLI:
    python src/train.py --config-name=sentiment_general model.c=0.5 model.max_features=100000
    python src/train.py --config-name=vntc preprocessor=sentiment
    python src/train.py --config-name=sentiment_general data.vlsp2016_dir=/path/to/VLSP2016_SA
"""

import os
import time
import logging
from pathlib import Path

import hydra
from omegaconf import DictConfig, OmegaConf
from sklearn.metrics import accuracy_score, f1_score, classification_report

from underthesea import TextClassifier, TextPreprocessor

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Preprocessor
# ---------------------------------------------------------------------------

def build_preprocessor(pp_cfg):
    """Construct a Rust-backed TextPreprocessor from the model.preprocessor config.

    Empty teencode maps / negation-word lists collapse to None so the
    preprocessor treats them as "feature disabled".
    """
    teencode_map = dict(pp_cfg.get("teencode", {})) or None
    negations = list(pp_cfg.get("negation_words", [])) or None
    window = pp_cfg.get("negation_window", 2)

    # Boolean toggles all default to True when absent from the config.
    return TextPreprocessor(
        lowercase=pp_cfg.get("lowercase", True),
        unicode_normalize=pp_cfg.get("unicode_normalize", True),
        remove_urls=pp_cfg.get("remove_urls", True),
        normalize_repeated_chars=pp_cfg.get("normalize_repeated_chars", True),
        normalize_punctuation=pp_cfg.get("normalize_punctuation", True),
        teencode=teencode_map,
        negation_words=negations,
        negation_window=window,
    )


# ---------------------------------------------------------------------------
# Data loaders
# ---------------------------------------------------------------------------

def read_file(filepath):
    """Read a text file, trying several encodings in turn.

    Returns whitespace-normalized text when a decode succeeds and yields
    more than 10 characters; otherwise None. latin-1 is last because it
    never fails to decode, making it the catch-all attempt.
    """
    for encoding in ('utf-16', 'utf-8', 'latin-1'):
        try:
            with open(filepath, 'r', encoding=encoding) as fh:
                raw = fh.read()
        except (UnicodeDecodeError, UnicodeError):
            continue
        normalized = ' '.join(raw.split())
        if len(normalized) > 10:
            return normalized
    return None


def load_vntc_data(data_dir):
    """Load VNTC documents from a directory tree.

    Each immediate subfolder name is used as the label for every .txt
    file inside it. Files that fail to decode or are too short
    (read_file returns None) are skipped.
    """
    texts, labels = [], []
    for category in sorted(os.listdir(data_dir)):
        category_dir = os.path.join(data_dir, category)
        if not os.path.isdir(category_dir):
            continue
        txt_names = (n for n in os.listdir(category_dir) if n.endswith('.txt'))
        for txt_name in txt_names:
            content = read_file(os.path.join(category_dir, txt_name))
            if content:
                texts.append(content)
                labels.append(category)
    return texts, labels


def load_vlsp2016(data_dir):
    """Load VLSP2016 sentiment data (train.txt / test.txt) from a directory.

    Lines look like ``__label__POS some text``; anything not starting
    with ``__label__`` is ignored. Returns
    (train_texts, train_labels, test_texts, test_labels).
    """
    label_map = {'POS': 'positive', 'NEG': 'negative', 'NEU': 'neutral'}

    def _read_split(filename):
        split_texts, split_labels = [], []
        path = os.path.join(data_dir, filename)
        with open(path, 'r', encoding='utf-8') as fh:
            for raw_line in fh:
                stripped = raw_line.strip()
                if not stripped.startswith('__label__'):
                    continue
                head, _, body = stripped.partition(' ')
                split_labels.append(label_map[head.replace('__label__', '')])
                split_texts.append(body)
        return split_texts, split_labels

    train_texts, train_labels = _read_split('train.txt')
    test_texts, test_labels = _read_split('test.txt')
    return train_texts, train_labels, test_texts, test_labels


def load_data(cfg):
    """Load train/test splits according to cfg.data.name.

    Returns (train_texts, train_labels, test_texts, test_labels,
    extra_test) where extra_test maps names to additional (texts,
    labels) evaluation sets (currently always empty).
    """
    data_cfg = cfg.data
    dataset_name = data_cfg.name
    extra_test = {}

    if dataset_name == "vntc":
        root = data_cfg.data_dir
        train_texts, train_labels = load_vntc_data(os.path.join(root, "Train_Full"))
        test_texts, test_labels = load_vntc_data(os.path.join(root, "Test_Full"))

    elif dataset_name == "bank":
        # Imported lazily so the dependency is only needed for this dataset.
        from datasets import load_dataset
        ds = load_dataset(data_cfg.dataset, data_cfg.config)
        train_texts = list(ds["train"]["text"])
        train_labels = list(ds["train"]["label"])
        test_texts = list(ds["test"]["text"])
        test_labels = list(ds["test"]["label"])

    elif dataset_name == "sentiment_general":
        train_texts, train_labels, test_texts, test_labels = load_vlsp2016(
            data_cfg.data_dir)

    elif dataset_name == "sentiment_bank":
        from datasets import load_dataset
        ds_class = load_dataset(data_cfg.dataset, "classification")
        ds_sent = load_dataset(data_cfg.dataset, "sentiment")

        # Joint label: "<class>#<sentiment>" per example, aligned by index.
        def _joined(split):
            pairs = zip(ds_class[split]["label"], ds_sent[split]["sentiment"])
            return list(ds_class[split]["text"]), [f'{c}#{s}' for c, s in pairs]

        train_texts, train_labels = _joined("train")
        test_texts, test_labels = _joined("test")

    else:
        raise ValueError(f"Unknown data: {dataset_name}")

    return train_texts, train_labels, test_texts, test_labels, extra_test


# ---------------------------------------------------------------------------
# Evaluate
# ---------------------------------------------------------------------------

def evaluate(test_labels, preds, name=""):
    """Log accuracy, weighted/macro F1 and a classification report.

    Returns the accuracy so callers can compare runs.
    """
    accuracy = accuracy_score(test_labels, preds)
    weighted_f1 = f1_score(test_labels, preds, average='weighted', zero_division=0)
    macro_f1 = f1_score(test_labels, preds, average='macro', zero_division=0)

    separator = "=" * 70
    log.info(separator)
    log.info(f"RESULTS ({name})" if name else "RESULTS")
    log.info(separator)
    log.info(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    log.info(f"  F1 (weighted): {weighted_f1:.4f}")
    log.info(f"  F1 (macro): {macro_f1:.4f}")
    log.info("\n" + classification_report(test_labels, preds, zero_division=0))
    return accuracy


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig):
    """Train a Vietnamese text classification model from a Hydra config.

    Loads the configured dataset, optionally builds a text preprocessor,
    fits a TextClassifier, evaluates on the test set(s), and saves the
    model to cfg.output.
    """
    banner = "=" * 70
    log.info(banner)
    log.info(f"Training: {cfg.data.name}")
    log.info(banner)
    log.info(f"\nConfig:\n{OmegaConf.to_yaml(cfg)}")

    # --- Data -------------------------------------------------------------
    log.info("Loading data...")
    start = time.perf_counter()
    train_texts, train_labels, test_texts, test_labels, extra_test = load_data(cfg)
    load_time = time.perf_counter() - start

    log.info(f"  Train samples: {len(train_texts)}")
    log.info(f"  Test samples: {len(test_texts)}")
    log.info(f"  Labels: {len(set(train_labels))}")
    log.info(f"  Load time: {load_time:.2f}s")

    # --- Preprocessor -----------------------------------------------------
    # model.preprocess=true activates the model.preprocessor config; the
    # preprocessor is handed to TextClassifier and packed into the .bin model.
    preprocessor = None
    if cfg.model.get("preprocess", False):
        preprocessor = build_preprocessor(cfg.model.preprocessor)
        log.info(f"\nPreprocessor: {preprocessor}")

    # --- Classifier -------------------------------------------------------
    model_cfg = cfg.model
    ngram_range = tuple(model_cfg.ngram_range)

    log.info("\nTraining TextClassifier...")
    log.info(f"  max_features={model_cfg.max_features}, ngram_range={ngram_range}, "
             f"max_df={model_cfg.max_df}, C={model_cfg.c}")

    clf = TextClassifier(
        max_features=model_cfg.max_features,
        ngram_range=ngram_range,
        min_df=model_cfg.min_df,
        max_df=model_cfg.max_df,
        c=model_cfg.c,
        max_iter=model_cfg.max_iter,
        tol=model_cfg.tol,
        preprocessor=preprocessor,
    )

    start = time.perf_counter()
    clf.fit(train_texts, train_labels)
    train_time = time.perf_counter() - start
    log.info(f"  Training time: {train_time:.3f}s")
    log.info(f"  Vocabulary size: {clf.n_features}")

    # --- Evaluation -------------------------------------------------------
    # TextClassifier auto-preprocesses via its built-in preprocessor.
    log.info("\nEvaluating...")
    predictions = clf.predict_batch(test_texts)
    evaluate(test_labels, predictions, cfg.data.name)

    # Extra test sets (e.g. VLSP2016), when load_data supplied any.
    for extra_name, (et_texts, et_labels) in extra_test.items():
        evaluate(et_labels, clf.predict_batch(et_texts), extra_name)

    # --- Persist ----------------------------------------------------------
    model_path = Path(cfg.output)
    model_path.parent.mkdir(parents=True, exist_ok=True)
    clf.save(str(model_path))

    size_mb = model_path.stat().st_size / (1024 * 1024)
    log.info(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")


# Entry point: Hydra parses CLI overrides and calls train() with the composed config.
if __name__ == "__main__":
    train()