| |
| """Audit entity-propagation output against its base JSONL file. |
| |
| The propagation step can be useful for recall, but it can also amplify noisy |
| surface forms across many documents. This script compares base and propagated |
| files record-by-record, summarizes added spans, and flags surfaces that need |
| manual review before using the propagated file for training. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import re |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from typing import Any |
|
|
# One annotated entity mention: (label, surface text, start offset, end offset).
Span = tuple[str, str, int, int]


# Lower-cased single words treated as too generic to trust as entity surfaces;
# flag_surface() marks a span "generic" when its normalized text is listed here.
GENERIC_SURFACES = set(
    """
    account accounts admin android app application attack backdoor
    browser client code command computer data device devices domain
    download email exploit file files host install malware network
    payload process program registry remote script server service
    shell software system target trojan update user users version
    web windows
    """.split()
)
|
|
|
|
def parse_spans(record: dict[str, Any]) -> set[Span]:
    """Decode a record's ``spans`` mapping into a set of Span tuples.

    Keys are expected as ``"<label>: <surface>"`` (split on the first ``": "``)
    and each value as an iterable of ``[start, end]`` pairs; offsets are coerced
    with ``int``. Keys without the ``": "`` separator are ignored, and a missing,
    ``None`` or empty ``spans`` field yields an empty set.
    """
    found = set()
    for span_key, pairs in (record.get("spans") or {}).items():
        label, separator, surface = span_key.partition(": ")
        if not separator:
            # Not in "<label>: <surface>" form, so there is no label to attribute.
            continue
        found.update((label, surface, int(lo), int(hi)) for lo, hi in pairs)
    return found
|
|
|
|
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load every non-blank line of the JSON Lines file at *path*.

    Args:
        path: File to read, decoded as UTF-8.

    Returns:
        The decoded objects in file order. Blank or whitespace-only lines are skipped.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is prefixed
            with ``<path>:<line number>:`` so the bad record can be located.
    """
    records: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Keep the same exception type for callers, but add file/line
                # context; err.pos/lineno are relative to this single line only.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records
|
|
|
|
def load_test_surfaces(paths: list[Path]) -> dict[str, set[str]]:
    """Collect lower-cased span surfaces per label from held-out test files.

    The result is used to flag propagated spans whose label/surface pair also
    occurs in test data. Missing files are still skipped, but a warning is now
    printed to stderr so a mistyped ``--test-data`` path does not silently
    disable that check.

    Args:
        paths: JSONL files whose records carry a ``spans`` mapping.

    Returns:
        Mapping of label -> set of lower-cased surfaces (a ``defaultdict(set)``).
    """
    import sys  # local: only needed for the missing-file warning

    surfaces: dict[str, set[str]] = defaultdict(set)
    for path in paths:
        if not path.exists():
            print(f"warning: test data file not found, skipping: {path}", file=sys.stderr)
            continue
        for rec in load_jsonl(path):
            for label, surface, _, _ in parse_spans(rec):
                surfaces[label].add(surface.lower())
    return surfaces
|
|
|
|
def flag_surface(
    label: str,
    surface: str,
    conflicts: set[str],
    test_seen: bool,
    *,
    generic_surfaces: set[str] | frozenset[str] | None = None,
) -> list[str]:
    """Return the review flags that apply to one (label, surface) pair.

    Args:
        label: Entity label of the span.
        surface: Surface text of the span. Surrounding whitespace is ignored by every check.
        conflicts: Labels other than *label* associated with this surface;
            any entry adds "label-conflict".
        test_seen: True if the label/surface pair occurs in test data;
            adds "seen-in-test".
        generic_surfaces: Lower-cased vocabulary treated as too generic.
            Defaults to the module-level ``GENERIC_SURFACES``.

    Returns:
        Applicable flags in a fixed order: "short", "generic", "long",
        "lowercase-word", "label-conflict", "seen-in-test", "high-risk-semantic".
        Empty when none apply.
    """
    if generic_surfaces is None:
        generic_surfaces = GENERIC_SURFACES
    stripped = surface.strip()
    norm = stripped.lower()
    token_count = len(norm.split())
    is_generic = norm in generic_surfaces

    flags = []
    if len(norm) < 4:
        flags.append("short")
    if is_generic:
        flags.append("generic")
    if token_count >= 6:
        flags.append("long")
    # Use the stripped text, like the other checks, so surrounding whitespace
    # cannot hide an all-lowercase single word.
    if re.fullmatch(r"[A-Za-z]+", stripped) and stripped.islower():
        flags.append("lowercase-word")
    if conflicts:
        flags.append("label-conflict")
    if test_seen:
        flags.append("seen-in-test")
    if label in {"Organization", "System", "Malware"} and token_count == 1 and is_generic:
        flags.append("high-risk-semantic")
    return flags
