|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Training script for Vietnamese POS Tagger (TRE-1). |
|
|
|
|
|
Supports 3 CRF trainers: |
|
|
- python-crfsuite: Original Python bindings to CRFsuite |
|
|
- crfsuite-rs: Rust bindings to CRFsuite (pip install crfsuite) |
|
|
- underthesea-core: Underthesea's native Rust CRF implementation |
|
|
|
|
|
Models are saved to: models/pos_tagger/{version}/model.crfsuite |
|
|
|
|
|
Usage: |
|
|
uv run scripts/train.py |
|
|
uv run scripts/train.py --trainer crfsuite-rs |
|
|
uv run scripts/train.py --trainer underthesea-core |
|
|
uv run scripts/train.py --version v1.1.0 |
|
|
uv run scripts/train.py --wandb |
|
|
uv run scripts/train.py --c1 0.5 --c2 0.01 --max-iterations 200 |
|
|
""" |
|
|
|
|
|
import platform |
|
|
import re |
|
|
import time |
|
|
from abc import ABC, abstractmethod |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
|
|
|
import click |
|
|
import psutil |
|
|
import yaml |
|
|
from datasets import load_dataset |
|
|
from sklearn.metrics import accuracy_score, classification_report |
|
|
|
|
|
|
|
|
|
|
|
# Repository root: this script is expected to live one directory below it
# (e.g. scripts/train.py), so two .parent hops reach the project root.
PROJECT_ROOT = Path(__file__).parent.parent

# Supported CRF backends, selectable via the --trainer CLI option.
TRAINERS = ["python-crfsuite", "crfsuite-rs", "underthesea-core"]
|
|
|
|
|
|
|
|
def get_hardware_info():
    """Gather platform, CPU, and memory details for run metadata.

    Returns a dict of system facts. The "cpu_model" key is only present
    on Linux (read from /proc/cpuinfo) or after a read failure, so
    callers should use .get() with a default.
    """
    details = {
        "platform": platform.system(),
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
    }

    # /proc/cpuinfo is Linux-only; the first "model name" row carries the
    # human-readable CPU description.
    try:
        if platform.system() == "Linux":
            with open("/proc/cpuinfo", "r") as cpuinfo:
                for row in cpuinfo:
                    if "model name" in row:
                        details["cpu_model"] = row.split(":")[1].strip()
                        break
    except Exception:
        details["cpu_model"] = "Unknown"

    return details
|
|
|
|
|
|
|
|
def format_duration(seconds):
    """Render a duration in seconds as "Ss", "Mm Ss", or "Hh Mm Ss" text."""
    if seconds >= 3600:
        whole_hours, remainder = divmod(seconds, 3600)
        return f"{int(whole_hours)}h {int(remainder // 60)}m {seconds % 60:.2f}s"
    if seconds >= 60:
        return f"{int(seconds // 60)}m {seconds % 60:.2f}s"
    return f"{seconds:.2f}s"
|
|
|
|
|
|
|
|
# Feature templates interpreted by parse_template / extract_features.
# Grammar: "T[offset(s)]" optionally followed by ".attribute":
#   T[i]        -> token at relative offset i from the current position
#   T[i,j]      -> tokens at offsets i and j combined into one feature
#   .lower etc. -> string attribute applied to the token value
#   .prefixN / .suffixN -> first/last N characters of the token
#   .is_in_dict -> dictionary membership (requires a dictionary argument)
FEATURE_TEMPLATES = [
    "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper",
    "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3",
    "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower",
    "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower",
    "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper",
    "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]",
    "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict",
]
|
|
|
|
|
|
|
|
def get_token_value(tokens, position, index):
    """Return the token at position+index, or a BOS/EOS sentinel off either end."""
    target = position + index
    if 0 <= target < len(tokens):
        return tokens[target]
    return "__BOS__" if target < 0 else "__EOS__"
|
|
|
|
|
|
|
|
def apply_attribute(value, attribute, dictionary=None):
    """Apply a named attribute transform to a token value.

    Sentinel tokens (__BOS__/__EOS__) pass through untouched, as does a
    None or unrecognized attribute. Predicate attributes are rendered as
    the strings "True"/"False" so they can serve as CRF feature values.
    """
    # Boundary sentinels never get transformed.
    if value in ("__BOS__", "__EOS__"):
        return value
    if attribute is None:
        return value

    if attribute == "lower":
        return value.lower()
    if attribute == "upper":
        return value.upper()

    # str has matching predicate methods for all of these.
    if attribute in ("istitle", "isupper", "islower", "isdigit", "isalpha"):
        return str(getattr(value, attribute)())

    if attribute == "is_in_dict":
        return str(value in dictionary) if dictionary else "False"

    if attribute.startswith("prefix"):
        size = int(attribute[6:]) if len(attribute) > 6 else 2
        return value[:size] if len(value) >= size else value
    if attribute.startswith("suffix"):
        size = int(attribute[6:]) if len(attribute) > 6 else 2
        return value[-size:] if len(value) >= size else value

    return value
|
|
|
|
|
|
|
|
def parse_template(template):
    """Parse a template like "T[-1,0].lower" into ([offsets], attribute).

    Returns (None, None) when the string does not match the expected
    "T[...]" shape; the attribute is None when no ".attr" part is given.
    """
    parsed = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
    if parsed is None:
        return None, None
    offsets = [int(part.strip()) for part in parsed.group(1).split(",")]
    return offsets, parsed.group(2)
|
|
|
|
|
|
|
|
def extract_features(tokens, position, dictionary=None):
    """Build the feature dict for the token at `position` from FEATURE_TEMPLATES.

    Single-offset templates yield a (possibly attribute-transformed) token
    value; multi-offset templates yield either a dictionary-membership flag
    (for .is_in_dict, joined with spaces) or the tokens joined with "|".
    Unparseable templates are skipped.
    """
    features = {}
    for template in FEATURE_TEMPLATES:
        offsets, attribute = parse_template(template)
        if offsets is None:
            continue
        token_values = [get_token_value(tokens, position, off) for off in offsets]
        if len(token_values) == 1:
            features[template] = apply_attribute(token_values[0], attribute, dictionary)
        elif attribute == "is_in_dict":
            # Dictionary lookups use space-joined multi-word phrases.
            phrase = " ".join(token_values)
            features[template] = str(phrase in dictionary) if dictionary else "False"
        else:
            features[template] = "|".join(token_values)
    return features
|
|
|
|
|
|
|
|
def sentence_to_features(tokens):
    """Convert a token list into per-position "key=value" CRF feature strings."""
    encoded = []
    for position in range(len(tokens)):
        feats = extract_features(tokens, position)
        encoded.append([f"{key}={val}" for key, val in feats.items()])
    return encoded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CRFTrainerBase(ABC):
    """Abstract base class for CRF trainers.

    Concrete subclasses wrap one CRF backend each and must implement both
    training (serialize a model to disk) and prediction (load that model
    and tag feature sequences).
    """

    # Backend identifier, used for CLI selection and logging.
    name: str = "base"

    @abstractmethod
    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train the CRF model and save to output_path.

        Args:
            X_train: List of sequences; each sequence is a list of
                per-token feature-string lists.
            y_train: List of label sequences aligned with X_train.
            output_path: Destination path for the serialized model.
            c1: L1 regularization coefficient.
            c2: L2 regularization coefficient.
            max_iterations: Maximum optimizer iterations.
            verbose: Whether the backend should print training progress.
        """
        pass

    @abstractmethod
    def predict(self, model_path, X_test):
        """Load model and predict on test data.

        Returns:
            A list of predicted label sequences, one per sequence in X_test.
        """
        pass
