|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Training script for Vietnamese Word Segmentation using CRF. |
|
|
|
|
|
Supports 3 CRF trainers: |
|
|
- python-crfsuite: Original Python bindings to CRFsuite |
|
|
- crfsuite-rs: Rust bindings to CRFsuite (pip install crfsuite) |
|
|
- underthesea-core: Underthesea's native Rust CRF implementation |
|
|
|
|
|
Models are saved to: models/word_segmentation/{version}/model.crfsuite
(the underthesea-core trainer writes model.crf in the same directory instead)
|
|
|
|
|
Uses B/I tagging at SYLLABLE level (the BIO scheme without an O tag, since every syllable belongs to some word):
|
|
- B: Beginning of a word (first syllable) |
|
|
- I: Inside a word (continuation syllables) |
|
|
|
|
|
Usage: |
|
|
uv run scripts/train_word_segmentation.py |
|
|
uv run scripts/train_word_segmentation.py --trainer crfsuite-rs |
|
|
uv run scripts/train_word_segmentation.py --trainer underthesea-core |
|
|
uv run scripts/train_word_segmentation.py --version v1.1.0 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import platform |
|
|
import time |
|
|
from abc import ABC, abstractmethod |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
|
|
|
import click |
|
|
import psutil |
|
|
import yaml |
|
|
from datasets import load_dataset |
|
|
from sklearn.metrics import accuracy_score, classification_report, f1_score |
|
|
from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize |
|
|
|
|
|
|
|
|
|
|
|
# Repository root, assuming this file lives one directory below it
# (the usage examples run it as scripts/train_word_segmentation.py).
PROJECT_ROOT = Path(__file__).parent.parent

# CLI names of the supported CRF backends (used for the --trainer choice).
# Keep in sync with the names handled by get_trainer().
TRAINERS = ["python-crfsuite", "crfsuite-rs", "underthesea-core"]
|
|
|
|
|
|
|
|
def get_hardware_info():
    """Collect hardware and system information for the training report.

    Returns:
        dict with keys ``platform``, ``platform_release``, ``architecture``,
        ``python_version``, ``cpu_physical_cores``, ``cpu_logical_cores``,
        ``ram_total_gb`` (total RAM in GiB, rounded to 2 decimals) and
        ``cpu_model``. ``cpu_model`` is always present: on Linux it is taken
        from the first ``model name`` line of ``/proc/cpuinfo``; on other
        platforms, or when that line is missing or unreadable, it is
        ``"Unknown"``.
    """
    system = platform.system()
    info = {
        "platform": system,
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
    }

    cpu_model = None
    if system == "Linux":
        try:
            with open("/proc/cpuinfo", "r", encoding="utf-8", errors="replace") as f:
                for line in f:
                    if "model name" in line:
                        # partition() keeps model names that themselves contain ':'.
                        cpu_model = line.partition(":")[2].strip()
                        break
        except OSError:
            # Best effort only: an unreadable cpuinfo just yields "Unknown".
            cpu_model = None

    # cpuinfo layout varies by architecture and may lack a "model name" line,
    # so always fill the key to give callers a consistent dict shape.
    info["cpu_model"] = cpu_model or "Unknown"
    return info
