| """ |
| Training script for Vietnamese Word Segmentation using CRF with Hydra config. |
| |
| Supports 3 CRF trainers: |
| - python-crfsuite: Original Python bindings to CRFsuite |
| - crfsuite-rs: Rust bindings to CRFsuite (pip install crfsuite) |
| - underthesea-core: Underthesea's native Rust CRF implementation |
| |
| Uses BIO tagging at SYLLABLE level: |
| - B: Beginning of a word (first syllable) |
| - I: Inside a word (continuation syllables) |
| |
| Usage: |
| python src/train_word_segmentation.py |
| python src/train_word_segmentation.py --config-name=vlsp2013 |
| python src/train_word_segmentation.py --config-name=udd1 |
| python src/train_word_segmentation.py model.trainer=python-crfsuite |
| python src/train_word_segmentation.py model.c1=0.5 model.c2=0.01 |
| python src/train_word_segmentation.py model.features.trigram=false |
| |
| Feature ablation: |
| python src/train_word_segmentation.py model.features.bigram=false model.features.trigram=false |
| python src/train_word_segmentation.py model.features.type=false model.features.morphology=false |
| """ |
|
|
| import logging |
| import platform |
| import time |
| from abc import ABC, abstractmethod |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import hydra |
| import psutil |
| import yaml |
| from omegaconf import DictConfig, OmegaConf |
| from sklearn.metrics import accuracy_score, classification_report, f1_score |
|
|
# Module-level logger, named after this module; every function below logs through it.
log = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
# Feature templates grouped by name. get_active_templates() and
# get_active_groups() look each group name up in the features config
# (missing key == enabled), so a whole group is switched on or off together.
# "S[k]" denotes the syllable at relative offset k from the current position;
# comma-separated offsets (e.g. "S[-1,0]") denote an n-gram over those offsets.
FEATURE_GROUPS = {
    "form": ["S[0]", "S[0].lower"],
    "type": ["S[0].istitle", "S[0].isupper", "S[0].isdigit", "S[0].ispunct", "S[0].len"],
    "morphology": ["S[0].prefix2", "S[0].suffix2"],
    "left": ["S[-1]", "S[-1].lower", "S[-2]", "S[-2].lower"],
    "right": ["S[1]", "S[1].lower", "S[2]", "S[2].lower"],
    "bigram": ["S[-1,0]", "S[0,1]"],
    "trigram": ["S[-1,0,1]"],
    "dictionary": ["S[-1,0].in_dict", "S[0,1].in_dict"],
}
|
|
|
|
def get_active_templates(features_cfg):
    """Collect the templates of every enabled feature group.

    Groups are visited in ``FEATURE_GROUPS`` order; a group that is absent
    from ``features_cfg`` counts as enabled.
    """
    return [
        template
        for group, group_templates in FEATURE_GROUPS.items()
        if features_cfg.get(group, True)
        for template in group_templates
    ]
|
|
|
|
def get_active_groups(features_cfg):
    """List the names of the feature groups that are switched on.

    A group missing from ``features_cfg`` is treated as enabled.
    """
    enabled = []
    for group in FEATURE_GROUPS:
        if features_cfg.get(group, True):
            enabled.append(group)
    return enabled
|
|
|
|
| |
| |
| |
|
|
def get_hardware_info():
    """Collect hardware and system information.

    Returns:
        Dict with the keys ``platform``, ``platform_release``,
        ``architecture``, ``python_version``, ``cpu_physical_cores``,
        ``cpu_logical_cores``, ``ram_total_gb`` and ``cpu_model``.
        ``cpu_model`` is always present; it is ``"Unknown"`` unless a
        "model name" entry could be read from ``/proc/cpuinfo`` on Linux.
    """
    system = platform.system()
    info = {
        "platform": system,
        "platform_release": platform.release(),
        "architecture": platform.machine(),
        "python_version": platform.python_version(),
        "cpu_physical_cores": psutil.cpu_count(logical=False),
        "cpu_logical_cores": psutil.cpu_count(logical=True),
        "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
        # Default keeps the schema identical on every platform.
        "cpu_model": "Unknown",
    }

    if system == "Linux":
        try:
            with open("/proc/cpuinfo", "r", encoding="utf-8", errors="replace") as f:
                for line in f:
                    if "model name" in line:
                        # partition() never raises, even if the line has no ":".
                        model = line.partition(":")[2].strip()
                        if model:
                            info["cpu_model"] = model
                        break
        except OSError:
            # Best effort only: an unreadable /proc/cpuinfo leaves "Unknown".
            pass

    return info
