File size: 9,402 Bytes
d76dce2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 | """V8 (193-dim) Path A feature encoder for engine/bitnet_classifier.py.
Layout
======
Per-drug encoding (90 features × 2 = 180):
[0..64) 64 BLAKE2b ternary hash trits ∈ {-1, 0, +1}
[64..90) 26 ATC pharmacology flag bits ∈ {0, 1}, ordered by
`docs/pharmacology_flags.json` ``flag_keys``.
Pair-level encoding (13 features):
[180..193) 13 pair-derived DDI-rule bits ∈ {0, 1}.
Total: 64 + 26 + 64 + 26 + 13 = 193 trits/bits.
Order canonicalisation
----------------------
Drug pairs are sorted lexicographically before encoding so that
`{warfarin, ibuprofen}` and `{ibuprofen, warfarin}` produce the same
193-dim vector. Same canonicalisation as `engine/bitnet_classifier`.
Source of truth
---------------
The encoder is bit-identical to `retrain_runpod/train_bitnet_v8_h256.py`
since the v8 ternary weights bundle (1f0f8859…) was trained against this
exact pipeline. Any divergence here would silently change forward-pass
output and invalidate the audit-chain bundle_id binding.
"""
from __future__ import annotations
import hashlib
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)

# The repository root is two directory levels above this file; the
# pharmacology flag table is expected under docs/ there.
_REPO_ROOT = Path(__file__).resolve().parent.parent
_PHARM_FLAGS_PATH = _REPO_ROOT.joinpath("docs", "pharmacology_flags.json")

# Nibble -> trit table, indexed by a 4-bit value (see hash_trits).
# 8 of the 16 slots map to 0, 4 to +1 and 4 to -1, i.e. a 50% / 25% / 25%
# split. The original comment documents this table as matching the one
# in `engine/bitnet_classifier._encode_drug_token` and in the v8 trainer,
# so the values and their order must not change.
_TRIT_LOOKUP: tuple[int, ...] = (0,) * 8 + (1,) * 4 + (-1,) * 4

# BLAKE2b digest size in bytes. hash_trits extracts 4 trits per digest
# byte, so 16 bytes yield exactly 64 hash trits. Documented as matching
# the digest size used by the v8 trainer.
_BLAKE2B_DIGEST_SIZE = 16

# Canonical (lowercase, single-spaced) drug names treated as nitrates by
# the PDE5/nitrate pair rule.
_NITRATE_NAMES = frozenset(
    (
        "isosorbide mononitrate",
        "isosorbide dinitrate",
        "nitroglycerin",
    )
)
# Cached pharmacology flag table: read and parsed once at module import.
# NOTE(review): this is a filesystem read performed at import time;
# confirm it is acceptable under the import-purity requirement described
# in the Iter-279 note below.
_FLAGS_DOC = json.loads(_PHARM_FLAGS_PATH.read_text(encoding="utf-8"))
# Ordered flag names; flag_bits() emits one bit per key, in this order.
FLAG_KEYS: tuple[str, ...] = tuple(_FLAGS_DOC["flag_keys"])
# Per-drug entries, looked up by canonical drug name; each entry's
# "flags" list names the flags that are set for that drug.
_FLAG_DRUGS: dict[str, dict] = _FLAGS_DOC["drugs"]
N_HASH_TRITS = 64  # hash trits emitted per drug
N_FLAG_BITS = len(FLAG_KEYS)  # flag bits emitted per drug
N_PER_DRUG = N_HASH_TRITS + N_FLAG_BITS  # features per drug slot
N_PAIR_DERIVED = 13  # iter-140: 6 baseline + 7 closure rules
FEAT_DIM = N_PER_DRUG * 2 + N_PAIR_DERIVED  # two drug slots + pair bits
# Iter-279: no log record is emitted at module import (the engine
# arch-mind gate requires every engine module to be pure on import).
# Instead, the flag-table snapshot identifier and counts are surfaced by
# a DEBUG record emitted ONCE per process, on the first encode_pair_v8
# call. This lets auditors correlate BitNet decisions to the encoder
# version without logging during import.
# Latch for that one-time DEBUG record; encode_pair_v8 sets it to True
# after emitting.
_LOAD_CONTEXT_LOGGED = False
def _canonical(name: str) -> str:
    """Return the canonical spelling of a drug name.

    The name is trimmed, lowercased, and every run of whitespace is
    collapsed to a single space. The original docstring documents this
    as the same canonicalisation `_encode_drug_token` applies in
    bitnet_classifier.
    """
    lowered = name.strip().lower()
    words = lowered.split()
    return " ".join(words)