|
|
|
|
|
|
|
|
def format_duration(seconds):
    """Format a duration in seconds as a short human-readable string.

    Examples: ``"12.34s"``, ``"3m 5.00s"``, ``"1h 2m 3.00s"``.

    The value is rounded to the displayed precision (hundredths of a second)
    *before* it is split into units. Values just under a unit boundary
    therefore roll over cleanly: ``59.999`` becomes ``"1m 0.00s"`` rather
    than ``"60.00s"``.
    """
    seconds = round(seconds, 2)
    if seconds < 60:
        return f"{seconds:.2f}s"
    if seconds < 3600:
        minutes, secs = divmod(seconds, 60)
        return f"{int(minutes)}m {secs:.2f}s"
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{int(hours)}h {int(minutes)}m {secs:.2f}s"
|
|
|
|
|
|
|
|
|
|
|
# Names of the per-syllable features produced by extract_syllable_features().
# Keep this list in sync with that function's dictionary keys. In this script
# it is only referenced to report the number of templates to wandb.
FEATURE_TEMPLATES = [
    # Current syllable: identity, casing/shape flags, length and 2-char affixes.
    "S[0]",
    "S[0].lower",
    "S[0].istitle",
    "S[0].isupper",
    "S[0].isdigit",
    "S[0].ispunct",
    "S[0].len",
    "S[0].prefix2",
    "S[0].suffix2",

    # Left context (up to two syllables back, "__BOS__" when out of range).
    "S[-1]",
    "S[-1].lower",
    "S[-2]",
    "S[-2].lower",

    # Right context (up to two syllables ahead, "__EOS__" when out of range).
    "S[1]",
    "S[1].lower",
    "S[2]",
    "S[2].lower",

    # Bigrams joining the current syllable with its neighbours.
    "S[-1,0]",
    "S[0,1]",

    # Trigram centred on the current syllable.
    "S[-1,0,1]",
]
|
|
|
|
|
|
|
|
def get_syllable_at(syllables, position, offset):
    """Return the syllable ``offset`` steps away from ``position``.

    Indices before the start of the sentence map to the ``"__BOS__"``
    sentinel and indices past the end map to ``"__EOS__"``, so context
    features never index out of range.
    """
    target = position + offset
    if 0 <= target < len(syllables):
        return syllables[target]
    return "__BOS__" if target < 0 else "__EOS__"
|
|
|
|
|
|
|
|
def is_punct(s):
    """Return True when ``s`` is exactly one non-alphanumeric character.

    Multi-character strings (e.g. ``"..."``) and the empty string are never
    treated as punctuation.
    """
    if len(s) != 1:
        return False
    return not s.isalnum()
|
|
|
|
|
|
|
|
def extract_syllable_features(syllables, position):
    """Build the feature dictionary for the syllable at ``position``.

    Keys follow ``FEATURE_TEMPLATES`` and are inserted in the same order:
    current-syllable features first, then the left context (-1, -2), the
    right context (+1, +2), and finally the bigram and trigram features.
    Out-of-range neighbours are the ``"__BOS__"`` / ``"__EOS__"`` sentinels
    from ``get_syllable_at``. Sentinels are never lowercased, and their
    shape flags are reported as ``"False"`` with length ``"0"``.
    """
    sentinels = ("__BOS__", "__EOS__")

    def lowered(token):
        # Padding sentinels are kept verbatim; real syllables are lowercased.
        return token if token in sentinels else token.lower()

    prev2, prev1, current, next1, next2 = (
        get_syllable_at(syllables, position, offset) for offset in (-2, -1, 0, 1, 2)
    )

    features = {"S[0]": current, "S[0].lower": lowered(current)}
    if current in sentinels:
        features["S[0].istitle"] = "False"
        features["S[0].isupper"] = "False"
        features["S[0].isdigit"] = "False"
        features["S[0].ispunct"] = "False"
        features["S[0].len"] = "0"
        features["S[0].prefix2"] = current
        features["S[0].suffix2"] = current
    else:
        features["S[0].istitle"] = str(current.istitle())
        features["S[0].isupper"] = str(current.isupper())
        features["S[0].isdigit"] = str(current.isdigit())
        features["S[0].ispunct"] = str(is_punct(current))
        features["S[0].len"] = str(len(current))
        # Slicing already returns the whole string when it is shorter than 2.
        features["S[0].prefix2"] = current[:2]
        features["S[0].suffix2"] = current[-2:]

    features["S[-1]"] = prev1
    features["S[-1].lower"] = lowered(prev1)
    features["S[-2]"] = prev2
    features["S[-2].lower"] = lowered(prev2)

    features["S[1]"] = next1
    features["S[1].lower"] = lowered(next1)
    features["S[2]"] = next2
    features["S[2].lower"] = lowered(next2)

    features["S[-1,0]"] = "|".join((prev1, current))
    features["S[0,1]"] = "|".join((current, next1))
    features["S[-1,0,1]"] = "|".join((prev1, current, next1))

    return features
|
|
|
|
|
|
|
|
def sentence_to_syllable_features(syllables):
    """Turn a syllable sequence into CRFsuite-style feature lists.

    Returns one list per syllable. Each list holds ``"name=value"`` strings
    built from ``extract_syllable_features`` for that position.
    """
    feature_sequences = []
    for index, _ in enumerate(syllables):
        feature_dict = extract_syllable_features(syllables, index)
        feature_sequences.append([f"{name}={value}" for name, value in feature_dict.items()])
    return feature_sequences