|
|
|
|
def format_duration(seconds):
    """Render a duration given in seconds as ``Xs``, ``Xm Ys`` or ``Xh Ym Zs``.

    Seconds are always printed with two decimals; minutes and hours are
    whole numbers.
    """
    if seconds < 60:
        return f"{seconds:.2f}s"

    remainder = seconds % 60
    if seconds < 3600:
        return f"{int(seconds // 60)}m {remainder:.2f}s"

    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    return f"{whole_hours}h {whole_minutes}m {remainder:.2f}s"
|
|
|
|
| |
| |
| |
|
|
def build_word_dictionary(train_data, min_freq=1, min_syls=2):
    """Collect multi-syllable words observed in BIO-labelled training data.

    A word starts at every "B" label and absorbs the syllables that follow
    until the next "B". Words are lowercased and joined with single spaces.

    Args:
        train_data: List of (syllables, labels) tuples with BIO labels.
        min_freq: Minimum number of occurrences for a word to be kept
            (default: 1).
        min_syls: Minimum number of syllables a word must have (default: 2).

    Returns:
        Set of lowercased words, e.g. {"chủ nghĩa", "hợp hiến"}.
    """
    from collections import Counter

    frequencies = Counter()

    def flush(buffer):
        # Count the buffered word only when it is long enough.
        if len(buffer) >= min_syls:
            frequencies[" ".join(buffer).lower()] += 1

    for syllables, labels in train_data:
        pending = []
        for syllable, tag in zip(syllables, labels):
            if tag != "B":
                pending.append(syllable)
                continue
            flush(pending)
            pending = [syllable]
        # The last word of the sentence has no following "B" to close it.
        flush(pending)

    return {word for word, seen in frequencies.items() if seen >= min_freq}
|
|
|
|
def load_external_dictionary(min_syls=2):
    """Load Viet74K + UTS Dictionary from underthesea package (~64K multi-syl entries).

    Entries are lowercased and stripped; only entries with at least
    ``min_syls`` whitespace-separated syllables are kept.
    """
    from underthesea.corpus.readers.dictionary_loader import DictionaryLoader
    from underthesea.datasets import get_dictionary

    entries = set()

    def absorb(words):
        # Normalise each entry and keep it when it has enough syllables.
        for raw in words:
            normalized = raw.lower().strip()
            if len(normalized.split()) >= min_syls:
                entries.add(normalized)

    absorb(DictionaryLoader("Viet74K.txt").words)
    absorb(get_dictionary())
    return entries
|
|
|
|
def build_dictionary(train_data, source="external", min_syls=2):
    """Return the word dictionary for the requested source.

    ``source`` is one of "training" (words seen in ``train_data``),
    "external" (underthesea dictionaries) or "combined" (union of both).

    Raises:
        ValueError: If ``source`` is none of the above.
    """
    if source == "training":
        return build_word_dictionary(train_data, min_freq=1, min_syls=min_syls)
    if source == "external":
        return load_external_dictionary(min_syls=min_syls)
    if source != "combined":
        raise ValueError(f"Unknown dictionary source: {source}")

    from_training = build_word_dictionary(train_data, min_freq=1, min_syls=min_syls)
    from_external = load_external_dictionary(min_syls=min_syls)
    return from_training | from_external
