| """ |
| Error analysis script for Vietnamese Word Segmentation (TRE-1). |
| |
| Loads a trained VLSP 2013 model, predicts on the test set, and performs |
| detailed error analysis across multiple dimensions: |
| - Syllable-level confusion (B/I) |
| - Word-level false splits and false joins |
| - Error rate by word length |
| - Top error patterns with context |
| - Boundary errors (near sentence start/end) |
| |
| Usage: |
| source .venv/bin/activate |
| python src/evaluate_word_segmentation.py |
| python src/evaluate_word_segmentation.py --model models/word_segmentation/vlsp2013 |
| python src/evaluate_word_segmentation.py --output results/word_segmentation |
| """ |
|
|
| import csv |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
| import click |
|
|
# Repository root: this file lives in src/, so parent.parent is the project root.
PROJECT_ROOT = Path(__file__).parent.parent


# CRF feature templates, grouped by the kind of information they capture.
# Group keys let callers enable/disable whole families at once; the values
# are the template names consumed by extract_syllable_features().
FEATURE_GROUPS = {
    "form": ["S[0]", "S[0].lower"],
    "type": ["S[0].istitle", "S[0].isupper", "S[0].isdigit", "S[0].ispunct", "S[0].len"],
    "morphology": ["S[0].prefix2", "S[0].suffix2"],
    "left": ["S[-1]", "S[-1].lower", "S[-2]", "S[-2].lower"],
    "right": ["S[1]", "S[1].lower", "S[2]", "S[2].lower"],
    "bigram": ["S[-1,0]", "S[0,1]"],
    "trigram": ["S[-1,0,1]"],
    "dictionary": ["S[-1,0].in_dict", "S[0,1].in_dict"],
}
|
|
|
|
def get_all_templates():
    """Return every feature template, i.e. the union of all FEATURE_GROUPS values."""
    return [template for group in FEATURE_GROUPS.values() for template in group]
|
|
|
|
def get_syllable_at(syllables, position, offset):
    """Return the syllable at position+offset, or a boundary sentinel when out of range."""
    idx = position + offset
    if 0 <= idx < len(syllables):
        return syllables[idx]
    return "__BOS__" if idx < 0 else "__EOS__"
|
|
|
|
def is_punct(s):
    """Return True when *s* is exactly one character and that character is not alphanumeric."""
    if len(s) != 1:
        return False
    return not s.isalnum()