|
|
|
|
|
|
|
|
def tokens_to_syllable_labels(tokens):
    """Flatten word tokens into syllables with B/I labels.

    Each token (e.g. the compound word "Thời hạn") is split into syllables
    with ``regex_tokenize``. Its first syllable is labelled ``"B"`` and every
    following syllable ``"I"``. A token that yields no syllables contributes
    nothing.

    Returns:
        ``(syllables, labels)``: two parallel lists of equal length.
    """
    syllables, labels = [], []
    for token in tokens:
        pieces = list(regex_tokenize(token))
        if not pieces:
            continue
        syllables.extend(pieces)
        labels.append("B")
        labels.extend(["I"] * (len(pieces) - 1))
    return syllables, labels
|
|
|
|
|
|
|
|
def labels_to_words(syllables, labels):
    """Rebuild space-joined words from syllables and their B/I labels.

    A ``"B"`` label starts a new word and any other label extends the current
    one. A leading non-``"B"`` label still opens the first word. Surplus items
    in the longer of the two inputs are ignored.
    """
    groups = []
    for syllable, tag in zip(syllables, labels):
        if tag == "B" or not groups:
            groups.append([syllable])
        else:
            groups[-1].append(syllable)
    return [" ".join(group) for group in groups]
|
|
|
|
|
|
|
|
def _word_spans(labels, limit):
    """Return the set of ``(start, end)`` syllable spans encoded by B/I ``labels``.

    Only the first ``limit`` labels are considered (the number of syllables in
    the sentence). A ``"B"`` label starts a new span, and a leading
    non-``"B"`` label opens the first span. This grouping is the same one
    ``labels_to_words`` uses.
    """
    spans = set()
    start = 0
    length = 0
    for idx, label in zip(range(limit), labels):
        if label == "B" and idx > 0:
            spans.add((start, idx))
            start = idx
        length = idx + 1
    if length:
        spans.add((start, length))
    return spans


def compute_word_metrics(y_true, y_pred, syllables_list):
    """Compute word-level precision, recall and F1 for B/I segmentations.

    A predicted word counts as correct only when its exact syllable span
    ``(start, end)`` also appears in the gold segmentation.

    Args:
        y_true: Gold label sequences, one per sentence.
        y_pred: Predicted label sequences, aligned with ``y_true``.
        syllables_list: Syllable sequences. Each one bounds how many labels
            of the matching sentence are scored.

    Returns:
        ``(precision, recall, f1)``. Each value is ``0.0`` when its
        denominator is zero.
    """
    correct = 0
    total_pred = 0
    total_true = 0

    for syllables, true_labels, pred_labels in zip(syllables_list, y_true, y_pred):
        n_syllables = len(syllables)
        # Spans are derived from the labels directly. Re-splitting joined word
        # strings would miscount syllables that are empty or contain spaces.
        true_spans = _word_spans(true_labels, n_syllables)
        pred_spans = _word_spans(pred_labels, n_syllables)

        total_true += len(true_spans)
        total_pred += len(pred_spans)
        correct += len(true_spans & pred_spans)

    precision = correct / total_pred if total_pred > 0 else 0.0
    recall = correct / total_true if total_true > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return precision, recall, f1