|
|
|
|
def save_dictionary(dictionary, path):
    """Write the dictionary to ``path`` as UTF-8 text, one sorted word per line."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(entry + "\n" for entry in sorted(dictionary))
|
|
|
|
def load_dictionary(path):
    """Read a UTF-8 word list (one word per line) into a set.

    Lines are stripped of surrounding whitespace; blank lines are skipped.
    """
    with open(path, encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return {entry for entry in stripped_lines if entry}
|
|
|
|
| |
| |
| |
|
|
def get_syllable_at(syllables, position, offset):
    """Return the syllable at ``position + offset``.

    Indices before the sentence yield "__BOS__", indices past its end
    yield "__EOS__".
    """
    target = position + offset
    if 0 <= target < len(syllables):
        return syllables[target]
    return "__BOS__" if target < 0 else "__EOS__"
|
|
|
|
def is_punct(s):
    """Return True when ``s`` is a single non-alphanumeric character."""
    if len(s) != 1:
        return False
    return not s.isalnum()
|
|
|
|
def extract_syllable_features(syllables, position, active_templates, dictionary=None):
    """Build the feature dict for the syllable at ``position``.

    Only templates named in ``active_templates`` are emitted. Neighbours
    that fall outside the sentence are represented by "__BOS__"/"__EOS__".
    The two ``in_dict`` templates are only considered when ``dictionary``
    is given: their value is the longest lowercased dictionary entry of
    2-5 syllables ending at (``S[-1,0].in_dict``) or starting at
    (``S[0,1].in_dict``) ``position``, or "0" when nothing matches.
    """
    enabled = set(active_templates)
    markers = ("__BOS__", "__EOS__")
    feats = {}

    def fold(token):
        # Boundary markers are kept verbatim; real syllables are lowercased.
        return token if token in markers else token.lower()

    cur = get_syllable_at(syllables, position, 0)
    real = cur not in markers

    # Surface form of the current syllable.
    if "S[0]" in enabled:
        feats["S[0]"] = cur
    if "S[0].lower" in enabled:
        feats["S[0].lower"] = fold(cur)

    # Orthographic type of the current syllable.
    if "S[0].istitle" in enabled:
        feats["S[0].istitle"] = str(cur.istitle()) if real else "False"
    if "S[0].isupper" in enabled:
        feats["S[0].isupper"] = str(cur.isupper()) if real else "False"
    if "S[0].isdigit" in enabled:
        feats["S[0].isdigit"] = str(cur.isdigit()) if real else "False"
    if "S[0].ispunct" in enabled:
        feats["S[0].ispunct"] = str(is_punct(cur)) if real else "False"
    if "S[0].len" in enabled:
        feats["S[0].len"] = str(len(cur)) if real else "0"

    # Two-character prefix/suffix; short syllables are used whole.
    sliceable = real and len(cur) >= 2
    if "S[0].prefix2" in enabled:
        feats["S[0].prefix2"] = cur[:2] if sliceable else cur
    if "S[0].suffix2" in enabled:
        feats["S[0].suffix2"] = cur[-2:] if sliceable else cur

    # Left context.
    prev1 = get_syllable_at(syllables, position, -1)
    prev2 = get_syllable_at(syllables, position, -2)
    for key, token in (("S[-1]", prev1), ("S[-2]", prev2)):
        if key in enabled:
            feats[key] = token
        if key + ".lower" in enabled:
            feats[key + ".lower"] = fold(token)

    # Right context.
    next1 = get_syllable_at(syllables, position, 1)
    next2 = get_syllable_at(syllables, position, 2)
    for key, token in (("S[1]", next1), ("S[2]", next2)):
        if key in enabled:
            feats[key] = token
        if key + ".lower" in enabled:
            feats[key + ".lower"] = fold(token)

    # Bigrams and trigram around the current syllable.
    if "S[-1,0]" in enabled:
        feats["S[-1,0]"] = f"{prev1}|{cur}"
    if "S[0,1]" in enabled:
        feats["S[0,1]"] = f"{cur}|{next1}"
    if "S[-1,0,1]" in enabled:
        feats["S[-1,0,1]"] = f"{prev1}|{cur}|{next1}"

    if dictionary is not None:
        total = len(syllables)

        def longest_entry(spans):
            # Spans arrive longest-first, so the first hit is the longest match.
            for lo, hi in spans:
                candidate = " ".join(syllables[lo:hi]).lower()
                if candidate in dictionary:
                    return candidate
            return "0"

        # Longest dictionary entry that ends on the current syllable.
        if position >= 1 and "S[-1,0].in_dict" in enabled:
            widest = min(5, position + 1)
            feats["S[-1,0].in_dict"] = longest_entry(
                (position - width + 1, position + 1) for width in range(widest, 1, -1)
            )

        # Longest dictionary entry that starts on the current syllable.
        if position < total - 1 and "S[0,1].in_dict" in enabled:
            widest = min(5, total - position)
            feats["S[0,1].in_dict"] = longest_entry(
                (position, position + width) for width in range(widest, 1, -1)
            )

    return feats
|
|
|
|
def sentence_to_syllable_features(syllables, active_templates, dictionary=None):
    """Turn a sentence into one list of "name=value" strings per syllable."""
    rows = []
    for index, _ in enumerate(syllables):
        extracted = extract_syllable_features(syllables, index, active_templates, dictionary)
        rows.append([f"{key}={value}" for key, value in extracted.items()])
    return rows
|
|
|
|
| |
| |
| |
|
|
def tokens_to_syllable_labels(tokens, regex_tokenize):
    """Split word tokens into syllables and label them B/I.

    Each token is split with ``regex_tokenize``; its first syllable is
    labelled "B" and every following syllable "I".

    Returns:
        Tuple ``(syllables, labels)`` of two parallel lists.
    """
    syllables, labels = [], []
    for token in tokens:
        tag = "B"
        for piece in regex_tokenize(token):
            syllables.append(piece)
            labels.append(tag)
            tag = "I"  # everything after the first syllable continues the word
    return syllables, labels
|
|
|
|
def labels_to_words(syllables, labels):
    """Rebuild space-joined words from syllables and their BIO labels.

    A "B" label opens a new word; any other label extends the current one
    (or opens the first word when the sequence does not start with "B").
    """
    groups = []
    for syllable, tag in zip(syllables, labels):
        if tag == "B" or not groups:
            groups.append([syllable])
        else:
            groups[-1].append(syllable)
    return [" ".join(group) for group in groups]
|
|
|
|
def _word_spans(labels, length):
    """Return the set of (start, end) syllable spans encoded by ``labels[:length]``.

    A word starts at index 0 and at every later "B" label; ``end`` is exclusive.
    """
    if length <= 0:
        return set()
    starts = [0]
    starts.extend(i for i in range(1, length) if labels[i] == "B")
    ends = starts[1:] + [length]
    return set(zip(starts, ends))


def compute_word_metrics(y_true, y_pred, syllables_list):
    """Compute word-level precision, recall and F1.

    A predicted word counts as correct when a gold word covers exactly the
    same syllable span. Spans are derived directly from the label
    sequences, so syllables that are empty or contain whitespace cannot
    shift the span positions.

    Args:
        y_true: Gold label sequences, one per sentence.
        y_pred: Predicted label sequences, parallel to ``y_true``.
        syllables_list: Syllable sequences, parallel to ``y_true``; only
            their lengths are used (labels beyond a sentence's syllable
            count are ignored).

    Returns:
        Tuple ``(precision, recall, f1)``; each is 0.0 when its denominator
        is zero.
    """
    correct = 0
    total_pred = 0
    total_true = 0

    for syllables, true_labels, pred_labels in zip(syllables_list, y_true, y_pred):
        n_syllables = len(syllables)
        true_spans = _word_spans(true_labels, min(n_syllables, len(true_labels)))
        pred_spans = _word_spans(pred_labels, min(n_syllables, len(pred_labels)))

        total_true += len(true_spans)
        total_pred += len(pred_spans)
        correct += len(true_spans & pred_spans)

    precision = correct / total_pred if total_pred > 0 else 0.0
    recall = correct / total_true if total_true > 0 else 0.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator > 0 else 0.0

    return precision, recall, f1
|
|
|
|
| |
| |
| |
|
|
def load_data(cfg):
    """Dispatch to the loader selected by ``cfg.data.source``.

    "local" loads the VLSP 2013 files, "huggingface" loads UDD-1.

    Raises:
        ValueError: For any other source value.
    """
    source = cfg.data.source
    if source == "local":
        return load_data_vlsp2013(cfg)
    if source == "huggingface":
        return load_data_udd1(cfg)
    raise ValueError(f"Unknown data source: {cfg.data.source}")
|
|
|
|
def load_data_udd1(cfg):
    """Load the UDD-1 dataset and convert it to syllable-level sequences.

    The dataset named by ``cfg.data.dataset`` is fetched with
    ``datasets.load_dataset``; its "train", "validation" and "test" splits
    are converted to ``(syllables, labels)`` pairs. Items without tokens,
    or that yield no syllables, are dropped.

    Returns:
        Tuple ``(train_data, val_data, test_data, stats)`` where ``stats``
        holds sentence and syllable counts per split.
    """
    from datasets import load_dataset
    from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize

    log.info("Loading UDD-1 dataset...")
    corpus = load_dataset(cfg.data.dataset)

    def to_sequences(split):
        converted = []
        for record in split:
            tokens = record["tokens"]
            if not tokens:
                continue
            syllables, labels = tokens_to_syllable_labels(tokens, regex_tokenize)
            if syllables:
                converted.append((syllables, labels))
        return converted

    def syllable_total(sequences):
        return sum(len(syls) for syls, _ in sequences)

    train_data = to_sequences(corpus["train"])
    val_data = to_sequences(corpus["validation"])
    test_data = to_sequences(corpus["test"])

    train_syls = syllable_total(train_data)
    val_syls = syllable_total(val_data)
    test_syls = syllable_total(test_data)

    log.info(f"Loaded {len(train_data)} train ({train_syls} syllables), "
             f"{len(val_data)} val ({val_syls} syllables), "
             f"{len(test_data)} test ({test_syls} syllables) sentences")

    stats = {
        "dataset": cfg.data.dataset,
        "train_sentences": len(train_data),
        "train_syllables": train_syls,
        "val_sentences": len(val_data),
        "val_syllables": val_syls,
        "test_sentences": len(test_data),
        "test_syllables": test_syls,
    }
    return train_data, val_data, test_data, stats
|
|
|
|
def load_data_vlsp2013(cfg):
    """Load the VLSP 2013 WTK dataset (syllable-level BIO format).

    Reads ``train.txt`` and ``test.txt`` from ``cfg.data.data_dir``. Each
    non-blank line is expected to be ``syllable<TAB>tag``; lines that do
    not split into exactly two tab-separated fields are skipped. Blank
    lines separate sentences. Tags "B-W"/"I-W" are mapped to "B"/"I";
    other tags are kept as they are.

    Returns:
        Tuple ``(train_data, None, test_data, stats)``; there is no
        validation split, so its counts in ``stats`` are 0.
    """
    log.info("Loading VLSP 2013 WTK dataset...")

    dataset_dir = Path(cfg.data.data_dir)
    tag_map = {"B-W": "B", "I-W": "I"}

    def read_sequences(path):
        sentences = []
        syls, tags = [], []
        with open(path, encoding="utf-8") as handle:
            for raw in handle:
                stripped = raw.strip()
                if stripped:
                    fields = stripped.split("\t")
                    if len(fields) == 2:
                        syllable, tag = fields
                        syls.append(syllable)
                        tags.append(tag_map.get(tag, tag))
                    continue
                # Blank line: close the sentence collected so far, if any.
                if syls:
                    sentences.append((syls, tags))
                    syls, tags = [], []
        # A file may end without a trailing blank line.
        if syls:
            sentences.append((syls, tags))
        return sentences

    train_data = read_sequences(dataset_dir / "train.txt")
    test_data = read_sequences(dataset_dir / "test.txt")

    train_syls = sum(len(syls) for syls, _ in train_data)
    test_syls = sum(len(syls) for syls, _ in test_data)

    log.info(f"Loaded {len(train_data)} train ({train_syls} syllables), "
             f"{len(test_data)} test ({test_syls} syllables) sentences")

    stats = {
        "dataset": "VLSP-2013-WTK",
        "train_sentences": len(train_data),
        "train_syllables": train_syls,
        "val_sentences": 0,
        "val_syllables": 0,
        "test_sentences": len(test_data),
        "test_syllables": test_syls,
    }
    return train_data, None, test_data, stats
|
|
|
|
| |
| |
| |
|
|
class CRFTrainerBase(ABC):
    """Interface shared by all CRF trainer backends.

    Subclasses set ``name`` and implement ``train`` and ``predict``.
    """

    # Identifier of the backend; overridden by each subclass.
    name: str = "base"

    @abstractmethod
    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train a CRF model on ``X_train``/``y_train`` and save it to ``output_path``."""

    @abstractmethod
    def predict(self, model_path, X_test):
        """Load the model at ``model_path`` and return predictions for ``X_test``."""
