| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Training script for Vietnamese POS Tagger (TRE-1). |
| |
| Supports 3 CRF trainers: |
| - python-crfsuite: Original Python bindings to CRFsuite |
| - crfsuite-rs: Rust bindings to CRFsuite (pip install crfsuite) |
| - underthesea-core: Underthesea's native Rust CRF implementation |
| |
| Models are saved to: models/pos_tagger/{version}/model.crfsuite |
| |
| Usage: |
| uv run scripts/train.py |
| uv run scripts/train.py --trainer crfsuite-rs |
| uv run scripts/train.py --trainer underthesea-core |
| uv run scripts/train.py --version v1.1.0 |
| uv run scripts/train.py --wandb |
| uv run scripts/train.py --c1 0.5 --c2 0.01 --max-iterations 200 |
| """ |
|
|
| import platform |
| import re |
| import time |
| from abc import ABC, abstractmethod |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import click |
| import psutil |
| import yaml |
| from datasets import load_dataset |
| from sklearn.metrics import accuracy_score, classification_report |
|
|
|
|
| |
# Repository root: the script is run as scripts/train.py, so the root is two
# levels above this file.
PROJECT_ROOT = Path(__file__).parent.parent

# Backend names accepted by --trainer (same keys get_trainer() dispatches on).
TRAINERS = [
    "python-crfsuite",
    "crfsuite-rs",
    "underthesea-core",
]
|
|
|
|
def get_hardware_info():
    """Collect hardware and system information.

    Returns:
        Dict with OS/platform details, Python version, physical and logical
        CPU core counts (``psutil.cpu_count`` may report ``None`` when it
        cannot determine a value), total RAM (bytes / 1024**3, rounded to 2
        decimals) and ``cpu_model``.

        ``cpu_model`` is always present:
        - On Linux it is read from the first "model name" line of
          ``/proc/cpuinfo``.
        - Elsewhere it is taken from ``platform.processor()``.
        - It falls back to ``"Unknown"`` when neither yields a value.
    """
    info = {
        "platform": platform.system(),
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
    }

    cpu_model = ""
    if platform.system() == "Linux":
        try:
            with open("/proc/cpuinfo", "r", encoding="utf-8") as f:
                for line in f:
                    if "model name" in line:
                        # partition() never raises, even if the line has no ':'.
                        cpu_model = line.partition(":")[2].strip()
                        break
        except (OSError, ValueError):
            # Best effort only: an unreadable/undecodable cpuinfo must not
            # abort a training run; fall through to "Unknown".
            cpu_model = ""
    else:
        cpu_model = platform.processor()

    # Some systems (e.g. many ARM kernels) have no "model name" line, and
    # platform.processor() may return "", so always fill in a value.
    info["cpu_model"] = cpu_model or "Unknown"

    return info
|
|
|
|
def format_duration(seconds):
    """Format a duration in seconds as a short human-readable string.

    Examples: ``"12.34s"``, ``"2m 5.50s"``, ``"1h 2m 5.00s"``.

    The value is rounded to the displayed precision (hundredths of a second)
    *before* it is split into hours/minutes/seconds. Values just below a unit
    boundary therefore roll over (``59.999`` -> ``"1m 0.00s"``) instead of
    printing ``"60.00s"`` or ``"59m 60.00s"``.
    """
    seconds = round(seconds, 2)
    if seconds < 60:
        return f"{seconds:.2f}s"
    if seconds < 3600:
        minutes, secs = divmod(seconds, 60)
        return f"{int(minutes)}m {secs:.2f}s"
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{int(hours)}h {int(minutes)}m {secs:.2f}s"
|
|
|
|
# Feature templates consumed by parse_template()/extract_features().
# ``T[i]`` is the token at relative offset i, ``T[i,j]`` combines several
# offsets, and an optional ``.attr`` suffix selects a derived attribute.
FEATURE_TEMPLATES = [
    # Current token: identity, casing/shape checks and affixes.
    "T[0]",
    "T[0].lower",
    "T[0].istitle",
    "T[0].isupper",
    "T[0].isdigit",
    "T[0].isalpha",
    "T[0].prefix2",
    "T[0].prefix3",
    "T[0].suffix2",
    "T[0].suffix3",
    # Left context.
    "T[-1]",
    "T[-1].lower",
    "T[-1].istitle",
    "T[-1].isupper",
    "T[-2]",
    "T[-2].lower",
    # Right context.
    "T[1]",
    "T[1].lower",
    "T[1].istitle",
    "T[1].isupper",
    "T[2]",
    "T[2].lower",
    # Token bigrams.
    "T[-1,0]",
    "T[0,1]",
    # Dictionary membership of the token and of its bigrams.
    "T[0].is_in_dict",
    "T[-1,0].is_in_dict",
    "T[0,1].is_in_dict",
]
|
|
|
|
def get_token_value(tokens, position, index):
    """Return the token at ``position + index``, or a boundary sentinel.

    Offsets before the start of the sentence yield ``"__BOS__"``; offsets at
    or past its end yield ``"__EOS__"``.
    """
    target = position + index
    if 0 <= target < len(tokens):
        return tokens[target]
    return "__BOS__" if target < 0 else "__EOS__"
|
|
|
|
# Attributes that depend only on the token text. Boolean checks are rendered
# as "True"/"False" strings so they can be used directly as feature values.
_TOKEN_TRANSFORMS = {
    "lower": lambda token: token.lower(),
    "upper": lambda token: token.upper(),
    "istitle": lambda token: str(token.istitle()),
    "isupper": lambda token: str(token.isupper()),
    "islower": lambda token: str(token.islower()),
    "isdigit": lambda token: str(token.isdigit()),
    "isalpha": lambda token: str(token.isalpha()),
}


def apply_attribute(value, attribute, dictionary=None):
    """Derive a feature value from a single token.

    Args:
        value: Token text, or the ``"__BOS__"``/``"__EOS__"`` sentinel.
        attribute: Attribute name from a template (``"lower"``, ``"istitle"``,
            ``"prefix3"``, ``"is_in_dict"``, ...) or ``None`` for the raw token.
        dictionary: Collection consulted by ``"is_in_dict"``; when it is falsy
            the result is ``"False"``.

    Returns:
        The derived string. Sentinels, a ``None`` attribute and unknown
        attribute names return ``value`` unchanged. ``prefixN``/``suffixN``
        return the first/last ``N`` characters (``N`` defaults to 2; shorter
        tokens are returned whole).
    """
    if value in ("__BOS__", "__EOS__") or attribute is None:
        return value

    transform = _TOKEN_TRANSFORMS.get(attribute)
    if transform is not None:
        return transform(value)

    if attribute == "is_in_dict":
        if not dictionary:
            return "False"
        return str(value in dictionary)

    if attribute.startswith("prefix"):
        width = int(attribute[len("prefix"):] or 2)
        return value[:width]
    if attribute.startswith("suffix"):
        width = int(attribute[len("suffix"):] or 2)
        return value[-width:]

    return value
|
|
|
|
# ``T[<offsets>]`` optionally followed by ``.<attribute>``.
_TEMPLATE_PATTERN = re.compile(r"T\[([^\]]+)\](?:\.(\w+))?")


def parse_template(template):
    """Split a feature template into its token offsets and attribute.

    ``"T[-1,0].is_in_dict"`` -> ``([-1, 0], "is_in_dict")``;
    ``"T[0]"`` -> ``([0], None)``. Strings that do not start with a
    ``T[...]`` pattern give ``(None, None)``.
    """
    parsed = _TEMPLATE_PATTERN.match(template)
    if parsed is None:
        return None, None
    offsets_text, attribute = parsed.groups()
    offsets = [int(part.strip()) for part in offsets_text.split(",")]
    return offsets, attribute
|
|
|
|
# parse_template() results keyed by template string. Templates are fixed, so
# parsing each one once avoids re-running the regex for every template on
# every token of the corpus.
_PARSED_TEMPLATES = {}


def _parse_template_cached(template):
    """Return ``parse_template(template)``, parsing each distinct template once."""
    parsed = _PARSED_TEMPLATES.get(template)
    if parsed is None:
        parsed = parse_template(template)
        _PARSED_TEMPLATES[template] = parsed
    return parsed


def extract_features(tokens, position, dictionary=None):
    """Build the feature dict for the token at ``position``.

    Args:
        tokens: Sequence of token strings for one sentence.
        position: Index of the current token within ``tokens``.
        dictionary: Optional collection of known words/phrases consulted by
            ``is_in_dict`` templates; when falsy those features are ``"False"``.

    Returns:
        Dict mapping each parseable template in ``FEATURE_TEMPLATES`` to its
        string value. Offsets outside the sentence resolve to ``"__BOS__"`` /
        ``"__EOS__"``.
    """
    features = {}
    for template in FEATURE_TEMPLATES:
        indices, attribute = _parse_template_cached(template)
        if indices is None:
            # Unparseable template: skip it rather than fail.
            continue

        if len(indices) == 1:
            value = get_token_value(tokens, position, indices[0])
            features[template] = apply_attribute(value, attribute, dictionary)
            continue

        values = [get_token_value(tokens, position, idx) for idx in indices]
        if attribute == "is_in_dict":
            # Multi-token lookup checks the space-joined phrase.
            phrase = " ".join(values)
            features[template] = str(phrase in dictionary) if dictionary else "False"
        else:
            # NOTE(review): any attribute other than is_in_dict on a
            # multi-token template is ignored; the raw tokens are joined with "|".
            features[template] = "|".join(values)
    return features
|
|
|
|
def sentence_to_features(tokens, dictionary=None):
    """Convert a tokenized sentence into CRFsuite-style string features.

    Args:
        tokens: Sequence of token strings for one sentence.
        dictionary: Optional collection of known words/phrases forwarded to
            ``extract_features`` for the ``is_in_dict`` templates. When omitted
            (the default), every ``is_in_dict`` feature is ``"False"``.

    Returns:
        One list per token, each holding ``"template=value"`` strings.
    """
    return [
        [f"{name}={value}" for name, value in extract_features(tokens, i, dictionary).items()]
        for i in range(len(tokens))
    ]
|
|
|
|
| |
| |
| |
|
|
class CRFTrainerBase(ABC):
    """Common interface implemented by every CRF backend in this script.

    Concrete trainers set ``name`` to the identifier used to select them and
    implement ``train`` (fit and persist a model) and ``predict`` (load a
    model and tag sequences).
    """

    # Backend identifier; each concrete trainer overrides it.
    name: str = "base"

    @abstractmethod
    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train the CRF model and save to output_path.

        Args:
            X_train: Per-sentence feature sequences.
            y_train: Per-sentence label sequences aligned with ``X_train``.
            output_path: Destination for the trained model.
            c1: L1 regularization coefficient.
            c2: L2 regularization coefficient.
            max_iterations: Upper bound on training iterations.
            verbose: Whether the backend should report progress.
        """

    @abstractmethod
    def predict(self, model_path, X_test):
        """Load model and predict on test data.

        Expected to return one predicted label sequence per item of ``X_test``.
        """
