| |
"""
Evaluate the ProtocolSegmentClassifier against labeled mixed dictation examples.

Measures:
1. Word-level accuracy, precision, recall, and F1 (is each word correctly classified?)
2. Example-level exact match (is every word in an example labeled correctly, i.e.
   did the classifier find the right segment start/end?), reported separately
   for pure-natural and mixed examples
3. False positive rate (natural words misclassified as protocol)
4. False negative rate (protocol words missed)
"""
|
|
| import json |
| import numpy as np |
| from pathlib import Path |
| import importlib |
# The training script's file name contains hyphens, which is not a valid
# identifier for an ``import`` statement, so it is loaded by name through
# importlib instead.
_mod = importlib.import_module("train-segment-classifier")

# Re-export the helpers/constants shared with the training script.
# NOTE(review): only ``extract_features`` is referenced in this file; the
# others appear to be re-exported for convenience — confirm before removing.
(
    extract_features,
    is_protocol_word,
    is_strong_protocol,
    STRONG_PROTOCOL,
    WEAK_PROTOCOL,
    EXPANDED_SYMBOLS,
) = (
    _mod.extract_features,
    _mod.is_protocol_word,
    _mod.is_strong_protocol,
    _mod.STRONG_PROTOCOL,
    _mod.WEAK_PROTOCOL,
    _mod.EXPANDED_SYMBOLS,
)
|
|
|
|
def load_model(path: str) -> tuple[np.ndarray, float]:
    """Load a trained model (weight vector + bias) from a JSON file.

    The JSON object must contain a ``"weights"`` list and a numeric
    ``"bias"``.

    Args:
        path: Path to the model JSON file.

    Returns:
        ``(weights, bias)`` where ``weights`` is a float ``np.ndarray`` and
        ``bias`` is a Python ``float``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``"weights"`` or ``"bias"`` is missing.
    """
    # JSON is UTF-8 by specification; don't depend on the locale's default.
    with open(path, encoding="utf-8") as f:
        model = json.load(f)
    # Force a float dtype so an all-integer weight list doesn't yield an int
    # array, and coerce the bias so the return matches the annotation.
    weights = np.asarray(model["weights"], dtype=float)
    return weights, float(model["bias"])
|
|
|
|
def classify_words(
    words: list[str],
    weights: np.ndarray,
    bias: float,
    threshold: float = 0.5,
) -> list[tuple[str, float, bool]]:
    """Score each word with the logistic model and threshold the result.

    Each word is featurized together with a context window of up to two
    words on either side, clipped at the list edges.

    Args:
        words: Tokens to classify, in order.
        weights: Model weight vector, dotted with each feature vector.
        bias: Model bias term.
        threshold: Minimum probability for a word to be flagged as protocol.

    Returns:
        One ``(word, probability, is_protocol)`` tuple per input word, in
        input order. An empty ``words`` list yields ``[]``.
    """
    n = len(words)
    results: list[tuple[str, float, bool]] = []
    for i, word in enumerate(words):
        # Context window words[i-2 : i+3], clipped to the list bounds.
        context = words[max(0, i - 2):min(n, i + 3)]
        features = extract_features(word, context, i, n)
        logit = float(np.dot(features, weights) + bias)
        # Numerically stable sigmoid: only ever exponentiate a non-positive
        # value, so extreme logits don't overflow np.exp (RuntimeWarning).
        if logit >= 0:
            prob = 1.0 / (1.0 + np.exp(-logit))
        else:
            z = np.exp(logit)
            prob = z / (1.0 + z)
        prob = float(prob)
        results.append((word, prob, prob >= threshold))
    return results
|
|
|
|
def expand_anchors(is_anchor: list[bool], n: int, radius: int = 2) -> list[bool]:
    """Grow every anchor position into a window of ``radius`` words each side.

    Used to pull in command tokens that sit next to a confidently-detected
    protocol word. Windows are clipped to ``[0, n)``; overlapping windows
    simply merge.

    Args:
        is_anchor: Per-word anchor flags; the first ``n`` entries are read.
        n: Number of words.
        radius: How many neighbours on each side to mark.

    Returns:
        A list of ``n`` booleans, True wherever some anchor's window reaches.
    """
    mask = [False] * n
    anchor_positions = (pos for pos in range(n) if is_anchor[pos])
    for pos in anchor_positions:
        lo = max(0, pos - radius)
        hi = min(n, pos + radius + 1)
        # When hi <= lo (e.g. negative radius) this is an empty slice and a
        # no-op, matching an empty inner range.
        mask[lo:hi] = [True] * (hi - lo)
    return mask
