Baladithya Balamurugan
Wave 21: adversarial-review fixes — all 9 verified findings closed
3bbcf21
Raw
History Blame Contribute Delete
5.2 kB
"""dedup.py — MinHash near-duplicate detection (finding D-12).
Cross-generation dedup is a flywheel-collapse mitigation: a self-training loop
that re-ingests its own outputs accumulates near-identical rows, and per-batch
`document_deduplicator` (the only dedup the old designs had) never sees across
runs. This module computes MinHash signatures over word 5-shingles so a run can
(a) dedup within itself and (b) accept the PRIOR run's signature file and dedup
against it (lineage threaded by `RunManifest.parent_run_id`).
Pragmatic v1: seeded-md5 permutation MinHash with N=64 seeds, no banding/LSH
(O(n^2) pair scan — fine for Stage-0 corpus sizes, thousands of rows).
`datasketch` (MinHashLSH) is the documented upgrade path when row counts make
the pair scan bite.
NOTE on hash stability: Python's builtin `hash()` over str is salted per
process (PYTHONHASHSEED), which would make signatures non-portable across
runs — and cross-run portability is exactly what cross-generation dedup
needs. We therefore use md5-based hashing (stable everywhere) despite the
small speed cost.
"""
from __future__ import annotations
import hashlib
import json
import re
from typing import IO, Callable, Iterable, Sequence
# Default number of seeded hash functions, i.e. slots per MinHash signature
# (the default `n_perm` of `minhash_signature`).
N_PERMUTATIONS = 64
# Default shingle width in words (the default `w` of `_shingles`).
_SHINGLE_W = 5
# Tokenizer used by `_shingles`: maximal runs of word characters.
_WORD_RE = re.compile(r"\w+")
# Largest unsigned 64-bit value. `_stable_hash` yields 8-byte integers, and
# `minhash_signature` fills every slot with this value for shingle-less text.
_MAX64 = (1 << 64) - 1
def _shingles(text: str, w: int = _SHINGLE_W) -> set[str]:
    """Return the set of word `w`-shingles of `text`.

    The text is lowercased and split into word tokens with `_WORD_RE`; each
    shingle is `w` consecutive tokens joined by single spaces.

    Args:
        text: Raw document text.
        w: Shingle width in words.

    Returns:
        The set of distinct shingles. A text with at most `w` tokens yields a
        single shingle holding all of its tokens; a text with no tokens
        yields the empty set.
    """
    tokens = _WORD_RE.findall(text.lower())
    n_tokens = len(tokens)
    if n_tokens > w:
        # Slide a w-token window over the text, one token at a time.
        windows = set()
        for start in range(n_tokens - w + 1):
            windows.add(" ".join(tokens[start:start + w]))
        return windows
    if tokens:
        # Too short for a full window: the whole text is the only shingle.
        return {" ".join(tokens)}
    return set()
def _stable_hash(s: str, seed: int) -> int:
    """Return a process-independent 64-bit hash of `s` under `seed`.

    The seed is prepended to the string as ``"<seed>:"``, the result is
    MD5-hashed, and the first 8 digest bytes are read as a big-endian
    unsigned integer. Unlike the builtin `hash()`, this does not depend on
    PYTHONHASHSEED.

    Args:
        s: The string (a shingle) to hash.
        seed: Selects which of the seeded hash functions to apply.

    Returns:
        An integer in ``[0, 2**64)``.
    """
    payload = f"{seed}:{s}".encode()
    digest = hashlib.md5(payload).digest()
    leading_bytes = digest[:8]
    return int.from_bytes(leading_bytes, byteorder="big")
def minhash_signature(text: str, n_perm: int = N_PERMUTATIONS) -> tuple[int, ...]:
    """Compute the MinHash signature of `text`.

    For each seed in ``range(n_perm)`` the signature slot is the minimum
    `_stable_hash` value over the text's shingle set.

    Args:
        text: Raw document text; shingled by `_shingles`.
        n_perm: Number of seeded hash functions, i.e. the signature length.

    Returns:
        A tuple of `n_perm` integers. A text with no shingles gets a
        signature whose every slot is `_MAX64`.
    """
    shingle_set = _shingles(text)
    if shingle_set:
        slots = []
        for seed in range(n_perm):
            slots.append(min(_stable_hash(sh, seed) for sh in shingle_set))
        return tuple(slots)
    # No shingles to hash: fall back to the all-_MAX64 sentinel signature.
    return (_MAX64,) * n_perm
def jaccard_estimate(sig_a: Sequence[int], sig_b: Sequence[int]) -> float:
    """Estimate Jaccard similarity from two MinHash signatures.

    The estimate is the fraction of slot positions at which the two
    signatures hold the same value.

    Args:
        sig_a: First signature.
        sig_b: Second signature; must be the same length as `sig_a`.

    Returns:
        A float in ``[0.0, 1.0]``.

    Raises:
        ValueError: If the signatures differ in length or are empty.
    """
    same_length = len(sig_a) == len(sig_b)
    if not (same_length and sig_a):
        raise ValueError("signatures must be equal-length and non-empty")
    agreeing = 0
    for slot_a, slot_b in zip(sig_a, sig_b):
        if slot_a == slot_b:
            agreeing += 1
    return agreeing / len(sig_a)
