"""
Error analysis script for Vietnamese Word Segmentation (TRE-1).

Loads a trained VLSP 2013 model, predicts on the test set, and performs
detailed error analysis across multiple dimensions:
- Syllable-level confusion (B/I)
- Word-level false splits and false joins
- Error rate by word length
- Top error patterns with context
- Boundary errors (near sentence start/end)

Usage:
    source .venv/bin/activate
    python src/evaluate_word_segmentation.py
    python src/evaluate_word_segmentation.py --model models/word_segmentation/vlsp2013
    python src/evaluate_word_segmentation.py --output results/word_segmentation
"""

import csv
from collections import Counter, defaultdict
from pathlib import Path

import click

PROJECT_ROOT = Path(__file__).parent.parent


# ============================================================================
# Feature Extraction (duplicated from train_word_segmentation.py)
# ============================================================================
# NOTE: this section is a copy of the training-side feature code. Any change
# here should be mirrored there (and vice versa); otherwise the feature strings
# produced at test time may not match the ones the model was trained on.

FEATURE_GROUPS = {
    "form": ["S[0]", "S[0].lower"],
    "type": ["S[0].istitle", "S[0].isupper", "S[0].isdigit", "S[0].ispunct", "S[0].len"],
    "morphology": ["S[0].prefix2", "S[0].suffix2"],
    "left": ["S[-1]", "S[-1].lower", "S[-2]", "S[-2].lower"],
    "right": ["S[1]", "S[1].lower", "S[2]", "S[2].lower"],
    "bigram": ["S[-1,0]", "S[0,1]"],
    "trigram": ["S[-1,0,1]"],
    "dictionary": ["S[-1,0].in_dict", "S[0,1].in_dict"],
}


def get_all_templates():
    """Return all feature templates (all groups enabled), in FEATURE_GROUPS order."""
    templates = []
    for group_templates in FEATURE_GROUPS.values():
        templates.extend(group_templates)
    return templates


def get_syllable_at(syllables, position, offset):
    """Return ``syllables[position + offset]``, or a sentinel when out of range.

    Indices before the start yield ``"__BOS__"``; indices past the end yield
    ``"__EOS__"``.
    """
    idx = position + offset
    if idx < 0:
        return "__BOS__"
    elif idx >= len(syllables):
        return "__EOS__"
    return syllables[idx]


def is_punct(s):
    """Return True if *s* is exactly one non-alphanumeric character."""
    return len(s) == 1 and not s.isalnum()


def load_dictionary(path):
    """Load dictionary from a text file (one word per line).

    Lines are stripped and blank lines are ignored. Returns a set of strings.
    """
    dictionary = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                dictionary.add(line)
    return dictionary


