File size: 3,442 Bytes
5850885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Deterministic result verification.

Canonicalizes floats to `FLOAT_PRECISION` decimal places and maps both
``None`` (Python's stand-in for SQL NULL) and float NaN to a single sentinel,
so that two result sets with the same semantic content hash to the same
digest regardless of row order or small floating-point noise. (Values that
round to opposite sides of a rounding boundary will still differ.)
"""

from __future__ import annotations

import hashlib
from collections.abc import Iterable
from typing import Any

# Canonical stand-in for NULL cells (None, and float NaN) in normalized rows;
# presumably wrapped in NUL characters so ordinary text values won't match it.
NULL_SENTINEL = "\x00NULL\x00"
# Number of decimal places floats are rounded to before hashing/comparison.
FLOAT_PRECISION = 6
# Row digests are SHA-256 outputs (256 bits), so accumulators wrap at 2**256.
_DIGEST_MODULUS = 2 ** 256


def _normalize_value(v: Any) -> Any:
    """Map a single result cell to its canonical form.

    * ``None`` (SQL NULL) and float NaN both become ``NULL_SENTINEL``.
    * Floats are rounded to ``FLOAT_PRECISION`` decimal places, and negative
      zero is folded into ``0.0`` so that e.g. ``-1e-9`` and ``1e-9``
      canonicalize identically.
    * ``memoryview`` cells are copied to ``bytes``: the ``repr`` of a
      memoryview embeds its memory address, which would make a repr-based
      digest differ between runs.
    * Every other value is returned unchanged.
    """
    if v is None:
        return NULL_SENTINEL
    if isinstance(v, float):
        # NaN compares unequal to everything, itself included, so two rows
        # holding distinct NaN objects would not match in tuple/set
        # comparisons. Folding NaN into the NULL sentinel makes it behave
        # like NULL everywhere.
        if v != v:  # NaN
            return NULL_SENTINEL
        # round() keeps the sign of zero (round(-1e-9, 6) == -0.0) and
        # repr(-0.0) != repr(0.0). Adding 0.0 turns -0.0 into +0.0 and
        # leaves every other value, infinities included, unchanged.
        return round(v, FLOAT_PRECISION) + 0.0
    if isinstance(v, memoryview):
        return v.tobytes()
    return v


def _row_digest_int(row: Iterable[Any]) -> int:
    """Hash one row to an unsigned 256-bit integer.

    Each cell is passed through ``_normalize_value``. The ``repr`` of the
    resulting tuple is UTF-8 encoded and hashed with SHA-256, and the
    32-byte digest is read as a big-endian unsigned integer.
    """
    canonical_text = repr(tuple(map(_normalize_value, row)))
    sha = hashlib.sha256(canonical_text.encode())
    return int.from_bytes(sha.digest(), byteorder="big", signed=False)


def canonical_row_hash(rows: Iterable[Iterable[Any]]) -> str:
    """Return an order-independent, duplicate-sensitive hex digest of *rows*.

    Rows are streamed, never materialized. Each row is reduced to a 256-bit
    integer by ``_row_digest_int`` and folded into order-insensitive
    accumulators:

    * the row count;
    * the sum of the row digests (mod ``2**256``);
    * the sum of their squares (mod ``2**256``);
    * their XOR.

    The four values are serialized as 32-byte big-endian integers and
    hashed once more with SHA-256. Permuting *rows* therefore never changes
    the result, while adding or dropping a duplicate row changes the count
    and hence the digest.
    """
    count = total = total_sq = folded = 0
    for row in rows:
        count += 1
        d = _row_digest_int(row)
        total = (total + d) % _DIGEST_MODULUS
        total_sq = (total_sq + pow(d, 2, _DIGEST_MODULUS)) % _DIGEST_MODULUS
        folded ^= d
    fields = (count, total, total_sq, folded)
    serialized = b"".join(value.to_bytes(32, "big", signed=False) for value in fields)
    return hashlib.sha256(serialized).hexdigest()


def result_matches(agent_rows: Iterable[Iterable[Any]], gt_hash: str) -> bool:
    """Report whether *agent_rows* hashes to the ground-truth digest *gt_hash*.

    The rows are canonicalized with ``canonical_row_hash``, and the result
    is compared to *gt_hash* by plain string equality. *gt_hash* is
    therefore expected in the same lowercase hex form that ``hexdigest()``
    produces.
    """
    agent_hash = canonical_row_hash(agent_rows)
    return agent_hash == gt_hash


def row_set_jaccard(a: Iterable[Iterable[Any]], b: Iterable[Iterable[Any]]) -> float:
    """Jaccard similarity of two result sets, compared as sets of normalized rows.

    Every cell is canonicalized with ``_normalize_value`` and each row
    becomes a tuple. Rows are then deduplicated within each input, so the
    score ignores both row order and repeated rows. This is deliberately a
    set (not multiset) Jaccard: multiset semantics would penalize correct
    queries that legitimately emit duplicate rows more harshly than
    intended. Two empty inputs are treated as identical and score ``1.0``.

    Not used by the lean reward today, but kept covered by tests so we can
    opt in later without rework.
    """

    def as_row_set(rows: Iterable[Iterable[Any]]) -> set[tuple[Any, ...]]:
        return {tuple(map(_normalize_value, row)) for row in rows}

    left = as_row_set(a)
    right = as_row_set(b)
    if not (left or right):
        return 1.0
    return len(left & right) / len(left | right)


# Public API exported by `from <module> import *`; the underscore-prefixed
# helpers (_normalize_value, _row_digest_int, _DIGEST_MODULUS) stay internal.
__all__ = [
    "FLOAT_PRECISION",
    "NULL_SENTINEL",
    "canonical_row_hash",
    "result_matches",
    "row_set_jaccard",
]