"""dedup.py — MinHash near-duplicate detection (finding D-12). Cross-generation dedup is a flywheel-collapse mitigation: a self-training loop that re-ingests its own outputs accumulates near-identical rows, and per-batch `document_deduplicator` (the only dedup the old designs had) never sees across runs. This module computes MinHash signatures over word 5-shingles so a run can (a) dedup within itself and (b) accept the PRIOR run's signature file and dedup against it (lineage threaded by `RunManifest.parent_run_id`). Pragmatic v1: builtin-hash permutation MinHash with N=64 seeds, no banding/LSH (O(n^2) pair scan — fine for Stage-0 corpus sizes, thousands of rows). `datasketch` (MinHashLSH) is the documented upgrade path when row counts make the pair scan bite. NOTE on hash stability: Python's builtin `hash()` over str is salted per process (PYTHONHASHSEED), which would make signatures non-portable across runs — exactly what cross-generation dedup needs. We therefore use md5-based hashing (stable everywhere) despite the small speed cost. """ from __future__ import annotations import hashlib import json import re from typing import IO, Callable, Iterable, Sequence N_PERMUTATIONS = 64 _SHINGLE_W = 5 _WORD_RE = re.compile(r"\w+") _MAX64 = (1 << 64) - 1 def _shingles(text: str, w: int = _SHINGLE_W) -> set[str]: words = _WORD_RE.findall(text.lower()) if len(words) <= w: return {" ".join(words)} if words else set() return {" ".join(words[i:i + w]) for i in range(len(words) - w + 1)} def _stable_hash(s: str, seed: int) -> int: h = hashlib.md5(f"{seed}:{s}".encode()).digest() return int.from_bytes(h[:8], "big") def minhash_signature(text: str, n_perm: int = N_PERMUTATIONS) -> tuple[int, ...]: """MinHash signature: per-seed minimum over the shingle set.""" sh = _shingles(text) if not sh: return tuple([_MAX64] * n_perm) return tuple(min(_stable_hash(s, seed) for s in sh) for seed in range(n_perm)) def jaccard_estimate(sig_a: Sequence[int], sig_b: Sequence[int]) -> float: """Estimated Jaccard similarity = fraction of agreeing signature slots.""" if len(sig_a) != len(sig_b) or not sig_a: raise ValueError("signatures must be equal-length and non-empty") return sum(1 for a, b in zip(sig_a, sig_b) if a == b) / len(sig_a) def find_near_duplicates( rows: Sequence[dict], key_fn: Callable[[dict], str], threshold: float = 0.85, *, prior_signatures: Sequence[Sequence[int]] | None = None, ) -> list[tuple[int, int]]: """All (i, j) index pairs whose estimated Jaccard >= threshold. `prior_signatures` (from a previous run) participate as virtual rows with negative indices -(k+1), so a pair (i, -1) means "row i duplicates prior signature 0" — the cross-generation case. """ sigs = [minhash_signature(key_fn(r)) for r in rows] pairs: list[tuple[int, int]] = [] for i in range(len(sigs)): for j in range(i + 1, len(sigs)): if jaccard_estimate(sigs[i], sigs[j]) >= threshold: pairs.append((i, j)) for k, prior in enumerate(prior_signatures or []): if jaccard_estimate(sigs[i], prior) >= threshold: pairs.append((i, -(k + 1))) return pairs def dedup( rows: Sequence[dict], key_fn: Callable[[dict], str], threshold: float = 0.85, *, prior_signatures: Sequence[Sequence[int]] | None = None, ) -> tuple[list[dict], dict]: """Keep-first dedup. Returns (kept_rows, stats). A row duplicating a PRIOR-run signature is dropped outright (the prior run already owns it); within-run duplicates keep the earliest occurrence. """ pairs = find_near_duplicates(rows, key_fn, threshold, prior_signatures=prior_signatures) # Partition into disjoint drop-reason sets (Wave-21 review P2: a row that # is both a within-run AND cross-generation duplicate must count once; # cross-generation wins the attribution since the prior run owns the row). drop_cross: set[int] = {i for i, j in pairs if j < 0} drop_within: set[int] = {j for _, j in pairs if j >= 0} - drop_cross drop = drop_cross | drop_within kept = [r for i, r in enumerate(rows) if i not in drop] return kept, { "rows_in": len(rows), "rows_kept": len(kept), "dropped_within_run": len(drop_within), "dropped_cross_generation": len(drop_cross), "threshold": threshold, } def signatures_to_jsonl(rows: Sequence[dict], key_fn: Callable[[dict], str], fh: IO[str]) -> int: """Persist this run's signatures so the NEXT generation can dedup against them (pass the loaded list as `prior_signatures`).""" n = 0 for r in rows: fh.write(json.dumps(list(minhash_signature(key_fn(r)))) + "\n") n += 1 return n def load_signatures(fh: IO[str]) -> list[tuple[int, ...]]: return [tuple(json.loads(line)) for line in fh if line.strip()] __all__ = [ "N_PERMUTATIONS", "dedup", "find_near_duplicates", "jaccard_estimate", "load_signatures", "minhash_signature", "signatures_to_jsonl", ]