def extract_syllable_features(syllables, position, active_templates, dictionary=None):
    """Build the feature dict for the syllable at *position*.

    Only templates listed in *active_templates* produce a feature. Offsets that
    fall outside the sentence use the ``__BOS__`` / ``__EOS__`` sentinels.

    Dictionary features are emitted only when *dictionary* is given:
    - ``S[-1,0].in_dict`` (only when ``position >= 1``): the longest lowercased
      n-gram of 2..5 syllables ending at *position* that is in *dictionary*.
    - ``S[0,1].in_dict`` (only when *position* is not the last syllable): the
      longest lowercased n-gram of 2..5 syllables starting at *position* that is
      in *dictionary*.
    Either holds ``"0"`` when no n-gram matches.
    """
    active = set(active_templates)
    features = {}

    # Current syllable
    s0 = get_syllable_at(syllables, position, 0)
    is_boundary = s0 in ("__BOS__", "__EOS__")

    if "S[0]" in active:
        features["S[0]"] = s0
    if "S[0].lower" in active:
        features["S[0].lower"] = s0.lower() if not is_boundary else s0
    if "S[0].istitle" in active:
        features["S[0].istitle"] = str(s0.istitle()) if not is_boundary else "False"
    if "S[0].isupper" in active:
        features["S[0].isupper"] = str(s0.isupper()) if not is_boundary else "False"
    if "S[0].isdigit" in active:
        features["S[0].isdigit"] = str(s0.isdigit()) if not is_boundary else "False"
    if "S[0].ispunct" in active:
        features["S[0].ispunct"] = str(is_punct(s0)) if not is_boundary else "False"
    if "S[0].len" in active:
        features["S[0].len"] = str(len(s0)) if not is_boundary else "0"
    if "S[0].prefix2" in active:
        features["S[0].prefix2"] = s0[:2] if not is_boundary and len(s0) >= 2 else s0
    if "S[0].suffix2" in active:
        features["S[0].suffix2"] = s0[-2:] if not is_boundary and len(s0) >= 2 else s0

    # Left context
    s_1 = get_syllable_at(syllables, position, -1)
    s_2 = get_syllable_at(syllables, position, -2)
    if "S[-1]" in active:
        features["S[-1]"] = s_1
    if "S[-1].lower" in active:
        features["S[-1].lower"] = s_1.lower() if s_1 not in ("__BOS__", "__EOS__") else s_1
    if "S[-2]" in active:
        features["S[-2]"] = s_2
    if "S[-2].lower" in active:
        features["S[-2].lower"] = s_2.lower() if s_2 not in ("__BOS__", "__EOS__") else s_2

    # Right context
    s1 = get_syllable_at(syllables, position, 1)
    s2 = get_syllable_at(syllables, position, 2)
    if "S[1]" in active:
        features["S[1]"] = s1
    if "S[1].lower" in active:
        features["S[1].lower"] = s1.lower() if s1 not in ("__BOS__", "__EOS__") else s1
    if "S[2]" in active:
        features["S[2]"] = s2
    if "S[2].lower" in active:
        features["S[2].lower"] = s2.lower() if s2 not in ("__BOS__", "__EOS__") else s2

    # N-gram conjunctions
    if "S[-1,0]" in active:
        features["S[-1,0]"] = f"{s_1}|{s0}"
    if "S[0,1]" in active:
        features["S[0,1]"] = f"{s0}|{s1}"
    if "S[-1,0,1]" in active:
        features["S[-1,0,1]"] = f"{s_1}|{s0}|{s1}"

    # G8: Dictionary lookup — longest match for bigram windows
    if dictionary is not None:
        n = len(syllables)
        if "S[-1,0].in_dict" in active and position >= 1:
            match = ""
            # Lengths ascend, so the last match kept is the longest one.
            for length in range(2, min(6, position + 2)):
                start = position - length + 1
                if start >= 0:
                    ngram = " ".join(syllables[start:position + 1]).lower()
                    if ngram in dictionary:
                        match = ngram
            features["S[-1,0].in_dict"] = match if match else "0"
        if "S[0,1].in_dict" in active and position < n - 1:
            match = ""
            for length in range(2, min(6, n - position + 1)):
                ngram = " ".join(syllables[position:position + length]).lower()
                if ngram in dictionary:
                    match = ngram
            features["S[0,1].in_dict"] = match if match else "0"

    return features


def sentence_to_syllable_features(syllables, active_templates, dictionary=None):
    """Return one list of ``"name=value"`` feature strings per syllable."""
    return [
        [f"{k}={v}" for k, v in extract_syllable_features(syllables, i, active_templates, dictionary).items()]
        for i in range(len(syllables))
    ]


# ============================================================================
# Data Loading
# ============================================================================