def hash_trits(name: str) -> list[int]:
    """Return the ternary hash features (values in {-1, 0, +1}) for a drug.

    The canonical drug name is UTF-8 encoded and hashed with BLAKE2b
    using ``digest_size=_BLAKE2B_DIGEST_SIZE``. Every digest byte
    contributes four `_TRIT_LOOKUP` entries, in this order:

        1. low nibble   (bits 0..3)
        2. high nibble  (bits 4..7)
        3. low nibble   (bits 0..3) again
        4. middle bits  (bits 2..5)

    The result is truncated to `N_HASH_TRITS` entries. The output
    depends only on the canonical name, so it is deterministic across
    calls.

    NOTE(review): the low nibble is looked up twice per byte, so
    entries 1 and 3 of each group of four are always equal, and entry 4
    overlaps the other two nibbles. The module documentation requires
    this encoder to stay bit-identical to the v8 trainer, so the
    extraction is preserved exactly as written; confirm it is
    intentional before ever changing it.
    """
    payload = _canonical(name).encode("utf-8")
    digest = hashlib.blake2b(payload, digest_size=_BLAKE2B_DIGEST_SIZE).digest()

    trits: list[int] = []
    for octet in digest:
        low = _TRIT_LOOKUP[octet & 0xF]
        high = _TRIT_LOOKUP[(octet >> 4) & 0xF]
        middle = _TRIT_LOOKUP[(octet >> 2) & 0xF]
        # Order matters: low, high, low (repeated), middle.
        trits.extend((low, high, low, middle))
    return trits[:N_HASH_TRITS]
def flag_bits(name: str) -> list[int]:
    """Return one 0/1 pharmacology flag bit per entry of `FLAG_KEYS`.

    The drug is looked up in the flag table by its canonical name. Bit
    ``i`` is 1 when ``FLAG_KEYS[i]`` appears in that drug's ``"flags"``
    list and 0 otherwise. A drug that is absent from the table yields
    all zeros; the original docstring notes that the v8 trainer used the
    same fall-through, so the model treats it as "no known class
    membership".
    """
    key = _canonical(name)
    if key in _FLAG_DRUGS:
        listed = _FLAG_DRUGS[key]["flags"]
    else:
        # Unknown drug: no flags set.
        listed = ()
    active = frozenset(listed)
    return [int(flag_key in active) for flag_key in FLAG_KEYS]
def pair_derived_flags(da: str, db: str) -> list[int]:
    """Return the 13 pair-level DDI-rule bits (each 0 or 1) for two drugs.

    Both names are canonicalised and looked up in the flag table; a drug
    missing from the table contributes no flags. Every rule is
    symmetric in the two drugs. Output indices:

        [0]  cyp3a4_inhib_substrate
        [1]  oatp1b1_inhib_statin
        [2]  p_gp_inhib_substrate
        [3]  cyp2c9_inhib_anticoag
        [4]  maoi_serotonergic
        [5]  pde5_nitrate (one drug carries ``is_pde5_inhibitor`` and
             the other's canonical name is in `_NITRATE_NAMES`; the
             nitrate side is matched by name, not by flag)
        [6]  iodinated_contrast_metformin
        [7]  cyp1a2_inhib_substrate
        [8]  xo_thiopurine
        [9]  folate_antagonist_pair (both drugs carry the same flag)
        [10] tetracycline_retinoid
        [11] ace_neprilysin
        [12] metformin_renal

    The original docstring documents these indices as matching the v8
    trainer and the iter-140 pair-derived rule set, so the order must
    not change.
    """

    def flags_of(canon_name: str) -> set:
        # Flags listed for a canonical drug name; empty when unknown.
        if canon_name in _FLAG_DRUGS:
            return set(_FLAG_DRUGS[canon_name]["flags"])
        return set()

    name_a = _canonical(da)
    flags_a = flags_of(name_a)
    name_b = _canonical(db)
    flags_b = flags_of(name_b)

    def crossed(flag_x: str, flag_y: str) -> bool:
        # True when one drug has flag_x and the other has flag_y,
        # in either assignment.
        forward = flag_x in flags_a and flag_y in flags_b
        return forward or (flag_x in flags_b and flag_y in flags_a)

    pde5 = "is_pde5_inhibitor"
    pde5_with_nitrate = (pde5 in flags_a and name_b in _NITRATE_NAMES) or (
        pde5 in flags_b and name_a in _NITRATE_NAMES
    )

    folate = "is_folate_antagonist"
    both_folate = folate in flags_a and folate in flags_b

    fired = (
        crossed("is_cyp3a4_strong_inhibitor", "is_cyp3a4_substrate"),  # [0]
        crossed("is_oatp1b1_inhibitor", "is_statin"),  # [1]
        crossed("is_p_gp_inhibitor", "is_p_gp_substrate"),  # [2]
        crossed("is_cyp2c9_inhibitor", "is_anticoagulant"),  # [3]
        crossed("is_maoi", "is_serotonergic"),  # [4]
        pde5_with_nitrate,  # [5]
        crossed("is_iodinated_contrast", "is_metformin"),  # [6]
        crossed("is_cyp1a2_inhibitor", "is_cyp1a2_substrate"),  # [7]
        crossed("is_xanthine_oxidase_inhibitor", "is_thiopurine"),  # [8]
        both_folate,  # [9]
        crossed("is_tetracycline", "is_retinoid"),  # [10]
        crossed("is_ace_inhibitor", "is_neprilysin_inhibitor"),  # [11]
        crossed("is_metformin", "is_renal_state"),  # [12]
    )
    return [int(hit) for hit in fired]