|
|
|
|
|
|
|
|
def load_data():
    """Download UDD-1 and convert each split into syllable-level training pairs.

    Returns:
        A 4-tuple ``(train_data, val_data, test_data, stats)``. Each
        ``*_data`` is a list of ``(syllables, labels)`` pairs produced by
        ``tokens_to_syllable_labels``. Rows without tokens, or whose tokens
        yield no syllables, are dropped. ``stats`` holds the sentence and
        syllable counts of every split.
    """
    click.echo("Loading UDD-1 dataset...")
    dataset = load_dataset("undertheseanlp/UDD-1")

    def to_pairs(rows):
        pairs = []
        for row in rows:
            tokens = row["tokens"]
            if not tokens:
                continue
            syllables, labels = tokens_to_syllable_labels(tokens)
            if syllables:
                pairs.append((syllables, labels))
        return pairs

    def syllable_total(pairs):
        return sum(len(syllables) for syllables, _ in pairs)

    train_data = to_pairs(dataset["train"])
    val_data = to_pairs(dataset["validation"])
    test_data = to_pairs(dataset["test"])

    train_syls, val_syls, test_syls = (
        syllable_total(pairs) for pairs in (train_data, val_data, test_data)
    )

    click.echo(f"Loaded {len(train_data)} train ({train_syls} syllables), "
               f"{len(val_data)} val ({val_syls} syllables), "
               f"{len(test_data)} test ({test_syls} syllables) sentences")

    stats = {
        "train_sentences": len(train_data),
        "train_syllables": train_syls,
        "val_sentences": len(val_data),
        "val_syllables": val_syls,
        "test_sentences": len(test_data),
        "test_syllables": test_syls,
    }
    return train_data, val_data, test_data, stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CRFTrainerBase(ABC):
    """Common interface implemented by every CRF backend in this script.

    Subclasses set ``name`` to their CLI identifier. They implement
    ``train`` (fit a model and persist it) and ``predict`` (reload a
    persisted model and tag feature sequences).
    """

    name: str = "base"

    @abstractmethod
    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Fit a CRF on ``X_train``/``y_train`` and write it to ``output_path``.

        ``c1``/``c2`` are the L1/L2 regularization coefficients and
        ``max_iterations`` bounds the optimizer.
        """

    @abstractmethod
    def predict(self, model_path, X_test):
        """Load the model at ``model_path`` and return one label sequence per item of ``X_test``."""
|
|
|
|
|
|
|
|
class PythonCRFSuiteTrainer(CRFTrainerBase):
    """CRF backend built on ``pycrfsuite`` (the python-crfsuite bindings).

    ``train`` writes a CRFsuite model file to ``output_path``, and ``predict``
    reads it back with a ``pycrfsuite.Tagger``.
    """

    name = "python-crfsuite"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Append every (features, labels) pair, configure, then train to ``output_path``."""
        import pycrfsuite

        crf = pycrfsuite.Trainer(verbose=verbose)
        for features, labels in zip(X_train, y_train):
            crf.append(features, labels)

        params = {
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            # Also generate transition features for label pairs unseen in training.
            "feature.possible_transitions": True,
        }
        crf.set_params(params)

        crf.train(str(output_path))

    def predict(self, model_path, X_test):
        """Tag each feature sequence in ``X_test`` with the model at ``model_path``."""
        import pycrfsuite

        tagger = pycrfsuite.Tagger()
        tagger.open(str(model_path))
        predictions = []
        for features in X_test:
            predictions.append(tagger.tag(features))
        return predictions
