File size: 5,748 Bytes
4129d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import argparse
import json
from collections import Counter, defaultdict

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

from training.benchmark_structured import ALLOWED_INPUTS, flatten_resume, predicted_spans_from_text
from training.benchmark_utils import classify_resume_noise
from training.labels import ID2LABEL
from training.structured_postprocess import StructuredPostProcessor, build_text_and_spans


def counter_diff_size(gold: list[str], pred: list[str]) -> tuple[int, int, int]:
    """Compare two value lists as multisets and count matches and misses.

    Args:
        gold: Reference values; duplicates are significant.
        pred: Predicted values; duplicates are significant.

    Returns:
        ``(tp, fp, fn)`` where ``tp`` is the size of the multiset intersection,
        ``fp`` the number of predicted values left unmatched, and ``fn`` the
        number of gold values left unmatched.
    """
    gold_counts = Counter(gold)
    pred_counts = Counter(pred)
    # The multiset intersection keeps min(count_gold, count_pred) per value.
    matched = sum((gold_counts & pred_counts).values())
    # Whatever was not matched on either side is a surplus (fp) or a miss (fn).
    surplus = sum(pred_counts.values()) - matched
    missed = sum(gold_counts.values()) - matched
    return matched, surplus, missed


def main() -> None:
    """Report per-resume structured-extraction errors for a token classifier.

    Command-line options:
        --model-dir: directory passed to the tokenizer, model and
            post-processor loaders (default ``"."``).
        --val-path: JSON file whose ``"data"`` entry holds the validation
            examples; each example provides ``"tokens"`` and ``"ner_tags"``.
        --top-k: number of worst resumes to include as outliers (default 10).

    For every example the gold spans and the model's predicted spans are each
    turned into a structured resume, flattened to ``field -> values`` and
    compared per field as multisets. A JSON summary (error distribution, micro
    TP/FP/FN totals, noise-bucket counts, the most error-prone fields and the
    top-k outlier resumes) is printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Analyze per-resume structured extraction errors and outliers")
    parser.add_argument("--model-dir", default=".")
    parser.add_argument("--val-path", default="training/data/ner_val.json")
    parser.add_argument("--top-k", type=int, default=10)
    args = parser.parse_args()

    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 rather than whatever the platform locale defaults to.
    with open(args.val_path, encoding="utf-8") as val_file:
        payload = json.load(val_file)
    examples = payload["data"]
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
    model = AutoModelForTokenClassification.from_pretrained(args.model_dir)
    model.eval()
    postprocessor = StructuredPostProcessor(args.model_dir)

    per_resume = []
    field_error_counts = Counter()
    bucket_counts = Counter()
    total_fp = 0
    total_fn = 0
    total_tp = 0

    for idx, example in enumerate(examples):
        # Gold side: rebuild the text and labelled spans from the annotations.
        gold_text, gold_spans = build_text_and_spans(example["tokens"], example["ner_tags"], ID2LABEL)
        gold_structured = postprocessor.build_structured_resume_from_spans(gold_spans, gold_text)
        bucket_info = classify_resume_noise(gold_text)
        bucket_counts[str(bucket_info["bucket"])] += 1

        # Predicted side: run the model on the same text (truncated to 512 tokens).
        tokenized = tokenizer(gold_text, return_tensors="pt", return_offsets_mapping=True, truncation=True, max_length=512)
        encoded = {k: v for k, v in tokenized.items() if k in ALLOWED_INPUTS}
        with torch.no_grad():
            pred_ids = model(**encoded).logits.argmax(dim=-1).squeeze(0).cpu().tolist()

        # Drop the first and last positions from offsets and predictions.
        # NOTE(review): assumes the tokenizer adds exactly one special token at
        # each end of the sequence -- confirm for the model in use.
        offsets = [tuple(pair) for pair in tokenized["offset_mapping"].squeeze(0).cpu().tolist()][1:-1]
        pred_text, pred_spans = predicted_spans_from_text(gold_text, offsets, pred_ids[1:-1])
        pred_structured = postprocessor.build_structured_resume_from_spans(pred_spans, pred_text)

        gold_flat = flatten_resume(gold_structured)
        pred_flat = flatten_resume(pred_structured)
        resume_tp = 0
        resume_fp = 0
        resume_fn = 0
        mismatched_fields = []

        # Compare every field that appears on either side.
        for field in sorted(set(gold_flat) | set(pred_flat)):
            gold_values = gold_flat.get(field, [])
            pred_values = pred_flat.get(field, [])
            tp, fp, fn = counter_diff_size(gold_values, pred_values)
            total_tp += tp
            total_fp += fp
            total_fn += fn
            resume_tp += tp
            resume_fp += fp
            resume_fn += fn
            if fp or fn:
                field_error_counts[field] += fp + fn
                mismatched_fields.append(
                    {
                        "field": field,
                        "gold": gold_values,
                        "pred": pred_values,
                        "fp": fp,
                        "fn": fn,
                    }
                )

        per_resume.append(
            {
                "index": idx,
                "text_preview": gold_text[:300],
                "bucket": bucket_info["bucket"],
                "noise_signals": bucket_info["signals"],
                "gold_field_count": sum(len(v) for v in gold_flat.values()),
                "pred_field_count": sum(len(v) for v in pred_flat.values()),
                # Sum of the per-field multiset overlaps accumulated above.
                "tp": resume_tp,
                "fp": resume_fp,
                "fn": resume_fn,
                "errors": resume_fp + resume_fn,
                "mismatched_fields": mismatched_fields,
            }
        )

    resume_count = len(per_resume)
    error_values = [item["errors"] for item in per_resume]
    avg_errors = sum(error_values) / len(error_values) if error_values else 0.0
    # Upper median: with an even number of resumes this takes the higher of
    # the two middle values rather than their average.
    median_errors = sorted(error_values)[len(error_values) // 2] if error_values else 0
    zero_error = sum(1 for value in error_values if value == 0)
    one_or_less = sum(1 for value in error_values if value <= 1)
    three_or_more = sum(1 for value in error_values if value >= 3)
    # Worst resumes first; ties are broken by original position.
    outliers = sorted(per_resume, key=lambda item: (-item["errors"], item["index"]))[: args.top_k]

    summary = {
        "examples": resume_count,
        "avg_errors_per_resume": avg_errors,
        "median_errors_per_resume": median_errors,
        "zero_error_resumes": zero_error,
        "zero_error_rate": zero_error / resume_count if resume_count else 0.0,
        "one_or_less_error_resumes": one_or_less,
        "one_or_less_error_rate": one_or_less / resume_count if resume_count else 0.0,
        "three_or_more_error_resumes": three_or_more,
        "three_or_more_error_rate": three_or_more / resume_count if resume_count else 0.0,
        "micro": {
            "tp": total_tp,
            "fp": total_fp,
            "fn": total_fn,
        },
        "bucket_counts": dict(bucket_counts),
        "top_error_fields": field_error_counts.most_common(10),
        "outliers": outliers,
        "note": "errors = fp + fn over flattened structured fields per resume",
    }
    print(json.dumps(summary, indent=2))


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()