|
|
|
|
|
|
|
|
class PythonCRFSuiteTrainer(CRFTrainerBase):
    """Trainer using python-crfsuite (original Python bindings)."""

    name = "python-crfsuite"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Fit a CRF on (X_train, y_train) and write the model to output_path."""
        import pycrfsuite

        crf = pycrfsuite.Trainer(verbose=verbose)
        for features, labels in zip(X_train, y_train):
            crf.append(features, labels)

        crf.set_params({
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        })
        crf.train(str(output_path))

    def predict(self, model_path, X_test):
        """Tag every feature sequence in X_test using the saved model."""
        import pycrfsuite

        model = pycrfsuite.Tagger()
        model.open(str(model_path))
        return [model.tag(features) for features in X_test]
|
|
|
|
|
|
|
|
class CRFSuiteRsTrainer(CRFTrainerBase):
    """Trainer using crfsuite-rs (Rust bindings via pip install crfsuite)."""

    name = "crfsuite-rs"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Fit a CRF on (X_train, y_train) and write the model to output_path."""
        import crfsuite

        rs_trainer = crfsuite.Trainer()
        rs_trainer.set_params({
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        })

        for features, labels in zip(X_train, y_train):
            rs_trainer.append(features, labels)

        rs_trainer.train(str(output_path))

    def predict(self, model_path, X_test):
        """Tag every feature sequence in X_test using the saved model."""
        import crfsuite

        loaded = crfsuite.Model(str(model_path))
        return [loaded.tag(features) for features in X_test]