|
|
|
|
class PythonCRFSuiteTrainer(CRFTrainerBase):
    """CRF backend built on python-crfsuite (imported lazily as ``pycrfsuite``)."""

    name = "python-crfsuite"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Feed all sequences to a ``pycrfsuite.Trainer`` and write the model to ``output_path``."""
        import pycrfsuite

        backend = pycrfsuite.Trainer(verbose=verbose)
        for features, tags in zip(X_train, y_train):
            backend.append(features, tags)

        params = {
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        }
        backend.set_params(params)
        backend.train(str(output_path))

    def predict(self, model_path, X_test):
        """Open the model at ``model_path`` and tag every sequence in ``X_test``."""
        import pycrfsuite

        tagger = pycrfsuite.Tagger()
        tagger.open(str(model_path))

        predictions = []
        for features in X_test:
            predictions.append(tagger.tag(features))
        return predictions
|
|
|
|
class CRFSuiteRsTrainer(CRFTrainerBase):
    """CRF backend built on the ``crfsuite`` package (Rust bindings, imported lazily)."""

    name = "crfsuite-rs"

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Configure a ``crfsuite.Trainer``, add all sequences and write the model.

        ``verbose`` is accepted for interface compatibility but not used here.
        """
        import crfsuite

        backend = crfsuite.Trainer()
        params = {
            "c1": c1,
            "c2": c2,
            "max_iterations": max_iterations,
            "feature.possible_transitions": True,
        }
        backend.set_params(params)

        for features, tags in zip(X_train, y_train):
            backend.append(features, tags)

        backend.train(str(output_path))

    def predict(self, model_path, X_test):
        """Load the model at ``model_path`` and tag every sequence in ``X_test``."""
        import crfsuite

        model = crfsuite.Model(str(model_path))
        predictions = []
        for features in X_test:
            predictions.append(model.tag(features))
        return predictions