|
|
|
|
|
|
|
|
class CRFSuiteRsTrainer(CRFTrainerBase):
    """CRF backend using the ``crfsuite`` package (Rust bindings, ``pip install crfsuite``)."""

    name = "crfsuite-rs"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Configure a ``crfsuite.Trainer``, feed it the data and train to ``output_path``.

        NOTE(review): ``verbose`` is accepted for interface compatibility but is
        not passed to the underlying trainer here.
        """
        import crfsuite

        crf = crfsuite.Trainer()

        # Parameters are set before the training data is appended.
        crf.set_params({
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        })

        for features, labels in zip(X_train, y_train):
            crf.append(features, labels)

        crf.train(str(output_path))

    def predict(self, model_path, X_test):
        """Load the model at ``model_path`` and tag every sequence in ``X_test``."""
        import crfsuite

        loaded = crfsuite.Model(str(model_path))
        return list(map(loaded.tag, X_test))
|
|
|
|
|
|
|
|
class UndertheseaCoreTrainer(CRFTrainerBase):
    """Trainer using underthesea-core native Rust CRF with LBFGS optimization.

    This trainer uses the native underthesea-core Rust CRF implementation
    with L-BFGS optimization, matching CRFsuite performance.

    The model is saved in underthesea-core's own format with a ``.crf``
    suffix. A requested ``.../model.crfsuite`` path is stored as
    ``.../model.crf``, and ``predict`` applies the same mapping to its
    ``model_path`` argument.

    Requires building underthesea-core from source:
        cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core
        uv venv && source .venv/bin/activate
        uv pip install maturin
        maturin develop --release
    """

    name = "underthesea-core"

    @staticmethod
    def _native_model_path(path):
        """Map a ``.crfsuite`` path to the ``.crf`` path this backend reads and writes.

        Only the final suffix is swapped. Directory names that happen to
        contain ".crfsuite" are left untouched, whereas ``str.replace`` would
        rewrite them too. Any other suffix is returned unchanged.
        """
        path = Path(path)
        if path.suffix == ".crfsuite":
            path = path.with_suffix(".crf")
        return str(path)

    def _check_trainer_import(self):
        """Return the ``CRFTrainer`` class, trying both known import locations.

        Raises:
            ImportError: if neither location provides ``CRFTrainer``.
        """
        try:
            from underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass

        try:
            from underthesea_core.underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass

        raise ImportError(
            "CRFTrainer not available in underthesea_core.\n"
            "Build from source with LBFGS support:\n"
            "  cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core\n"
            "  source .venv/bin/activate && maturin develop --release"
        )

    def _check_tagger_import(self):
        """Return ``(CRFModel, CRFTagger)``, trying both known import locations.

        Raises:
            ImportError: if neither location provides both classes.
        """
        try:
            from underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass

        try:
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass

        raise ImportError("CRFModel/CRFTagger not available in underthesea_core")

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train with L-BFGS and save the model next to ``output_path`` as ``.crf``.

        ``c1``/``c2`` are passed as the L1/L2 penalties. The path actually
        written is remembered in ``self._model_path``.
        """
        CRFTrainer = self._check_trainer_import()

        trainer = CRFTrainer(
            loss_function="lbfgs",
            l1_penalty=c1,
            l2_penalty=c2,
            max_iterations=max_iterations,
            verbose=1 if verbose else 0,
        )

        model = trainer.train(X_train, y_train)

        model_path = self._native_model_path(output_path)
        model.save(model_path)

        # Kept for callers that want to know where the model actually landed.
        self._model_path = model_path

    def predict(self, model_path, X_test):
        """Load the ``.crf`` model for ``model_path`` and tag every sequence in ``X_test``.

        The explicit ``model_path`` argument is always honoured (after the
        ``.crfsuite`` -> ``.crf`` mapping). A path remembered from an earlier
        ``train`` call never silently overrides it.
        """
        CRFModel, CRFTagger = self._check_tagger_import()

        model = CRFModel.load(self._native_model_path(model_path))
        tagger = CRFTagger.from_model(model)
        return [tagger.tag(xseq) for xseq in X_test]
