# src/train.py
# Refactor training to Hydra config and use underthesea imports (commit 903cdb2)
"""
Training CLI for Vietnamese Text Classification using Hydra.
Usage:
python src/train.py --config-name=vntc
python src/train.py --config-name=sentiment_general
python src/train.py --config-name=sentiment_bank
python src/train.py --config-name=bank
Override params from CLI:
python src/train.py --config-name=sentiment_general model.c=0.5 model.max_features=100000
python src/train.py --config-name=vntc preprocessor=sentiment
python src/train.py --config-name=sentiment_general data.vlsp2016_dir=/path/to/VLSP2016_SA
"""
import os
import time
import logging
from pathlib import Path
import hydra
from omegaconf import DictConfig, OmegaConf
from sklearn.metrics import accuracy_score, f1_score, classification_report
from underthesea import TextClassifier, TextPreprocessor
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Preprocessor
# ---------------------------------------------------------------------------
def build_preprocessor(pp_cfg):
    """Construct a Rust-backed TextPreprocessor from the model.preprocessor config.

    Empty teencode maps / negation word lists are passed as None so the
    preprocessor skips those stages entirely.
    """
    def _flag(key):
        # All boolean preprocessing switches default to enabled.
        return pp_cfg.get(key, True)

    teencode_map = dict(pp_cfg.get("teencode", {})) or None
    negation_words = list(pp_cfg.get("negation_words", [])) or None
    return TextPreprocessor(
        lowercase=_flag("lowercase"),
        unicode_normalize=_flag("unicode_normalize"),
        remove_urls=_flag("remove_urls"),
        normalize_repeated_chars=_flag("normalize_repeated_chars"),
        normalize_punctuation=_flag("normalize_punctuation"),
        teencode=teencode_map,
        negation_words=negation_words,
        negation_window=pp_cfg.get("negation_window", 2),
    )
# ---------------------------------------------------------------------------
# Data loaders
# ---------------------------------------------------------------------------
def read_file(filepath):
    """Read a text file, trying several encodings in order.

    Returns whitespace-normalized text when a decode yields more than 10
    characters; otherwise the next encoding is tried. Returns None when no
    encoding produces usable text.
    """
    for encoding in ('utf-16', 'utf-8', 'latin-1'):
        try:
            with open(filepath, 'r', encoding=encoding) as fh:
                content = fh.read()
        except (UnicodeDecodeError, UnicodeError):
            continue
        # Collapse all runs of whitespace (incl. newlines) to single spaces.
        normalized = ' '.join(content.split())
        if len(normalized) > 10:
            return normalized
    return None
def load_vntc_data(data_dir):
    """Load VNTC texts and labels; each subdirectory name is the class label."""
    texts, labels = [], []
    for label in sorted(os.listdir(data_dir)):
        label_dir = os.path.join(data_dir, label)
        if not os.path.isdir(label_dir):
            continue
        txt_files = (f for f in os.listdir(label_dir) if f.endswith('.txt'))
        for fname in txt_files:
            content = read_file(os.path.join(label_dir, fname))
            # read_file returns None (or short text) for unreadable files; skip those.
            if content:
                texts.append(content)
                labels.append(label)
    return texts, labels
def load_vlsp2016(data_dir):
    """Load VLSP2016 sentiment data.

    Parses fastText-style lines ("__label__POS some text") from train.txt and
    test.txt; returns (train_texts, train_labels, test_texts, test_labels).
    """
    label_map = {'POS': 'positive', 'NEG': 'negative', 'NEU': 'neutral'}

    def _read_split(filename):
        split_texts, split_labels = [], []
        with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
            for raw in f:
                raw = raw.strip()
                if not raw.startswith('__label__'):
                    continue
                # Split off the label token; text may be empty if no space follows.
                head, _, body = raw.partition(' ')
                split_labels.append(label_map[head.replace('__label__', '')])
                split_texts.append(body)
        return split_texts, split_labels

    train_texts, train_labels = _read_split('train.txt')
    test_texts, test_labels = _read_split('test.txt')
    return train_texts, train_labels, test_texts, test_labels