|
|
|
|
|
|
|
|
class UndertheseaCoreTrainer(CRFTrainerBase):
    """Trainer using underthesea-core native Rust CRF with LBFGS optimization.

    This trainer uses the native underthesea-core Rust CRF implementation
    with L-BFGS optimization, matching CRFsuite performance.

    The native backend saves models with a ``.crf`` extension, so paths
    ending in ``.crfsuite`` are remapped via _to_native_path.

    Requires building underthesea-core from source:
        cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core
        uv venv && source .venv/bin/activate
        uv pip install maturin
        maturin develop --release
    """

    name = "underthesea-core"

    @staticmethod
    def _to_native_path(path_str):
        """Swap a trailing '.crfsuite' extension for the backend's '.crf'.

        Uses slicing instead of str.replace so that ONLY the trailing
        extension is rewritten; replace() would also rewrite an earlier
        '.crfsuite' occurrence elsewhere in the path.
        """
        if path_str.endswith('.crfsuite'):
            return path_str[:-len('.crfsuite')] + '.crf'
        return path_str

    def _check_trainer_import(self):
        """Check if CRFTrainer is available.

        Tries both the package root and the nested native-module layout,
        since the exported location depends on how the wheel was built.

        Raises:
            ImportError: If CRFTrainer cannot be found in either location.
        """
        try:
            from underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass

        try:
            from underthesea_core.underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass

        raise ImportError(
            "CRFTrainer not available in underthesea_core.\n"
            "Build from source with LBFGS support:\n"
            " cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core\n"
            " source .venv/bin/activate && maturin develop --release"
        )

    def _check_tagger_import(self):
        """Check if CRFModel and CRFTagger are available.

        Raises:
            ImportError: If the classes cannot be imported from either layout.
        """
        try:
            from underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass

        try:
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass

        raise ImportError("CRFModel/CRFTagger not available in underthesea_core")

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train with the native L-BFGS CRF and save the model.

        The model is written with a '.crf' extension; the path actually
        used is remembered on the instance so predict() can locate it even
        when callers pass the original '.crfsuite' path.
        """
        CRFTrainer = self._check_trainer_import()

        trainer = CRFTrainer(
            loss_function="lbfgs",
            l1_penalty=c1,
            l2_penalty=c2,
            max_iterations=max_iterations,
            verbose=1 if verbose else 0,
        )

        model = trainer.train(X_train, y_train)

        output_path_str = self._to_native_path(str(output_path))
        model.save(output_path_str)

        # Remember where the model actually went for later predict() calls.
        self._model_path = output_path_str

    def predict(self, model_path, X_test):
        """Load the saved native model and tag each sequence in X_test."""
        CRFModel, CRFTagger = self._check_tagger_import()

        # Prefer the path recorded by train(); otherwise remap the given one.
        if hasattr(self, '_model_path'):
            model_path_str = self._model_path
        else:
            model_path_str = self._to_native_path(str(model_path))

        model = CRFModel.load(model_path_str)
        tagger = CRFTagger.from_model(model)
        return [tagger.tag(xseq) for xseq in X_test]
|
|
|
|
|
|
|
|
def get_trainer(trainer_name: str) -> CRFTrainerBase:
    """Get trainer instance by name."""
    registry = {
        "python-crfsuite": PythonCRFSuiteTrainer,
        "crfsuite-rs": CRFSuiteRsTrainer,
        "underthesea-core": UndertheseaCoreTrainer,
    }
    if trainer_name not in registry:
        raise ValueError(f"Unknown trainer: {trainer_name}. Available: {list(registry.keys())}")
    return registry[trainer_name]()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_data():
    """Download UDD-1 and return (train, val, test) lists of (tokens, tags) pairs.

    Sentences with empty token or tag lists are dropped.
    """
    click.echo("Loading UDD-1 dataset...")
    dataset = load_dataset("undertheseanlp/UDD-1")

    def to_pairs(split):
        # Keep only items where both the tokens and the UPOS tags are non-empty.
        return [
            (item["tokens"], item["upos"])
            for item in split
            if item["tokens"] and item["upos"]
        ]

    train_data = to_pairs(dataset["train"])
    val_data = to_pairs(dataset["validation"])
    test_data = to_pairs(dataset["test"])

    click.echo(f"Loaded {len(train_data)} train, {len(val_data)} val, {len(test_data)} test sentences")
    return train_data, val_data, test_data
|
|
|
|
|
|
|
|
def save_metadata(output_dir, version, trainer_name, train_data, val_data, test_data, c1, c2, max_iterations, accuracy, hw_info, training_time):
    """Save model metadata to YAML file."""
    model_info = {
        "name": "Vietnamese POS Tagger",
        "version": version,
        "type": "CRF (Conditional Random Field)",
        "framework": trainer_name,
    }
    training_info = {
        "dataset": "undertheseanlp/UDD-1",
        "train_sentences": len(train_data),
        "val_sentences": len(val_data),
        "test_sentences": len(test_data),
        "hyperparameters": {"c1": c1, "c2": c2, "max_iterations": max_iterations},
        "duration_seconds": round(training_time, 2),
    }
    environment_info = {
        "platform": hw_info["platform"],
        "cpu_model": hw_info.get("cpu_model", "Unknown"),
        "python_version": hw_info["python_version"],
    }
    metadata = {
        "model": model_info,
        "training": training_info,
        "performance": {"test_accuracy": round(accuracy, 4)},
        "environment": environment_info,
        "files": {
            "model": "model.crfsuite",
            "config": "../../../configs/pos_tagger.yaml",
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }

    metadata_path = output_dir / "metadata.yaml"
    with open(metadata_path, "w") as handle:
        # Keep insertion order and unicode text readable in the output file.
        yaml.dump(metadata, handle, default_flow_style=False, allow_unicode=True, sort_keys=False)
    click.echo(f"Metadata saved to {metadata_path}")
|
|
|
|
|
|
|
|
def get_default_version():
    """Generate timestamp-based version."""
    return f"{datetime.now():%Y%m%d_%H%M%S}"
|
|
|
|
|
|
|
|
@click.command()
@click.option(
    "--trainer", "-t",
    type=click.Choice(TRAINERS),
    default="python-crfsuite",
    help="CRF trainer to use",
    show_default=True,
)
@click.option(
    "--version", "-v",
    default=None,
    help="Model version (default: timestamp, e.g., 20260131_154530)",
)
@click.option(
    "--output", "-o",
    default=None,
    help="Custom output path (overrides version-based path)",
)
@click.option(
    "--c1",
    default=1.0,
    type=float,
    help="L1 regularization coefficient",
    show_default=True,
)
@click.option(
    "--c2",
    default=0.001,
    type=float,
    help="L2 regularization coefficient",
    show_default=True,
)
@click.option(
    "--max-iterations",
    default=100,
    type=int,
    help="Maximum training iterations",
    show_default=True,
)
@click.option(
    "--wandb/--no-wandb",
    default=False,
    help="Enable Weights & Biases logging",
)
def train(trainer, version, output, c1, c2, max_iterations, wandb):
    """Train Vietnamese POS Tagger using CRF on UDD-1 dataset."""
    total_start_time = time.time()
    start_datetime = datetime.now()

    crf_trainer = get_trainer(trainer)

    if version is None:
        version = get_default_version()

    # Resolve the output location. A custom --output path overrides the
    # version-based layout. In both branches the destination directory is
    # created up front (previously only the version-based branch did this,
    # so a custom path into a nonexistent directory failed at save time).
    if output:
        output_path = Path(output)
        output_dir = output_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        output_dir = PROJECT_ROOT / "models" / "pos_tagger" / version
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / "model.crfsuite"

    hw_info = get_hardware_info()

    # --- Run banner ---
    click.echo("=" * 60)
    click.echo(f"POS Tagger Training - {version}")
    click.echo("=" * 60)
    click.echo(f"Trainer: {trainer}")
    click.echo(f"Platform: {hw_info['platform']}")
    click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    click.echo("=" * 60)

    train_data, val_data, test_data = load_data()

    click.echo(f"\nTrain: {len(train_data)} sentences")
    click.echo(f"Validation: {len(val_data)} sentences")
    click.echo(f"Test: {len(test_data)} sentences")

    # --- Feature extraction ---
    click.echo("\nExtracting features...")
    feature_start = time.time()
    X_train = [sentence_to_features(tokens) for tokens, _ in train_data]
    y_train = [tags for _, tags in train_data]
    click.echo(f"Feature extraction: {format_duration(time.time() - feature_start)}")

    click.echo(f"\nTraining CRF model with {trainer}...")

    # Optional Weights & Biases logging; degrades to a warning when the
    # package is not installed.
    use_wandb = wandb
    if use_wandb:
        try:
            import wandb as wb
            wb.init(project="pos-tagger-vietnamese", name=f"crf-{trainer}-{version}")
            wb.config.update({
                "trainer": trainer,
                "c1": c1,
                "c2": c2,
                "max_iterations": max_iterations,
                "num_features": len(FEATURE_TEMPLATES),
                "train_sentences": len(train_data),
                "val_sentences": len(val_data),
                "test_sentences": len(test_data),
                "version": version,
            })
        except ImportError:
            click.echo("wandb not installed, skipping logging", err=True)
            use_wandb = False

    # --- Training ---
    crf_start = time.time()
    crf_trainer.train(X_train, y_train, output_path, c1, c2, max_iterations, verbose=True)
    crf_time = time.time() - crf_start
    click.echo(f"\nModel saved to {output_path}")
    click.echo(f"CRF training: {format_duration(crf_time)}")

    # --- Evaluation on the held-out test split ---
    click.echo("\nEvaluating on test set...")
    X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
    y_test = [tags for _, tags in test_data]

    y_pred = crf_trainer.predict(output_path, X_test)

    # Flatten per-sentence tag sequences for token-level metrics.
    y_test_flat = [tag for tags in y_test for tag in tags]
    y_pred_flat = [tag for tags in y_pred for tag in tags]

    accuracy = accuracy_score(y_test_flat, y_pred_flat)

    total_time = time.time() - total_start_time

    click.echo(f"\nAccuracy: {accuracy:.4f}")
    click.echo("\nClassification Report:")
    click.echo(classification_report(y_test_flat, y_pred_flat))

    # Metadata only makes sense in the version-based directory layout.
    if not output:
        save_metadata(output_dir, version, trainer, train_data, val_data, test_data,
                      c1, c2, max_iterations, accuracy, hw_info, total_time)

    click.echo("\n" + "=" * 60)
    click.echo("Training Summary")
    click.echo("=" * 60)
    click.echo(f"Trainer: {trainer}")
    click.echo(f"Version: {version}")
    click.echo(f"Model: {output_path}")
    click.echo(f"Accuracy: {accuracy:.4f}")
    click.echo(f"Total time: {format_duration(total_time)}")
    click.echo("=" * 60)

    if use_wandb:
        wb.log({"accuracy": accuracy})
        wb.finish()
|
|
|
|
|
|
|
|
# Entry point: Click parses CLI arguments and dispatches into train().
if __name__ == "__main__":
    train()
|
|
|