|
|
|
|
def build_segments(words: list[str], in_protocol: list[bool]) -> list[tuple[str, str]]:
    """Merge runs of same-kind words into ``(kind, text)`` segments.

    Args:
        words: Tokens in order.
        in_protocol: Per-word flags (truthy = protocol); must have at least
            ``len(words)`` entries.

    Returns:
        A list of ``(kind, text)`` tuples with ``kind`` either
        ``"protocol"`` or ``"natural"`` and ``text`` the run's words joined
        by single spaces. Empty ``words`` yields ``[]``.

    Raises:
        IndexError: If ``in_protocol`` is shorter than ``words``.
    """
    from itertools import groupby

    # Index explicitly (rather than zip) so a too-short flag list raises
    # instead of silently truncating the words.
    kinds = ["protocol" if in_protocol[i] else "natural" for i in range(len(words))]
    return [
        (kind, " ".join(word for _, word in run))
        for kind, run in groupby(zip(kinds, words), key=lambda pair: pair[0])
    ]
|
|
|
|
def label_words_from_example(example: dict) -> list[int]:
    """Derive per-word ground-truth labels for one eval example.

    Words of ``example["text"]`` (whitespace-split) are labeled 1 when they
    fall inside one of the protocol segments listed in ``example["protocol"]``
    (multiple segments are separated by ``|||``), otherwise 0. Natural words
    are just the complement, so the ``"natural"`` field is not needed here.

    Each segment is matched as an exact word sequence. It is first searched
    for at or after the end of the previous segment's match, so repeated
    segments (e.g. ``"new line ||| new line"``) land on successive
    occurrences instead of all re-labeling the first one. If that fails, the
    earliest occurrence that doesn't overlap already-labeled words is used.

    Args:
        example: Dict with at least string fields ``"text"`` and ``"protocol"``.

    Returns:
        A list of 0/1 labels, one per word of ``example["text"]``.

    Warns:
        UserWarning: If a protocol segment can't be placed in the text; its
        words are left labeled 0.
    """
    import warnings

    words = example["text"].split()
    protocol_parts = [p.strip() for p in example["protocol"].split("|||") if p.strip()]
    labels = [0] * len(words)

    def _find_unlabeled(seq: list[str], start: int) -> int:
        """Index of the first unlabeled occurrence of ``seq`` at/after ``start``, or -1."""
        plen = len(seq)
        for s in range(start, len(words) - plen + 1):
            if words[s:s + plen] == seq and not any(labels[s:s + plen]):
                return s
        return -1

    cursor = 0  # end index of the previous segment's match
    for part in protocol_parts:
        # ``part`` is stripped and non-empty, so this split is non-empty too.
        proto_words = part.split()
        plen = len(proto_words)
        start = _find_unlabeled(proto_words, cursor)
        if start < 0:
            # Segment listed out of order relative to the text: fall back to
            # any free occurrence.
            start = _find_unlabeled(proto_words, 0)
        if start < 0:
            warnings.warn(
                f"protocol segment {part!r} not found in example text",
                stacklevel=2,
            )
            continue
        labels[start:start + plen] = [1] * plen
        cursor = start + plen

    return labels
|
|
|
|
def _safe_div(num: float, den: float) -> float:
    """Return ``num / den``, or 0.0 when ``den`` is zero (empty metric)."""
    return num / den if den else 0.0


def _render_word_diff(words: list[str], gt: list[int], pred: list[int]) -> str:
    """Render one example's words with ANSI colors showing label agreement.

    Correct protocol words are green, correct natural words are plain,
    false positives are red ``[FP:word]`` and false negatives are yellow
    ``[FN:word]``.
    """
    parts = []
    for word, g, p in zip(words, gt, pred):
        if g == p:
            parts.append(f"\033[92m{word}\033[0m" if g == 1 else word)
        elif p == 1 and g == 0:
            parts.append(f"\033[91m[FP:{word}]\033[0m")
        else:
            parts.append(f"\033[93m[FN:{word}]\033[0m")
    return " ".join(parts)