def find_near_duplicates(
    rows: Sequence[dict],
    key_fn: Callable[[dict], str],
    threshold: float = 0.85,
    *,
    prior_signatures: Sequence[Sequence[int]] | None = None,
) -> list[tuple[int, int]]:
    """Find every index pair whose estimated Jaccard similarity >= threshold.

    Each row is signed with `minhash_signature(key_fn(row))` and compared
    against every later row (an O(n^2) pair scan) and against every prior
    signature.

    Args:
        rows: The current run's rows.
        key_fn: Extracts the text to fingerprint from a row.
        threshold: Minimum estimated Jaccard similarity for a pair to count.
        prior_signatures: Signatures from a previous run. They take part as
            virtual rows with negative indices ``-(k + 1)``, so ``(i, -1)``
            means "row i duplicates prior signature 0" (the
            cross-generation case).

    Returns:
        A list of pairs. Within-run pairs are ``(i, j)`` with ``i < j``;
        cross-generation pairs are ``(i, -(k + 1))``. Pairs are grouped by
        ``i``, within-run matches first.

    Raises:
        ValueError: Propagated from `jaccard_estimate` when a prior
            signature's length differs from the row signatures'.
    """
    signatures = [minhash_signature(key_fn(row)) for row in rows]
    matches: list[tuple[int, int]] = []
    for i, sig_i in enumerate(signatures):
        # Within-run: compare only against later rows so each pair appears once.
        for j, sig_j in enumerate(signatures[i + 1:], start=i + 1):
            if jaccard_estimate(sig_i, sig_j) >= threshold:
                matches.append((i, j))
        # Cross-generation: prior signature k is reported as index -(k + 1).
        for k, prior_sig in enumerate(prior_signatures or []):
            if jaccard_estimate(sig_i, prior_sig) >= threshold:
                matches.append((i, -(k + 1)))
    return matches
def dedup(
    rows: Sequence[dict],
    key_fn: Callable[[dict], str],
    threshold: float = 0.85,
    *,
    prior_signatures: Sequence[Sequence[int]] | None = None,
) -> tuple[list[dict], dict]:
    """Drop near-duplicate rows, keeping the first occurrence.

    A row that duplicates a PRIOR-run signature is dropped outright (the
    prior run already owns it). Within the run, a row is dropped when it is
    the later member of any near-duplicate pair.

    Args:
        rows: The current run's rows.
        key_fn: Extracts the text to fingerprint from a row.
        threshold: Minimum estimated Jaccard similarity for a duplicate.
        prior_signatures: Signatures from a previous run, if any.

    Returns:
        ``(kept_rows, stats)``. `kept_rows` preserves input order. `stats`
        has the keys ``rows_in``, ``rows_kept``, ``dropped_within_run``,
        ``dropped_cross_generation`` and ``threshold``.
    """
    dup_pairs = find_near_duplicates(
        rows, key_fn, threshold, prior_signatures=prior_signatures
    )
    # Partition into disjoint drop-reason sets (Wave-21 review P2: a row that
    # is both a within-run AND cross-generation duplicate must count once;
    # cross-generation wins the attribution since the prior run owns the row).
    cross_hits: set[int] = set()
    within_hits: set[int] = set()
    for left, right in dup_pairs:
        if right < 0:
            # Negative index: `left` matched a prior-run signature.
            cross_hits.add(left)
        else:
            # Within-run pair: the later row (`right`) is the one dropped.
            within_hits.add(right)
    within_hits -= cross_hits
    survivors = [
        row
        for idx, row in enumerate(rows)
        if idx not in cross_hits and idx not in within_hits
    ]
    stats = {
        "rows_in": len(rows),
        "rows_kept": len(survivors),
        "dropped_within_run": len(within_hits),
        "dropped_cross_generation": len(cross_hits),
        "threshold": threshold,
    }
    return survivors, stats
def signatures_to_jsonl(rows: Sequence[dict], key_fn: Callable[[dict], str],
                        fh: IO[str]) -> int:
    """Write one MinHash signature per row to `fh` as JSON Lines.

    This persists the run's signatures so the NEXT generation can dedup
    against them (load them with `load_signatures` and pass the list as
    `prior_signatures`).

    Args:
        rows: The rows to fingerprint.
        key_fn: Extracts the text to fingerprint from a row.
        fh: Writable text handle; each signature is written as a JSON array
            followed by a newline.

    Returns:
        The number of signatures written.
    """
    written = 0
    for written, row in enumerate(rows, start=1):
        signature = minhash_signature(key_fn(row))
        fh.write(json.dumps(list(signature)) + "\n")
    return written
def load_signatures(fh: IO[str]) -> list[tuple[int, ...]]:
    """Read signatures from a JSON Lines handle.

    Each non-blank line is parsed as JSON and converted to a tuple; lines
    that are empty or whitespace-only are skipped.

    Args:
        fh: Readable text handle, e.g. a file written by
            `signatures_to_jsonl`.

    Returns:
        The signatures in file order.
    """
    loaded = []
    for raw_line in fh:
        if not raw_line.strip():
            continue
        loaded.append(tuple(json.loads(raw_line)))
    return loaded
# Names exported by `from <module> import *`. The underscore-prefixed helpers
# and constants defined in this module are not listed.
__all__ = [
    "N_PERMUTATIONS",
    "dedup",
    "find_near_duplicates",
    "jaccard_estimate",
    "load_signatures",
    "minhash_signature",
    "signatures_to_jsonl",
]