rain1024 committed
Commit 1489e5c · 1 Parent(s): c56099b

Add word segmentation support and underthesea-core integration

- Update handler.py to support both pycrfsuite and underthesea-core formats
- Add word segmentation training and prediction scripts
- Add training configurations (configs/pos_tagger.yaml, configs/word_segmentation.yaml)
- Update training scripts with multi-trainer support
- Update CLAUDE.md with folder structure and word segmentation docs

.gitignore CHANGED
@@ -30,3 +30,5 @@ per_tag_metrics.png
 # Logs
 *.log
 wandb/
+
+models
CLAUDE.md CHANGED
@@ -4,14 +4,48 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 
 ## Project Overview
 
-Vietnamese POS Tagger (TRE-1) - a CRF-based Part-of-Speech tagger for Vietnamese, deployed on Hugging Face at [undertheseanlp/tre-1](https://huggingface.co/undertheseanlp/tre-1). Uses python-crfsuite with 27 handcrafted feature templates. Trained on UDD-v0.1 dataset (80/20 train/test split, random_state=42). Achieves 95.57% accuracy.
+Vietnamese NLP Models (TRE-1) - CRF-based models for Vietnamese NLP tasks, deployed on Hugging Face at [undertheseanlp/tre-1](https://huggingface.co/undertheseanlp/tre-1). Includes:
+- **POS Tagger**: 27 handcrafted feature templates, predicts 15 Universal POS tags
+- **Word Segmentation**: BIO tagging at syllable level, 21 feature templates
+
+Trained on UDD-1 dataset from Hugging Face.
+
+## Folder Structure
+
+```
+tre-1/
+├── models/                      # Trained models (versioned by timestamp)
+│   ├── pos_tagger/
+│   │   └── 20260131_154530/     # YYYYMMDD_HHMMSS format
+│   │       ├── model.crfsuite
+│   │       └── metadata.yaml
+│   └── word_segmentation/
+│       └── 20260131_154530/
+│           ├── model.crfsuite
+│           └── metadata.yaml
+├── configs/                     # Training configurations
+│   ├── pos_tagger.yaml
+│   └── word_segmentation.yaml
+├── results/                     # Evaluation outputs (plots, metrics)
+│   ├── pos_tagger/
+│   └── word_segmentation/
+├── scripts/                     # Training, evaluation, inference scripts
+│   ├── train.py
+│   ├── train_word_segmentation.py
+│   ├── evaluate.py
+│   ├── predict.py
+│   └── predict_word_segmentation.py
+├── handler.py                   # Hugging Face Custom Handler
+├── pos_tagger.crfsuite          # Legacy model (for HF deployment)
+└── CLAUDE.md
+```
 
 ## Running the Model
 
 **Local inference:**
 ```python
 from handler import EndpointHandler
-handler = EndpointHandler(path="./")
+handler = EndpointHandler(path="models/pos_tagger/20260131_000000")
 result = handler({"inputs": "Tôi yêu Việt Nam"})
 ```
 
@@ -21,40 +55,67 @@ result = handler({"inputs": "Tôi yêu Việt Nam"})
 
 Scripts use inline script metadata (PEP 723) - no separate requirements file needed.
 
+### POS Tagger
+
 ```bash
-# Train model from scratch
+# Train model (auto-generates timestamp version, e.g., 20260131_154530)
 uv run scripts/train.py
 
-# Train with custom output path and W&B logging
-uv run scripts/train.py --output model.crfsuite --wandb
+# Train with custom version name
+uv run scripts/train.py --version my_experiment
+
+# Train with W&B logging
+uv run scripts/train.py --wandb
+
+# Evaluate latest model
+uv run scripts/evaluate.py
 
-# Evaluate trained model
-uv run scripts/evaluate.py --model pos_tagger.crfsuite
+# Evaluate specific version
+uv run scripts/evaluate.py --version 20260131_000000
 
-# Evaluate with confusion matrix and per-tag plots
-uv run scripts/evaluate.py --model pos_tagger.crfsuite --save-plots
+# Evaluate with plots (saves to results/pos_tagger/)
+uv run scripts/evaluate.py --save-plots
 
-# Inference (formats: inline, json, conll)
+# Inference (uses latest model by default)
 uv run scripts/predict.py "Tôi yêu Việt Nam"
-uv run scripts/predict.py --format json "Hà Nội là thủ đô"
-echo "Học sinh đang học bài" | uv run scripts/predict.py -
+uv run scripts/predict.py --version 20260131_000000 --format json "Hà Nội là thủ đô"
+```
+
+### Word Segmentation
+
+```bash
+# Train model (auto-generates timestamp version)
+uv run scripts/train_word_segmentation.py
+
+# Train with custom version name
+uv run scripts/train_word_segmentation.py --version my_experiment
+
+# Inference
+uv run scripts/predict_word_segmentation.py "Tôi yêu Việt Nam"
 ```
 
 ## Architecture
 
 Single-file implementation (`handler.py`) following Hugging Face Custom Handler pattern:
 
-- **PythonCRFFeaturizer**: Extracts 27 linguistic features per token (word form, case, prefix/suffix, context windows, dictionary lookups)
+- **PythonCRFFeaturizer**: Extracts linguistic features per token (word form, case, prefix/suffix, context windows, dictionary lookups)
 - **EndpointHandler**: Hugging Face API entry point - loads CRF model, handles tokenization and inference
-- **pos_tagger.crfsuite**: Binary CRF model (Git LFS tracked)
 
 **Data flow:** Input text → whitespace tokenization → feature extraction → CRF prediction → `[{"token": "...", "tag": "..."}]`
 
 ## Key Constraints
 
 - Input must be pre-tokenized (whitespace-separated Vietnamese tokens)
-- No word segmentation - expects already segmented Vietnamese text
 - Feature template syntax: `T[index].attribute` (e.g., `T[-1].lower`, `T[0,1].is_in_dict`)
-- Predicts 15 Universal POS tags: ADJ, ADP, ADV, AUX, CCONJ, DET, NOUN, NUM, PART, PRON, PROPN, PUNCT, SCONJ, VERB, X
+- POS Tagger predicts 15 Universal POS tags: ADJ, ADP, ADV, AUX, CCONJ, DET, NOUN, NUM, PART, PRON, PROPN, PUNCT, SCONJ, VERB, X
+- Word Segmentation uses BIO tagging: B (beginning), I (inside)
 - CRF training params: c1=1.0 (L1), c2=0.001 (L2), max_iterations=100
-- Model trained on legal documents domain (UDD-v0.1) - may underperform on casual/social text
+
+## Model Versioning
+
+Models use timestamp-based versioning (`YYYYMMDD_HHMMSS`):
+- Each version has its own directory under `models/{task}/{timestamp}/`
+- Auto-generated when training without `--version` flag
+- Scripts default to **latest** version (sorted alphabetically)
+- `metadata.yaml` contains training info, hyperparameters, and performance metrics
+- `configs/` stores reusable training configurations
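The versioning notes above rely on the fact that `YYYYMMDD_HHMMSS` strings sort lexicographically in chronological order, so "latest = sorted alphabetically" is safe. A minimal sketch of that convention (version strings here are illustrative, not real checkpoints):

```python
def latest_version(versions):
    """Return the newest timestamp-style version name, or None if empty.

    YYYYMMDD_HHMMSS sorts lexicographically in chronological order,
    so a plain sorted() picks the most recent training run.
    """
    return sorted(versions)[-1] if versions else None


versions = ["20250930_090000", "20260131_154530", "20251201_120000"]
print(latest_version(versions))  # 20260131_154530
```

This is why the scripts can resolve "latest" without parsing dates at all.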
configs/pos_tagger.yaml ADDED
@@ -0,0 +1,54 @@
+# POS Tagger Training Configuration
+# Dataset: UDD-1 from Hugging Face
+
+model:
+  name: pos_tagger
+  type: crf
+  version: v1.0.0
+
+training:
+  c1: 1.0        # L1 regularization coefficient
+  c2: 0.001      # L2 regularization coefficient
+  max_iterations: 100
+  feature_possible_transitions: true
+
+data:
+  dataset: undertheseanlp/UDD-1
+  train_split: train
+  val_split: validation
+  test_split: test
+
+features:
+  num_templates: 27
+  templates:
+    - T[0]
+    - T[0].lower
+    - T[0].istitle
+    - T[0].isupper
+    - T[0].isdigit
+    - T[0].isalpha
+    - T[0].prefix2
+    - T[0].prefix3
+    - T[0].suffix2
+    - T[0].suffix3
+    - T[-1]
+    - T[-1].lower
+    - T[-1].istitle
+    - T[-1].isupper
+    - T[-2]
+    - T[-2].lower
+    - T[1]
+    - T[1].lower
+    - T[1].istitle
+    - T[1].isupper
+    - T[2]
+    - T[2].lower
+    - T[-1,0]
+    - T[0,1]
+    - T[0].is_in_dict
+    - T[-1,0].is_in_dict
+    - T[0,1].is_in_dict
+
+output:
+  model_dir: models/pos_tagger
+  results_dir: results/pos_tagger
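The `T[index].attribute` entries in this config are parsed with the regex that `parse_template` in scripts/evaluate.py uses; a standalone sketch of that split (the tuple-return handling after the regex is illustrative):

```python
import re

def parse_template(template):
    """Split a feature template like 'T[-1].lower' into (index, attribute).

    Uses the same regex as parse_template in scripts/evaluate.py;
    templates without an attribute (e.g. 'T[0]') yield attribute None.
    """
    match = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
    if not match:
        return None, None
    index_str, attribute = match.groups()
    return index_str, attribute


print(parse_template("T[-1].lower"))        # ('-1', 'lower')
print(parse_template("T[0,1].is_in_dict"))  # ('0,1', 'is_in_dict')
print(parse_template("T[0]"))               # ('0', None)
```

Note that comma indices like `0,1` (token n-grams) come back as a single string and are split downstream.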
configs/word_segmentation.yaml ADDED
@@ -0,0 +1,49 @@
+# Word Segmentation Training Configuration
+# Dataset: UDD-1 from Hugging Face
+
+model:
+  name: word_segmentation
+  type: crf
+  version: v1.0.0
+  tagging_scheme: BIO  # B=Beginning, I=Inside
+
+training:
+  c1: 1.0        # L1 regularization coefficient
+  c2: 0.001      # L2 regularization coefficient
+  max_iterations: 100
+  feature_possible_transitions: true
+
+data:
+  dataset: undertheseanlp/UDD-1
+  train_split: train
+  val_split: validation
+  test_split: test
+  preprocessing: underthesea.regex_tokenize  # Syllable splitting
+
+features:
+  num_templates: 21
+  templates:
+    - S[0]
+    - S[0].lower
+    - S[0].istitle
+    - S[0].isupper
+    - S[0].isdigit
+    - S[0].ispunct
+    - S[0].len
+    - S[0].prefix2
+    - S[0].suffix2
+    - S[-1]
+    - S[-1].lower
+    - S[-2]
+    - S[-2].lower
+    - S[1]
+    - S[1].lower
+    - S[2]
+    - S[2].lower
+    - S[-1,0]
+    - S[0,1]
+    - S[-1,0,1]
+
+output:
+  model_dir: models/word_segmentation
+  results_dir: results/word_segmentation
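The BIO scheme above marks each syllable as B (starts a word) or I (continues the previous word). A hedged sketch of decoding tagged syllables back into words; joining a word's syllables with a space is an illustrative choice, not necessarily the repo's output convention:

```python
def decode_bio(syllables, tags):
    """Group syllables into words from B/I tags.

    Each "B" opens a new word; each "I" extends the current one.
    A leading "I" (malformed input) is treated as "B" for robustness.
    """
    words = []
    for syllable, tag in zip(syllables, tags):
        if tag == "B" or not words:
            words.append([syllable])
        else:  # "I"
            words[-1].append(syllable)
    return [" ".join(w) for w in words]


print(decode_bio(["Tôi", "yêu", "Việt", "Nam"], ["B", "B", "B", "I"]))
# ['Tôi', 'yêu', 'Việt Nam']
```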
handler.py CHANGED
@@ -1,11 +1,32 @@
 """
 Custom handler for Vietnamese POS Tagger inference on Hugging Face.
+
+Supports two model formats:
+- CRFsuite format (.crfsuite) - loaded with pycrfsuite
+- underthesea-core format (.crf) - loaded with underthesea_core
 """
 
+import os
 import re
-import pycrfsuite
 from typing import Dict, List, Any
 
+# Try importing both taggers
+try:
+    import pycrfsuite
+    HAS_PYCRFSUITE = True
+except ImportError:
+    HAS_PYCRFSUITE = False
+
+try:
+    from underthesea_core import CRFModel, CRFTagger
+    HAS_UNDERTHESEA_CORE = True
+except ImportError:
+    try:
+        from underthesea_core.underthesea_core import CRFModel, CRFTagger
+        HAS_UNDERTHESEA_CORE = True
+    except ImportError:
+        HAS_UNDERTHESEA_CORE = False
+
 
 class PythonCRFFeaturizer:
     """
@@ -101,10 +122,39 @@ class EndpointHandler:
 
         self.featurizer = PythonCRFFeaturizer(self.feature_templates)
 
-        # Load CRF model
-        model_path = os.path.join(path, "pos_tagger.crfsuite")
-        self.tagger = pycrfsuite.Tagger()
-        self.tagger.open(model_path)
+        # Load CRF model - check multiple possible locations and formats
+        # Priority: .crfsuite (pycrfsuite) > .crf (underthesea-core)
+        model_candidates = [
+            (os.path.join(path, "model.crfsuite"), "pycrfsuite"),
+            (os.path.join(path, "pos_tagger.crfsuite"), "pycrfsuite"),
+            (os.path.join(path, "model.crf"), "underthesea-core"),
+        ]
+
+        model_path = None
+        model_format = None
+        for candidate, fmt in model_candidates:
+            if os.path.exists(candidate):
+                model_path = candidate
+                model_format = fmt
+                break
+
+        if model_path is None:
+            raise FileNotFoundError(
+                f"No model found. Checked: {[c for c, _ in model_candidates]}"
+            )
+
+        # Load model based on format
+        self.model_format = model_format
+        if model_format == "pycrfsuite":
+            if not HAS_PYCRFSUITE:
+                raise ImportError("pycrfsuite not installed. Install with: pip install python-crfsuite")
+            self.tagger = pycrfsuite.Tagger()
+            self.tagger.open(model_path)
+        elif model_format == "underthesea-core":
+            if not HAS_UNDERTHESEA_CORE:
+                raise ImportError("underthesea-core not installed")
+            model = CRFModel.load(model_path)
+            self.tagger = CRFTagger.from_model(model)
 
     def _tokenize(self, text: str) -> List[str]:
         """Simple whitespace tokenization."""
scripts/evaluate.py CHANGED
@@ -6,24 +6,29 @@
 #     "scikit-learn>=1.6.1",
 #     "matplotlib>=3.5.0",
 #     "seaborn>=0.12.0",
+#     "click>=8.0.0",
 # ]
 # ///
 """
 Evaluation script for Vietnamese POS Tagger (TRE-1).
 
-Generates detailed metrics, confusion matrix, and visualizations
-as described in TECHNICAL_REPORT.md.
-
 Usage:
     uv run scripts/evaluate.py
-    uv run scripts/evaluate.py --model pos_tagger.crfsuite
-    uv run scripts/evaluate.py --save-plots  # Save confusion matrix and charts
+    uv run scripts/evaluate.py --version v1.0.0
+    uv run scripts/evaluate.py --model models/pos_tagger/v1.0.0/model.crfsuite
+    uv run scripts/evaluate.py --save-plots
 """
 
-import argparse
+import re
+from collections import Counter
+from pathlib import Path
+
+import click
 import pycrfsuite
 from datasets import load_dataset
-from sklearn.model_selection import train_test_split
+
+# Get project root directory
+PROJECT_ROOT = Path(__file__).parent.parent
 from sklearn.metrics import (
     accuracy_score,
     precision_recall_fscore_support,
@@ -83,7 +88,6 @@ def apply_attribute(value, attribute, dictionary=None):
 
 
 def parse_template(template):
-    import re
     match = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
     if not match:
         return None, None
@@ -122,20 +126,18 @@ def sentence_to_features(tokens):
 
 
 def load_test_data():
-    print("Loading UDD-v0.1 dataset...")
-    dataset = load_dataset("undertheseanlp/UDD-v0.1")
+    click.echo("Loading UDD-1 dataset...")
+    dataset = load_dataset("undertheseanlp/UDD-1")
 
     sentences = []
-    for item in dataset["train"]:
+    for item in dataset["test"]:
         tokens = item["tokens"]
        tags = item["upos"]
         if tokens and tags:
             sentences.append((tokens, tags))
 
-    # Use same split as training
-    _, test_data = train_test_split(sentences, test_size=0.2, random_state=42)
-    print(f"Test set: {len(test_data)} sentences")
-    return test_data
+    click.echo(f"Test set: {len(sentences)} sentences")
+    return sentences
 
 
 def plot_confusion_matrix(y_true, y_pred, labels, output_path):
@@ -159,13 +161,12 @@ def plot_confusion_matrix(y_true, y_pred, labels, output_path):
     plt.tight_layout()
     plt.savefig(output_path, dpi=150)
     plt.close()
-    print(f"Confusion matrix saved to {output_path}")
+    click.echo(f"Confusion matrix saved to {output_path}")
 
 
 def plot_per_tag_metrics(report_dict, output_path):
     import matplotlib.pyplot as plt
 
-    # Filter out aggregate metrics
     tags = [k for k in report_dict.keys() if k not in ("accuracy", "macro avg", "weighted avg")]
 
     precision = [report_dict[t]["precision"] for t in tags]
@@ -192,13 +193,11 @@ def plot_per_tag_metrics(report_dict, output_path):
     plt.tight_layout()
     plt.savefig(output_path, dpi=150)
     plt.close()
-    print(f"Per-tag metrics saved to {output_path}")
+    click.echo(f"Per-tag metrics saved to {output_path}")
 
 
 def analyze_errors(y_true, y_pred, tokens_flat, top_n=10):
     """Analyze common error patterns."""
-    from collections import Counter
-
     errors = Counter()
     error_examples = {}
 
@@ -209,24 +208,69 @@ def analyze_errors(y_true, y_pred, tokens_flat, top_n=10):
         if key not in error_examples:
             error_examples[key] = token
 
-    print(f"\nTop {top_n} Error Patterns:")
-    print("-" * 60)
-    print(f"{'True':<10} {'Predicted':<10} {'Count':<8} {'Example'}")
-    print("-" * 60)
+    click.echo(f"\nTop {top_n} Error Patterns:")
+    click.echo("-" * 60)
+    click.echo(f"{'True':<10} {'Predicted':<10} {'Count':<8} {'Example'}")
+    click.echo("-" * 60)
 
     for (true, pred), count in errors.most_common(top_n):
         example = error_examples.get((true, pred), "")
-        print(f"{true:<10} {pred:<10} {count:<8} {example}")
+        click.echo(f"{true:<10} {pred:<10} {count:<8} {example}")
 
 
-def evaluate(model_path, save_plots=False):
-    print(f"Loading model from {model_path}...")
+def get_latest_version(task="pos_tagger"):
+    """Get the latest model version (sorted by timestamp)."""
+    models_dir = PROJECT_ROOT / "models" / task
+    if not models_dir.exists():
+        return None
+    versions = [d.name for d in models_dir.iterdir() if d.is_dir()]
+    if not versions:
+        return None
+    return sorted(versions)[-1]  # Latest timestamp
+
+
+@click.command()
+@click.option(
+    "--version", "-v",
+    default=None,
+    help="Model version to evaluate (default: latest)",
+)
+@click.option(
+    "--model", "-m",
+    default=None,
+    help="Custom model path (overrides version-based path)",
+)
+@click.option(
+    "--save-plots",
+    is_flag=True,
+    help="Save confusion matrix and per-tag metrics plots",
+)
+def evaluate(version, model, save_plots):
+    """Evaluate Vietnamese POS Tagger on UDD-1 test set."""
+    # Use latest version if not specified
+    if version is None and model is None:
+        version = get_latest_version("pos_tagger")
+        if version is None:
+            raise click.ClickException("No models found in models/pos_tagger/")
+
+    # Determine model path
+    if model:
+        model_path = Path(model)
+    else:
+        model_path = PROJECT_ROOT / "models" / "pos_tagger" / version / "model.crfsuite"
+
+    # Determine output directory for plots
+    if save_plots:
+        results_dir = PROJECT_ROOT / "results" / "pos_tagger"
+        results_dir.mkdir(parents=True, exist_ok=True)
+
+    click.echo(f"Loading model from {model_path}...")
     tagger = pycrfsuite.Tagger()
-    tagger.open(model_path)
+    tagger.open(str(model_path))
 
     test_data = load_test_data()
 
-    print("Extracting features and predicting...")
+    click.echo("Extracting features and predicting...")
     X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
     y_test = [tags for _, tags in test_data]
     tokens_test = [tokens for tokens, _ in test_data]
@@ -246,70 +290,54 @@ def evaluate(model_path, save_plots=False):
     precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
         y_test_flat, y_pred_flat, average="macro"
     )
-    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
+    _, _, f1_weighted, _ = precision_recall_fscore_support(
         y_test_flat, y_pred_flat, average="weighted"
     )
 
-    print("\n" + "=" * 60)
-    print("EVALUATION RESULTS")
-    print("=" * 60)
+    click.echo("\n" + "=" * 60)
+    click.echo("EVALUATION RESULTS")
+    click.echo("=" * 60)
 
-    print("\nOverall Metrics:")
-    print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
-    print(f"  Precision (macro): {precision_macro:.4f}")
-    print(f"  Recall (macro): {recall_macro:.4f}")
-    print(f"  F1 (macro): {f1_macro:.4f}")
-    print(f"  F1 (weighted): {f1_weighted:.4f}")
+    click.echo("\nOverall Metrics:")
+    click.echo(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
+    click.echo(f"  Precision (macro): {precision_macro:.4f}")
+    click.echo(f"  Recall (macro): {recall_macro:.4f}")
+    click.echo(f"  F1 (macro): {f1_macro:.4f}")
+    click.echo(f"  F1 (weighted): {f1_weighted:.4f}")
 
-    print("\nPer-Tag Classification Report:")
+    click.echo("\nPer-Tag Classification Report:")
     report = classification_report(y_test_flat, y_pred_flat, digits=4)
-    print(report)
+    click.echo(report)
 
     # Error analysis
     analyze_errors(y_test_flat, y_pred_flat, tokens_flat)
 
     # Dataset statistics
-    from collections import Counter
     tag_counts = Counter(y_test_flat)
     total_tokens = len(y_test_flat)
 
-    print("\nTest Set Tag Distribution:")
-    print("-" * 40)
+    click.echo("\nTest Set Tag Distribution:")
+    click.echo("-" * 40)
     for tag in labels:
         count = tag_counts[tag]
         pct = count / total_tokens * 100
-        print(f"  {tag:<8} {count:>6} ({pct:>5.2f}%)")
+        click.echo(f"  {tag:<8} {count:>6} ({pct:>5.2f}%)")
 
     if save_plots:
+        cm_path = results_dir / f"confusion_matrix_{version}.png"
         plot_confusion_matrix(
             y_test_flat, y_pred_flat, labels,
-            "confusion_matrix.png"
+            str(cm_path)
         )
 
         report_dict = classification_report(
             y_test_flat, y_pred_flat, output_dict=True
         )
-        plot_per_tag_metrics(report_dict, "per_tag_metrics.png")
+        metrics_path = results_dir / f"per_tag_metrics_{version}.png"
+        plot_per_tag_metrics(report_dict, str(metrics_path))
 
     return accuracy
 
 
-def main():
-    parser = argparse.ArgumentParser(description="Evaluate Vietnamese POS Tagger")
-    parser.add_argument(
-        "--model", "-m",
-        default="pos_tagger.crfsuite",
-        help="Path to trained model"
-    )
-    parser.add_argument(
-        "--save-plots",
-        action="store_true",
-        help="Save confusion matrix and per-tag metrics plots"
-    )
-    args = parser.parse_args()
-
-    evaluate(args.model, save_plots=args.save_plots)
-
-
 if __name__ == "__main__":
-    main()
+    evaluate()
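evaluate.py reports both macro and weighted F1 via scikit-learn. The difference is easy to miss: macro averages each tag's F1 equally, while weighted scales each tag's F1 by its share of the gold labels. A pure-Python sketch of both, on toy POS labels (the repo computes these with `precision_recall_fscore_support`; this reimplementation is only for illustration):

```python
from collections import Counter

def per_class_f1(y_true, y_pred, label):
    """F1 for one label: 2*TP / (2*TP + FP + FN); 0.0 when TP is zero."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def macro_and_weighted_f1(y_true, y_pred):
    """Macro = unweighted mean of per-tag F1; weighted = support-weighted mean."""
    support = Counter(y_true)
    labels = sorted(support)
    f1s = {lab: per_class_f1(y_true, y_pred, lab) for lab in labels}
    macro = sum(f1s.values()) / len(labels)
    weighted = sum(f1s[lab] * support[lab] for lab in labels) / len(y_true)
    return macro, weighted


y_true = ["NOUN", "NOUN", "NOUN", "VERB"]
y_pred = ["NOUN", "NOUN", "VERB", "VERB"]
macro, weighted = macro_and_weighted_f1(y_true, y_pred)
# NOUN F1 = 0.8, VERB F1 = 2/3: macro averages them equally,
# weighted tilts toward NOUN (3 of 4 gold tokens).
```

On a tag set as skewed as POS distributions usually are, weighted F1 tracks the frequent tags (NOUN, PUNCT) while macro F1 exposes rare-tag performance.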
scripts/predict.py CHANGED
@@ -2,6 +2,8 @@
 # requires-python = ">=3.9"
 # dependencies = [
 #     "python-crfsuite>=0.9.11",
+#     "click>=8.0.0",
+#     "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
 # ]
 # ///
 """
@@ -9,68 +11,98 @@ Inference script for Vietnamese POS Tagger (TRE-1).
 
 Usage:
     uv run scripts/predict.py "Tôi yêu Việt Nam"
-    uv run scripts/predict.py --model pos_tagger.crfsuite "Hà Nội là thủ đô"
+    uv run scripts/predict.py --version v1.0.0 "Hà Nội là thủ đô"
+    uv run scripts/predict.py --model models/pos_tagger/v1.0.0 "Test"
     echo "Học sinh đang học bài" | uv run scripts/predict.py -
 """
 
-import argparse
+import json
 import sys
 import os
+from pathlib import Path
+
+import click
 
 # Add parent directory to import handler
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
+# Get project root directory
+PROJECT_ROOT = Path(__file__).parent.parent
+
 from handler import EndpointHandler
 
 
-def main():
-    parser = argparse.ArgumentParser(description="Vietnamese POS Tagger Inference")
-    parser.add_argument(
-        "text",
-        nargs="?",
-        default="-",
-        help="Text to tag (use '-' for stdin)"
-    )
-    parser.add_argument(
-        "--model", "-m",
-        default=".",
-        help="Path to model directory (default: current directory)"
-    )
-    parser.add_argument(
-        "--format", "-f",
-        choices=["inline", "json", "conll"],
-        default="inline",
-        help="Output format"
-    )
-    args = parser.parse_args()
+def get_latest_version(task="pos_tagger"):
+    """Get the latest model version (sorted by timestamp)."""
+    models_dir = PROJECT_ROOT / "models" / task
+    if not models_dir.exists():
+        return None
+    versions = [d.name for d in models_dir.iterdir() if d.is_dir()]
+    if not versions:
+        return None
+    return sorted(versions)[-1]  # Latest timestamp
+
+
+@click.command()
+@click.argument("text", default="-")
+@click.option(
+    "--version", "-v",
+    default=None,
+    help="Model version to use (default: latest)",
+)
+@click.option(
+    "--model", "-m",
+    default=None,
+    help="Custom model directory path (overrides version-based path)",
+)
+@click.option(
+    "--format", "-f",
+    "output_format",
+    type=click.Choice(["inline", "json", "conll"]),
+    default="inline",
+    help="Output format",
+    show_default=True,
+)
+def predict(text, version, model, output_format):
+    """Tag Vietnamese text with POS tags.
+
+    TEXT is the input text to tag. Use '-' to read from stdin.
+    """
+    # Use latest version if not specified
+    if version is None and model is None:
+        version = get_latest_version("pos_tagger")
+        if version is None:
+            raise click.ClickException("No models found in models/pos_tagger/")
+
+    # Determine model path
+    if model:
+        model_path = model
+    else:
+        model_path = str(PROJECT_ROOT / "models" / "pos_tagger" / version)
 
     # Read input
-    if args.text == "-":
+    if text == "-":
         text = sys.stdin.read().strip()
-    else:
-        text = args.text
 
     if not text:
-        print("Error: No input text provided", file=sys.stderr)
-        sys.exit(1)
+        raise click.ClickException("No input text provided")
 
     # Load model
-    handler = EndpointHandler(path=args.model)
+    handler = EndpointHandler(path=model_path)
 
     # Predict
     result = handler({"inputs": text})
 
     # Format output
-    if args.format == "json":
-        import json
-        print(json.dumps(result, ensure_ascii=False, indent=2))
-    elif args.format == "conll":
+    if output_format == "json":
+        click.echo(json.dumps(result, ensure_ascii=False, indent=2))
+    elif output_format == "conll":
         for i, item in enumerate(result, 1):
-            print(f"{i}\t{item['token']}\t{item['tag']}")
+            click.echo(f"{i}\t{item['token']}\t{item['tag']}")
     else:  # inline
         tagged = " ".join(f"{item['token']}/{item['tag']}" for item in result)
-        print(tagged)
+        click.echo(tagged)
 
 
 if __name__ == "__main__":
-    main()
+    predict()
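predict.py keeps the usual CLI convention of treating a `-` argument as "read from stdin" (the `echo ... | predict.py -` usage above). The same resolution step in isolation, with stdin injectable so it can be checked without a pipe:

```python
import io
import sys

def resolve_text(arg, stdin=None):
    """Return the text to tag: the argument itself, or stdin when arg is '-'."""
    stdin = stdin or sys.stdin
    text = stdin.read().strip() if arg == "-" else arg
    if not text:
        raise ValueError("No input text provided")
    return text


print(resolve_text("Tôi yêu Việt Nam"))         # Tôi yêu Việt Nam
print(resolve_text("-", io.StringIO("x y\n")))  # x y
```

Empty input (including a blank pipe) is rejected up front, matching the script's `ClickException` path.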
scripts/predict_word_segmentation.py ADDED
@@ -0,0 +1,201 @@
+# /// script
+# requires-python = ">=3.9"
+# dependencies = [
+#     "python-crfsuite>=0.9.11",
+#     "click>=8.0.0",
+#     "underthesea>=6.8.0",
+#     "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
+# ]
+# ///
+"""
+Prediction script for Vietnamese Word Segmentation.
+
+Uses underthesea regex_tokenize to split text into syllables,
+then applies CRF model at syllable level to decide word boundaries.
+
+Usage:
+    uv run scripts/predict_word_segmentation.py "Trên thế giới, giá vàng đang giao dịch"
+    echo "Text here" | uv run scripts/predict_word_segmentation.py -
+"""
+
+import sys
+
+import click
+import pycrfsuite
+from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize
+
+
+def get_syllable_at(syllables, position, offset):
+    """Get syllable at position + offset, with boundary handling."""
+    idx = position + offset
+    if idx < 0:
+        return "__BOS__"
+    elif idx >= len(syllables):
+        return "__EOS__"
+    return syllables[idx]
+
+
+def is_punct(s):
+    """Check if string is punctuation."""
+    return len(s) == 1 and not s.isalnum()
+
+
+def extract_syllable_features(syllables, position):
+    """Extract features for a syllable at given position."""
+    features = {}
+
+    # Current syllable
+    s0 = get_syllable_at(syllables, position, 0)
+    is_boundary = s0 in ("__BOS__", "__EOS__")
+
+    features["S[0]"] = s0
+    features["S[0].lower"] = s0.lower() if not is_boundary else s0
+    features["S[0].istitle"] = str(s0.istitle()) if not is_boundary else "False"
+    features["S[0].isupper"] = str(s0.isupper()) if not is_boundary else "False"
+    features["S[0].isdigit"] = str(s0.isdigit()) if not is_boundary else "False"
+    features["S[0].ispunct"] = str(is_punct(s0)) if not is_boundary else "False"
+    features["S[0].len"] = str(len(s0)) if not is_boundary else "0"
+    features["S[0].prefix2"] = s0[:2] if not is_boundary and len(s0) >= 2 else s0
+    features["S[0].suffix2"] = s0[-2:] if not is_boundary and len(s0) >= 2 else s0
+
+    # Previous syllables
+    s_1 = get_syllable_at(syllables, position, -1)
+    s_2 = get_syllable_at(syllables, position, -2)
+    features["S[-1]"] = s_1
+    features["S[-1].lower"] = s_1.lower() if s_1 not in ("__BOS__", "__EOS__") else s_1
+    features["S[-2]"] = s_2
+    features["S[-2].lower"] = s_2.lower() if s_2 not in ("__BOS__", "__EOS__") else s_2
68
+
69
+ # Next syllables
70
+ s1 = get_syllable_at(syllables, position, 1)
71
+ s2 = get_syllable_at(syllables, position, 2)
72
+ features["S[1]"] = s1
73
+ features["S[1].lower"] = s1.lower() if s1 not in ("__BOS__", "__EOS__") else s1
74
+ features["S[2]"] = s2
75
+ features["S[2].lower"] = s2.lower() if s2 not in ("__BOS__", "__EOS__") else s2
76
+
77
+ # Bigrams
78
+ features["S[-1,0]"] = f"{s_1}|{s0}"
79
+ features["S[0,1]"] = f"{s0}|{s1}"
80
+
81
+ # Trigrams
82
+ features["S[-1,0,1]"] = f"{s_1}|{s0}|{s1}"
83
+
84
+ return features
85
+
86
+
87
+ def sentence_to_syllable_features(syllables):
88
+ """Convert syllable sequence to feature sequences."""
89
+ return [
90
+ [f"{k}={v}" for k, v in extract_syllable_features(syllables, i).items()]
91
+ for i in range(len(syllables))
92
+ ]
93
+
94
+
95
+ def labels_to_words(syllables, labels):
96
+ """Convert syllable sequence and BIO labels back to words."""
97
+ words = []
98
+ current_word = []
99
+
100
+ for syl, label in zip(syllables, labels):
101
+ if label == "B":
102
+ if current_word:
103
+ words.append(" ".join(current_word))
104
+ current_word = [syl]
105
+ else: # I
106
+ current_word.append(syl)
107
+
108
+ if current_word:
109
+ words.append(" ".join(current_word))
110
+
111
+ return words
112
+
113
+
114
+ def segment_text(text, tagger):
115
+ """
116
+ Full pipeline: regex tokenize -> CRF segment -> output words.
117
+ """
118
+ # Step 1: Regex tokenize into syllables
119
+ syllables = regex_tokenize(text)
120
+
121
+ if not syllables:
122
+ return ""
123
+
124
+ # Step 2: Extract syllable features
125
+ X = sentence_to_syllable_features(syllables)
126
+
127
+ # Step 3: Predict BIO labels
128
+ labels = tagger.tag(X)
129
+
130
+    # Step 4: Convert to words (syllables joined with underscore for compound words)
+    words = labels_to_words(syllables, labels)
+
+    return " ".join(w.replace(" ", "_") for w in words)
134
+
135
+
136
+ def segment_text_formatted(text, tagger, use_underscore=True):
137
+ """
138
+ Full pipeline with formatted output.
139
+ """
140
+ syllables = regex_tokenize(text)
141
+
142
+ if not syllables:
143
+ return ""
144
+
145
+ X = sentence_to_syllable_features(syllables)
146
+ labels = tagger.tag(X)
147
+ words = labels_to_words(syllables, labels)
148
+
149
+ if use_underscore:
150
+ # Join compound word syllables with underscore
151
+ return " ".join(w.replace(" ", "_") for w in words)
152
+ else:
153
+ return " ".join(words)
154
+
155
+
156
+ @click.command()
157
+ @click.argument("text", required=False)
158
+ @click.option(
159
+ "--model", "-m",
160
+ default="word_segmenter.crfsuite",
161
+ help="Path to CRF model file",
162
+ show_default=True,
163
+ )
164
+ @click.option(
165
+ "--underscore/--no-underscore",
166
+ default=True,
167
+ help="Use underscore to join compound word syllables",
168
+ )
169
+ def main(text, model, underscore):
170
+ """Segment Vietnamese text into words."""
171
+ # Handle stdin input
172
+ if text == "-" or text is None:
173
+ text = sys.stdin.read().strip()
174
+
175
+ if not text:
176
+ click.echo("No input text provided", err=True)
177
+ return
178
+
179
+ # Load model - support both pycrfsuite and underthesea-core formats
180
+ if model.endswith(".crf"):
181
+ # underthesea-core format
182
+ try:
183
+ from underthesea_core import CRFModel, CRFTagger
184
+ except ImportError:
185
+ from underthesea_core.underthesea_core import CRFModel, CRFTagger
186
+ crf_model = CRFModel.load(model)
187
+ tagger = CRFTagger.from_model(crf_model)
188
+ else:
189
+ # pycrfsuite format
190
+ tagger = pycrfsuite.Tagger()
191
+ tagger.open(model)
192
+
193
+ # Process each line
194
+ for line in text.split("\n"):
195
+ if line.strip():
196
+ result = segment_text_formatted(line, tagger, use_underscore=underscore)
197
+ click.echo(result)
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
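The BIO-to-word conversion at the heart of this script can be exercised standalone. A minimal sketch mirroring `labels_to_words` above, with a hypothetical label sequence:

```python
def labels_to_words(syllables, labels):
    """Group syllables into words: 'B' starts a new word, 'I' continues it."""
    words, current = [], []
    for syl, label in zip(syllables, labels):
        if label == "B":
            if current:
                words.append(" ".join(current))
            current = [syl]
        else:  # "I"
            current.append(syl)
    if current:
        words.append(" ".join(current))
    return words

syllables = ["Tôi", "yêu", "Việt", "Nam"]
labels = ["B", "B", "B", "I"]
words = labels_to_words(syllables, labels)
print(" ".join(w.replace(" ", "_") for w in words))  # Tôi yêu Việt_Nam
```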
scripts/train.py CHANGED
@@ -2,25 +2,94 @@
2
  # requires-python = ">=3.9"
3
  # dependencies = [
4
  # "python-crfsuite>=0.9.11",
 
5
  # "datasets>=4.5.0",
6
  # "scikit-learn>=1.6.1",
7
  # ]
8
  # ///
9
  """
10
  Training script for Vietnamese POS Tagger (TRE-1).
11
 
12
- Reproduces the training process from TECHNICAL_REPORT.md.
13
 
14
  Usage:
15
  uv run scripts/train.py
16
- uv run scripts/train.py --output model.crfsuite
17
- uv run scripts/train.py --wandb # Enable W&B logging
18
  """
19
 
20
- import argparse
21
- import pycrfsuite
22
  from datasets import load_dataset
23
- from sklearn.model_selection import train_test_split
24
 
25
 
26
  FEATURE_TEMPLATES = [
@@ -74,7 +143,6 @@ def apply_attribute(value, attribute, dictionary=None):
74
 
75
 
76
  def parse_template(template):
77
- import re
78
  match = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
79
  if not match:
80
  return None, None
@@ -112,115 +180,422 @@ def sentence_to_features(tokens):
112
  ]
113
 
114
 
115
- def load_data():
116
- print("Loading UDD-v0.1 dataset...")
117
- dataset = load_dataset("undertheseanlp/UDD-v0.1")
118
 
119
- sentences = []
120
- for item in dataset["train"]:
121
- tokens = item["tokens"]
122
- tags = item["upos"]
123
- if tokens and tags:
124
- sentences.append((tokens, tags))
125
 
126
- print(f"Loaded {len(sentences)} sentences")
127
- return sentences
128
 
129
 
130
- def train(output_path, use_wandb=False):
131
- sentences = load_data()
132
 
133
- # Split 80/20 as per technical report
134
- train_data, test_data = train_test_split(
135
- sentences, test_size=0.2, random_state=42
136
- )
 
137
 
138
- print(f"Train: {len(train_data)} sentences")
139
- print(f"Test: {len(test_data)} sentences")
140
 
141
  # Prepare training data
142
- print("Extracting features...")
 
143
  X_train = [sentence_to_features(tokens) for tokens, _ in train_data]
144
  y_train = [tags for _, tags in train_data]
 
145
 
146
  # Train CRF
147
- print("Training CRF model...")
148
- trainer = pycrfsuite.Trainer(verbose=True)
149
-
150
- for xseq, yseq in zip(X_train, y_train):
151
- trainer.append(xseq, yseq)
152
-
153
- # Training parameters from technical report
154
- trainer.set_params({
155
- "c1": 1.0, # L1 regularization
156
- "c2": 0.001, # L2 regularization
157
- "max_iterations": 100, # Max iterations
158
- "feature.possible_transitions": True,
159
- })
160
 
 
161
  if use_wandb:
162
  try:
163
- import wandb
164
- wandb.init(project="pos-tagger-vietnamese", name="underthesea-crf")
165
- wandb.config.update({
166
- "c1": 1.0,
167
- "c2": 0.001,
168
- "max_iterations": 100,
 
169
  "num_features": len(FEATURE_TEMPLATES),
170
  "train_sentences": len(train_data),
 
171
  "test_sentences": len(test_data),
 
172
  })
173
  except ImportError:
174
- print("wandb not installed, skipping logging")
175
  use_wandb = False
176
 
177
- trainer.train(output_path)
178
- print(f"Model saved to {output_path}")
179
-
180
- # Quick evaluation
181
- print("\nEvaluating on test set...")
182
- tagger = pycrfsuite.Tagger()
183
- tagger.open(output_path)
184
 
 
 
185
  X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
186
  y_test = [tags for _, tags in test_data]
187
 
188
- y_pred = [tagger.tag(xseq) for xseq in X_test]
189
 
190
  # Flatten for metrics
191
  y_test_flat = [tag for tags in y_test for tag in tags]
192
  y_pred_flat = [tag for tags in y_pred for tag in tags]
193
 
194
- from sklearn.metrics import accuracy_score, classification_report
195
-
196
  accuracy = accuracy_score(y_test_flat, y_pred_flat)
197
- print(f"\nAccuracy: {accuracy:.4f}")
198
- print("\nClassification Report:")
199
- print(classification_report(y_test_flat, y_pred_flat))
200
 
201
- if use_wandb:
202
- wandb.log({"accuracy": accuracy})
203
- wandb.finish()
204
 
205
- return output_path
 
 
206
 
 
 
 
 
207
 
208
- def main():
209
- parser = argparse.ArgumentParser(description="Train Vietnamese POS Tagger")
210
- parser.add_argument(
211
- "--output", "-o",
212
- default="pos_tagger.crfsuite",
213
- help="Output model path"
214
- )
215
- parser.add_argument(
216
- "--wandb",
217
- action="store_true",
218
- help="Enable Weights & Biases logging"
219
- )
220
- args = parser.parse_args()
221
 
222
- train(args.output, use_wandb=args.wandb)
 
 
223
 
224
 
225
  if __name__ == "__main__":
226
- main()
 
2
  # requires-python = ">=3.9"
3
  # dependencies = [
4
  # "python-crfsuite>=0.9.11",
5
+ # "crfsuite>=0.3.0",
6
  # "datasets>=4.5.0",
7
  # "scikit-learn>=1.6.1",
8
+ # "click>=8.0.0",
9
+ # "psutil>=5.9.0",
10
+ # "pyyaml>=6.0.0",
11
+ # "underthesea>=6.8.0",
12
+ # "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
13
  # ]
14
  # ///
15
  """
16
  Training script for Vietnamese POS Tagger (TRE-1).
17
 
18
+ Supports 3 CRF trainers:
19
+ - python-crfsuite: Original Python bindings to CRFsuite
20
+ - crfsuite-rs: Rust bindings to CRFsuite (pip install crfsuite)
21
+ - underthesea-core: Underthesea's native Rust CRF implementation
22
+
23
+ Models are saved to: models/pos_tagger/{version}/model.crfsuite
24
 
25
  Usage:
26
  uv run scripts/train.py
27
+ uv run scripts/train.py --trainer crfsuite-rs
28
+ uv run scripts/train.py --trainer underthesea-core
29
+ uv run scripts/train.py --version v1.1.0
30
+ uv run scripts/train.py --wandb
31
+ uv run scripts/train.py --c1 0.5 --c2 0.01 --max-iterations 200
32
  """
33
 
34
+ import platform
35
+ import re
36
+ import time
37
+ from abc import ABC, abstractmethod
38
+ from datetime import datetime
39
+ from pathlib import Path
40
+
41
+ import click
42
+ import psutil
43
+ import yaml
44
  from datasets import load_dataset
45
+ from sklearn.metrics import accuracy_score, classification_report
46
+
47
+
48
+ # Get project root directory
49
+ PROJECT_ROOT = Path(__file__).parent.parent
50
+
51
+ # Available trainers
52
+ TRAINERS = ["python-crfsuite", "crfsuite-rs", "underthesea-core"]
53
+
54
+
55
+ def get_hardware_info():
56
+ """Collect hardware and system information."""
57
+ info = {
58
+ "platform": platform.system(),
59
+ "platform_release": platform.release(),
60
+ "architecture": platform.machine(),
61
+ "python_version": platform.python_version(),
62
+ "cpu_physical_cores": psutil.cpu_count(logical=False),
63
+ "cpu_logical_cores": psutil.cpu_count(logical=True),
64
+ "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
65
+ }
66
+
67
+ try:
68
+ if platform.system() == "Linux":
69
+ with open("/proc/cpuinfo", "r") as f:
70
+ for line in f:
71
+ if "model name" in line:
72
+ info["cpu_model"] = line.split(":")[1].strip()
73
+ break
74
+ except Exception:
75
+ info["cpu_model"] = "Unknown"
76
+
77
+ return info
78
+
79
+
80
+ def format_duration(seconds):
81
+ """Format duration in human-readable format."""
82
+ if seconds < 60:
83
+ return f"{seconds:.2f}s"
84
+ elif seconds < 3600:
85
+ minutes = int(seconds // 60)
86
+ secs = seconds % 60
87
+ return f"{minutes}m {secs:.2f}s"
88
+ else:
89
+ hours = int(seconds // 3600)
90
+ minutes = int((seconds % 3600) // 60)
91
+ secs = seconds % 60
92
+ return f"{hours}h {minutes}m {secs:.2f}s"
93
 
94
 
95
  FEATURE_TEMPLATES = [
 
143
 
144
 
145
  def parse_template(template):
 
146
  match = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
147
  if not match:
148
  return None, None
 
180
  ]
181
 
182
 
183
+ # ============================================================================
184
+ # Trainer Abstraction
185
+ # ============================================================================
186
+
187
+ class CRFTrainerBase(ABC):
188
+ """Abstract base class for CRF trainers."""
189
+
190
+ name: str = "base"
191
+
192
+ @abstractmethod
193
+ def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
194
+ """Train the CRF model and save to output_path."""
195
+ pass
196
+
197
+ @abstractmethod
198
+ def predict(self, model_path, X_test):
199
+ """Load model and predict on test data."""
200
+ pass
201
+
202
+
203
+ class PythonCRFSuiteTrainer(CRFTrainerBase):
204
+ """Trainer using python-crfsuite (original Python bindings)."""
205
+
206
+ name = "python-crfsuite"
207
+
208
+ def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
209
+ import pycrfsuite
210
+
211
+ trainer = pycrfsuite.Trainer(verbose=verbose)
212
+
213
+ for xseq, yseq in zip(X_train, y_train):
214
+ trainer.append(xseq, yseq)
215
+
216
+ trainer.set_params({
217
+ "c1": c1,
218
+ "c2": c2,
219
+ "max_iterations": max_iterations,
220
+ "feature.possible_transitions": True,
221
+ })
222
+
223
+ trainer.train(str(output_path))
224
+
225
+ def predict(self, model_path, X_test):
226
+ import pycrfsuite
227
+
228
+ tagger = pycrfsuite.Tagger()
229
+ tagger.open(str(model_path))
230
+ return [tagger.tag(xseq) for xseq in X_test]
231
+
232
+
233
+ class CRFSuiteRsTrainer(CRFTrainerBase):
234
+ """Trainer using crfsuite-rs (Rust bindings via pip install crfsuite)."""
235
+
236
+ name = "crfsuite-rs"
237
+
238
+ def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
239
+ import crfsuite
240
+
241
+ trainer = crfsuite.Trainer()
242
+
243
+ # Set parameters
244
+ trainer.set_params({
245
+ "c1": c1,
246
+ "c2": c2,
247
+ "max_iterations": max_iterations,
248
+ "feature.possible_transitions": True,
249
+ })
250
+
251
+ # Add training data
252
+ for xseq, yseq in zip(X_train, y_train):
253
+ trainer.append(xseq, yseq)
254
+
255
+ # Train
256
+ trainer.train(str(output_path))
257
+
258
+ def predict(self, model_path, X_test):
259
+ import crfsuite
260
+
261
+ model = crfsuite.Model(str(model_path))
262
+ return [model.tag(xseq) for xseq in X_test]
263
+
264
+
265
+ class UndertheseaCoreTrainer(CRFTrainerBase):
266
+ """Trainer using underthesea-core native Rust CRF with LBFGS optimization.
267
+
268
+ This trainer uses the native underthesea-core Rust CRF implementation
269
+ with L-BFGS optimization, matching CRFsuite performance.
270
 
271
+ Requires building underthesea-core from source:
272
+ cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core
273
+ uv venv && source .venv/bin/activate
274
+ uv pip install maturin
275
+ maturin develop --release
276
+ """
277
 
278
+ name = "underthesea-core"
 
279
 
280
+ def _check_trainer_import(self):
281
+ """Check if CRFTrainer is available."""
282
+ try:
283
+ from underthesea_core import CRFTrainer
284
+ return CRFTrainer
285
+ except ImportError:
286
+ pass
287
+
288
+ try:
289
+ from underthesea_core.underthesea_core import CRFTrainer
290
+ return CRFTrainer
291
+ except ImportError:
292
+ pass
293
+
294
+ raise ImportError(
295
+ "CRFTrainer not available in underthesea_core.\n"
296
+ "Build from source with LBFGS support:\n"
297
+ " cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core\n"
298
+ " source .venv/bin/activate && maturin develop --release"
299
+ )
300
+
301
+ def _check_tagger_import(self):
302
+ """Check if CRFModel and CRFTagger are available."""
303
+ try:
304
+ from underthesea_core import CRFModel, CRFTagger
305
+ return CRFModel, CRFTagger
306
+ except ImportError:
307
+ pass
308
+
309
+ try:
310
+ from underthesea_core.underthesea_core import CRFModel, CRFTagger
311
+ return CRFModel, CRFTagger
312
+ except ImportError:
313
+ pass
314
+
315
+ raise ImportError("CRFModel/CRFTagger not available in underthesea_core")
316
+
317
+ def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
318
+ CRFTrainer = self._check_trainer_import()
319
+
320
+ # Use LBFGS (default, fast)
321
+ trainer = CRFTrainer(
322
+ loss_function="lbfgs",
323
+ l1_penalty=c1,
324
+ l2_penalty=c2,
325
+ max_iterations=max_iterations,
326
+ verbose=1 if verbose else 0,
327
+ )
328
 
329
+ # Train
330
+ model = trainer.train(X_train, y_train)
331
 
332
+ # Save model
333
+ output_path_str = str(output_path)
334
+ if output_path_str.endswith('.crfsuite'):
335
+ output_path_str = output_path_str.replace('.crfsuite', '.crf')
336
+ model.save(output_path_str)
337
 
338
+ # Store the actual path for prediction
339
+ self._model_path = output_path_str
340
+
341
+ def predict(self, model_path, X_test):
342
+ CRFModel, CRFTagger = self._check_tagger_import()
343
+
344
+ # Use the actual saved path if available
345
+ model_path_str = str(model_path)
346
+ if hasattr(self, '_model_path'):
347
+ model_path_str = self._model_path
348
+ elif model_path_str.endswith('.crfsuite'):
349
+ model_path_str = model_path_str.replace('.crfsuite', '.crf')
350
+
351
+ model = CRFModel.load(model_path_str)
352
+ tagger = CRFTagger.from_model(model)
353
+ return [tagger.tag(xseq) for xseq in X_test]
354
+
355
+
356
+ def get_trainer(trainer_name: str) -> CRFTrainerBase:
357
+ """Get trainer instance by name."""
358
+ trainers = {
359
+ "python-crfsuite": PythonCRFSuiteTrainer,
360
+ "crfsuite-rs": CRFSuiteRsTrainer,
361
+ "underthesea-core": UndertheseaCoreTrainer,
362
+ }
363
+ if trainer_name not in trainers:
364
+ raise ValueError(f"Unknown trainer: {trainer_name}. Available: {list(trainers.keys())}")
365
+ return trainers[trainer_name]()
366
+
367
+
368
+ # ============================================================================
369
+ # Data Loading
370
+ # ============================================================================
371
+
372
+ def load_data():
373
+ click.echo("Loading UDD-1 dataset...")
374
+ dataset = load_dataset("undertheseanlp/UDD-1")
375
+
376
+ def extract_sentences(split):
377
+ sentences = []
378
+ for item in split:
379
+ tokens = item["tokens"]
380
+ tags = item["upos"]
381
+ if tokens and tags:
382
+ sentences.append((tokens, tags))
383
+ return sentences
384
+
385
+ train_data = extract_sentences(dataset["train"])
386
+ val_data = extract_sentences(dataset["validation"])
387
+ test_data = extract_sentences(dataset["test"])
388
+
389
+ click.echo(f"Loaded {len(train_data)} train, {len(val_data)} val, {len(test_data)} test sentences")
390
+ return train_data, val_data, test_data
391
+
392
+
393
+ def save_metadata(output_dir, version, trainer_name, train_data, val_data, test_data, c1, c2, max_iterations, accuracy, hw_info, training_time):
394
+ """Save model metadata to YAML file."""
395
+ metadata = {
396
+ "model": {
397
+ "name": "Vietnamese POS Tagger",
398
+ "version": version,
399
+ "type": "CRF (Conditional Random Field)",
400
+ "framework": trainer_name,
401
+ },
402
+ "training": {
403
+ "dataset": "undertheseanlp/UDD-1",
404
+ "train_sentences": len(train_data),
405
+ "val_sentences": len(val_data),
406
+ "test_sentences": len(test_data),
407
+ "hyperparameters": {
408
+ "c1": c1,
409
+ "c2": c2,
410
+ "max_iterations": max_iterations,
411
+ },
412
+ "duration_seconds": round(training_time, 2),
413
+ },
414
+ "performance": {
415
+ "test_accuracy": round(accuracy, 4),
416
+ },
417
+ "environment": {
418
+ "platform": hw_info["platform"],
419
+ "cpu_model": hw_info.get("cpu_model", "Unknown"),
420
+ "python_version": hw_info["python_version"],
421
+ },
422
+ "files": {
423
+ "model": "model.crfsuite",
424
+ "config": "../../../configs/pos_tagger.yaml",
425
+ },
426
+ "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
427
+ "author": "undertheseanlp",
428
+ }
429
+
430
+ metadata_path = output_dir / "metadata.yaml"
431
+ with open(metadata_path, "w") as f:
432
+ yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
433
+ click.echo(f"Metadata saved to {metadata_path}")
434
+
435
+
436
+ def get_default_version():
437
+ """Generate timestamp-based version."""
438
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
439
+
440
+
441
+ @click.command()
442
+ @click.option(
443
+ "--trainer", "-t",
444
+ type=click.Choice(TRAINERS),
445
+ default="python-crfsuite",
446
+ help="CRF trainer to use",
447
+ show_default=True,
448
+ )
449
+ @click.option(
450
+ "--version", "-v",
451
+ default=None,
452
+ help="Model version (default: timestamp, e.g., 20260131_154530)",
453
+ )
454
+ @click.option(
455
+ "--output", "-o",
456
+ default=None,
457
+ help="Custom output path (overrides version-based path)",
458
+ )
459
+ @click.option(
460
+ "--c1",
461
+ default=1.0,
462
+ type=float,
463
+ help="L1 regularization coefficient",
464
+ show_default=True,
465
+ )
466
+ @click.option(
467
+ "--c2",
468
+ default=0.001,
469
+ type=float,
470
+ help="L2 regularization coefficient",
471
+ show_default=True,
472
+ )
473
+ @click.option(
474
+ "--max-iterations",
475
+ default=100,
476
+ type=int,
477
+ help="Maximum training iterations",
478
+ show_default=True,
479
+ )
480
+ @click.option(
481
+ "--wandb/--no-wandb",
482
+ default=False,
483
+ help="Enable Weights & Biases logging",
484
+ )
485
+ def train(trainer, version, output, c1, c2, max_iterations, wandb):
486
+ """Train Vietnamese POS Tagger using CRF on UDD-1 dataset."""
487
+ total_start_time = time.time()
488
+ start_datetime = datetime.now()
489
+
490
+ # Get trainer
491
+ crf_trainer = get_trainer(trainer)
492
+
493
+ # Use timestamp version if not specified
494
+ if version is None:
495
+ version = get_default_version()
496
+
497
+ # Determine output directory
498
+ if output:
499
+ output_path = Path(output)
500
+ output_dir = output_path.parent
501
+ else:
502
+ output_dir = PROJECT_ROOT / "models" / "pos_tagger" / version
503
+ output_dir.mkdir(parents=True, exist_ok=True)
504
+ output_path = output_dir / "model.crfsuite"
505
+
506
+ # Collect hardware info
507
+ hw_info = get_hardware_info()
508
+
509
+ click.echo("=" * 60)
510
+ click.echo(f"POS Tagger Training - {version}")
511
+ click.echo("=" * 60)
512
+ click.echo(f"Trainer: {trainer}")
513
+ click.echo(f"Platform: {hw_info['platform']}")
514
+ click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
515
+ click.echo(f"Output: {output_path}")
516
+ click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
517
+ click.echo("=" * 60)
518
+
519
+ train_data, val_data, test_data = load_data()
520
+
521
+ click.echo(f"\nTrain: {len(train_data)} sentences")
522
+ click.echo(f"Validation: {len(val_data)} sentences")
523
+ click.echo(f"Test: {len(test_data)} sentences")
524
 
525
  # Prepare training data
526
+ click.echo("\nExtracting features...")
527
+ feature_start = time.time()
528
  X_train = [sentence_to_features(tokens) for tokens, _ in train_data]
529
  y_train = [tags for _, tags in train_data]
530
+ click.echo(f"Feature extraction: {format_duration(time.time() - feature_start)}")
531
 
532
  # Train CRF
533
+ click.echo(f"\nTraining CRF model with {trainer}...")
534
 
535
+ use_wandb = wandb
536
  if use_wandb:
537
  try:
538
+ import wandb as wb
539
+ wb.init(project="pos-tagger-vietnamese", name=f"crf-{trainer}-{version}")
540
+ wb.config.update({
541
+ "trainer": trainer,
542
+ "c1": c1,
543
+ "c2": c2,
544
+ "max_iterations": max_iterations,
545
  "num_features": len(FEATURE_TEMPLATES),
546
  "train_sentences": len(train_data),
547
+ "val_sentences": len(val_data),
548
  "test_sentences": len(test_data),
549
+ "version": version,
550
  })
551
  except ImportError:
552
+ click.echo("wandb not installed, skipping logging", err=True)
553
  use_wandb = False
554
 
555
+ crf_start = time.time()
556
+ crf_trainer.train(X_train, y_train, output_path, c1, c2, max_iterations, verbose=True)
557
+ crf_time = time.time() - crf_start
558
+ click.echo(f"\nModel saved to {output_path}")
559
+ click.echo(f"CRF training: {format_duration(crf_time)}")
 
 
560
 
561
+ # Evaluation
562
+ click.echo("\nEvaluating on test set...")
563
  X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
564
  y_test = [tags for _, tags in test_data]
565
 
566
+ y_pred = crf_trainer.predict(output_path, X_test)
567
 
568
  # Flatten for metrics
569
  y_test_flat = [tag for tags in y_test for tag in tags]
570
  y_pred_flat = [tag for tags in y_pred for tag in tags]
571
 
 
 
572
  accuracy = accuracy_score(y_test_flat, y_pred_flat)
 
 
 
573
 
574
+ total_time = time.time() - total_start_time
 
 
575
 
576
+ click.echo(f"\nAccuracy: {accuracy:.4f}")
577
+ click.echo("\nClassification Report:")
578
+ click.echo(classification_report(y_test_flat, y_pred_flat))
579
 
580
+ # Save metadata
581
+ if not output:
582
+ save_metadata(output_dir, version, trainer, train_data, val_data, test_data,
583
+ c1, c2, max_iterations, accuracy, hw_info, total_time)
584
 
585
+ click.echo("\n" + "=" * 60)
586
+ click.echo("Training Summary")
587
+ click.echo("=" * 60)
588
+ click.echo(f"Trainer: {trainer}")
589
+ click.echo(f"Version: {version}")
590
+ click.echo(f"Model: {output_path}")
591
+ click.echo(f"Accuracy: {accuracy:.4f}")
592
+ click.echo(f"Total time: {format_duration(total_time)}")
593
+ click.echo("=" * 60)
 
 
 
 
594
 
595
+ if use_wandb:
596
+ wb.log({"accuracy": accuracy})
597
+ wb.finish()
598
 
599
 
600
  if __name__ == "__main__":
601
+ train()
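The multi-trainer design above is a simple strategy pattern: one abstract interface, one concrete class per CRF backend, and a lookup by name. Stripped to its essentials (the `DummyTrainer` backend here is an illustrative stand-in, not one of the real trainers):

```python
from abc import ABC, abstractmethod

class CRFTrainerBase(ABC):
    """Common interface every CRF backend must implement."""
    name = "base"

    @abstractmethod
    def train(self, X_train, y_train, output_path):
        ...

class DummyTrainer(CRFTrainerBase):
    """Stand-in backend used only to demonstrate the dispatch."""
    name = "dummy"

    def train(self, X_train, y_train, output_path):
        return f"trained {len(X_train)} sequences -> {output_path}"

TRAINERS = {"dummy": DummyTrainer}

def get_trainer(name):
    """Instantiate a trainer by name, failing loudly on unknown backends."""
    if name not in TRAINERS:
        raise ValueError(f"Unknown trainer: {name}. Available: {list(TRAINERS)}")
    return TRAINERS[name]()

print(get_trainer("dummy").train([["f=1"]], [["B"]], "model.crf"))
```

Adding a backend is then just another subclass plus one dictionary entry, which is why the script can offer python-crfsuite, crfsuite-rs, and underthesea-core behind the same `--trainer` flag.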
scripts/train_word_segmentation.py ADDED
@@ -0,0 +1,755 @@
1
+ # /// script
2
+ # requires-python = ">=3.9"
3
+ # dependencies = [
4
+ # "python-crfsuite>=0.9.11",
5
+ # "crfsuite>=0.3.0",
6
+ # "datasets>=4.5.0",
7
+ # "scikit-learn>=1.6.1",
8
+ # "click>=8.0.0",
9
+ # "psutil>=5.9.0",
10
+ # "pyyaml>=6.0.0",
11
+ # "underthesea>=6.8.0",
12
+ # "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
13
+ # ]
14
+ # ///
15
+ # Note: underthesea-core trainer now uses crfsuite (LBFGS) for fast training
16
+ """
17
+ Training script for Vietnamese Word Segmentation using CRF.
18
+
19
+ Supports 3 CRF trainers:
20
+ - python-crfsuite: Original Python bindings to CRFsuite
21
+ - crfsuite-rs: Rust bindings to CRFsuite (pip install crfsuite)
22
+ - underthesea-core: Underthesea's native Rust CRF implementation
23
+
24
+ Models are saved to: models/word_segmentation/{version}/model.crfsuite
25
+
26
+ Uses BIO tagging at SYLLABLE level:
27
+ - B: Beginning of a word (first syllable)
28
+ - I: Inside a word (continuation syllables)
29
+
30
+ Usage:
31
+ uv run scripts/train_word_segmentation.py
32
+ uv run scripts/train_word_segmentation.py --trainer crfsuite-rs
33
+ uv run scripts/train_word_segmentation.py --trainer underthesea-core
34
+ uv run scripts/train_word_segmentation.py --version v1.1.0
35
+ """
36
+
37
+ import os
38
+ import platform
39
+ import time
40
+ from abc import ABC, abstractmethod
41
+ from datetime import datetime
42
+ from pathlib import Path
43
+
44
+ import click
45
+ import psutil
46
+ import yaml
47
+ from datasets import load_dataset
48
+ from sklearn.metrics import accuracy_score, classification_report, f1_score
49
+ from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize
50
+
51
+
52
+ # Get project root directory
53
+ PROJECT_ROOT = Path(__file__).parent.parent
54
+
55
+ # Available trainers
56
+ TRAINERS = ["python-crfsuite", "crfsuite-rs", "underthesea-core"]
57
+
58
+
59
+ def get_hardware_info():
+     """Collect hardware and system information."""
+     info = {
+         "platform": platform.system(),
+         "platform_release": platform.release(),
+         "architecture": platform.machine(),
+         "python_version": platform.python_version(),
+         "cpu_physical_cores": psutil.cpu_count(logical=False),
+         "cpu_logical_cores": psutil.cpu_count(logical=True),
+         "ram_total_gb": round(psutil.virtual_memory().total / (1024**3), 2),
+     }
+
+     try:
+         if platform.system() == "Linux":
+             with open("/proc/cpuinfo", "r") as f:
+                 for line in f:
+                     if "model name" in line:
+                         info["cpu_model"] = line.split(":")[1].strip()
+                         break
+     except Exception:
+         info["cpu_model"] = "Unknown"
+
+     return info
+
+
+ def format_duration(seconds):
+     """Format duration in human-readable format."""
+     if seconds < 60:
+         return f"{seconds:.2f}s"
+     elif seconds < 3600:
+         minutes = int(seconds // 60)
+         secs = seconds % 60
+         return f"{minutes}m {secs:.2f}s"
+     else:
+         hours = int(seconds // 3600)
+         minutes = int((seconds % 3600) // 60)
+         secs = seconds % 60
+         return f"{hours}h {minutes}m {secs:.2f}s"
+
+
+ # Syllable-level feature templates
+ FEATURE_TEMPLATES = [
+     # Current syllable
+     "S[0]",          # Syllable text
+     "S[0].lower",    # Lowercase
+     "S[0].istitle",  # Is title case
+     "S[0].isupper",  # Is all uppercase
+     "S[0].isdigit",  # Is digit
+     "S[0].ispunct",  # Is punctuation
+     "S[0].len",      # Length
+     "S[0].prefix2",  # First 2 chars
+     "S[0].suffix2",  # Last 2 chars
+     # Previous syllables
+     "S[-1]",
+     "S[-1].lower",
+     "S[-2]",
+     "S[-2].lower",
+     # Next syllables
+     "S[1]",
+     "S[1].lower",
+     "S[2]",
+     "S[2].lower",
+     # Bigrams
+     "S[-1,0]",
+     "S[0,1]",
+     # Trigrams
+     "S[-1,0,1]",
+ ]
+
+
+ def get_syllable_at(syllables, position, offset):
+     """Get syllable at position + offset, with boundary handling."""
+     idx = position + offset
+     if idx < 0:
+         return "__BOS__"
+     elif idx >= len(syllables):
+         return "__EOS__"
+     return syllables[idx]
+
+
+ def is_punct(s):
+     """Check if string is punctuation."""
+     return len(s) == 1 and not s.isalnum()
+
+
+ def extract_syllable_features(syllables, position):
+     """Extract features for a syllable at given position."""
+     features = {}
+
+     # Current syllable
+     s0 = get_syllable_at(syllables, position, 0)
+     is_boundary = s0 in ("__BOS__", "__EOS__")
+
+     features["S[0]"] = s0
+     features["S[0].lower"] = s0.lower() if not is_boundary else s0
+     features["S[0].istitle"] = str(s0.istitle()) if not is_boundary else "False"
+     features["S[0].isupper"] = str(s0.isupper()) if not is_boundary else "False"
+     features["S[0].isdigit"] = str(s0.isdigit()) if not is_boundary else "False"
+     features["S[0].ispunct"] = str(is_punct(s0)) if not is_boundary else "False"
+     features["S[0].len"] = str(len(s0)) if not is_boundary else "0"
+     features["S[0].prefix2"] = s0[:2] if not is_boundary and len(s0) >= 2 else s0
+     features["S[0].suffix2"] = s0[-2:] if not is_boundary and len(s0) >= 2 else s0
+
+     # Previous syllables
+     s_1 = get_syllable_at(syllables, position, -1)
+     s_2 = get_syllable_at(syllables, position, -2)
+     features["S[-1]"] = s_1
+     features["S[-1].lower"] = s_1.lower() if s_1 not in ("__BOS__", "__EOS__") else s_1
+     features["S[-2]"] = s_2
+     features["S[-2].lower"] = s_2.lower() if s_2 not in ("__BOS__", "__EOS__") else s_2
+
+     # Next syllables
+     s1 = get_syllable_at(syllables, position, 1)
+     s2 = get_syllable_at(syllables, position, 2)
+     features["S[1]"] = s1
+     features["S[1].lower"] = s1.lower() if s1 not in ("__BOS__", "__EOS__") else s1
+     features["S[2]"] = s2
+     features["S[2].lower"] = s2.lower() if s2 not in ("__BOS__", "__EOS__") else s2
+
+     # Bigrams
+     features["S[-1,0]"] = f"{s_1}|{s0}"
+     features["S[0,1]"] = f"{s0}|{s1}"
+
+     # Trigrams
+     features["S[-1,0,1]"] = f"{s_1}|{s0}|{s1}"
+
+     return features
+
+
+ def sentence_to_syllable_features(syllables):
+     """Convert syllable sequence to feature sequences."""
+     return [
+         [f"{k}={v}" for k, v in extract_syllable_features(syllables, i).items()]
+         for i in range(len(syllables))
+     ]
+
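For context, the extractor above emits flat `key=value` strings per position, which is the input format all three CRF backends accept. A trimmed standalone sketch (only a unigram/bigram subset of the script's templates; `toy_features` is illustrative, not part of the script):

```python
def toy_features(syllables, i):
    """Emit a small subset of the script's feature templates as key=value strings."""
    def get(j):
        # Boundary handling mirrors get_syllable_at
        if j < 0:
            return "__BOS__"
        if j >= len(syllables):
            return "__EOS__"
        return syllables[j]

    s_1, s0, s1 = get(i - 1), get(i), get(i + 1)
    feats = {"S[0]": s0, "S[-1]": s_1, "S[1]": s1, "S[-1,0]": f"{s_1}|{s0}"}
    return [f"{k}={v}" for k, v in feats.items()]

sent = ["Tôi", "yêu", "Việt", "Nam"]
print(toy_features(sent, 0))
# ['S[0]=Tôi', 'S[-1]=__BOS__', 'S[1]=yêu', 'S[-1,0]=__BOS__|Tôi']
```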
+
+ def tokens_to_syllable_labels(tokens):
+     """
+     Convert tokenized compound words to syllable-level BIO labels.
+
+     Each compound word (e.g., "Thời hạn") is split into syllables;
+     the first syllable gets 'B', the rest get 'I'.
+     """
+     syllables = []
+     labels = []
+
+     for token in tokens:
+         # Split compound word into syllables using regex_tokenize
+         token_syllables = regex_tokenize(token)
+
+         for i, syl in enumerate(token_syllables):
+             syllables.append(syl)
+             if i == 0:
+                 labels.append("B")
+             else:
+                 labels.append("I")
+
+     return syllables, labels
+
+
+ def labels_to_words(syllables, labels):
+     """Convert syllable sequence and BIO labels back to words."""
+     words = []
+     current_word = []
+
+     for syl, label in zip(syllables, labels):
+         if label == "B":
+             if current_word:
+                 words.append(" ".join(current_word))
+             current_word = [syl]
+         else:  # I
+             current_word.append(syl)
+
+     if current_word:
+         words.append(" ".join(current_word))
+
+     return words
+
+
+ def compute_word_metrics(y_true, y_pred, syllables_list):
+     """Compute word-level precision, recall, and F1 from boundary matches."""
+     correct = 0
+     total_pred = 0
+     total_true = 0
+
+     for syllables, true_labels, pred_labels in zip(syllables_list, y_true, y_pred):
+         true_words = labels_to_words(syllables, true_labels)
+         pred_words = labels_to_words(syllables, pred_labels)
+
+         total_true += len(true_words)
+         total_pred += len(pred_words)
+
+         # Count exact word matches at the same positions via (start, end) spans
+         true_boundaries = set()
+         pred_boundaries = set()
+
+         pos = 0
+         for word in true_words:
+             n_syls = len(word.split())
+             true_boundaries.add((pos, pos + n_syls))
+             pos += n_syls
+
+         pos = 0
+         for word in pred_words:
+             n_syls = len(word.split())
+             pred_boundaries.add((pos, pos + n_syls))
+             pos += n_syls
+
+         correct += len(true_boundaries & pred_boundaries)
+
+     precision = correct / total_pred if total_pred > 0 else 0
+     recall = correct / total_true if total_true > 0 else 0
+     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+     return precision, recall, f1
+
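A word counts as correct only if its (start, end) syllable span matches exactly, so one wrong boundary costs both of the words it touches. A toy example of the span-matching arithmetic (spans chosen for illustration):

```python
# Gold segments five syllables as four words; the prediction wrongly splits
# the first two-syllable compound, so its first two spans differ.
true_spans = {(0, 2), (2, 3), (3, 4), (4, 5)}
pred_spans = {(0, 1), (1, 3), (3, 4), (4, 5)}

correct = len(true_spans & pred_spans)  # only (3,4) and (4,5) match exactly
precision = correct / len(pred_spans)   # 2/4
recall = correct / len(true_spans)      # 2/4
f1 = 2 * precision * recall / (precision + recall)
print(precision, recall, f1)  # 0.5 0.5 0.5
```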
+
+ def load_data():
+     """Load UDD-1 dataset and convert to syllable-level sequences."""
+     click.echo("Loading UDD-1 dataset...")
+     dataset = load_dataset("undertheseanlp/UDD-1")
+
+     def extract_syllable_sequences(split):
+         sequences = []
+         for item in split:
+             tokens = item["tokens"]
+             if tokens:
+                 syllables, labels = tokens_to_syllable_labels(tokens)
+                 if syllables:
+                     sequences.append((syllables, labels))
+         return sequences
+
+     train_data = extract_syllable_sequences(dataset["train"])
+     val_data = extract_syllable_sequences(dataset["validation"])
+     test_data = extract_syllable_sequences(dataset["test"])
+
+     # Statistics
+     train_syls = sum(len(syls) for syls, _ in train_data)
+     val_syls = sum(len(syls) for syls, _ in val_data)
+     test_syls = sum(len(syls) for syls, _ in test_data)
+
+     click.echo(f"Loaded {len(train_data)} train ({train_syls} syllables), "
+                f"{len(val_data)} val ({val_syls} syllables), "
+                f"{len(test_data)} test ({test_syls} syllables) sentences")
+
+     return train_data, val_data, test_data, {
+         "train_sentences": len(train_data),
+         "train_syllables": train_syls,
+         "val_sentences": len(val_data),
+         "val_syllables": val_syls,
+         "test_sentences": len(test_data),
+         "test_syllables": test_syls,
+     }
+
+
+ # ============================================================================
+ # Trainer Abstraction
+ # ============================================================================
+
+ class CRFTrainerBase(ABC):
+     """Abstract base class for CRF trainers."""
+
+     name: str = "base"
+
+     @abstractmethod
+     def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
+         """Train the CRF model and save to output_path."""
+
+     @abstractmethod
+     def predict(self, model_path, X_test):
+         """Load model and predict on test data."""
+
+
+ class PythonCRFSuiteTrainer(CRFTrainerBase):
+     """Trainer using python-crfsuite (original Python bindings)."""
+
+     name = "python-crfsuite"
+
+     def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
+         import pycrfsuite
+
+         trainer = pycrfsuite.Trainer(verbose=verbose)
+
+         for xseq, yseq in zip(X_train, y_train):
+             trainer.append(xseq, yseq)
+
+         trainer.set_params({
+             "c1": c1,
+             "c2": c2,
+             "max_iterations": max_iterations,
+             "feature.possible_transitions": True,
+         })
+
+         trainer.train(str(output_path))
+
+     def predict(self, model_path, X_test):
+         import pycrfsuite
+
+         tagger = pycrfsuite.Tagger()
+         tagger.open(str(model_path))
+         return [tagger.tag(xseq) for xseq in X_test]
+
+
+ class CRFSuiteRsTrainer(CRFTrainerBase):
+     """Trainer using crfsuite-rs (Rust bindings via pip install crfsuite)."""
+
+     name = "crfsuite-rs"
+
+     def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
+         import crfsuite
+
+         trainer = crfsuite.Trainer()
+
+         # Set parameters
+         trainer.set_params({
+             "c1": c1,
+             "c2": c2,
+             "max_iterations": max_iterations,
+             "feature.possible_transitions": True,
+         })
+
+         # Add training data
+         for xseq, yseq in zip(X_train, y_train):
+             trainer.append(xseq, yseq)
+
+         # Train
+         trainer.train(str(output_path))
+
+     def predict(self, model_path, X_test):
+         import crfsuite
+
+         model = crfsuite.Model(str(model_path))
+         return [model.tag(xseq) for xseq in X_test]
+
+
+ class UndertheseaCoreTrainer(CRFTrainerBase):
+     """Trainer using underthesea-core's native Rust CRF with L-BFGS optimization.
+
+     This trainer uses the native underthesea-core Rust CRF implementation
+     with L-BFGS optimization, matching CRFsuite performance.
+
+     Requires building underthesea-core from source:
+         cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core
+         uv venv && source .venv/bin/activate
+         uv pip install maturin
+         maturin develop --release
+     """
+
+     name = "underthesea-core"
+
+     def _check_trainer_import(self):
+         """Check if CRFTrainer is available."""
+         try:
+             from underthesea_core import CRFTrainer
+             return CRFTrainer
+         except ImportError:
+             pass
+
+         try:
+             from underthesea_core.underthesea_core import CRFTrainer
+             return CRFTrainer
+         except ImportError:
+             pass
+
+         raise ImportError(
+             "CRFTrainer not available in underthesea_core.\n"
+             "Build from source with LBFGS support:\n"
+             "  cd ~/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core\n"
+             "  source .venv/bin/activate && maturin develop --release"
+         )
+
+     def _check_tagger_import(self):
+         """Check if CRFModel and CRFTagger are available."""
+         try:
+             from underthesea_core import CRFModel, CRFTagger
+             return CRFModel, CRFTagger
+         except ImportError:
+             pass
+
+         try:
+             from underthesea_core.underthesea_core import CRFModel, CRFTagger
+             return CRFModel, CRFTagger
+         except ImportError:
+             pass
+
+         raise ImportError("CRFModel/CRFTagger not available in underthesea_core")
+
+     def train(self, X_train, y_train, output_path, c1, c2, max_iterations, verbose=True):
+         CRFTrainer = self._check_trainer_import()
+
+         # Use LBFGS (default, fast)
+         trainer = CRFTrainer(
+             loss_function="lbfgs",
+             l1_penalty=c1,
+             l2_penalty=c2,
+             max_iterations=max_iterations,
+             verbose=1 if verbose else 0,
+         )
+
+         # Train
+         model = trainer.train(X_train, y_train)
+
+         # Save model (underthesea-core uses a .crf extension)
+         output_path_str = str(output_path)
+         if output_path_str.endswith(".crfsuite"):
+             output_path_str = output_path_str.replace(".crfsuite", ".crf")
+         model.save(output_path_str)
+
+         # Store the actual path for prediction
+         self._model_path = output_path_str
+
+     def predict(self, model_path, X_test):
+         CRFModel, CRFTagger = self._check_tagger_import()
+
+         # Use the actual saved path if available
+         model_path_str = str(model_path)
+         if hasattr(self, "_model_path"):
+             model_path_str = self._model_path
+         elif model_path_str.endswith(".crfsuite"):
+             model_path_str = model_path_str.replace(".crfsuite", ".crf")
+
+         model = CRFModel.load(model_path_str)
+         tagger = CRFTagger.from_model(model)
+         return [tagger.tag(xseq) for xseq in X_test]
+
+
+ def get_trainer(trainer_name: str) -> CRFTrainerBase:
+     """Get trainer instance by name."""
+     trainers = {
+         "python-crfsuite": PythonCRFSuiteTrainer,
+         "crfsuite-rs": CRFSuiteRsTrainer,
+         "underthesea-core": UndertheseaCoreTrainer,
+     }
+     if trainer_name not in trainers:
+         raise ValueError(f"Unknown trainer: {trainer_name}. Available: {list(trainers.keys())}")
+     return trainers[trainer_name]()
+
+
+ # ============================================================================
+ # Metadata and CLI
+ # ============================================================================
+
+ def save_metadata(output_dir, version, trainer_name, data_stats, c1, c2,
+                   max_iterations, metrics, hw_info, training_time):
+     """Save model metadata to YAML file."""
+     metadata = {
+         "model": {
+             "name": "Vietnamese Word Segmentation",
+             "version": version,
+             "type": "CRF (Conditional Random Field)",
+             "framework": trainer_name,
+             "tagging_scheme": "BIO",
+         },
+         "training": {
+             "dataset": "undertheseanlp/UDD-1",
+             "train_sentences": data_stats["train_sentences"],
+             "train_syllables": data_stats["train_syllables"],
+             "val_sentences": data_stats["val_sentences"],
+             "val_syllables": data_stats["val_syllables"],
+             "test_sentences": data_stats["test_sentences"],
+             "test_syllables": data_stats["test_syllables"],
+             "hyperparameters": {
+                 "c1": c1,
+                 "c2": c2,
+                 "max_iterations": max_iterations,
+             },
+             "duration_seconds": round(training_time, 2),
+         },
+         "performance": {
+             "syllable_accuracy": round(metrics["syl_accuracy"], 4),
+             "syllable_f1": round(metrics["syl_f1"], 4),
+             "word_precision": round(metrics["word_precision"], 4),
+             "word_recall": round(metrics["word_recall"], 4),
+             "word_f1": round(metrics["word_f1"], 4),
+         },
+         "environment": {
+             "platform": hw_info["platform"],
+             "cpu_model": hw_info.get("cpu_model", "Unknown"),
+             "python_version": hw_info["python_version"],
+         },
+         "files": {
+             "model": "model.crfsuite",
+             "config": "../../../configs/word_segmentation.yaml",
+         },
+         "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+         "author": "undertheseanlp",
+     }
+
+     metadata_path = output_dir / "metadata.yaml"
+     with open(metadata_path, "w") as f:
+         yaml.dump(metadata, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
+     click.echo(f"Metadata saved to {metadata_path}")
+
+
+ def get_default_version():
+     """Generate timestamp-based version."""
+     return datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
+ @click.command()
+ @click.option(
+     "--trainer", "-t",
+     type=click.Choice(TRAINERS),
+     default="python-crfsuite",
+     help="CRF trainer to use",
+     show_default=True,
+ )
+ @click.option(
+     "--version", "-v",
+     default=None,
+     help="Model version (default: timestamp, e.g., 20260131_154530)",
+ )
+ @click.option(
+     "--output", "-o",
+     default=None,
+     help="Custom output path (overrides version-based path)",
+ )
+ @click.option(
+     "--c1",
+     default=1.0,
+     type=float,
+     help="L1 regularization coefficient",
+     show_default=True,
+ )
+ @click.option(
+     "--c2",
+     default=0.001,
+     type=float,
+     help="L2 regularization coefficient",
+     show_default=True,
+ )
+ @click.option(
+     "--max-iterations",
+     default=100,
+     type=int,
+     help="Maximum training iterations",
+     show_default=True,
+ )
+ @click.option(
+     "--wandb/--no-wandb",
+     default=False,
+     help="Enable Weights & Biases logging",
+ )
+ def train(trainer, version, output, c1, c2, max_iterations, wandb):
+     """Train Vietnamese Word Segmenter using CRF on UDD-1 dataset."""
+     total_start_time = time.time()
+     start_datetime = datetime.now()
+
+     # Get trainer
+     crf_trainer = get_trainer(trainer)
+
+     # Use timestamp version if not specified
+     if version is None:
+         version = get_default_version()
+
+     # Determine output directory
+     if output:
+         output_path = Path(output)
+         output_dir = output_path.parent
+     else:
+         output_dir = PROJECT_ROOT / "models" / "word_segmentation" / version
+         output_dir.mkdir(parents=True, exist_ok=True)
+         output_path = output_dir / "model.crfsuite"
+
+     # Collect hardware info
+     hw_info = get_hardware_info()
+
+     click.echo("=" * 60)
+     click.echo(f"Word Segmentation Training - {version}")
+     click.echo("=" * 60)
+     click.echo(f"Trainer: {trainer}")
+     click.echo(f"Platform: {hw_info['platform']}")
+     click.echo(f"CPU: {hw_info.get('cpu_model', 'Unknown')}")
+     click.echo(f"Output: {output_path}")
+     click.echo(f"Started: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
+     click.echo("=" * 60)
+
+     # Load data
+     train_data, val_data, test_data, data_stats = load_data()
+
+     click.echo(f"\nTrain: {len(train_data)} sentences ({data_stats['train_syllables']} syllables)")
+     click.echo(f"Validation: {len(val_data)} sentences ({data_stats['val_syllables']} syllables)")
+     click.echo(f"Test: {len(test_data)} sentences ({data_stats['test_syllables']} syllables)")
+
+     # Prepare training data
+     click.echo("\nExtracting syllable-level features...")
+     feature_start = time.time()
+     X_train = [sentence_to_syllable_features(syls) for syls, _ in train_data]
+     y_train = [labels for _, labels in train_data]
+     click.echo(f"Feature extraction: {format_duration(time.time() - feature_start)}")
+
+     # Train CRF
+     click.echo(f"\nTraining CRF model with {trainer}...")
+
+     use_wandb = wandb
+     if use_wandb:
+         try:
+             import wandb as wb
+             wb.init(project="word-segmentation-vietnamese", name=f"crf-{version}")
+             wb.config.update({
+                 "trainer": trainer,
+                 "c1": c1,
+                 "c2": c2,
+                 "max_iterations": max_iterations,
+                 "num_feature_templates": len(FEATURE_TEMPLATES),
+                 "train_sentences": len(train_data),
+                 "val_sentences": len(val_data),
+                 "test_sentences": len(test_data),
+                 "version": version,
+                 "level": "syllable",
+             })
+         except ImportError:
+             click.echo("wandb not installed, skipping logging", err=True)
+             use_wandb = False
+
+     crf_start = time.time()
+     crf_trainer.train(X_train, y_train, output_path, c1, c2, max_iterations, verbose=True)
+     crf_time = time.time() - crf_start
+     click.echo(f"\nModel saved to {output_path}")
+     click.echo(f"CRF training: {format_duration(crf_time)}")
+
+     # Evaluation
+     click.echo("\nEvaluating on test set...")
+
+     X_test = [sentence_to_syllable_features(syls) for syls, _ in test_data]
+     y_test = [labels for _, labels in test_data]
+     syllables_test = [syls for syls, _ in test_data]
+
+     y_pred = crf_trainer.predict(output_path, X_test)
+
+     # Syllable-level metrics
+     y_test_flat = [label for labels in y_test for label in labels]
+     y_pred_flat = [label for labels in y_pred for label in labels]
+
+     syl_accuracy = accuracy_score(y_test_flat, y_pred_flat)
+     syl_f1 = f1_score(y_test_flat, y_pred_flat, average="weighted")
+
+     click.echo(f"\nSyllable-level Accuracy: {syl_accuracy:.4f}")
+     click.echo(f"Syllable-level F1 (weighted): {syl_f1:.4f}")
+     click.echo("\nSyllable-level Classification Report:")
+     click.echo(classification_report(y_test_flat, y_pred_flat))
+
+     # Word-level metrics
+     precision, recall, word_f1 = compute_word_metrics(y_test, y_pred, syllables_test)
+     click.echo("\nWord-level Metrics:")
+     click.echo(f"  Precision: {precision:.4f}")
+     click.echo(f"  Recall: {recall:.4f}")
+     click.echo(f"  F1: {word_f1:.4f}")
+
+     total_time = time.time() - total_start_time
+
+     # Collect metrics
+     metrics = {
+         "syl_accuracy": syl_accuracy,
+         "syl_f1": syl_f1,
+         "word_precision": precision,
+         "word_recall": recall,
+         "word_f1": word_f1,
+     }
+
+     # Save metadata
+     if not output:
+         save_metadata(output_dir, version, trainer, data_stats, c1, c2, max_iterations,
+                       metrics, hw_info, total_time)
+
+     # Show examples
+     click.echo("\n" + "=" * 60)
+     click.echo("Example predictions:")
+     click.echo("=" * 60)
+     for i in range(min(3, len(test_data))):
+         syllables = syllables_test[i]
+         true_words = labels_to_words(syllables, y_test[i])
+         pred_words = labels_to_words(syllables, y_pred[i])
+         click.echo(f"\nInput: {' '.join(syllables)}")
+         click.echo(f"True:  {' | '.join(true_words)}")
+         click.echo(f"Pred:  {' | '.join(pred_words)}")
+
+     click.echo("\n" + "=" * 60)
+     click.echo("Training Summary")
+     click.echo("=" * 60)
+     click.echo(f"Trainer: {trainer}")
+     click.echo(f"Version: {version}")
+     click.echo(f"Model: {output_path}")
+     click.echo(f"Syllable Accuracy: {syl_accuracy:.4f}")
+     click.echo(f"Word F1: {word_f1:.4f}")
+     click.echo(f"Total time: {format_duration(total_time)}")
+     click.echo("=" * 60)
+
+     if use_wandb:
+         wb.log(metrics)
+         wb.finish()
+
+
+ if __name__ == "__main__":
+     train()