def load_vlsp2013_test(data_dir):
    """Load VLSP 2013 test set.

    Reads ``data_dir / "test.txt"``: one ``syllable<TAB>tag`` pair per line,
    sentences separated by blank lines. Tags ``B-W``/``I-W`` are mapped to
    ``B``/``I``; any other tag is kept unchanged. Lines that do not split into
    exactly two tab-separated fields are skipped.

    Returns:
        List of ``(syllables, labels)`` tuples, one per sentence.
    """
    tag_map = {"B-W": "B", "I-W": "I"}
    sequences = []
    current_syls = []
    current_labels = []

    with open(data_dir / "test.txt", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                if current_syls:
                    sequences.append((current_syls, current_labels))
                    current_syls = []
                    current_labels = []
            else:
                parts = line.split("\t")
                if len(parts) == 2:
                    current_syls.append(parts[0])
                    current_labels.append(tag_map.get(parts[1], parts[1]))

    # The file may not end with a blank line.
    if current_syls:
        sequences.append((current_syls, current_labels))

    return sequences


# ============================================================================
# Label Utilities
# ============================================================================


def labels_to_words(syllables, labels):
    """Convert a syllable sequence and B/I labels back to words.

    A ``"B"`` label starts a new word; any other label continues the current
    one. Words are the space-joined syllables.
    """
    words = []
    current_word = []
    for syl, label in zip(syllables, labels):
        if label == "B":
            if current_word:
                words.append(" ".join(current_word))
            current_word = [syl]
        else:
            current_word.append(syl)
    if current_word:
        words.append(" ".join(current_word))
    return words


def labels_to_word_spans(syllables, labels):
    """Convert B/I labels to word spans as ``(start_idx, end_idx, word_text)``.

    ``end_idx`` is exclusive. A ``"B"`` at any position other than 0 closes the
    previous word; the first syllable always opens a word.
    """
    spans = []
    start = 0
    for i, (_, label) in enumerate(zip(syllables, labels)):
        if label == "B" and i > 0:
            spans.append((start, i, " ".join(syllables[start:i])))
            start = i
    if start < len(syllables):
        spans.append((start, len(syllables), " ".join(syllables[start:])))
    return spans


# ============================================================================
# Error Analysis
# ============================================================================


def analyze_syllable_errors(all_true, all_pred):
    """Analyze syllable-level B/I confusion.

    Only labels exactly equal to ``"B"`` or ``"I"`` are counted.

    Returns:
        Dict with totals of true B/I, the two confusion counts, and their rates
        relative to the respective true totals (0 when a total is 0).
    """
    b_to_i = 0  # false join: predicted I where truth is B
    i_to_b = 0  # false split: predicted B where truth is I
    total_b = 0
    total_i = 0

    for true_labels, pred_labels in zip(all_true, all_pred):
        for t, p in zip(true_labels, pred_labels):
            if t == "B":
                total_b += 1
                if p == "I":
                    b_to_i += 1
            elif t == "I":
                total_i += 1
                if p == "B":
                    i_to_b += 1

    return {
        "total_b": total_b,
        "total_i": total_i,
        "b_to_i": b_to_i,
        "i_to_b": i_to_b,
        "b_to_i_rate": b_to_i / total_b if total_b > 0 else 0,
        "i_to_b_rate": i_to_b / total_i if total_i > 0 else 0,
    }


def _overlapping_parts(spans, start, end):
    """Return the text of every ``(s, e, text)`` span that overlaps ``[start, end)``."""
    return [text for s, e, text in spans if s < end and e > start]


def _context_window(syllables, start, end, width=2):
    """Return ``syllables[start:end]`` plus up to *width* syllables on each side, space-joined."""
    return " ".join(syllables[max(0, start - width):min(len(syllables), end + width)])


def analyze_word_errors(all_syllables, all_true, all_pred):
    """Analyze word-level errors: false splits and false joins.

    - False split: a true multi-syllable word whose exact span was not
      predicted and which overlaps more than one predicted word.
    - False join: a predicted multi-syllable word whose exact span is not in
      the truth and which overlaps more than one true word.

    Returns:
        ``(false_splits, false_joins)``. Each is a list of
        ``(word, overlapping_parts, context)`` tuples, where *context* is the
        word's syllables with up to two neighbours on each side.
    """
    false_splits = []  # compound words incorrectly broken apart (I→B)
    false_joins = []  # separate words incorrectly merged (B→I)

    for syllables, true_labels, pred_labels in zip(all_syllables, all_true, all_pred):
        true_span_list = labels_to_word_spans(syllables, true_labels)
        pred_span_list = labels_to_word_spans(syllables, pred_labels)
        true_spans = {(s, e) for s, e, _ in true_span_list}
        pred_spans = {(s, e) for s, e, _ in pred_span_list}

        # True multi-syllable words the model did not predict as a unit.
        for start, end, word in true_span_list:
            if end - start > 1 and (start, end) not in pred_spans:
                pred_parts = _overlapping_parts(pred_span_list, start, end)
                if len(pred_parts) > 1:
                    false_splits.append((word, pred_parts, _context_window(syllables, start, end)))

        # Predicted multi-syllable words that do not exist in the truth.
        for start, end, word in pred_span_list:
            if end - start > 1 and (start, end) not in true_spans:
                true_parts = _overlapping_parts(true_span_list, start, end)
                if len(true_parts) > 1:
                    false_joins.append((word, true_parts, _context_window(syllables, start, end)))

    return false_splits, false_joins


def analyze_errors_by_word_length(all_syllables, all_true, all_pred):
    """Compute error rates broken down by true word length (in syllables).

    A true word counts as correct when a predicted word has exactly the same
    span.

    Returns:
        Dict mapping length to a dict with ``total``, ``correct``, ``errors``,
        ``accuracy`` and ``error_rate``, in ascending length order.
    """
    correct_by_len = Counter()
    total_by_len = Counter()

    for syllables, true_labels, pred_labels in zip(all_syllables, all_true, all_pred):
        pred_spans = {(s, e) for s, e, _ in labels_to_word_spans(syllables, pred_labels)}
        for start, end, _ in labels_to_word_spans(syllables, true_labels):
            n_syls = end - start
            total_by_len[n_syls] += 1
            if (start, end) in pred_spans:
                correct_by_len[n_syls] += 1

    results = {}
    for length in sorted(total_by_len.keys()):
        total = total_by_len[length]
        correct = correct_by_len[length]
        results[length] = {
            "total": total,
            "correct": correct,
            "errors": total - correct,
            "accuracy": correct / total if total > 0 else 0,
            "error_rate": (total - correct) / total if total > 0 else 0,
        }
    return results


def analyze_boundary_errors(all_syllables, all_true, all_pred, window=3):
    """Analyze syllable label errors near sentence start/end.

    The first *window* positions of a sentence are "start", the last *window*
    positions are "end" (start takes precedence in short sentences), and the
    rest are "middle".

    Returns:
        Dict keyed ``"start"``, ``"end"``, ``"middle"`` (in that order), each with
        ``errors``, ``total`` and ``error_rate``.
    """
    regions = ("start", "end", "middle")
    errors = dict.fromkeys(regions, 0)
    totals = dict.fromkeys(regions, 0)

    for syllables, true_labels, pred_labels in zip(all_syllables, all_true, all_pred):
        n = len(syllables)
        for i, (t, p) in enumerate(zip(true_labels, pred_labels)):
            if i < window:
                region = "start"
            elif i >= n - window:
                region = "end"
            else:
                region = "middle"
            totals[region] += 1
            if t != p:
                errors[region] += 1

    return {
        region: {
            "errors": errors[region],
            "total": totals[region],
            "error_rate": errors[region] / totals[region] if totals[region] > 0 else 0,
        }
        for region in regions
    }


def get_top_error_patterns(all_syllables, all_true, all_pred, top_n=20):
    """Find the most common mislabelled syllables together with their neighbours.

    Returns:
        ``Counter.most_common(top_n)`` over keys
        ``(prev_syl, syl, next_syl, "T→P")``; missing neighbours are
        ``__BOS__`` / ``__EOS__``.
    """
    error_patterns = Counter()

    for syllables, true_labels, pred_labels in zip(all_syllables, all_true, all_pred):
        for i, (t, p) in enumerate(zip(true_labels, pred_labels)):
            if t != p:
                syl = syllables[i]
                prev_syl = syllables[i - 1] if i > 0 else "__BOS__"
                next_syl = syllables[i + 1] if i < len(syllables) - 1 else "__EOS__"
                error_type = f"{t}→{p}"
                error_patterns[(prev_syl, syl, next_syl, error_type)] += 1

    return error_patterns.most_common(top_n)


def compute_word_metrics(all_syllables, all_true, all_pred):
    """Compute word-level precision, recall, F1.

    A predicted word is correct when its (start, end) syllable span exactly
    matches a true word span. Spans come from labels_to_word_spans, the same
    source used by the other analyses.
    """
    correct = 0
    total_pred = 0
    total_true = 0

    for syllables, true_labels, pred_labels in zip(all_syllables, all_true, all_pred):
        true_spans = {(s, e) for s, e, _ in labels_to_word_spans(syllables, true_labels)}
        pred_spans = {(s, e) for s, e, _ in labels_to_word_spans(syllables, pred_labels)}
        total_true += len(true_spans)
        total_pred += len(pred_spans)
        correct += len(true_spans & pred_spans)

    precision = correct / total_pred if total_pred > 0 else 0
    recall = correct / total_true if total_true > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "total_true": total_true,
        "total_pred": total_pred,
        "correct": correct,
    }