|
|
|
|
def _surface_flags(
    label: str,
    surface: str,
    original_surface_labels: dict[str, Counter[str]],
    test_surfaces: dict[str, set[str]],
) -> list[str]:
    """Compute flag_surface() flags for a pair, deriving conflicts and test overlap."""
    key = surface.lower()
    # Labels this surface (case-insensitively) already carries in the base file,
    # minus the label being proposed now.
    conflicts = set(original_surface_labels.get(key, {})) - {label}
    test_seen = key in test_surfaces.get(label, set())
    return flag_surface(label, surface, conflicts, test_seen)


def _escape_md_cell(text: str) -> str:
    """Escape pipe characters so *text* cannot split a Markdown table cell."""
    return text.replace("|", "\\|")


def _render_markdown(summary: dict[str, Any]) -> str:
    """Render the audit *summary* (the dict written to JSON) as a Markdown report."""
    lines = [
        "# Entity Propagation Audit",
        "",
        f"- Base: `{summary['base']}`",
        f"- Propagated: `{summary['propagated']}`",
        f"- Records: {summary['records']:,}",
        f"- Records with added spans: {summary['records_with_added']:,}",
        f"- Original spans: {summary['original_spans']:,}",
        f"- Added spans: {summary['added_spans']:,}",
        f"- Removed spans: {summary['removed_spans']:,}",
        f"- Propagated total spans: {summary['propagated_spans']:,}",
        "",
        "## Added Spans By Class",
        "",
        "| Class | Original | Added | Added/Original |",
        "|---|---:|---:|---:|",
    ]
    original_by_class = summary["original_by_class"]
    added_by_class = summary["added_by_class"]
    for label in sorted(set(original_by_class) | set(added_by_class)):
        original = original_by_class.get(label, 0)
        added = added_by_class.get(label, 0)
        # A class with no base spans has no meaningful growth ratio; printing
        # 0.00x there would hide that propagation created the class outright.
        ratio = f"{added / original:.2f}x" if original else "n/a"
        lines.append(f"| {label} | {original:,} | {added:,} | {ratio} |")

    lines += ["", "## Flags", "", "| Flag | Count |", "|---|---:|"]
    # Same order as Counter.most_common(): count descending, ties in insertion order.
    for flag, count in sorted(summary["flagged_by_reason"].items(), key=lambda kv: kv[1], reverse=True):
        lines.append(f"| {flag} | {count:,} |")

    for title, items in (
        ("Top Added Surfaces", summary["top_added_surfaces"]),
        ("High-Risk Top Surfaces", summary["high_risk_top_surfaces"]),
    ):
        lines += ["", f"## {title}", "", "| Label | Surface | Added |", "|---|---|---:|"]
        for item in items[:25]:
            lines.append(f"| {item['label']} | `{_escape_md_cell(item['surface'])}` | {item['added']:,} |")
    return "\n".join(lines) + "\n"