|
|
|
|
class PythonCRFSuiteTrainer(CRFTrainerBase):
    """Trainer using python-crfsuite (original Python bindings)."""

    name = "python-crfsuite"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train with ``pycrfsuite.Trainer`` and write the model to ``output_path``.

        Args:
            X_train: Per-sentence feature sequences.
            y_train: Per-sentence label sequences aligned with ``X_train``.
            output_path: Model file to write.
            c1: L1 regularization coefficient.
            c2: L2 regularization coefficient.
            max_iterations: Maximum number of training iterations.
            verbose: Forwarded to ``pycrfsuite.Trainer``.
        """
        # Imported lazily so the other backends work without python-crfsuite.
        import pycrfsuite

        trainer = pycrfsuite.Trainer(verbose=verbose)

        for xseq, yseq in zip(X_train, y_train):
            trainer.append(xseq, yseq)

        trainer.set_params({
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        })

        trainer.train(str(output_path))

    def predict(self, model_path, X_test):
        """Tag every sequence in ``X_test`` with the model at ``model_path``.

        Returns:
            A list with one predicted label sequence per input sequence.
        """
        import pycrfsuite

        tagger = pycrfsuite.Tagger()
        tagger.open(str(model_path))
        try:
            return [tagger.tag(xseq) for xseq in X_test]
        finally:
            # Release the opened model even if tagging raises.
            tagger.close()
|
|
|
|
class CRFSuiteRsTrainer(CRFTrainerBase):
    """Trainer using crfsuite-rs (Rust bindings via pip install crfsuite)."""

    name = "crfsuite-rs"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train with ``crfsuite.Trainer`` and write the model to ``output_path``.

        ``verbose`` is accepted for interface compatibility but is not
        forwarded to the crfsuite trainer.
        """
        # Imported lazily so the other backends work without this package.
        import crfsuite

        crf = crfsuite.Trainer()

        # Configure hyperparameters, then feed the sequences, then train.
        hyperparameters = {
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        }
        crf.set_params(hyperparameters)

        for features, labels in zip(X_train, y_train):
            crf.append(features, labels)

        crf.train(str(output_path))

    def predict(self, model_path, X_test):
        """Tag every sequence in ``X_test`` with the model at ``model_path``."""
        import crfsuite

        loaded = crfsuite.Model(str(model_path))
        return list(map(loaded.tag, X_test))