# ============================================================================
# Reporting
# ============================================================================


def _count_with_examples(errors):
    """Count ``(word, parts, context)`` error tuples by word.

    Returns:
        ``(counts, examples)``: a Counter of words, and a dict mapping each word
        to the ``(parts, context)`` of its first occurrence.
    """
    counts = Counter()
    examples = {}
    for word, parts, context in errors:
        counts[word] += 1
        examples.setdefault(word, (parts, context))
    return counts, examples


def format_report(syl_errors, word_metrics, false_splits, false_joins,
                  length_errors, boundary_errors, top_patterns,
                  num_sentences, num_syllables, boundary_window=3):
    """Format error analysis as text report.

    *boundary_window* should equal the ``window`` passed to
    analyze_boundary_errors; it is only used for the section-6 labels.
    """
    lines = []
    lines.append("=" * 70)
    lines.append("Word Segmentation Error Analysis — VLSP 2013 Test Set")
    lines.append("=" * 70)
    lines.append("")

    # 1. Summary
    total_syl_errors = syl_errors["b_to_i"] + syl_errors["i_to_b"]
    # Guard against an empty test set.
    syl_error_rate = total_syl_errors / num_syllables if num_syllables > 0 else 0
    lines.append("1. Summary")
    lines.append("-" * 40)
    lines.append(f" Sentences: {num_sentences:,}")
    lines.append(f" Syllables: {num_syllables:,}")
    lines.append(f" True words: {word_metrics['total_true']:,}")
    lines.append(f" Predicted words: {word_metrics['total_pred']:,}")
    lines.append(f" Correct words: {word_metrics['correct']:,}")
    lines.append(f" Word Precision: {word_metrics['precision']:.4f} ({word_metrics['precision']*100:.2f}%)")
    lines.append(f" Word Recall: {word_metrics['recall']:.4f} ({word_metrics['recall']*100:.2f}%)")
    lines.append(f" Word F1: {word_metrics['f1']:.4f} ({word_metrics['f1']*100:.2f}%)")
    lines.append(f" Syllable errors: {total_syl_errors:,} / {num_syllables:,} ({syl_error_rate*100:.2f}%)")
    lines.append(f" Word errors (FN): {word_metrics['total_true'] - word_metrics['correct']:,}")
    lines.append(f" Word errors (FP): {word_metrics['total_pred'] - word_metrics['correct']:,}")
    lines.append("")

    # 2. Syllable confusion
    lines.append("2. Syllable-Level Confusion (B/I)")
    lines.append("-" * 40)
    lines.append(f" True B, Predicted I (false join): {syl_errors['b_to_i']:,} / {syl_errors['total_b']:,} ({syl_errors['b_to_i_rate']*100:.2f}%)")
    lines.append(f" True I, Predicted B (false split): {syl_errors['i_to_b']:,} / {syl_errors['total_i']:,} ({syl_errors['i_to_b_rate']*100:.2f}%)")
    lines.append("")
    lines.append(" Confusion Matrix:")
    lines.append(" Pred B Pred I")
    lines.append(f" True B {syl_errors['total_b'] - syl_errors['b_to_i']:>8,} {syl_errors['b_to_i']:>8,}")
    lines.append(f" True I {syl_errors['i_to_b']:>8,} {syl_errors['total_i'] - syl_errors['i_to_b']:>8,}")
    lines.append("")

    # 3. False splits
    split_counter, split_examples = _count_with_examples(false_splits)
    lines.append("3. Top False Splits (compound words broken apart)")
    lines.append("-" * 70)
    lines.append(f" Total false splits: {len(false_splits):,}")
    lines.append(f" Unique words affected: {len(split_counter):,}")
    lines.append("")
    lines.append(f" {'Word':<25} {'Count':<8} {'Example context'}")
    lines.append(f" {'----':<25} {'-----':<8} {'---------------'}")
    for word, count in split_counter.most_common(20):
        _, ctx = split_examples[word]
        lines.append(f" {word:<25} {count:<8} {ctx}")
    lines.append("")

    # 4. False joins
    join_counter, join_examples = _count_with_examples(false_joins)
    lines.append("4. Top False Joins (separate words merged)")
    lines.append("-" * 70)
    lines.append(f" Total false joins: {len(false_joins):,}")
    lines.append(f" Unique words affected: {len(join_counter):,}")
    lines.append("")
    lines.append(f" {'Merged as':<25} {'Count':<8} {'Should be':<30} {'Context'}")
    lines.append(f" {'---------':<25} {'-----':<8} {'---------':<30} {'-------'}")
    for word, count in join_counter.most_common(20):
        parts, ctx = join_examples[word]
        should_be = " | ".join(parts)
        lines.append(f" {word:<25} {count:<8} {should_be:<30} {ctx}")
    lines.append("")

    # 5. Error by word length
    lines.append("5. Error Rate by Word Length (syllables)")
    lines.append("-" * 70)
    lines.append(f" {'Length':<10} {'Total':<10} {'Correct':<10} {'Errors':<10} {'Accuracy':<12} {'Error Rate'}")
    lines.append(f" {'------':<10} {'-----':<10} {'-------':<10} {'------':<10} {'--------':<12} {'----------'}")
    for length, stats in sorted(length_errors.items()):
        label = f"{length}-syl"
        lines.append(f" {label:<10} {stats['total']:<10,} {stats['correct']:<10,} {stats['errors']:<10,} {stats['accuracy']*100:>8.2f}% {stats['error_rate']*100:.2f}%")
    lines.append("")

    # 6. Boundary errors
    region_labels = {
        "start": f"Start (first {boundary_window} syls)",
        "end": f"End (last {boundary_window} syls)",
        "middle": "Middle",
    }
    lines.append("6. Error Rate by Position in Sentence")
    lines.append("-" * 40)
    for region, stats in boundary_errors.items():
        label = region_labels.get(region, region.capitalize())
        lines.append(f" {label:<35} {stats['errors']:,} / {stats['total']:,} ({stats['error_rate']*100:.2f}%)")
    lines.append("")

    # 7. Top error patterns
    lines.append("7. Top Error Patterns (syllable in context)")
    lines.append("-" * 70)
    lines.append(f" {'Prev syl':<15} {'Current':<15} {'Next syl':<15} {'Error':<8} {'Count'}")
    lines.append(f" {'--------':<15} {'-------':<15} {'--------':<15} {'-----':<8} {'-----'}")
    for (prev_syl, syl, next_syl, error_type), count in top_patterns:
        lines.append(f" {prev_syl:<15} {syl:<15} {next_syl:<15} {error_type:<8} {count}")
    lines.append("")
    lines.append("=" * 70)

    return "\n".join(lines)