def main() -> None:
    """Compare base and propagated JSONL files and write JSON + Markdown audit reports.

    Records are compared pairwise by position. The run aborts via SystemExit if
    the files differ in record count or in any record's ``text``, because span
    offsets would then not be comparable. Spans present only in the propagated
    file are "added"; spans present only in the base file are "removed".
    """
    parser = argparse.ArgumentParser(description="Audit propagated NER JSONL data")
    parser.add_argument("--base", required=True, type=Path)
    parser.add_argument("--propagated", required=True, type=Path)
    parser.add_argument("--test-data", nargs="*", type=Path, default=[])
    parser.add_argument("--json-out", type=Path, default=Path("results/entity_propagation_audit.json"))
    parser.add_argument("--md-out", type=Path, default=Path("results/entity_propagation_audit.md"))
    parser.add_argument("--max-samples", type=int, default=8)
    args = parser.parse_args()

    base = load_jsonl(args.base)
    prop = load_jsonl(args.propagated)
    if len(base) != len(prop):
        raise SystemExit(f"record count mismatch: base={len(base)} propagated={len(prop)}")

    # surface (lower-cased) -> Counter of labels it carries in the base file.
    original_surface_labels: dict[str, Counter[str]] = defaultdict(Counter)
    for rec in base:
        for label, surface, _, _ in parse_spans(rec):
            original_surface_labels[surface.lower()][label] += 1

    test_surfaces = load_test_surfaces(args.test_data)

    original_by_class: Counter[str] = Counter()
    added_by_class: Counter[str] = Counter()
    added_by_surface: Counter[tuple[str, str]] = Counter()
    flagged_by_reason: Counter[str] = Counter()
    flagged_by_class: Counter[str] = Counter()
    examples: dict[str, list[dict[str, Any]]] = defaultdict(list)
    total_added = 0
    total_removed = 0
    total_propagated = 0
    records_with_added = 0

    for idx, (base_rec, prop_rec) in enumerate(zip(base, prop)):
        if base_rec.get("text") != prop_rec.get("text"):
            raise SystemExit(f"text mismatch at record {idx}")
        text = prop_rec.get("text") or ""

        base_spans = parse_spans(base_rec)
        prop_spans = parse_spans(prop_rec)
        for label, _, _, _ in base_spans:
            original_by_class[label] += 1
        # Count the real propagated total: propagation may drop or relabel base
        # spans, so "original + added" would overstate it.
        total_propagated += len(prop_spans)
        total_removed += len(base_spans - prop_spans)

        # Sort by position: iterating a raw set of str-containing tuples depends
        # on the per-process hash seed, which made sample selection and
        # most_common() tie order differ between runs.
        added = sorted(prop_spans - base_spans, key=lambda span: (span[2], span[3], span[0], span[1]))
        if added:
            records_with_added += 1

        for label, surface, start, end in added:
            total_added += 1
            added_by_class[label] += 1
            added_by_surface[(label, surface)] += 1

            flags = _surface_flags(label, surface, original_surface_labels, test_surfaces)
            if not flags:
                continue
            flagged_by_reason.update(flags)
            flagged_by_class[label] += 1
            # Examples are bucketed by the exact flag combination.
            key = ",".join(flags)
            if len(examples[key]) < args.max_samples:
                examples[key].append(
                    {
                        "record": idx,
                        "label": label,
                        "surface": surface,
                        "offset": [start, end],
                        "flags": flags,
                        "context": text[max(0, start - 80) : end + 80],
                    }
                )

    top_surfaces = [
        {"label": label, "surface": surface, "added": count}
        for (label, surface), count in added_by_surface.most_common(50)
    ]
    high_risk_surfaces = [
        item
        for item in top_surfaces
        if _surface_flags(item["label"], item["surface"], original_surface_labels, test_surfaces)
    ]

    summary = {
        "base": str(args.base),
        "propagated": str(args.propagated),
        "records": len(base),
        "records_with_added": records_with_added,
        "original_spans": sum(original_by_class.values()),
        "added_spans": total_added,
        "removed_spans": total_removed,
        "propagated_spans": total_propagated,
        "original_by_class": dict(original_by_class),
        "added_by_class": dict(added_by_class),
        "flagged_by_reason": dict(flagged_by_reason),
        "flagged_by_class": dict(flagged_by_class),
        "top_added_surfaces": top_surfaces,
        "high_risk_top_surfaces": high_risk_surfaces[:30],
        "examples": examples,
    }

    args.json_out.parent.mkdir(parents=True, exist_ok=True)
    args.json_out.write_text(json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8")

    args.md_out.parent.mkdir(parents=True, exist_ok=True)
    args.md_out.write_text(_render_markdown(summary), encoding="utf-8")

    print(f"Records: {len(base):,}")
    print(f"Added spans: {total_added:,}")
    print(f"Audit JSON: {args.json_out}")
    print(f"Audit markdown: {args.md_out}")
|
|
|
|
if __name__ == "__main__":
    # Run the audit only when executed as a script, not when imported.
    main()
|
|