|
|
|
|
class UndertheseaCoreTrainer(CRFTrainerBase):
    """Trainer using underthesea-core native Rust CRF with LBFGS optimization.

    This trainer uses the native underthesea-core Rust CRF implementation
    with L-BFGS optimization, matching CRFsuite performance.

    Models are stored in underthesea-core's own format: an output path ending
    in ``.crfsuite`` is saved (and later loaded) with a ``.crf`` suffix instead.

    Requires building underthesea-core from source:
        cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core
        uv venv && source .venv/bin/activate
        uv pip install maturin
        maturin develop --release
    """

    name = "underthesea-core"

    @staticmethod
    def _native_model_path(path):
        """Return ``path`` as a string with a final ``.crfsuite`` suffix swapped for ``.crf``.

        Only the file suffix is changed, so directory names that contain
        ".crfsuite" are left intact (``str.replace`` would rewrite them too).
        """
        path = Path(path)
        if path.suffix == ".crfsuite":
            path = path.with_suffix(".crf")
        return str(path)

    def _check_trainer_import(self):
        """Return the ``CRFTrainer`` class from underthesea_core.

        Tries the package top level first, then the compiled submodule.

        Raises:
            ImportError: If neither location provides ``CRFTrainer``.
        """
        try:
            from underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass

        try:
            from underthesea_core.underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass

        raise ImportError(
            "CRFTrainer not available in underthesea_core.\n"
            "Build from source with LBFGS support:\n"
            "  cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core\n"
            "  source .venv/bin/activate && maturin develop --release"
        )

    def _check_tagger_import(self):
        """Return ``(CRFModel, CRFTagger)`` from underthesea_core.

        Tries the package top level first, then the compiled submodule.

        Raises:
            ImportError: If neither location provides both classes.
        """
        try:
            from underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass

        try:
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass

        raise ImportError("CRFModel/CRFTagger not available in underthesea_core")

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train with underthesea-core's L-BFGS CRF and save the model.

        The model is written to ``output_path`` with a ``.crfsuite`` suffix
        replaced by ``.crf``; the actual location is stored in
        ``self._model_path``.
        """
        CRFTrainer = self._check_trainer_import()

        trainer = CRFTrainer(
            loss_function="lbfgs",
            l1_penalty=c1,
            l2_penalty=c2,
            max_iterations=max_iterations,
            verbose=1 if verbose else 0,
        )

        # The trainer takes all sequences at once and returns the fitted model.
        model = trainer.train(X_train, y_train)

        model_path = self._native_model_path(output_path)
        model.save(model_path)

        # Record where the model actually went (it may differ from
        # output_path) so callers can report the real location.
        self._model_path = model_path

    def predict(self, model_path, X_test):
        """Tag every sequence in ``X_test`` with the model at ``model_path``.

        ``model_path`` goes through the same ``.crfsuite`` -> ``.crf`` mapping as
        ``train``, so passing the path given to ``train`` loads that model.
        The argument is always honoured; a path cached by an earlier ``train``
        call is not used.
        """
        CRFModel, CRFTagger = self._check_tagger_import()

        model = CRFModel.load(self._native_model_path(model_path))
        tagger = CRFTagger.from_model(model)
        return [tagger.tag(xseq) for xseq in X_test]
|
|
|
|
def get_trainer(trainer_name: str) -> CRFTrainerBase:
    """Instantiate the CRF backend registered under ``trainer_name``.

    Raises:
        ValueError: If ``trainer_name`` is not one of the known backends.
    """
    registry = {
        backend.name: backend
        for backend in (PythonCRFSuiteTrainer, CRFSuiteRsTrainer, UndertheseaCoreTrainer)
    }
    backend_cls = registry.get(trainer_name)
    if backend_cls is None:
        raise ValueError(f"Unknown trainer: {trainer_name}. Available: {list(registry)}")
    return backend_cls()
|
|
|
|
| |
| |
| |
|
|
def load_data():
    """Load UDD-1 and return ``(train, validation, test)`` sentence lists.

    Each sentence is a ``(tokens, tags)`` pair taken from a row's ``tokens``
    and ``upos`` columns; rows where either value is empty are dropped.
    """
    click.echo("Loading UDD-1 dataset...")
    dataset = load_dataset("undertheseanlp/UDD-1")

    def extract_sentences(split):
        return [
            (row["tokens"], row["upos"])
            for row in split
            if row["tokens"] and row["upos"]
        ]

    train_data, val_data, test_data = [
        extract_sentences(dataset[split_name])
        for split_name in ("train", "validation", "test")
    ]

    click.echo(f"Loaded {len(train_data)} train, {len(val_data)} val, {len(test_data)} test sentences")
    return train_data, val_data, test_data
|
|
|
|
def save_metadata(output_dir, version, trainer_name, train_data, val_data, test_data, c1, c2, max_iterations, accuracy, hw_info, training_time):
    """Save model metadata to ``<output_dir>/metadata.yaml``.

    Args:
        output_dir: Model directory; must support ``/`` (e.g. ``pathlib.Path``).
        version: Model version string.
        trainer_name: CRF backend used. It is recorded as the framework and
            determines the model file name listed under ``files``.
        train_data: Training split; only its length is recorded.
        val_data: Validation split; only its length is recorded.
        test_data: Test split; only its length is recorded.
        c1: L1 regularization coefficient.
        c2: L2 regularization coefficient.
        max_iterations: Maximum training iterations.
        accuracy: Test accuracy, stored rounded to 4 decimals.
        hw_info: Dict as produced by ``get_hardware_info()``.
        training_time: Duration in seconds, stored rounded to 2 decimals.
    """
    # underthesea-core saves "model.crfsuite" as "model.crf", so list the file
    # that is actually on disk.
    model_file = "model.crf" if trainer_name == "underthesea-core" else "model.crfsuite"

    metadata = {
        "model": {
            "name": "Vietnamese POS Tagger",
            "version": version,
            "type": "CRF (Conditional Random Field)",
            "framework": trainer_name,
        },
        "training": {
            "dataset": "undertheseanlp/UDD-1",
            "train_sentences": len(train_data),
            "val_sentences": len(val_data),
            "test_sentences": len(test_data),
            "hyperparameters": {
                "c1": c1,
                "c2": c2,
                "max_iterations": max_iterations,
            },
            "duration_seconds": round(training_time, 2),
        },
        "performance": {
            "test_accuracy": round(accuracy, 4),
        },
        "environment": {
            "platform": hw_info["platform"],
            "cpu_model": hw_info.get("cpu_model", "Unknown"),
            "python_version": hw_info["python_version"],
        },
        "files": {
            "model": model_file,
            "config": "../../../configs/pos_tagger.yaml",
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }

    metadata_path = output_dir / "metadata.yaml"
    # allow_unicode=True may emit non-ASCII characters; write UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(metadata_path, "w", encoding="utf-8") as f:
        yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    click.echo(f"Metadata saved to {metadata_path}")
|
|
|
|
def get_default_version():
    """Return a local-time timestamp version such as ``20260131_154530``."""
    return f"{datetime.now():%Y%m%d_%H%M%S}"
|
|
|
|
@click.command()
@click.option(
    "--trainer", "-t",
    type=click.Choice(TRAINERS),
    default="python-crfsuite",
    help="CRF trainer to use",
    show_default=True,
)
@click.option(
    "--version", "-v",
    default=None,
    help="Model version (default: timestamp, e.g., 20260131_154530)",
)
@click.option(
    "--output", "-o",
    default=None,
    help="Custom output path (overrides version-based path)",
)
@click.option(
    "--c1",
    default=1.0,
    type=float,
    help="L1 regularization coefficient",
    show_default=True,
)
@click.option(
    "--c2",
    default=0.001,
    type=float,
    help="L2 regularization coefficient",
    show_default=True,
)
@click.option(
    "--max-iterations",
    default=100,
    type=int,
    help="Maximum training iterations",
    show_default=True,
)
@click.option(
    "--wandb/--no-wandb",
    default=False,
    help="Enable Weights & Biases logging",
)
def train(trainer, version, output, c1, c2, max_iterations, wandb):
    """Train Vietnamese POS Tagger using CRF on UDD-1 dataset."""
    # (The docstring above doubles as click's --help text.)
    total_start_time = time.time()
    start_datetime = datetime.now()

    # Build the backend first so an unknown trainer fails before heavy work.
    crf_trainer = get_trainer(trainer)

    if version is None:
        version = get_default_version()

    # --output overrides the versioned layout models/pos_tagger/<version>/.
    if output:
        output_path = Path(output)
        output_dir = output_path.parent
    else:
        output_dir = PROJECT_ROOT / "models" / "pos_tagger" / version
        output_path = output_dir / "model.crfsuite"
    # Create the directory for custom --output paths too; otherwise a missing
    # directory only fails when the model is written, after loading and training.
    output_dir.mkdir(parents=True, exist_ok=True)

    hw_info = get_hardware_info()

    click.echo("=" * 60)
    click.echo(f"POS Tagger Training - {version}")
    click.echo("=" * 60)
    click.echo(f"Trainer: {trainer}")
    click.echo(f"Platform: {hw_info['platform']}")
    click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    click.echo("=" * 60)

    train_data, val_data, test_data = load_data()

    click.echo(f"\nTrain: {len(train_data)} sentences")
    click.echo(f"Validation: {len(val_data)} sentences")
    click.echo(f"Test: {len(test_data)} sentences")

    click.echo("\nExtracting features...")
    feature_start = time.time()
    X_train = [sentence_to_features(tokens) for tokens, _ in train_data]
    y_train = [tags for _, tags in train_data]
    click.echo(f"Feature extraction: {format_duration(time.time() - feature_start)}")

    click.echo(f"\nTraining CRF model with {trainer}...")

    # Optional Weights & Biases tracking; a missing package just disables it.
    use_wandb = wandb
    if use_wandb:
        try:
            import wandb as wb
            wb.init(project="pos-tagger-vietnamese", name=f"crf-{trainer}-{version}")
            wb.config.update({
                "trainer": trainer,
                "c1": c1,
                "c2": c2,
                "max_iterations": max_iterations,
                "num_features": len(FEATURE_TEMPLATES),
                "train_sentences": len(train_data),
                "val_sentences": len(val_data),
                "test_sentences": len(test_data),
                "version": version,
            })
        except ImportError:
            click.echo("wandb not installed, skipping logging", err=True)
            use_wandb = False

    crf_start = time.time()
    crf_trainer.train(X_train, y_train, output_path, c1, c2, max_iterations, verbose=True)
    crf_time = time.time() - crf_start
    # Some backends write to a different file name (underthesea-core saves a
    # .crf file and records it in ``_model_path``); report the real location.
    saved_model_path = getattr(crf_trainer, "_model_path", output_path)
    click.echo(f"\nModel saved to {saved_model_path}")
    click.echo(f"CRF training: {format_duration(crf_time)}")

    click.echo("\nEvaluating on test set...")
    X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
    y_test = [tags for _, tags in test_data]

    y_pred = crf_trainer.predict(output_path, X_test)

    # Token-level metrics: flatten all sentences into one tag sequence.
    y_test_flat = [tag for tags in y_test for tag in tags]
    y_pred_flat = [tag for tags in y_pred for tag in tags]

    accuracy = accuracy_score(y_test_flat, y_pred_flat)

    total_time = time.time() - total_start_time

    click.echo(f"\nAccuracy: {accuracy:.4f}")
    click.echo("\nClassification Report:")
    click.echo(classification_report(y_test_flat, y_pred_flat))

    # metadata.yaml is only written for the versioned layout, not for --output.
    if not output:
        save_metadata(output_dir, version, trainer, train_data, val_data, test_data,
                      c1, c2, max_iterations, accuracy, hw_info, total_time)

    click.echo("\n" + "=" * 60)
    click.echo("Training Summary")
    click.echo("=" * 60)
    click.echo(f"Trainer: {trainer}")
    click.echo(f"Version: {version}")
    click.echo(f"Model: {saved_model_path}")
    click.echo(f"Accuracy: {accuracy:.4f}")
    click.echo(f"Total time: {format_duration(total_time)}")
    click.echo("=" * 60)

    if use_wandb:
        wb.log({"accuracy": accuracy})
        wb.finish()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the click command defined above.
    train()
|
|