def load_data(cfg):
    """Dispatch to the appropriate dataset loader based on cfg.data.name.

    Returns (train_texts, train_labels, test_texts, test_labels, extra_test);
    extra_test is a name -> (texts, labels) mapping of additional eval sets
    (currently always empty).
    """
    data_cfg = cfg.data
    name = data_cfg.name
    extra_test = {}
    if name == "vntc":
        root = data_cfg.data_dir
        train_texts, train_labels = load_vntc_data(os.path.join(root, "Train_Full"))
        test_texts, test_labels = load_vntc_data(os.path.join(root, "Test_Full"))
    elif name == "bank":
        from datasets import load_dataset
        dataset = load_dataset(data_cfg.dataset, data_cfg.config)
        train_split, test_split = dataset["train"], dataset["test"]
        train_texts = list(train_split["text"])
        train_labels = list(train_split["label"])
        test_texts = list(test_split["text"])
        test_labels = list(test_split["label"])
    elif name == "sentiment_general":
        train_texts, train_labels, test_texts, test_labels = load_vlsp2016(
            data_cfg.data_dir)
    elif name == "sentiment_bank":
        from datasets import load_dataset
        ds_class = load_dataset(data_cfg.dataset, "classification")
        ds_sent = load_dataset(data_cfg.dataset, "sentiment")

        def _joint_labels(split):
            # Combined "category#sentiment" label for joint classification.
            return [f'{c}#{s}' for c, s in
                    zip(ds_class[split]["label"], ds_sent[split]["sentiment"])]

        train_texts = list(ds_class["train"]["text"])
        train_labels = _joint_labels("train")
        test_texts = list(ds_class["test"]["text"])
        test_labels = _joint_labels("test")
    else:
        raise ValueError(f"Unknown data: {name}")
    return train_texts, train_labels, test_texts, test_labels, extra_test
# ---------------------------------------------------------------------------
# Evaluate
# ---------------------------------------------------------------------------
def evaluate(test_labels, preds, name=""):
    """Log accuracy, weighted/macro F1 and a classification report; return accuracy."""
    acc = accuracy_score(test_labels, preds)
    weighted_f1 = f1_score(test_labels, preds, average='weighted', zero_division=0)
    macro_f1 = f1_score(test_labels, preds, average='macro', zero_division=0)
    banner = "=" * 70
    log.info(banner)
    log.info(f"RESULTS ({name})" if name else "RESULTS")
    log.info(banner)
    log.info(f" Accuracy: {acc:.4f} ({acc*100:.2f}%)")
    log.info(f" F1 (weighted): {weighted_f1:.4f}")
    log.info(f" F1 (macro): {macro_f1:.4f}")
    log.info("\n" + classification_report(test_labels, preds, zero_division=0))
    return acc
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig):
    """Train a Vietnamese text classification model from a Hydra config."""
    divider = "=" * 70
    log.info(divider)
    log.info(f"Training: {cfg.data.name}")
    log.info(divider)
    log.info(f"\nConfig:\n{OmegaConf.to_yaml(cfg)}")

    # --- load data ---
    log.info("Loading data...")
    start = time.perf_counter()
    train_texts, train_labels, test_texts, test_labels, extra_test = load_data(cfg)
    load_time = time.perf_counter() - start
    log.info(f" Train samples: {len(train_texts)}")
    log.info(f" Test samples: {len(test_texts)}")
    log.info(f" Labels: {len(set(train_labels))}")
    log.info(f" Load time: {load_time:.2f}s")

    # --- preprocessor ---
    # model.preprocess=true activates model.preprocessor; the preprocessor is
    # handed to TextClassifier and packed into the saved .bin model.
    model_cfg = cfg.model
    preprocessor = None
    if model_cfg.get("preprocess", False):
        preprocessor = build_preprocessor(model_cfg.preprocessor)
        log.info(f"\nPreprocessor: {preprocessor}")

    # --- build and fit classifier ---
    ngram_range = tuple(model_cfg.ngram_range)
    log.info("\nTraining TextClassifier...")
    log.info(f" max_features={model_cfg.max_features}, ngram_range={ngram_range}, "
             f"max_df={model_cfg.max_df}, C={model_cfg.c}")
    clf = TextClassifier(
        max_features=model_cfg.max_features,
        ngram_range=ngram_range,
        min_df=model_cfg.min_df,
        max_df=model_cfg.max_df,
        c=model_cfg.c,
        max_iter=model_cfg.max_iter,
        tol=model_cfg.tol,
        preprocessor=preprocessor,
    )
    start = time.perf_counter()
    clf.fit(train_texts, train_labels)
    train_time = time.perf_counter() - start
    log.info(f" Training time: {train_time:.3f}s")
    log.info(f" Vocabulary size: {clf.n_features}")

    # --- evaluate ---
    # TextClassifier applies its built-in preprocessor during prediction.
    log.info("\nEvaluating...")
    evaluate(test_labels, clf.predict_batch(test_texts), cfg.data.name)
    # Extra held-out sets (e.g. VLSP2016), if any were loaded.
    for extra_name, (extra_texts, extra_labels) in extra_test.items():
        evaluate(extra_labels, clf.predict_batch(extra_texts), extra_name)

    # --- persist model ---
    model_path = Path(cfg.output)
    model_path.parent.mkdir(parents=True, exist_ok=True)
    clf.save(str(model_path))
    size_mb = model_path.stat().st_size / (1024 * 1024)
    log.info(f"\nModel saved to {model_path} ({size_mb:.2f} MB)")
# Entry point: Hydra parses CLI overrides (see module docstring) and calls train().
if __name__ == "__main__":
    train()