def save_errors_csv(output_path, false_splits, false_joins, length_errors):
    """Save error details to CSV files next to *output_path*.

    Writes ``false_splits.csv``, ``false_joins.csv`` and ``error_by_length.csv``
    into ``output_path.parent``. Split/join rows are ordered by descending count
    and show the first example seen for each word.

    Returns:
        ``(splits_path, joins_path, length_path)``.
    """
    output_dir = output_path.parent

    # False splits CSV
    splits_path = output_dir / "false_splits.csv"
    split_counter, split_examples = _count_with_examples(false_splits)
    with open(splits_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["word", "count", "predicted_parts", "context"])
        for word, count in split_counter.most_common():
            parts, ctx = split_examples[word]
            writer.writerow([word, count, " | ".join(parts), ctx])

    # False joins CSV
    joins_path = output_dir / "false_joins.csv"
    join_counter, join_examples = _count_with_examples(false_joins)
    with open(joins_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["merged_word", "count", "true_parts", "context"])
        for word, count in join_counter.most_common():
            parts, ctx = join_examples[word]
            writer.writerow([word, count, " | ".join(parts), ctx])

    # Word length error rates CSV
    length_path = output_dir / "error_by_length.csv"
    with open(length_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["word_length_syllables", "total", "correct", "errors", "accuracy", "error_rate"])
        for length, stats in sorted(length_errors.items()):
            writer.writerow([length, stats["total"], stats["correct"], stats["errors"],
                             f"{stats['accuracy']:.4f}", f"{stats['error_rate']:.4f}"])

    return splits_path, joins_path, length_path