def encode_pair_v8(drug_a: str, drug_b: str) -> list[int]:
    """Build the v8 feature vector (length `FEAT_DIM`) for a drug pair.

    The two names are sorted, then the vector is assembled as::

        hash_trits(a) + flag_bits(a) + hash_trits(b) + flag_bits(b)
        + pair_derived_flags(a, b)

    where ``a`` is the name that sorts first. The original docstring
    documents this as bit-identical to the v8 trainer's ``encode_pair``.

    Logging side effects:
      * On the first call in a process, one DEBUG record
        (``bitnet_features_v8_loaded``) describing the flag-table
        snapshot is emitted, and the module-level latch is set.
      * When either drug is absent from the flag table, a WARNING
        record (``bitnet_v8_oov_drug``) is emitted. Drug names appear
        in it only as the first 16 hex characters of the SHA-256 of
        their canonical form, never raw. The original docstring notes
        that this out-of-vocabulary event is safety-relevant because
        the cohort-aggregate recall claim (`43/43`) covers
        in-distribution drugs only.

    Raises:
        RuntimeError: if the assembled vector's length differs from
            `FEAT_DIM` (an ERROR record is logged first).

    NOTE(review): the pair is ordered by sorting the RAW input strings,
    while each drug is encoded from its canonical (lowercased, trimmed)
    form. Differently cased or padded spellings of the same pair can
    therefore land in swapped slots. Left unchanged here because the
    module requires bit-identity with the trainer; confirm that callers
    pass canonical names.
    """
    global _LOAD_CONTEXT_LOGGED
    if not _LOAD_CONTEXT_LOGGED:
        # Iter-279: the load context is reported on first use, not at
        # import, so that importing this module emits no log records.
        load_context = {
            "flags_path_basename": _PHARM_FLAGS_PATH.name,
            "flag_keys_count": N_FLAG_BITS,
            "drug_count": len(_FLAG_DRUGS),
            "n_pair_derived": N_PAIR_DERIVED,
            "feat_dim": FEAT_DIM,
        }
        logger.debug("bitnet_features_v8_loaded", extra=load_context)
        _LOAD_CONTEXT_LOGGED = True

    def hash_prefix(canon_name: str) -> str:
        # Truncated SHA-256 of the canonical name, used in place of the
        # raw drug name in log records.
        return hashlib.sha256(canon_name.encode("utf-8")).hexdigest()[:16]

    first, second = sorted((drug_a, drug_b))
    first_canon = _canonical(first)
    second_canon = _canonical(second)
    first_known = first_canon in _FLAG_DRUGS
    second_known = second_canon in _FLAG_DRUGS

    if not first_known or not second_known:
        # Out-of-vocabulary drug: its flag bits will be all zero, so
        # only the hash trits carry information for it.
        oov_details = {
            "drug_a_known": first_known,
            "drug_b_known": second_known,
            "drug_a_hash_prefix": hash_prefix(first_canon),
            "drug_b_hash_prefix": hash_prefix(second_canon),
            "fallback": "hash_only_encoding",
            "feat_dim": FEAT_DIM,
        }
        logger.warning("bitnet_v8_oov_drug", extra=oov_details)

    out = [
        *hash_trits(first),
        *flag_bits(first),
        *hash_trits(second),
        *flag_bits(second),
        *pair_derived_flags(first, second),
    ]

    if len(out) == FEAT_DIM:
        return out

    # Defensive check: a length mismatch means the layout constants and
    # the encoders above have drifted apart.
    logger.error(
        "bitnet_v8_encoder_dim_mismatch",
        extra={"expected_dim": FEAT_DIM, "actual_dim": len(out)},
    )
    raise RuntimeError(
        f"v8 encoder produced {len(out)}-dim vector, expected {FEAT_DIM}"
    )
|