File size: 9,402 Bytes
d76dce2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""V8 (193-dim) Path A feature encoder for engine/bitnet_classifier.py.

Layout
======
Per-drug encoding (90 features × 2 = 180):
    [0..64)   64 BLAKE2b ternary hash trits ∈ {-1, 0, +1}
    [64..90)  26 ATC pharmacology flag bits ∈ {0, 1}, ordered by
              `docs/pharmacology_flags.json` ``flag_keys``.

Pair-level encoding (13 features):
    [180..193)  13 pair-derived DDI-rule bits ∈ {0, 1}.

Total: 64 + 26 + 64 + 26 + 13 = 193 trits/bits.

Order canonicalisation
----------------------
Drug pairs are sorted lexicographically before encoding so that
`{warfarin, ibuprofen}` and `{ibuprofen, warfarin}` produce the same
193-dim vector. Same canonicalisation as `engine/bitnet_classifier`.

Source of truth
---------------
The encoder is bit-identical to `retrain_runpod/train_bitnet_v8_h256.py`
since the v8 ternary weights bundle (1f0f8859…) was trained against this
exact pipeline. Any divergence here would silently change forward-pass
output and invalidate the audit-chain bundle_id binding.
"""
from __future__ import annotations

import hashlib
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

_REPO_ROOT = Path(__file__).resolve().parent.parent
_PHARM_FLAGS_PATH = _REPO_ROOT / "docs" / "pharmacology_flags.json"

# Distribution-balanced trit table (50% zeros, 25% +1, 25% -1) — matches
# the table used in `engine/bitnet_classifier._encode_drug_token` and in
# the v8 trainer.
_TRIT_LOOKUP: tuple[int, ...] = (
    0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1,
    -1, -1, -1, -1,
)

# A 16-byte BLAKE2b digest yields exactly 64 trits via the 4-trit-per-byte
# extraction in `hash_trits` (16 × 4 = 64); this digest size matches the v8 trainer.
_BLAKE2B_DIGEST_SIZE = 16

_NITRATE_NAMES = frozenset({
    "isosorbide mononitrate",
    "isosorbide dinitrate",
    "nitroglycerin",
})

# Cached pharmacology flag table — read once at module import.
_FLAGS_DOC = json.loads(_PHARM_FLAGS_PATH.read_text(encoding="utf-8"))
FLAG_KEYS: tuple[str, ...] = tuple(_FLAGS_DOC["flag_keys"])
_FLAG_DRUGS: dict[str, dict] = _FLAGS_DOC["drugs"]

N_HASH_TRITS = 64
N_FLAG_BITS = len(FLAG_KEYS)
N_PER_DRUG = N_HASH_TRITS + N_FLAG_BITS
N_PAIR_DERIVED = 13  # iter-140: 6 baseline + 7 closure rules
FEAT_DIM = N_PER_DRUG * 2 + N_PAIR_DERIVED

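# Iter-279: no logging happens at module load (the engine arch-mind gate
# requires every engine module to be pure on import). The flag-table file
# name and counts are surfaced at FIRST USE instead, via the load-context
# DEBUG emitted by `encode_pair_v8`; the OOV warning's `extra` block also
# carries `feat_dim`. This lets auditors correlate every BitNet decision
# to the encoder version without import-time log side effects.
# NOTE(review): the flag table itself is still read from disk at import
# (`_FLAGS_DOC` above), so import is not strictly I/O-free.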

# Latch: emit the load-context DEBUG ONCE per process on the first
# encode_pair_v8 call, instead of at module import. Same audit
# correlation; preserves engine purity discipline.
_LOAD_CONTEXT_LOGGED = False


def _canonical(name: str) -> str:
    """Normalise a drug name by lowercasing it and collapsing whitespace.

    Mirrors the canonicalisation used by `_encode_drug_token` in
    bitnet_classifier, so both encoders key drugs identically.
    """
    # str.split() with no separator already drops leading/trailing
    # whitespace and collapses internal runs, so no explicit strip().
    words = name.lower().split()
    return " ".join(words)


def hash_trits(name: str) -> list[int]:
    """Map a drug name to 64 ternary hash trits in {-1, 0, +1}.

    The canonicalised name is hashed with BLAKE2b (digest size
    `_BLAKE2B_DIGEST_SIZE`), and every digest byte contributes four trits
    looked up in `_TRIT_LOOKUP`. The output must stay bit-identical to
    `engine.bitnet_classifier._encode_drug_token` and to the v8 trainer:
    the same canonical drug name yields the same vector on every machine.
    """
    # Right-shift applied to each byte, in emission order, before masking
    # the result down to a 4-bit lookup index.
    # NOTE(review): shift 0 appears twice, so trits 4k and 4k+2 are always
    # equal, and the shift-2 window overlaps both nibbles. This looks
    # unintentional, but the trained v8 weights depend on it; do not change
    # without retraining.
    shifts = (0, 4, 0, 2)
    payload = _canonical(name).encode("utf-8")
    digest = hashlib.blake2b(payload, digest_size=_BLAKE2B_DIGEST_SIZE).digest()
    trits = [
        _TRIT_LOOKUP[(byte >> shift) & 0xF]
        for byte in digest
        for shift in shifts
    ]
    return trits[:N_HASH_TRITS]


def flag_bits(name: str) -> list[int]:
    """Return one {0, 1} bit per pharmacology flag key for a single drug.

    Bits follow `FLAG_KEYS` order (the 26 ATC flags of the v8 table).
    A drug missing from the flag table yields all zeros. The v8 trainer was
    trained against the same fall-through, so the model reads it as
    "no known class membership".
    """
    canon = _canonical(name)
    listed = _FLAG_DRUGS[canon]["flags"] if canon in _FLAG_DRUGS else []
    active = frozenset(listed)
    return [int(key in active) for key in FLAG_KEYS]


def pair_derived_flags(da: str, db: str) -> list[int]:
    """13 pair-derived DDI-rule bits that encode canonical interaction
    rules directly. These bypass hash noise to make the decision
    boundary explicit.

    Each bit fires iff the corresponding rule applies to the (drug_a,
    drug_b) pair. Every rule is symmetric, so swapping the arguments does
    not change the result. Indices match the v8 trainer (and the iter-140
    pair-derived rule set):

      [0]  cyp3a4_inhib_substrate
      [1]  oatp1b1_inhib_statin
      [2]  p_gp_inhib_substrate
      [3]  cyp2c9_inhib_anticoag
      [4]  maoi_serotonergic
      [5]  pde5_nitrate            (special: the nitrate side is matched by
                                    exact canonical name against
                                    `_NITRATE_NAMES`, not by a flag or a
                                    name suffix)
      [6]  iodinated_contrast_metformin
      [7]  cyp1a2_inhib_substrate
      [8]  xo_thiopurine
      [9]  folate_antagonist_pair  (both drugs carry the same flag)
      [10] tetracycline_retinoid
      [11] ace_neprilysin
      [12] metformin_renal

    A drug absent from the flag table contributes no flags. For such a
    drug, only rule [5] can still fire, and only with that drug on the
    nitrate side.
    """
    # Canonicalise each name once; the result feeds both the flag lookup
    # and the exact-name nitrate check below.
    a_norm = _canonical(da)
    b_norm = _canonical(db)
    fa = set(_FLAG_DRUGS.get(a_norm, {"flags": []})["flags"])
    fb = set(_FLAG_DRUGS.get(b_norm, {"flags": []})["flags"])

    def has_pair(flag_x: str, flag_y: str) -> bool:
        # Symmetric: either drug may carry either side of the rule.
        return (flag_x in fa and flag_y in fb) or (flag_x in fb and flag_y in fa)

    def both_have(flag: str) -> bool:
        return flag in fa and flag in fb

    pde5_nitrate = (
        ("is_pde5_inhibitor" in fa and b_norm in _NITRATE_NAMES)
        or ("is_pde5_inhibitor" in fb and a_norm in _NITRATE_NAMES)
    )

    # Order is load-bearing: it must match the trainer's index layout
    # documented above (and N_PAIR_DERIVED).
    return [
        1 if has_pair("is_cyp3a4_strong_inhibitor", "is_cyp3a4_substrate") else 0,
        1 if has_pair("is_oatp1b1_inhibitor", "is_statin") else 0,
        1 if has_pair("is_p_gp_inhibitor", "is_p_gp_substrate") else 0,
        1 if has_pair("is_cyp2c9_inhibitor", "is_anticoagulant") else 0,
        1 if has_pair("is_maoi", "is_serotonergic") else 0,
        1 if pde5_nitrate else 0,
        1 if has_pair("is_iodinated_contrast", "is_metformin") else 0,
        1 if has_pair("is_cyp1a2_inhibitor", "is_cyp1a2_substrate") else 0,
        1 if has_pair("is_xanthine_oxidase_inhibitor", "is_thiopurine") else 0,
        1 if both_have("is_folate_antagonist") else 0,
        1 if has_pair("is_tetracycline", "is_retinoid") else 0,
        1 if has_pair("is_ace_inhibitor", "is_neprilysin_inhibitor") else 0,
        1 if has_pair("is_metformin", "is_renal_state") else 0,
    ]


def encode_pair_v8(drug_a: str, drug_b: str) -> list[int]:
    """V8 193-dim feature vector for an order-canonicalised drug pair.

    Layout: hash_trits(a) + flag_bits(a) + hash_trits(b) + flag_bits(b)
    + pair_derived_flags(a, b), where (a, b) is the input pair ordered by
    canonical name (lowercase, whitespace-collapsed). For already-canonical
    names this is the same order as a plain lexicographic sort, so the
    result stays bit-identical to the v8 trainer's ``encode_pair``.

    Emits a structured WARNING when EITHER drug is unknown to the flag
    table. This is the OOV signal that the model is falling back to
    hash-only encoding for that drug, which is a safety-relevant
    quality-of-prediction event (the cohort-aggregate recall claim
    `43/43` covers in-distribution drugs only).

    Args:
        drug_a: First drug name (any casing/whitespace).
        drug_b: Second drug name (any casing/whitespace).

    Returns:
        A list of FEAT_DIM ints (trits in {-1, 0, +1} and bits in {0, 1}).

    Raises:
        RuntimeError: if the assembled vector's length differs from FEAT_DIM.
    """
    global _LOAD_CONTEXT_LOGGED
    if not _LOAD_CONTEXT_LOGGED:
        # Iter-279: emit the load-context DEBUG on first call instead of
        # at import (preserves engine module purity for the arch-mind
        # gate). Auditors get the same correlation between decisions and
        # the flag-table snapshot.
        logger.debug(
            "bitnet_features_v8_loaded",
            extra={
                "flags_path_basename": _PHARM_FLAGS_PATH.name,
                "flag_keys_count": N_FLAG_BITS,
                "drug_count": len(_FLAG_DRUGS),
                "n_pair_derived": N_PAIR_DERIVED,
                "feat_dim": FEAT_DIM,
            },
        )
        _LOAD_CONTEXT_LOGGED = True

    # Order by CANONICAL name, not raw input. A raw sort is case- and
    # whitespace-sensitive ("Warfarin" < "ibuprofen" but "warfarin" >
    # "ibuprofen"), whereas every per-drug feature uses the canonical name.
    # The same pair would otherwise land in different slots depending on
    # input casing. Ties (same canonical name) encode identically either way.
    a, b = sorted((drug_a, drug_b), key=_canonical)
    a_canon = _canonical(a)
    b_canon = _canonical(b)
    a_known = a_canon in _FLAG_DRUGS
    b_known = b_canon in _FLAG_DRUGS
    if not (a_known and b_known):
        # PHI-safe shape: drug names are logged only as truncated SHA-256
        # hex prefixes of their canonical form (NOT raw names). Auditors
        # get a stable identifier that ties the OOV event to the
        # audit-replay row without leaking patient-context information
        # through the log.
        logger.warning(
            "bitnet_v8_oov_drug",
            extra={
                "drug_a_known": a_known,
                "drug_b_known": b_known,
                "drug_a_hash_prefix": hashlib.sha256(
                    a_canon.encode("utf-8")
                ).hexdigest()[:16],
                "drug_b_hash_prefix": hashlib.sha256(
                    b_canon.encode("utf-8")
                ).hexdigest()[:16],
                "fallback": "hash_only_encoding",
                "feat_dim": FEAT_DIM,
            },
        )

    out = (
        hash_trits(a)
        + flag_bits(a)
        + hash_trits(b)
        + flag_bits(b)
        + pair_derived_flags(a, b)
    )
    # NOTE(review): FEAT_DIM is derived from len(FLAG_KEYS), so this guard
    # catches hash/pair-rule length drift but NOT a change in the flag
    # table's key count relative to the trained 193-dim weights.
    if len(out) != FEAT_DIM:
        logger.error(
            "bitnet_v8_encoder_dim_mismatch",
            extra={
                "expected_dim": FEAT_DIM,
                "actual_dim": len(out),
            },
        )
        raise RuntimeError(
            f"v8 encoder produced {len(out)}-dim vector, expected {FEAT_DIM}"
        )
    return out