|
|
|
|
def load_dictionary(path):
    """Load a word list (one entry per line) into a set, skipping blank lines."""
    with open(path, encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return {entry for entry in stripped if entry}
|
|
|
|
def extract_syllable_features(syllables, position, active_templates, dictionary=None):
    """Extract CRF features for the syllable at *position* in *syllables*.

    Only templates listed in *active_templates* are computed; each produces
    one entry in the returned dict.  Context positions outside the sentence
    use the sentinel values "__BOS__" / "__EOS__" (via get_syllable_at).

    Args:
        syllables: list of syllable strings for one sentence.
        position: 0-based index of the target syllable.
        active_templates: iterable of template names (see FEATURE_GROUPS).
        dictionary: optional set of known words (lowercased, space-joined
            syllables); enables the "*.in_dict" templates when not None.

    Returns:
        dict mapping template name -> feature value (all strings).
    """
    active = set(active_templates)
    features = {}

    # Current syllable.  NOTE(review): for a valid in-range position s0 is
    # never a sentinel, so is_boundary acts as a defensive guard here.
    s0 = get_syllable_at(syllables, position, 0)
    is_boundary = s0 in ("__BOS__", "__EOS__")

    # Surface-form and character-type features of the current syllable.
    if "S[0]" in active:
        features["S[0]"] = s0
    if "S[0].lower" in active:
        features["S[0].lower"] = s0.lower() if not is_boundary else s0
    if "S[0].istitle" in active:
        features["S[0].istitle"] = str(s0.istitle()) if not is_boundary else "False"
    if "S[0].isupper" in active:
        features["S[0].isupper"] = str(s0.isupper()) if not is_boundary else "False"
    if "S[0].isdigit" in active:
        features["S[0].isdigit"] = str(s0.isdigit()) if not is_boundary else "False"
    if "S[0].ispunct" in active:
        features["S[0].ispunct"] = str(is_punct(s0)) if not is_boundary else "False"
    if "S[0].len" in active:
        features["S[0].len"] = str(len(s0)) if not is_boundary else "0"
    # Two-character prefix/suffix; short syllables fall back to the whole string.
    if "S[0].prefix2" in active:
        features["S[0].prefix2"] = s0[:2] if not is_boundary and len(s0) >= 2 else s0
    if "S[0].suffix2" in active:
        features["S[0].suffix2"] = s0[-2:] if not is_boundary and len(s0) >= 2 else s0

    # Left-context syllables (lowercased variants leave sentinels unchanged).
    s_1 = get_syllable_at(syllables, position, -1)
    s_2 = get_syllable_at(syllables, position, -2)
    if "S[-1]" in active:
        features["S[-1]"] = s_1
    if "S[-1].lower" in active:
        features["S[-1].lower"] = s_1.lower() if s_1 not in ("__BOS__", "__EOS__") else s_1
    if "S[-2]" in active:
        features["S[-2]"] = s_2
    if "S[-2].lower" in active:
        features["S[-2].lower"] = s_2.lower() if s_2 not in ("__BOS__", "__EOS__") else s_2

    # Right-context syllables.
    s1 = get_syllable_at(syllables, position, 1)
    s2 = get_syllable_at(syllables, position, 2)
    if "S[1]" in active:
        features["S[1]"] = s1
    if "S[1].lower" in active:
        features["S[1].lower"] = s1.lower() if s1 not in ("__BOS__", "__EOS__") else s1
    if "S[2]" in active:
        features["S[2]"] = s2
    if "S[2].lower" in active:
        features["S[2].lower"] = s2.lower() if s2 not in ("__BOS__", "__EOS__") else s2

    # N-gram conjunction features ("|"-joined raw syllables).
    if "S[-1,0]" in active:
        features["S[-1,0]"] = f"{s_1}|{s0}"
    if "S[0,1]" in active:
        features["S[0,1]"] = f"{s0}|{s1}"
    if "S[-1,0,1]" in active:
        features["S[-1,0,1]"] = f"{s_1}|{s0}|{s1}"

    # Dictionary-lookup features: known word of 2..5 syllables ending at /
    # starting from the current position.  Lengths are tried in ascending
    # order and overwrite `match`, so the LONGEST hit wins.
    if dictionary is not None:
        n = len(syllables)

        if "S[-1,0].in_dict" in active and position >= 1:
            match = ""
            for length in range(2, min(6, position + 2)):
                start = position - length + 1
                if start >= 0:  # always true given the range bound; defensive
                    ngram = " ".join(syllables[start:position + 1]).lower()
                    if ngram in dictionary:
                        match = ngram
            # "0" marks "no dictionary match" so the feature is always present.
            features["S[-1,0].in_dict"] = match if match else "0"

        if "S[0,1].in_dict" in active and position < n - 1:
            match = ""
            for length in range(2, min(6, n - position + 1)):
                ngram = " ".join(syllables[position:position + length]).lower()
                if ngram in dictionary:
                    match = ngram
            features["S[0,1].in_dict"] = match if match else "0"

    return features
|
|
|
|
def sentence_to_syllable_features(syllables, active_templates, dictionary=None):
    """Encode every syllable of a sentence as a list of "name=value" CRF feature strings."""
    encoded = []
    for pos in range(len(syllables)):
        feats = extract_syllable_features(syllables, pos, active_templates, dictionary)
        encoded.append([f"{name}={value}" for name, value in feats.items()])
    return encoded
|
|
|
|
| |
| |
| |
|
|
def load_vlsp2013_test(data_dir):
    """Load the VLSP 2013 test split from ``data_dir/test.txt``.

    The file is in two-column format (syllable<TAB>tag) with blank lines
    separating sentences; B-W/I-W tags are normalized to B/I.  Lines that do
    not have exactly two columns are skipped.

    Returns:
        List of (syllables, labels) pairs, one per sentence.
    """
    tag_map = {"B-W": "B", "I-W": "I"}
    sentences = []
    syls, labels = [], []

    def flush():
        # Close the current sentence, if any syllables were collected.
        if syls:
            sentences.append((list(syls), list(labels)))
            syls.clear()
            labels.clear()

    with open(data_dir / "test.txt", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                flush()
                continue
            columns = stripped.split("\t")
            if len(columns) == 2:
                syls.append(columns[0])
                labels.append(tag_map.get(columns[1], columns[1]))
    flush()  # final sentence may not be followed by a blank line

    return sentences
|
|
|
|
| |
| |
| |
|
|
def labels_to_words(syllables, labels):
    """Rebuild space-joined words from a syllable sequence and its B/I labels.

    A "B" label opens a new word; anything else extends the current one.
    A malformed leading "I" simply starts the implicit first word.
    """
    n = min(len(syllables), len(labels))
    if n == 0:
        return []
    # Word boundaries: position 0 always opens a word; each later "B" opens another.
    starts = [0] + [i for i in range(1, n) if labels[i] == "B"]
    starts.append(n)
    return [" ".join(syllables[a:b]) for a, b in zip(starts, starts[1:])]
|
|
|
|
def labels_to_word_spans(syllables, labels):
    """Convert BIO labels to word spans as (start_idx, end_idx, word_text) tuples."""
    spans = []
    left = 0
    limit = min(len(syllables), len(labels))
    for i in range(limit):
        # A "B" after position 0 closes the word that started at `left`.
        if labels[i] == "B" and i > 0:
            spans.append((left, i, " ".join(syllables[left:i])))
            left = i
    if left < len(syllables):
        spans.append((left, len(syllables), " ".join(syllables[left:])))
    return spans
|
|
|
|
| |
| |
| |
|
|
def analyze_syllable_errors(all_true, all_pred):
    """Tally syllable-level B/I confusion counts and per-class error rates."""
    pair_counts = Counter()
    for gold_seq, pred_seq in zip(all_true, all_pred):
        pair_counts.update(zip(gold_seq, pred_seq))

    total_b = sum(c for (gold, _), c in pair_counts.items() if gold == "B")
    total_i = sum(c for (gold, _), c in pair_counts.items() if gold == "I")
    b_to_i = pair_counts[("B", "I")]
    i_to_b = pair_counts[("I", "B")]

    return {
        "total_b": total_b,
        "total_i": total_i,
        "b_to_i": b_to_i,
        "i_to_b": i_to_b,
        "b_to_i_rate": b_to_i / total_b if total_b > 0 else 0,
        "i_to_b_rate": i_to_b / total_i if total_i > 0 else 0,
    }
|
|
|
|
def _overlapping_parts(span_list, start, end):
    """Texts of word spans overlapping the half-open syllable range [start, end)."""
    return [text for s, e, text in span_list if s < end and e > start]


def analyze_word_errors(all_syllables, all_true, all_pred):
    """Analyze word-level errors: false splits and false joins.

    A *false split* is a multi-syllable gold word whose span is absent from
    the prediction and overlaps more than one predicted word.  A *false join*
    is the symmetric case: a multi-syllable predicted word overlapping several
    gold words.  Each error is reported as (word, parts, context), where
    context is the word plus up to two syllables on either side.

    Returns:
        (false_splits, false_joins) lists of (word, parts, context) tuples.
    """
    # Fixes over the previous version: the span lists are computed once per
    # sentence (they were built twice), the unused true_words/pred_words dead
    # code is gone, and the redundant "fully inside OR overlapping" two-branch
    # test is collapsed to a single overlap check (inside implies overlap).
    false_splits = []
    false_joins = []

    for syllables, true_labels, pred_labels in zip(all_syllables, all_true, all_pred):
        true_span_list = labels_to_word_spans(syllables, true_labels)
        pred_span_list = labels_to_word_spans(syllables, pred_labels)
        true_spans = {(s, e) for s, e, _ in true_span_list}
        pred_spans = {(s, e) for s, e, _ in pred_span_list}

        def context_of(start, end):
            # Two syllables of context on each side, clamped to the sentence.
            return " ".join(syllables[max(0, start - 2):min(len(syllables), end + 2)])

        # Gold multi-syllable words the prediction broke apart.
        for start, end, word in true_span_list:
            if end - start > 1 and (start, end) not in pred_spans:
                parts = _overlapping_parts(pred_span_list, start, end)
                if len(parts) > 1:
                    false_splits.append((word, parts, context_of(start, end)))

        # Predicted multi-syllable words that merged several gold words.
        for start, end, word in pred_span_list:
            if end - start > 1 and (start, end) not in true_spans:
                parts = _overlapping_parts(true_span_list, start, end)
                if len(parts) > 1:
                    false_joins.append((word, parts, context_of(start, end)))

    return false_splits, false_joins
|
|
|
|
def analyze_errors_by_word_length(all_syllables, all_true, all_pred):
    """Compute per-length accuracy/error-rate stats for gold words (length in syllables)."""
    correct_by_len = Counter()
    total_by_len = Counter()

    for syllables, gold, pred in zip(all_syllables, all_true, all_pred):
        # Spans within one sentence are disjoint, so sets lose no counts.
        gold_spans = {(s, e) for s, e, _ in labels_to_word_spans(syllables, gold)}
        pred_spans = {(s, e) for s, e, _ in labels_to_word_spans(syllables, pred)}
        for s, e in gold_spans:
            total_by_len[e - s] += 1
            if (s, e) in pred_spans:
                correct_by_len[e - s] += 1

    summary = {}
    for length in sorted(total_by_len):
        total = total_by_len[length]
        hits = correct_by_len[length]
        summary[length] = {
            "total": total,
            "correct": hits,
            "errors": total - hits,
            "accuracy": hits / total if total > 0 else 0,
            "error_rate": (total - hits) / total if total > 0 else 0,
        }
    return summary
|
|
|
|
def analyze_boundary_errors(all_syllables, all_true, all_pred, window=3):
    """Break syllable-level error counts into sentence start / end / middle regions.

    For short sentences the regions can overlap; a position counts as "start"
    first, then "end", then "middle".
    """
    tallies = {region: [0, 0] for region in ("start", "end", "middle")}  # [errors, total]

    for syllables, gold, pred in zip(all_syllables, all_true, all_pred):
        n = len(syllables)
        for i, (t, p) in enumerate(zip(gold, pred)):
            if i < window:
                region = "start"
            elif i >= n - window:
                region = "end"
            else:
                region = "middle"
            tallies[region][1] += 1
            tallies[region][0] += t != p

    return {
        region: {
            "errors": errors,
            "total": total,
            "error_rate": errors / total if total > 0 else 0,
        }
        for region, (errors, total) in tallies.items()
    }
|
|
|
|
def get_top_error_patterns(all_syllables, all_true, all_pred, top_n=20):
    """Return the *top_n* most frequent mislabeled syllables with their left/right context."""
    patterns = Counter()

    for syllables, gold, pred in zip(all_syllables, all_true, all_pred):
        last = len(syllables) - 1
        for i, (t, p) in enumerate(zip(gold, pred)):
            if t == p:
                continue
            left = syllables[i - 1] if i > 0 else "__BOS__"
            right = syllables[i + 1] if i < last else "__EOS__"
            patterns[(left, syllables[i], right, f"{t}→{p}")] += 1

    return patterns.most_common(top_n)
|
|
|
|
def compute_word_metrics(all_syllables, all_true, all_pred):
    """Compute word-level precision, recall, and F1 over exact span matches."""

    def boundaries(words):
        # Rebuild (start, end) syllable spans from a space-joined word list.
        spans = set()
        cursor = 0
        for word in words:
            width = len(word.split())
            spans.add((cursor, cursor + width))
            cursor += width
        return spans

    correct = 0
    total_pred = 0
    total_true = 0

    for syllables, gold, pred in zip(all_syllables, all_true, all_pred):
        gold_words = labels_to_words(syllables, gold)
        pred_words = labels_to_words(syllables, pred)
        total_true += len(gold_words)
        total_pred += len(pred_words)
        # A word is correct only when both segmentations agree on its span.
        correct += len(boundaries(gold_words) & boundaries(pred_words))

    precision = correct / total_pred if total_pred > 0 else 0
    recall = correct / total_true if total_true > 0 else 0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "total_true": total_true,
        "total_pred": total_pred,
        "correct": correct,
    }
|
|
|
|
| |
| |
| |
|
|
def format_report(syl_errors, word_metrics, false_splits, false_joins,
                  length_errors, boundary_errors, top_patterns,
                  num_sentences, num_syllables):
    """Format error analysis as a plain-text report.

    Args:
        syl_errors: dict from analyze_syllable_errors().
        word_metrics: dict from compute_word_metrics().
        false_splits / false_joins: lists of (word, parts, context) tuples.
        length_errors: dict from analyze_errors_by_word_length().
        boundary_errors: dict from analyze_boundary_errors().
        top_patterns: list from get_top_error_patterns().
        num_sentences / num_syllables: corpus sizes (num_syllables must be > 0).

    Returns:
        The full report as one newline-joined string.
    """
    lines = []
    lines.append("=" * 70)
    lines.append("Word Segmentation Error Analysis — VLSP 2013 Test Set")
    lines.append("=" * 70)
    lines.append("")

    # Section 1: headline counts and word-level metrics.
    total_syl_errors = syl_errors["b_to_i"] + syl_errors["i_to_b"]
    lines.append("1. Summary")
    lines.append("-" * 40)
    lines.append(f" Sentences: {num_sentences:,}")
    lines.append(f" Syllables: {num_syllables:,}")
    lines.append(f" True words: {word_metrics['total_true']:,}")
    lines.append(f" Predicted words: {word_metrics['total_pred']:,}")
    lines.append(f" Correct words: {word_metrics['correct']:,}")
    lines.append(f" Word Precision: {word_metrics['precision']:.4f} ({word_metrics['precision']*100:.2f}%)")
    lines.append(f" Word Recall: {word_metrics['recall']:.4f} ({word_metrics['recall']*100:.2f}%)")
    lines.append(f" Word F1: {word_metrics['f1']:.4f} ({word_metrics['f1']*100:.2f}%)")
    lines.append(f" Syllable errors: {total_syl_errors:,} / {num_syllables:,} ({total_syl_errors/num_syllables*100:.2f}%)")
    lines.append(f" Word errors (FN): {word_metrics['total_true'] - word_metrics['correct']:,}")
    lines.append(f" Word errors (FP): {word_metrics['total_pred'] - word_metrics['correct']:,}")
    lines.append("")

    # Section 2: syllable-level B/I confusion.
    lines.append("2. Syllable-Level Confusion (B/I)")
    lines.append("-" * 40)
    lines.append(f" True B, Predicted I (false join): {syl_errors['b_to_i']:,} / {syl_errors['total_b']:,} ({syl_errors['b_to_i_rate']*100:.2f}%)")
    lines.append(f" True I, Predicted B (false split): {syl_errors['i_to_b']:,} / {syl_errors['total_i']:,} ({syl_errors['i_to_b_rate']*100:.2f}%)")
    lines.append("")
    lines.append(" Confusion Matrix:")
    lines.append(f" Pred B Pred I")
    lines.append(f" True B {syl_errors['total_b'] - syl_errors['b_to_i']:>8,} {syl_errors['b_to_i']:>8,}")
    lines.append(f" True I {syl_errors['i_to_b']:>8,} {syl_errors['total_i'] - syl_errors['i_to_b']:>8,}")
    lines.append("")

    # Section 3: false splits.  Count occurrences per word and remember the
    # FIRST example context in one pass (the previous version rescanned the
    # whole false_splits list for every top word).
    split_counter = Counter()
    split_examples = {}
    for word, parts, context in false_splits:
        split_counter[word] += 1
        split_examples.setdefault(word, context)

    lines.append("3. Top False Splits (compound words broken apart)")
    lines.append("-" * 70)
    lines.append(f" Total false splits: {len(false_splits):,}")
    lines.append(f" Unique words affected: {len(split_counter):,}")
    lines.append("")
    lines.append(f" {'Word':<25} {'Count':<8} {'Example context'}")
    lines.append(f" {'----':<25} {'-----':<8} {'---------------'}")
    for word, count in split_counter.most_common(20):
        lines.append(f" {word:<25} {count:<8} {split_examples[word]}")
    lines.append("")

    # Section 4: false joins, same one-pass aggregation.
    join_counter = Counter()
    join_examples = {}
    for word, parts, context in false_joins:
        join_counter[word] += 1
        join_examples.setdefault(word, (parts, context))

    lines.append("4. Top False Joins (separate words merged)")
    lines.append("-" * 70)
    lines.append(f" Total false joins: {len(false_joins):,}")
    lines.append(f" Unique words affected: {len(join_counter):,}")
    lines.append("")
    lines.append(f" {'Merged as':<25} {'Count':<8} {'Should be':<30} {'Context'}")
    lines.append(f" {'---------':<25} {'-----':<8} {'---------':<30} {'-------'}")
    for word, count in join_counter.most_common(20):
        parts, ctx = join_examples[word]
        should_be = " | ".join(parts)
        lines.append(f" {word:<25} {count:<8} {should_be:<30} {ctx}")
    lines.append("")

    # Section 5: accuracy broken down by gold word length.
    lines.append("5. Error Rate by Word Length (syllables)")
    lines.append("-" * 70)
    lines.append(f" {'Length':<10} {'Total':<10} {'Correct':<10} {'Errors':<10} {'Accuracy':<12} {'Error Rate'}")
    lines.append(f" {'------':<10} {'-----':<10} {'-------':<10} {'------':<10} {'--------':<12} {'----------'}")
    for length, stats in sorted(length_errors.items()):
        label = f"{length}-syl"
        lines.append(f" {label:<10} {stats['total']:<10,} {stats['correct']:<10,} {stats['errors']:<10,} {stats['accuracy']*100:>8.2f}% {stats['error_rate']*100:.2f}%")
    lines.append("")

    # Section 6: boundary vs middle error rates.
    # NOTE(review): the "3" in the label is hardcoded while
    # analyze_boundary_errors takes a window parameter — confirm they agree.
    lines.append("6. Error Rate by Position in Sentence")
    lines.append("-" * 40)
    for region, stats in boundary_errors.items():
        label = f"{region.capitalize()} (first/last 3 syls)" if region != "middle" else "Middle"
        lines.append(f" {label:<35} {stats['errors']:,} / {stats['total']:,} ({stats['error_rate']*100:.2f}%)")
    lines.append("")

    # Section 7: most frequent error patterns with one-syllable context.
    lines.append("7. Top Error Patterns (syllable in context)")
    lines.append("-" * 70)
    lines.append(f" {'Prev syl':<15} {'Current':<15} {'Next syl':<15} {'Error':<8} {'Count'}")
    lines.append(f" {'--------':<15} {'-------':<15} {'--------':<15} {'-----':<8} {'-----'}")
    for (prev_syl, syl, next_syl, error_type), count in top_patterns:
        lines.append(f" {prev_syl:<15} {syl:<15} {next_syl:<15} {error_type:<8} {count}")
    lines.append("")

    lines.append("=" * 70)
    return "\n".join(lines)
|
|
|
|
def save_errors_csv(output_path, false_splits, false_joins, length_errors):
    """Save error details as three CSV files in *output_path*'s directory.

    Writes false_splits.csv, false_joins.csv and error_by_length.csv and
    returns their paths as a tuple in that order.
    """
    out_dir = output_path.parent

    def _dump_errors(filename, header, records):
        # Aggregate per word: occurrence count plus the first example seen,
        # then write rows ordered by descending count.
        counts = Counter()
        examples = {}
        for word, parts, context in records:
            counts[word] += 1
            examples.setdefault(word, (parts, context))
        path = out_dir / filename
        with open(path, "w", newline="", encoding="utf-8") as handle:
            writer = csv.writer(handle)
            writer.writerow(header)
            for word, count in counts.most_common():
                parts, context = examples[word]
                writer.writerow([word, count, " | ".join(parts), context])
        return path

    splits_path = _dump_errors(
        "false_splits.csv",
        ["word", "count", "predicted_parts", "context"],
        false_splits,
    )
    joins_path = _dump_errors(
        "false_joins.csv",
        ["merged_word", "count", "true_parts", "context"],
        false_joins,
    )

    length_path = out_dir / "error_by_length.csv"
    with open(length_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["word_length_syllables", "total", "correct", "errors", "accuracy", "error_rate"])
        for length, stats in sorted(length_errors.items()):
            writer.writerow([
                length, stats["total"], stats["correct"], stats["errors"],
                f"{stats['accuracy']:.4f}", f"{stats['error_rate']:.4f}",
            ])

    return splits_path, joins_path, length_path
|
|
|
|
| |
| |
| |
|
|
@click.command()
@click.option(
    "--model", "-m",
    default=None,
    help="Model directory (default: models/word_segmentation/vlsp2013)",
)
@click.option(
    "--data-dir", "-d",
    default=None,
    help="Dataset directory (default: datasets/c7veardo0e)",
)
@click.option(
    "--output", "-o",
    default=None,
    help="Output directory for results (default: results/word_segmentation)",
)
def main(model, data_dir, output):
    """Run error analysis on VLSP 2013 word segmentation test set.

    Loads a CRF model (underthesea_core ``model.crf`` or pycrfsuite
    ``model.crfsuite``), predicts B/I labels on the test split, runs every
    analysis in this module, prints the report, and writes the report plus
    CSV breakdowns to the output directory.
    """
    # Resolve paths, falling back to repository-relative defaults.
    model_dir = Path(model) if model else PROJECT_ROOT / "models" / "word_segmentation" / "vlsp2013"
    data_path = Path(data_dir) if data_dir else PROJECT_ROOT / "datasets" / "c7veardo0e"
    output_dir = Path(output) if output else PROJECT_ROOT / "results" / "word_segmentation"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Prefer the .crf format; fall back to pycrfsuite's .crfsuite file.
    model_path = model_dir / "model.crf"
    if not model_path.exists():
        model_path = model_dir / "model.crfsuite"
        if not model_path.exists():
            raise click.ClickException(f"No model file found in {model_dir}")

    click.echo(f"Model: {model_path}")
    click.echo(f"Data: {data_path}")
    click.echo(f"Output: {output_dir}")
    click.echo("")

    # Import the tagging backend lazily so only the needed library loads.
    click.echo("Loading model...")
    model_path_str = str(model_path)
    if model_path_str.endswith(".crf"):
        from underthesea_core import CRFModel, CRFTagger
        crf_model = CRFModel.load(model_path_str)
        tagger = CRFTagger.from_model(crf_model)
        predict_fn = lambda X: [tagger.tag(xseq) for xseq in X]
    else:
        import pycrfsuite
        tagger = pycrfsuite.Tagger()
        tagger.open(model_path_str)
        # NOTE(review): identical lambda in both branches; only the bound
        # tagger object differs.
        predict_fn = lambda X: [tagger.tag(xseq) for xseq in X]

    click.echo("Loading VLSP 2013 test set...")
    test_data = load_vlsp2013_test(data_path)
    click.echo(f" {len(test_data)} sentences")

    all_syllables = [syls for syls, _ in test_data]
    all_true = [labels for _, labels in test_data]
    num_syllables = sum(len(syls) for syls in all_syllables)
    click.echo(f" {num_syllables:,} syllables")

    # Optional dictionary shipped alongside the model enables the
    # "*.in_dict" feature templates.
    dict_path = model_dir / "dictionary.txt"
    dictionary = None
    if dict_path.exists():
        dictionary = load_dictionary(dict_path)
        click.echo(f" Dictionary: {len(dictionary)} words from {dict_path}")

    # Drop the dictionary templates when no dictionary file exists —
    # presumably so features match training-time configuration; verify.
    click.echo("Extracting features...")
    active_templates = get_all_templates()
    if dictionary is None:
        active_templates = [t for t in active_templates if t not in FEATURE_GROUPS["dictionary"]]
    X_test = [sentence_to_syllable_features(syls, active_templates, dictionary) for syls in all_syllables]

    click.echo("Predicting...")
    all_pred = predict_fn(X_test)

    click.echo("Analyzing errors...")

    # Syllable-level B/I confusion.
    syl_errors = analyze_syllable_errors(all_true, all_pred)

    # Word-level precision/recall/F1.
    word_metrics = compute_word_metrics(all_syllables, all_true, all_pred)

    # False splits / false joins with context.
    false_splits, false_joins = analyze_word_errors(all_syllables, all_true, all_pred)

    # Error rate by gold word length.
    length_errors = analyze_errors_by_word_length(all_syllables, all_true, all_pred)

    # Errors near sentence boundaries vs the middle.
    boundary_errors = analyze_boundary_errors(all_syllables, all_true, all_pred)

    # Most frequent syllable-in-context error patterns.
    top_patterns = get_top_error_patterns(all_syllables, all_true, all_pred, top_n=20)

    report = format_report(
        syl_errors, word_metrics, false_splits, false_joins,
        length_errors, boundary_errors, top_patterns,
        len(test_data), num_syllables,
    )

    click.echo("")
    click.echo(report)

    # Persist the text report next to the CSV breakdowns.
    report_path = output_dir / "error_analysis.txt"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)
    click.echo(f"\nReport saved to {report_path}")

    splits_csv, joins_csv, length_csv = save_errors_csv(
        report_path, false_splits, false_joins, length_errors
    )
    click.echo(f"False splits CSV: {splits_csv}")
    click.echo(f"False joins CSV: {joins_csv}")
    click.echo(f"Error by length: {length_csv}")
|
|
|
|
# Script entry point; click parses CLI arguments from sys.argv.
if __name__ == "__main__":
    main()
|
|