# ============================================================================
# Main
# ============================================================================


@click.command()
@click.option(
    "--model", "-m",
    default=None,
    help="Model directory (default: models/word_segmentation/vlsp2013)",
)
@click.option(
    "--data-dir", "-d",
    default=None,
    help="Dataset directory (default: datasets/c7veardo0e)",
)
@click.option(
    "--output", "-o",
    default=None,
    help="Output directory for results (default: results/word_segmentation)",
)
def main(model, data_dir, output):
    """Run error analysis on VLSP 2013 word segmentation test set."""
    # Resolve paths
    model_dir = Path(model) if model else PROJECT_ROOT / "models" / "word_segmentation" / "vlsp2013"
    data_path = Path(data_dir) if data_dir else PROJECT_ROOT / "datasets" / "c7veardo0e"
    output_dir = Path(output) if output else PROJECT_ROOT / "results" / "word_segmentation"

    # Prefer model.crf (underthesea_core); fall back to model.crfsuite (pycrfsuite).
    model_path = model_dir / "model.crf"
    if not model_path.exists():
        model_path = model_dir / "model.crfsuite"
    if not model_path.exists():
        raise click.ClickException(f"No model file found in {model_dir}")

    test_file = data_path / "test.txt"
    if not test_file.exists():
        raise click.ClickException(f"Test file not found: {test_file}")

    # Create the output directory only after the inputs have been validated.
    output_dir.mkdir(parents=True, exist_ok=True)

    click.echo(f"Model: {model_path}")
    click.echo(f"Data: {data_path}")
    click.echo(f"Output: {output_dir}")
    click.echo("")

    # Load model (backend imported lazily so only the needed one must be installed)
    click.echo("Loading model...")
    model_path_str = str(model_path)
    if model_path_str.endswith(".crf"):
        from underthesea_core import CRFModel, CRFTagger

        tagger = CRFTagger.from_model(CRFModel.load(model_path_str))
    else:
        import pycrfsuite

        tagger = pycrfsuite.Tagger()
        tagger.open(model_path_str)

    # Load test data
    click.echo("Loading VLSP 2013 test set...")
    test_data = load_vlsp2013_test(data_path)
    click.echo(f" {len(test_data)} sentences")

    all_syllables = [syls for syls, _ in test_data]
    all_true = [labels for _, labels in test_data]
    num_syllables = sum(len(syls) for syls in all_syllables)
    click.echo(f" {num_syllables:,} syllables")

    # Load dictionary if available
    dict_path = model_dir / "dictionary.txt"
    dictionary = None
    if dict_path.exists():
        dictionary = load_dictionary(dict_path)
        click.echo(f" Dictionary: {len(dictionary)} words from {dict_path}")

    # Extract features and predict
    click.echo("Extracting features...")
    active_templates = get_all_templates()
    if dictionary is None:
        # Without a dictionary the in_dict templates would produce nothing anyway.
        active_templates = [t for t in active_templates if t not in FEATURE_GROUPS["dictionary"]]
    X_test = [sentence_to_syllable_features(syls, active_templates, dictionary) for syls in all_syllables]

    click.echo("Predicting...")
    all_pred = [tagger.tag(xseq) for xseq in X_test]

    # The analyses below zip() labels together, which would silently truncate
    # a sentence whose predicted length differs from the gold length.
    for idx, (true_labels, pred_labels) in enumerate(zip(all_true, all_pred)):
        if len(true_labels) != len(pred_labels):
            raise click.ClickException(
                f"Sentence {idx}: model returned {len(pred_labels)} labels "
                f"for {len(true_labels)} syllables"
            )

    # Run analyses
    click.echo("Analyzing errors...")
    boundary_window = 3

    # 1. Syllable confusion
    syl_errors = analyze_syllable_errors(all_true, all_pred)

    # 2. Word metrics
    word_metrics = compute_word_metrics(all_syllables, all_true, all_pred)

    # 3. Word-level errors
    false_splits, false_joins = analyze_word_errors(all_syllables, all_true, all_pred)

    # 4. Error by word length
    length_errors = analyze_errors_by_word_length(all_syllables, all_true, all_pred)

    # 5. Boundary errors
    boundary_errors = analyze_boundary_errors(all_syllables, all_true, all_pred, window=boundary_window)

    # 6. Top error patterns
    top_patterns = get_top_error_patterns(all_syllables, all_true, all_pred, top_n=20)

    # Generate report
    report = format_report(
        syl_errors, word_metrics, false_splits, false_joins,
        length_errors, boundary_errors, top_patterns,
        len(test_data), num_syllables,
        boundary_window=boundary_window,
    )

    # Print to console
    click.echo("")
    click.echo(report)

    # Save report
    report_path = output_dir / "error_analysis.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)
    click.echo(f"\nReport saved to {report_path}")

    # Save CSVs
    splits_csv, joins_csv, length_csv = save_errors_csv(
        report_path, false_splits, false_joins, length_errors
    )
    click.echo(f"False splits CSV: {splits_csv}")
    click.echo(f"False joins CSV: {joins_csv}")
    click.echo(f"Error by length: {length_csv}")


if __name__ == "__main__":
    main()