|
|
|
|
class UndertheseaCoreTrainer(CRFTrainerBase):
    """Trainer using underthesea-core native Rust CRF with LBFGS optimization.

    Models are stored with a ``.crf`` suffix: a requested path ending in
    ``.crfsuite`` is mapped to the same path ending in ``.crf`` by both
    ``train`` and ``predict``.
    """

    name = "underthesea-core"

    @staticmethod
    def _native_model_path(path):
        """Return ``path`` as a string with a trailing ``.crfsuite`` replaced by ``.crf``.

        Only the suffix is rewritten; a ``.crfsuite`` fragment elsewhere in
        the path (for example in a directory name) is left untouched.
        """
        path_str = str(path)
        suffix = ".crfsuite"
        if path_str.endswith(suffix):
            return path_str[: -len(suffix)] + ".crf"
        return path_str

    def _check_trainer_import(self):
        """Return ``CRFTrainer`` from underthesea_core, trying both known module layouts."""
        try:
            from underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass
        try:
            from underthesea_core.underthesea_core import CRFTrainer
            return CRFTrainer
        except ImportError:
            pass
        raise ImportError("CRFTrainer not available in underthesea_core.")

    def _check_tagger_import(self):
        """Return ``(CRFModel, CRFTagger)`` from underthesea_core, trying both module layouts."""
        try:
            from underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass
        try:
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
            return CRFModel, CRFTagger
        except ImportError:
            pass
        raise ImportError("CRFModel/CRFTagger not available in underthesea_core")

    def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
        """Train with ``CRFTrainer`` and save the model.

        ``c1``/``c2`` are passed as the L1/L2 penalties. The model is saved
        to ``output_path`` with a trailing ``.crfsuite`` replaced by
        ``.crf``; the path actually written is kept in ``self._model_path``.
        """
        CRFTrainer = self._check_trainer_import()

        trainer = CRFTrainer(
            loss_function="lbfgs",
            l1_penalty=c1,
            l2_penalty=c2,
            max_iterations=max_iterations,
            verbose=1 if verbose else 0,
        )

        model = trainer.train(X_train, y_train)

        output_path_str = self._native_model_path(output_path)
        model.save(output_path_str)

        # Remember where the model really went (differs from output_path
        # when the suffix was rewritten).
        self._model_path = output_path_str

    def predict(self, model_path, X_test):
        """Load the model for ``model_path`` and tag every sequence in ``X_test``.

        The same suffix mapping as in ``train`` is applied, so passing the
        path given to ``train`` finds the saved model. The argument is
        always honoured: a path remembered from an earlier ``train`` call
        never overrides an explicitly requested ``model_path``.
        """
        CRFModel, CRFTagger = self._check_tagger_import()

        model = CRFModel.load(self._native_model_path(model_path))
        tagger = CRFTagger.from_model(model)
        return [tagger.tag(xseq) for xseq in X_test]
