|
|
""" |
|
|
Training CLI for Vietnamese Text Classification using Hydra. |
|
|
|
|
|
Usage: |
|
|
python src/train.py --config-name=vntc |
|
|
python src/train.py --config-name=sentiment_general |
|
|
python src/train.py --config-name=sentiment_bank |
|
|
python src/train.py --config-name=bank |
|
|
|
|
|
Override params from CLI: |
|
|
python src/train.py --config-name=sentiment_general model.c=0.5 model.max_features=100000 |
|
|
python src/train.py --config-name=vntc preprocessor=sentiment |
|
|
python src/train.py --config-name=sentiment_general data.vlsp2016_dir=/path/to/VLSP2016_SA |
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
import hydra |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
from sklearn.metrics import accuracy_score, f1_score, classification_report |
|
|
|
|
|
from underthesea import TextClassifier, TextPreprocessor |
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_preprocessor(pp_cfg):
    """Construct the Rust-backed ``TextPreprocessor`` from ``model.preprocessor`` config.

    Every flag falls back to a sensible default when missing from the config,
    so a partial preprocessor section is valid.
    """
    # Empty teencode map / negation list are collapsed to None so the
    # preprocessor treats them as "feature disabled" rather than "empty table".
    teencode_map = dict(pp_cfg.get("teencode", {})) or None
    negation_words = list(pp_cfg.get("negation_words", [])) or None

    return TextPreprocessor(
        lowercase=pp_cfg.get("lowercase", True),
        unicode_normalize=pp_cfg.get("unicode_normalize", True),
        remove_urls=pp_cfg.get("remove_urls", True),
        normalize_repeated_chars=pp_cfg.get("normalize_repeated_chars", True),
        normalize_punctuation=pp_cfg.get("normalize_punctuation", True),
        teencode=teencode_map,
        negation_words=negation_words,
        negation_window=pp_cfg.get("negation_window", 2),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_file(filepath, encodings=('utf-16', 'utf-8', 'latin-1'), min_length=10):
    """Read a text file, trying several encodings in order.

    All runs of whitespace (including newlines) are collapsed into single
    spaces. The first encoding that decodes the file AND yields text longer
    than ``min_length`` characters wins.

    Args:
        filepath: Path to the text file.
        encodings: Candidate encodings, tried in priority order. Defaults to
            the original hard-coded ('utf-16', 'utf-8', 'latin-1') chain.
        min_length: Minimum character count for the content to be accepted.

    Returns:
        The normalized text, or ``None`` when every encoding fails or the
        decoded content is too short.
    """
    for enc in encodings:
        try:
            with open(filepath, 'r', encoding=enc) as f:
                # Collapse all whitespace runs to single spaces.
                text = ' '.join(f.read().split())
            if len(text) > min_length:
                return text
        except (UnicodeDecodeError, UnicodeError):
            # Wrong encoding guess — try the next candidate.
            continue
    return None
|
|
|
|
|
|
|
|
def load_vntc_data(data_dir):
    """Load VNTC texts and labels from *data_dir*.

    Each immediate subdirectory is one class; its name is used as the label
    for every readable ``.txt`` file it contains.
    """
    texts, labels = [], []
    # Sorted for a deterministic class/document order across runs.
    for category in sorted(os.listdir(data_dir)):
        category_dir = os.path.join(data_dir, category)
        if not os.path.isdir(category_dir):
            continue
        for filename in os.listdir(category_dir):
            if not filename.endswith('.txt'):
                continue
            content = read_file(os.path.join(category_dir, filename))
            # read_file returns None for undecodable or too-short files.
            if content:
                texts.append(content)
                labels.append(category)
    return texts, labels
|
|
|
|
|
|
|
|
def load_vlsp2016(data_dir):
    """Load VLSP2016 sentiment data (fastText-style ``__label__`` lines).

    Reads ``train.txt`` and ``test.txt`` from *data_dir* and maps the raw
    POS/NEG/NEU tags to human-readable sentiment names.

    Returns:
        (train_texts, train_labels, test_texts, test_labels)
    """
    sentiment_names = {'POS': 'positive', 'NEG': 'negative', 'NEU': 'neutral'}
    splits = []
    for filename in ('train.txt', 'test.txt'):
        split_texts, split_labels = [], []
        with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
            for raw in f:
                raw = raw.strip()
                # Lines without the fastText label prefix are ignored.
                if not raw.startswith('__label__'):
                    continue
                tag, _, body = raw.partition(' ')
                split_texts.append(body)
                split_labels.append(sentiment_names[tag.replace('__label__', '')])
        splits.append((split_texts, split_labels))
    (train_texts, train_labels), (test_texts, test_labels) = splits
    return train_texts, train_labels, test_texts, test_labels
|
|
|
|
|
|
|
|
def load_data(cfg):
    """Resolve train/test splits for the dataset named in the Hydra config.

    Returns:
        (train_texts, train_labels, test_texts, test_labels, extra_test)
        where ``extra_test`` maps a display name to a (texts, labels) pair
        for additional held-out sets (currently always empty).

    Raises:
        ValueError: When ``cfg.data.name`` is not a known dataset.
    """
    data_cfg = cfg.data
    dataset_name = data_cfg.name
    extra_test = {}

    if dataset_name == "vntc":
        # On-disk VNTC corpus: one folder of .txt files per class.
        train_texts, train_labels = load_vntc_data(
            os.path.join(data_cfg.data_dir, "Train_Full"))
        test_texts, test_labels = load_vntc_data(
            os.path.join(data_cfg.data_dir, "Test_Full"))

    elif dataset_name == "bank":
        from datasets import load_dataset  # deferred: optional dependency
        ds = load_dataset(data_cfg.dataset, data_cfg.config)
        train_texts = list(ds["train"]["text"])
        train_labels = list(ds["train"]["label"])
        test_texts = list(ds["test"]["text"])
        test_labels = list(ds["test"]["label"])

    elif dataset_name == "sentiment_general":
        train_texts, train_labels, test_texts, test_labels = load_vlsp2016(
            data_cfg.data_dir)

    elif dataset_name == "sentiment_bank":
        from datasets import load_dataset  # deferred: optional dependency
        ds_class = load_dataset(data_cfg.dataset, "classification")
        ds_sent = load_dataset(data_cfg.dataset, "sentiment")

        def _joint_labels(split):
            # Fuse the two annotation layers into one "category#sentiment" label.
            return [f'{c}#{s}' for c, s in
                    zip(ds_class[split]["label"], ds_sent[split]["sentiment"])]

        train_texts = list(ds_class["train"]["text"])
        train_labels = _joint_labels("train")
        test_texts = list(ds_class["test"]["text"])
        test_labels = _joint_labels("test")

    else:
        raise ValueError(f"Unknown data: {dataset_name}")

    return train_texts, train_labels, test_texts, test_labels, extra_test
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(test_labels, preds, name=""):
    """Log accuracy, weighted/macro F1 and a per-class report; return accuracy.

    Args:
        test_labels: Ground-truth labels.
        preds: Predicted labels, aligned with *test_labels*.
        name: Optional dataset name shown in the results header.
    """
    acc = accuracy_score(test_labels, preds)
    weighted_f1 = f1_score(test_labels, preds, average='weighted', zero_division=0)
    macro_f1 = f1_score(test_labels, preds, average='macro', zero_division=0)

    separator = "=" * 70
    log.info(separator)
    log.info(f"RESULTS ({name})" if name else "RESULTS")
    log.info(separator)
    log.info(f"  Accuracy:      {acc:.4f} ({acc*100:.2f}%)")
    log.info(f"  F1 (weighted): {weighted_f1:.4f}")
    log.info(f"  F1 (macro):    {macro_f1:.4f}")
    # Full per-class precision/recall/F1 breakdown.
    log.info("\n" + classification_report(test_labels, preds, zero_division=0))
    return acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig):
    """Train a Vietnamese text classification model from the composed Hydra config.

    Loads the configured dataset, optionally builds a text preprocessor,
    fits a TextClassifier, evaluates it, and saves the model to ``cfg.output``.
    """
    banner = "=" * 70
    log.info(banner)
    log.info(f"Training: {cfg.data.name}")
    log.info(banner)
    log.info(f"\nConfig:\n{OmegaConf.to_yaml(cfg)}")

    # --- Data ---------------------------------------------------------------
    log.info("Loading data...")
    tic = time.perf_counter()
    train_texts, train_labels, test_texts, test_labels, extra_test = load_data(cfg)
    load_time = time.perf_counter() - tic

    log.info(f"  Train samples: {len(train_texts)}")
    log.info(f"  Test samples:  {len(test_texts)}")
    log.info(f"  Labels: {len(set(train_labels))}")
    log.info(f"  Load time: {load_time:.2f}s")

    # --- Optional preprocessing --------------------------------------------
    preprocessor = None
    if cfg.model.get("preprocess", False):
        preprocessor = build_preprocessor(cfg.model.preprocessor)
        log.info(f"\nPreprocessor: {preprocessor}")

    # --- Model --------------------------------------------------------------
    model_cfg = cfg.model
    ngram_range = tuple(model_cfg.ngram_range)  # config lists must become tuples

    log.info("\nTraining TextClassifier...")
    log.info(f"  max_features={model_cfg.max_features}, ngram_range={ngram_range}, "
             f"max_df={model_cfg.max_df}, C={model_cfg.c}")

    classifier = TextClassifier(
        max_features=model_cfg.max_features,
        ngram_range=ngram_range,
        min_df=model_cfg.min_df,
        max_df=model_cfg.max_df,
        c=model_cfg.c,
        max_iter=model_cfg.max_iter,
        tol=model_cfg.tol,
        preprocessor=preprocessor,
    )

    tic = time.perf_counter()
    classifier.fit(train_texts, train_labels)
    train_time = time.perf_counter() - tic
    log.info(f"  Training time: {train_time:.3f}s")
    log.info(f"  Vocabulary size: {classifier.n_features}")

    # --- Evaluation ---------------------------------------------------------
    log.info("\nEvaluating...")
    predictions = classifier.predict_batch(test_texts)
    evaluate(test_labels, predictions, cfg.data.name)

    # Any additional held-out sets supplied by load_data.
    for extra_name, (extra_texts, extra_labels) in extra_test.items():
        extra_predictions = classifier.predict_batch(extra_texts)
        evaluate(extra_labels, extra_predictions, extra_name)

    # --- Persist ------------------------------------------------------------
    model_path = Path(cfg.output)
    model_path.parent.mkdir(parents=True, exist_ok=True)
    classifier.save(str(model_path))

    size_mb = model_path.stat().st_size / (1024 * 1024)
    log.info(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Hydra's decorator parses CLI args/overrides and calls train() with the
    # composed DictConfig; no arguments are passed here.
    train()
|
|
|