File size: 5,196 Bytes
9a2ce20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bbcf21
 
 
 
 
 
9a2ce20
 
 
 
3bbcf21
 
9a2ce20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""dedup.py — MinHash near-duplicate detection (finding D-12).

Cross-generation dedup is a flywheel-collapse mitigation: a self-training loop
that re-ingests its own outputs accumulates near-identical rows, and per-batch
`document_deduplicator` (the only dedup the old designs had) never sees across
runs. This module computes MinHash signatures over word 5-shingles so a run can
(a) dedup within itself and (b) accept the PRIOR run's signature file and dedup
against it (lineage threaded by `RunManifest.parent_run_id`).

Pragmatic v1: seeded-md5 MinHash with N=64 seeds (one stable hash per seed
stands in for a permutation), no banding/LSH (O(n^2) pair scan — fine for
Stage-0 corpus sizes, thousands of rows). `datasketch` (MinHashLSH) is the
documented upgrade path when row counts make the pair scan bite.

NOTE on hash stability: Python's builtin `hash()` over str is salted per
process (PYTHONHASHSEED), which would make signatures non-portable across
runs — and portability across runs is exactly what cross-generation dedup
needs. We therefore use md5-based hashing (stable everywhere) despite the
small speed cost.
"""
from __future__ import annotations

import hashlib
import json
import re
from typing import IO, Callable, Iterable, Sequence

# Default signature length: one slot per hash seed (see `minhash_signature`).
N_PERMUTATIONS = 64
# Default shingle width, in words, used by `_shingles`.
_SHINGLE_W = 5
# Tokenizer: runs of word characters, applied to lower-cased text.
_WORD_RE = re.compile(r"\w+")
# Largest unsigned 64-bit value; fills the signature of text with no shingles.
_MAX64 = (1 << 64) - 1


def _shingles(text: str, w: int = _SHINGLE_W) -> set[str]:
    """Return the set of lower-cased word `w`-shingles of `text`.

    Text with more than `w` words yields every window of `w` consecutive
    words, each joined with single spaces. Text with at most `w` words
    collapses to a single shingle holding all of its words, and text with no
    words at all yields the empty set.
    """
    tokens = _WORD_RE.findall(text.lower())
    count = len(tokens)
    if count > w:
        # Sliding window over the token list: count - w + 1 start positions.
        windows = (tokens[start:start + w] for start in range(count - w + 1))
        return {" ".join(window) for window in windows}
    if not tokens:
        return set()
    # Too short for even one full window: the whole text is the only shingle.
    return {" ".join(tokens)}


def _stable_hash(s: str, seed: int) -> int:
    """Return a process-independent 64-bit hash of `s` under `seed`.

    The seed is mixed in by prefixing it to the string (``"<seed>:<s>"``);
    the result is the first 8 bytes of the md5 digest read as a big-endian
    unsigned integer.
    """
    payload = f"{seed}:{s}".encode()
    digest = hashlib.md5(payload).digest()
    # Keep only the leading 8 bytes so the value fits in 64 bits.
    return int.from_bytes(digest[:8], "big")


def minhash_signature(text: str, n_perm: int = N_PERMUTATIONS) -> tuple[int, ...]:
    """Compute the MinHash signature of `text`.

    Slot `k` of the signature is the smallest seeded hash (seed = `k`) over
    the text's shingle set. Text that produces no shingles gets a signature
    filled with the 64-bit maximum in every slot.

    Returns:
        A tuple of `n_perm` integers.
    """
    shingle_set = _shingles(text)
    if shingle_set:
        return tuple(
            min(_stable_hash(shingle, seed) for shingle in shingle_set)
            for seed in range(n_perm)
        )
    # No shingles: nothing to take a minimum over, so use the sentinel value.
    return (_MAX64,) * n_perm


def jaccard_estimate(sig_a: Sequence[int], sig_b: Sequence[int]) -> float:
    """Estimate Jaccard similarity from two MinHash signatures.

    The estimate is the share of positions at which both signatures hold the
    same value.

    Raises:
        ValueError: if the signatures differ in length or are empty.
    """
    same_length = len(sig_a) == len(sig_b)
    if not (same_length and sig_a):
        raise ValueError("signatures must be equal-length and non-empty")
    agreeing = 0
    for left, right in zip(sig_a, sig_b):
        if left == right:
            agreeing += 1
    return agreeing / len(sig_a)


def find_near_duplicates(
    rows: Sequence[dict],
    key_fn: Callable[[dict], str],
    threshold: float = 0.85,
    *,
    prior_signatures: Sequence[Sequence[int]] | None = None,
) -> list[tuple[int, int]]:
    """Return all (i, j) index pairs whose estimated Jaccard >= threshold.

    Args:
        rows: the current run's rows.
        key_fn: extracts the text to fingerprint from a row.
        threshold: minimum estimated Jaccard similarity for a pair to count.
        prior_signatures: signatures from a previous run. They participate as
            virtual rows with negative indices -(k+1), so a pair (i, -1)
            means "row i duplicates prior signature 0" — the cross-generation
            case.

    Returns:
        Pairs ordered by `i`; for each `i`, within-run pairs (i, j) with
        j > i come first, followed by that row's cross-generation pairs.

    Raises:
        ValueError: propagated from `jaccard_estimate` when a prior signature
            does not have the same length as this run's signatures.
    """
    sigs = [minhash_signature(key_fn(r)) for r in rows]
    # Materialise the prior signatures once, outside the row loop. Besides
    # hoisting loop-invariant work, this keeps a one-shot iterable (e.g. a
    # generator over a signature file) from being exhausted by row 0 and
    # silently skipping the cross-generation check for every later row.
    priors = list(prior_signatures) if prior_signatures is not None else []
    pairs: list[tuple[int, int]] = []
    for i, sig in enumerate(sigs):
        # Within-run: compare only against later rows so each pair is
        # reported once, as (earlier, later).
        for j in range(i + 1, len(sigs)):
            if jaccard_estimate(sig, sigs[j]) >= threshold:
                pairs.append((i, j))
        # Cross-generation: prior signature k is reported as index -(k + 1).
        for k, prior in enumerate(priors):
            if jaccard_estimate(sig, prior) >= threshold:
                pairs.append((i, -(k + 1)))
    return pairs


def dedup(
    rows: Sequence[dict],
    key_fn: Callable[[dict], str],
    threshold: float = 0.85,
    *,
    prior_signatures: Sequence[Sequence[int]] | None = None,
) -> tuple[list[dict], dict]:
    """Drop near-duplicate rows, keeping first occurrences.

    Any row matching a PRIOR-run signature is removed outright (the prior run
    already owns it). Among within-run near-duplicates, the later row of each
    reported pair is removed, so the earliest occurrence survives.

    Returns:
        (kept_rows, stats), where `stats` reports the input and kept row
        counts, the number of rows dropped for each reason, and the
        threshold used.
    """
    matches = find_near_duplicates(
        rows, key_fn, threshold, prior_signatures=prior_signatures
    )
    cross_hits: set[int] = set()
    later_members: set[int] = set()
    for first, second in matches:
        if second < 0:
            # Negative partner index marks a match against a prior signature.
            cross_hits.add(first)
        else:
            later_members.add(second)
    # Keep the two drop reasons disjoint (Wave-21 review P2): a row that is
    # both a within-run and a cross-generation duplicate is counted once,
    # attributed to cross-generation because the prior run owns the row.
    within_hits = later_members - cross_hits
    doomed = cross_hits | within_hits
    survivors = [row for idx, row in enumerate(rows) if idx not in doomed]
    stats = {
        "rows_in": len(rows),
        "rows_kept": len(survivors),
        "dropped_within_run": len(within_hits),
        "dropped_cross_generation": len(cross_hits),
        "threshold": threshold,
    }
    return survivors, stats


def signatures_to_jsonl(rows: Sequence[dict], key_fn: Callable[[dict], str],
                        fh: IO[str]) -> int:
    """Write one JSON-encoded MinHash signature per row to `fh`.

    The output is what the NEXT generation loads and passes as
    `prior_signatures` to dedup against this run.

    Returns:
        The number of signature lines written.
    """
    written = 0
    for written, row in enumerate(rows, start=1):
        signature = minhash_signature(key_fn(row))
        # Tuples serialise as JSON arrays; one array per line.
        fh.write(f"{json.dumps(list(signature))}\n")
    return written


def load_signatures(fh: IO[str]) -> list[tuple[int, ...]]:
    """Read signatures back from a JSONL stream, one tuple per non-blank line.

    Lines that are empty or whitespace-only are skipped; every other line is
    parsed as JSON and converted to a tuple.
    """
    loaded: list[tuple[int, ...]] = []
    for raw in fh:
        if not raw.strip():
            continue
        loaded.append(tuple(json.loads(raw)))
    return loaded


# Public API of this module. The underscore-prefixed helpers and constants
# (`_shingles`, `_stable_hash`, `_SHINGLE_W`, `_WORD_RE`, `_MAX64`) are left
# out on purpose.
__all__ = [
    "N_PERMUTATIONS",
    "dedup",
    "find_near_duplicates",
    "jaccard_estimate",
    "load_signatures",
    "minhash_signature",
    "signatures_to_jsonl",
]