|
|
|
|
def get_trainer(trainer_name: str) -> CRFTrainerBase:
    """Instantiate the trainer registered under ``trainer_name``.

    Raises:
        ValueError: If the name is not one of the known trainers.
    """
    registry = {
        "python-crfsuite": PythonCRFSuiteTrainer,
        "crfsuite-rs": CRFSuiteRsTrainer,
        "underthesea-core": UndertheseaCoreTrainer,
    }
    trainer_cls = registry.get(trainer_name)
    if trainer_cls is None:
        raise ValueError(f"Unknown trainer: {trainer_name}. Available: {list(registry.keys())}")
    return trainer_cls()
|
|
|
|
| |
| |
| |
|
|
def save_metadata(output_dir, cfg, data_stats, active_groups, active_templates, metrics, hw_info, training_time):
    """Save model metadata to ``output_dir / "metadata.yaml"``.

    Args:
        output_dir: Directory (path object) that receives ``metadata.yaml``.
        cfg: Run configuration; ``cfg.model`` supplies trainer, c1, c2 and
            max_iterations.
        data_stats: Dict with the sentence/syllable counts per split and,
            optionally, "dataset".
        active_groups: Names of the enabled feature groups.
        active_templates: Enabled feature templates.
        metrics: Dict with "syl_accuracy", "syl_f1", "word_precision",
            "word_recall" and "word_f1".
        hw_info: Dict with at least "platform" and "python_version";
            "cpu_model" is optional.
        training_time: Duration in seconds, stored as "duration_seconds".
    """
    model_cfg = cfg.model

    def rounded_metric(key):
        # Cast to a builtin float first: numeric scalars that are not plain
        # floats (e.g. numpy scalars) are not emitted as plain YAML numbers.
        return round(float(metrics[key]), 4)

    metadata = {
        "model": {
            "name": "Vietnamese Word Segmentation",
            "type": "CRF (Conditional Random Field)",
            "framework": model_cfg.trainer,
            "tagging_scheme": "BIO",
        },
        "training": {
            "dataset": data_stats.get("dataset", "unknown"),
            "train_sentences": data_stats["train_sentences"],
            "train_syllables": data_stats["train_syllables"],
            "val_sentences": data_stats["val_sentences"],
            "val_syllables": data_stats["val_syllables"],
            "test_sentences": data_stats["test_sentences"],
            "test_syllables": data_stats["test_syllables"],
            "hyperparameters": {
                "c1": model_cfg.c1,
                "c2": model_cfg.c2,
                "max_iterations": model_cfg.max_iterations,
            },
            "feature_groups": active_groups,
            "num_feature_templates": len(active_templates),
            "feature_templates": active_templates,
            "duration_seconds": round(training_time, 2),
        },
        "performance": {
            "syllable_accuracy": rounded_metric("syl_accuracy"),
            "syllable_f1": rounded_metric("syl_f1"),
            "word_precision": rounded_metric("word_precision"),
            "word_recall": rounded_metric("word_recall"),
            "word_f1": rounded_metric("word_f1"),
        },
        "environment": {
            "platform": hw_info["platform"],
            "cpu_model": hw_info.get("cpu_model", "Unknown"),
            "python_version": hw_info["python_version"],
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "author": "undertheseanlp",
    }

    metadata_path = output_dir / "metadata.yaml"
    # allow_unicode=True writes non-ASCII text verbatim, so the file must be
    # opened as UTF-8 regardless of the platform's default encoding.
    with open(metadata_path, "w", encoding="utf-8") as f:
        yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    log.info(f"Metadata saved to {metadata_path}")
|
|
|
|
| |
| |
| |
|
|
@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig):
    """Train Vietnamese Word Segmenter using CRF.

    Pipeline, in order: resolve feature templates and trainer from
    ``cfg.model``; create the output directory under the original working
    directory; load the dataset; optionally build and save the word
    dictionary; extract features; train and save the model; evaluate on the
    test split (syllable-level and word-level); write ``metadata.yaml``;
    log a few example predictions and a summary.
    """
    total_start_time = time.time()
    start_datetime = datetime.now()

    log.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

    model_cfg = cfg.model

    # Resolve which feature templates/groups are enabled.
    active_templates = get_active_templates(model_cfg.features)
    active_groups = get_active_groups(model_cfg.features)

    # Instantiate the configured CRF backend (raises ValueError on unknown names).
    crf_trainer = get_trainer(model_cfg.trainer)

    # Outputs are placed relative to the directory the script was launched
    # from, as reported by hydra.utils.get_original_cwd().
    original_cwd = Path(hydra.utils.get_original_cwd())
    output_dir = original_cwd / cfg.output
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "model.crfsuite"

    # Hardware info is logged here and stored in the metadata later.
    hw_info = get_hardware_info()

    log.info("=" * 60)
    log.info(f"Word Segmentation Training")
    log.info("=" * 60)
    log.info(f"Dataset: {cfg.data.name}")
    log.info(f"Trainer: {model_cfg.trainer}")
    log.info(f"Features: {active_groups} ({len(active_templates)} templates)")
    log.info(f"Platform: {hw_info['platform']}")
    log.info(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
    log.info(f"Output: {output_dir}")
    log.info(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    log.info("=" * 60)

    # Load data; val_data is None for the VLSP loader and is only logged here.
    train_data, val_data, test_data, data_stats = load_data(cfg)

    log.info(f"Train: {len(train_data)} sentences ({data_stats['train_syllables']} syllables)")
    if val_data:
        log.info(f"Validation: {len(val_data)} sentences ({data_stats['val_syllables']} syllables)")
    log.info(f"Test: {len(test_data)} sentences ({data_stats['test_syllables']} syllables)")

    # Build the dictionary only when the "dictionary" feature group is enabled;
    # with dictionary=None the in_dict features are never emitted.
    dictionary = None
    if model_cfg.features.get("dictionary", True):
        dict_source = model_cfg.features.get("dictionary_source", "external")
        log.info(f"Building dictionary (source={dict_source})...")
        dictionary = build_dictionary(train_data, source=dict_source)
        log.info(f"Dictionary: {len(dictionary)} multi-syllable words")
        save_dictionary(dictionary, output_dir / "dictionary.txt")
        log.info(f"Dictionary saved to {output_dir / 'dictionary.txt'}")

    # Feature extraction for the training split.
    log.info("Extracting syllable-level features...")
    feature_start = time.time()
    X_train = [sentence_to_syllable_features(syls, active_templates, dictionary) for syls, _ in train_data]
    y_train = [labels for _, labels in train_data]
    log.info(f"Feature extraction: {format_duration(time.time() - feature_start)}")

    # Train and persist the model.
    log.info(f"Training CRF model with {model_cfg.trainer}...")
    crf_start = time.time()
    crf_trainer.train(
        X_train, y_train, output_path,
        model_cfg.c1, model_cfg.c2, model_cfg.max_iterations,
        verbose=True,
    )
    crf_time = time.time() - crf_start
    # NOTE(review): UndertheseaCoreTrainer.train writes the model with a ".crf"
    # suffix instead of ".crfsuite", so the path logged here may not be the
    # file that exists on disk for that trainer — confirm.
    log.info(f"Model saved to {output_path}")
    log.info(f"CRF training: {format_duration(crf_time)}")

    # Evaluation on the test split.
    log.info("Evaluating on test set...")

    X_test = [sentence_to_syllable_features(syls, active_templates, dictionary) for syls, _ in test_data]
    y_test = [labels for _, labels in test_data]
    syllables_test = [syls for syls, _ in test_data]

    y_pred = crf_trainer.predict(output_path, X_test)

    # Syllable-level metrics are computed over all labels flattened together.
    y_test_flat = [label for labels in y_test for label in labels]
    y_pred_flat = [label for labels in y_pred for label in labels]

    syl_accuracy = accuracy_score(y_test_flat, y_pred_flat)
    syl_f1 = f1_score(y_test_flat, y_pred_flat, average="weighted")

    log.info(f"Syllable-level Accuracy: {syl_accuracy:.4f}")
    log.info(f"Syllable-level F1 (weighted): {syl_f1:.4f}")
    log.info(f"Syllable-level Classification Report:\n{classification_report(y_test_flat, y_pred_flat)}")

    # Word-level metrics compare word spans per sentence.
    precision, recall, word_f1 = compute_word_metrics(y_test, y_pred, syllables_test)
    log.info(f"Word-level Precision: {precision:.4f}")
    log.info(f"Word-level Recall: {recall:.4f}")
    log.info(f"Word-level F1: {word_f1:.4f}")

    # Wall-clock time since the start of this function (setup, data loading,
    # feature extraction, training and evaluation).
    total_time = time.time() - total_start_time

    metrics = {
        "syl_accuracy": syl_accuracy,
        "syl_f1": syl_f1,
        "word_precision": precision,
        "word_recall": recall,
        "word_f1": word_f1,
    }

    # NOTE(review): total_time (not crf_time) is what save_metadata stores as
    # training "duration_seconds" — confirm this is intended.
    save_metadata(output_dir, cfg, data_stats, active_groups, active_templates, metrics, hw_info, total_time)

    # Show up to three test sentences with gold and predicted segmentation.
    log.info("=" * 60)
    log.info("Example predictions:")
    log.info("=" * 60)
    for i in range(min(3, len(test_data))):
        syllables = syllables_test[i]
        true_words = labels_to_words(syllables, y_test[i])
        pred_words = labels_to_words(syllables, y_pred[i])
        log.info(f"Input: {' '.join(syllables)}")
        log.info(f"True: {' | '.join(true_words)}")
        log.info(f"Pred: {' | '.join(pred_words)}")

    log.info("=" * 60)
    log.info("Training Summary")
    log.info("=" * 60)
    log.info(f"Dataset: {cfg.data.name}")
    log.info(f"Trainer: {model_cfg.trainer}")
    log.info(f"Features: {active_groups} ({len(active_templates)} templates)")
    log.info(f"Model: {output_path}")
    log.info(f"Syllable Accuracy: {syl_accuracy:.4f}")
    log.info(f"Word F1: {word_f1:.4f}")
    log.info(f"Total time: {format_duration(total_time)}")
    log.info("=" * 60)
|
|
|
|
# Script entry point: train() is wrapped by @hydra.main, so it is called
# without arguments here.
if __name__ == "__main__":
    train()
|
|