|
|
|
|
|
|
|
|
def get_trainer(trainer_name: str) -> CRFTrainerBase:
    """Instantiate the CRF backend registered under ``trainer_name``.

    Raises:
        ValueError: if the name is not one of the known backends.
    """
    registry = {
        "python-crfsuite": PythonCRFSuiteTrainer,
        "crfsuite-rs": CRFSuiteRsTrainer,
        "underthesea-core": UndertheseaCoreTrainer,
    }
    trainer_cls = registry.get(trainer_name)
    if trainer_cls is None:
        raise ValueError(f"Unknown trainer: {trainer_name}. Available: {list(registry.keys())}")
    return trainer_cls()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_metadata(output_dir, version, trainer_name, data_stats, c1, c2, max_iterations, metrics, hw_info, training_time):
    """Write ``metadata.yaml`` describing a trained model into ``output_dir``.

    Args:
        output_dir: ``pathlib.Path`` of the model directory; the YAML is written there.
        version: Model version string.
        trainer_name: CLI name of the CRF backend. It is recorded as the
            framework and also decides the recorded model file name.
        data_stats: Sentence/syllable counts as returned by ``load_data``.
        c1, c2, max_iterations: Training hyperparameters.
        metrics: Dict with ``syl_accuracy``, ``syl_f1``, ``word_precision``,
            ``word_recall`` and ``word_f1``. Each value is rounded to 4 decimals.
        hw_info: Dict from ``get_hardware_info``.
        training_time: Elapsed seconds, stored as ``duration_seconds``.
    """
    # underthesea-core saves its model with a ".crf" suffix instead of
    # ".crfsuite", so record the file that actually exists on disk.
    model_file = "model.crf" if trainer_name == "underthesea-core" else "model.crfsuite"

    metadata = {
        "model": {
            "name": "Vietnamese Word Segmentation",
            "version": version,
            "type": "CRF (Conditional Random Field)",
            "framework": trainer_name,
            "tagging_scheme": "BIO",
        },
        "training": {
            "dataset": "undertheseanlp/UDD-1",
            "train_sentences": data_stats["train_sentences"],
            "train_syllables": data_stats["train_syllables"],
            "val_sentences": data_stats["val_sentences"],
            "val_syllables": data_stats["val_syllables"],
            "test_sentences": data_stats["test_sentences"],
            "test_syllables": data_stats["test_syllables"],
            "hyperparameters": {
                "c1": c1,
                "c2": c2,
                "max_iterations": max_iterations,
            },
            "duration_seconds": round(training_time, 2),
        },
        "performance": {
            "syllable_accuracy": round(metrics["syl_accuracy"], 4),
            "syllable_f1": round(metrics["syl_f1"], 4),
            "word_precision": round(metrics["word_precision"], 4),
            "word_recall": round(metrics["word_recall"], 4),
            "word_f1": round(metrics["word_f1"], 4),
        },
        "environment": {
            "platform": hw_info["platform"],
            "cpu_model": hw_info.get("cpu_model", "Unknown"),
            "python_version": hw_info["python_version"],
        },
        "files": {
            "model": model_file,
            # Relative to models/word_segmentation/<version>/ -> project root.
            "config": "../../../configs/word_segmentation.yaml",
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }

    metadata_path = output_dir / "metadata.yaml"
    # allow_unicode=True emits raw non-ASCII characters, so pin the file
    # encoding instead of relying on the platform default.
    with open(metadata_path, "w", encoding="utf-8") as f:
        yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    click.echo(f"Metadata saved to {metadata_path}")
|
|
|
|
|
|
|
|
def get_default_version():
    """Return a version tag built from the current local time, e.g. ``20260131_154530``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
|
|
|
|
|
|
|
|
@click.command()
@click.option(
    "--trainer", "-t",
    type=click.Choice(TRAINERS),
    default="python-crfsuite",
    help="CRF trainer to use",
    show_default=True,
)
@click.option(
    "--version", "-v",
    default=None,
    help="Model version (default: timestamp, e.g., 20260131_154530)",
)
@click.option(
    "--output", "-o",
    default=None,
    help="Custom output path (overrides version-based path)",
)
@click.option(
    "--c1",
    default=1.0,
    type=float,
    help="L1 regularization coefficient",
    show_default=True,
)
@click.option(
    "--c2",
    default=0.001,
    type=float,
    help="L2 regularization coefficient",
    show_default=True,
)
@click.option(
    "--max-iterations",
    default=100,
    type=int,
    help="Maximum training iterations",
    show_default=True,
)
@click.option(
    "--wandb/--no-wandb",
    default=False,
    help="Enable Weights & Biases logging",
)
def train(trainer, version, output, c1, c2, max_iterations, wandb):
    """Train Vietnamese Word Segmenter using CRF on UDD-1 dataset."""
    # NOTE: the docstring above is also the click ``--help`` text.
    # Pipeline: resolve backend and paths -> load data -> extract features ->
    # train -> evaluate on the test split -> write metadata -> print summary.
    total_start_time = time.time()
    start_datetime = datetime.now()

    # Resolve the backend up front so an unknown name fails before any work.
    crf_trainer = get_trainer(trainer)

    if version is None:
        version = get_default_version()

    # An explicit --output path wins over models/word_segmentation/<version>/.
    if output:
        output_path = Path(output)
        output_dir = output_path.parent
    else:
        output_dir = PROJECT_ROOT / "models" / "word_segmentation" / version
        output_path = output_dir / "model.crfsuite"
    # Create the target directory in both cases, so that a custom --output
    # pointing into a new directory does not fail when the model is saved.
    output_dir.mkdir(parents=True, exist_ok=True)

    hw_info = get_hardware_info()

    click.echo("=" * 60)
    click.echo(f"Word Segmentation Training - {version}")
    click.echo("=" * 60)
    click.echo(f"Trainer: {trainer}")
    click.echo(f"Platform: {hw_info['platform']}")
    click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    click.echo("=" * 60)

    train_data, val_data, test_data, data_stats = load_data()

    click.echo(f"\nTrain: {len(train_data)} sentences ({data_stats['train_syllables']} syllables)")
    click.echo(f"Validation: {len(val_data)} sentences ({data_stats['val_syllables']} syllables)")
    click.echo(f"Test: {len(test_data)} sentences ({data_stats['test_syllables']} syllables)")

    click.echo("\nExtracting syllable-level features...")
    feature_start = time.time()
    X_train = [sentence_to_syllable_features(syls) for syls, _ in train_data]
    y_train = [labels for _, labels in train_data]
    click.echo(f"Feature extraction: {format_duration(time.time() - feature_start)}")

    click.echo(f"\nTraining CRF model with {trainer}...")

    # The --wandb flag shadows the module name, hence the ``wb`` alias.
    use_wandb = wandb
    if use_wandb:
        try:
            import wandb as wb
            wb.init(project="word-segmentation-vietnamese", name=f"crf-{version}")
            wb.config.update({
                "trainer": trainer,
                "c1": c1,
                "c2": c2,
                "max_iterations": max_iterations,
                "num_feature_templates": len(FEATURE_TEMPLATES),
                "train_sentences": len(train_data),
                "val_sentences": len(val_data),
                "test_sentences": len(test_data),
                "version": version,
                "level": "syllable",
            })
        except ImportError:
            click.echo("wandb not installed, skipping logging", err=True)
            use_wandb = False

    crf_start = time.time()
    crf_trainer.train(X_train, y_train, output_path, c1, c2, max_iterations, verbose=True)
    crf_time = time.time() - crf_start
    click.echo(f"\nModel saved to {output_path}")
    click.echo(f"CRF training: {format_duration(crf_time)}")

    click.echo("\nEvaluating on test set...")

    X_test = [sentence_to_syllable_features(syls) for syls, _ in test_data]
    y_test = [labels for _, labels in test_data]
    syllables_test = [syls for syls, _ in test_data]

    y_pred = crf_trainer.predict(output_path, X_test)

    # Syllable-level metrics are computed over all sentences flattened together.
    y_test_flat = [label for labels in y_test for label in labels]
    y_pred_flat = [label for labels in y_pred for label in labels]

    syl_accuracy = accuracy_score(y_test_flat, y_pred_flat)
    syl_f1 = f1_score(y_test_flat, y_pred_flat, average="weighted")

    click.echo(f"\nSyllable-level Accuracy: {syl_accuracy:.4f}")
    click.echo(f"Syllable-level F1 (weighted): {syl_f1:.4f}")
    click.echo("\nSyllable-level Classification Report:")
    click.echo(classification_report(y_test_flat, y_pred_flat))

    precision, recall, word_f1 = compute_word_metrics(y_test, y_pred, syllables_test)
    click.echo("\nWord-level Metrics:")
    click.echo(f"  Precision: {precision:.4f}")
    click.echo(f"  Recall: {recall:.4f}")
    click.echo(f"  F1: {word_f1:.4f}")

    total_time = time.time() - total_start_time

    metrics = {
        "syl_accuracy": syl_accuracy,
        "syl_f1": syl_f1,
        "word_precision": precision,
        "word_recall": recall,
        "word_f1": word_f1,
    }

    # Metadata is only written for the managed, version-based layout.
    if not output:
        save_metadata(output_dir, version, trainer, data_stats, c1, c2, max_iterations,
                      metrics, hw_info, total_time)

    click.echo("\n" + "=" * 60)
    click.echo("Example predictions:")
    click.echo("=" * 60)
    for i in range(min(3, len(test_data))):
        syllables = syllables_test[i]
        true_words = labels_to_words(syllables, y_test[i])
        pred_words = labels_to_words(syllables, y_pred[i])
        click.echo(f"\nInput: {' '.join(syllables)}")
        click.echo(f"True: {' | '.join(true_words)}")
        click.echo(f"Pred: {' | '.join(pred_words)}")

    click.echo("\n" + "=" * 60)
    click.echo("Training Summary")
    click.echo("=" * 60)
    click.echo(f"Trainer: {trainer}")
    click.echo(f"Version: {version}")
    click.echo(f"Model: {output_path}")
    click.echo(f"Syllable Accuracy: {syl_accuracy:.4f}")
    click.echo(f"Word F1: {word_f1:.4f}")
    click.echo(f"Total time: {format_duration(total_time)}")
    click.echo("=" * 60)

    if use_wandb:
        wb.log(metrics)
        wb.finish()


if __name__ == "__main__":
    train()
|
|
|