def evaluate(eval_path: str, model_path: str):
    """Run the full evaluation and print a report to stdout.

    For every example, words are scored with ``classify_words``, positives
    are widened with ``expand_anchors`` (default radius), and the result is
    compared word-by-word to the labels from ``label_words_from_example``.

    Args:
        eval_path: JSON file holding a list of example dicts. Each needs
            ``"text"`` and ``"protocol"`` (plus whatever
            ``label_words_from_example`` reads); optional ``"id"`` and
            ``"note"`` are shown in the error listing.
        model_path: Model JSON file accepted by ``load_model``.

    Returns:
        ``(word_accuracy, precision, recall, f1)`` over all words; each is 0
        when its denominator is empty (e.g. an empty eval set).
    """
    with open(eval_path, encoding="utf-8") as f:
        examples = json.load(f)

    weights, bias = load_model(model_path)

    print("=" * 70)
    print("PROTOCOL SEGMENT CLASSIFIER EVALUATION")
    print(f" Model: {model_path}")
    print(f" Eval set: {len(examples)} examples")
    print("=" * 70)

    # Word-level confusion counts, with protocol as the positive class.
    total_words = correct_words = 0
    tp = fp = fn = tn = 0
    # Example-level exact-match counts, split by whether any protocol exists.
    pure_natural_count = pure_natural_correct = 0
    mixed_count = mixed_correct = 0
    errors = []

    for idx, ex in enumerate(examples):
        words = ex["text"].split()
        n = len(words)
        is_pure_natural = not ex["protocol"].strip()

        gt_labels = label_words_from_example(ex)

        # Threshold each word, then widen the positives so adjacent command
        # tokens are included.
        results = classify_words(words, weights, bias)
        predicted = expand_anchors([r[2] for r in results], n)
        pred_labels = [1 if p else 0 for p in predicted]

        for pred, gt in zip(pred_labels, gt_labels):
            total_words += 1
            if pred == gt:
                correct_words += 1
            if pred == 1 and gt == 1:
                tp += 1
            elif pred == 1 and gt == 0:
                fp += 1
            elif pred == 0 and gt == 1:
                fn += 1
            else:
                tn += 1

        # Both label lists come from the same whitespace split, so equal
        # lengths make list equality an exact per-word match.
        example_correct = pred_labels == gt_labels

        if is_pure_natural:
            pure_natural_count += 1
            pure_natural_correct += example_correct
        else:
            mixed_count += 1
            mixed_correct += example_correct

        if not example_correct:
            errors.append({
                "id": ex.get("id", idx),
                "note": ex.get("note", ""),
                "words": words,
                "gt": gt_labels,
                "pred": pred_labels,
                "pred_segments": build_segments(words, predicted),
                "gt_segments": build_segments(words, [bool(g) for g in gt_labels]),
            })

    word_acc = _safe_div(correct_words, total_words)
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)

    print(f"\n{'WORD-LEVEL METRICS':─^70}")
    print(f" Accuracy: {word_acc:.1%} ({correct_words}/{total_words})")
    print(f" Precision: {precision:.1%} (protocol words correctly identified)")
    print(f" Recall: {recall:.1%} (protocol words found)")
    print(f" F1: {f1:.3f}")
    print(f" Confusion: TP={tp} FP={fp} FN={fn} TN={tn}")

    print(f"\n{'EXAMPLE-LEVEL METRICS':─^70}")
    print(f" Pure natural (no protocol): {pure_natural_correct}/{pure_natural_count} perfect")
    print(f" Mixed (natural + protocol): {mixed_correct}/{mixed_count} perfect")
    total_ex = pure_natural_count + mixed_count
    total_correct = pure_natural_correct + mixed_correct
    # Guarded: an empty eval set previously raised ZeroDivisionError here.
    overall_pct = _safe_div(100 * total_correct, total_ex)
    print(f" Overall: {total_correct}/{total_ex} perfect ({overall_pct:.0f}%)")

    if errors:
        print(f"\n{'ERRORS':─^70}")
        for err in errors:
            print(f"\n #{err['id']}: {err['note']}")
            print(f" {_render_word_diff(err['words'], err['gt'], err['pred'])}")

    print(f"\n{'SUMMARY':─^70}")
    fp_rate = _safe_div(fp, fp + tn)
    fn_rate = _safe_div(fn, fn + tp)
    print(f" False positive rate: {fp_rate:.1%} (natural words misclassified as protocol)")
    print(f" False negative rate: {fn_rate:.1%} (protocol words missed)")
    print(" Key concern: FP on natural text → model hallucinates on speech")
    print(" Secondary: FN on protocol → falls through as raw dictation")

    return word_acc, precision, recall, f1
|
|
|
|
if __name__ == "__main__":
    # Both the eval set and the trained model live next to this script.
    script_dir = Path(__file__).parent
    data_files = ("eval-segment-classifier.json", "segment-classifier-model.json")
    evaluate(*(str(script_dir / name) for name in data_files))
|
|