| """V8 (193-dim) Path A feature encoder for engine/bitnet_classifier.py. |
| |
| Layout |
| ====== |
| Per-drug encoding (90 features × 2 = 180): |
| [0..64) 64 BLAKE2b ternary hash trits ∈ {-1, 0, +1} |
| [64..90) 26 ATC pharmacology flag bits ∈ {0, 1}, ordered by |
| `docs/pharmacology_flags.json` ``flag_keys``. |
| |
| Pair-level encoding (13 features): |
| [180..193) 13 pair-derived DDI-rule bits ∈ {0, 1}. |
| |
| Total: 64 + 26 + 64 + 26 + 13 = 193 trits/bits. |
| |
| Order canonicalisation |
| ---------------------- |
| Drug pairs are sorted lexicographically before encoding so that |
| `{warfarin, ibuprofen}` and `{ibuprofen, warfarin}` produce the same |
| 193-dim vector. Same canonicalisation as `engine/bitnet_classifier`. |
| |
| Source of truth |
| --------------- |
| The encoder is bit-identical to `retrain_runpod/train_bitnet_v8_h256.py` |
| since the v8 ternary weights bundle (1f0f8859…) was trained against this |
| exact pipeline. Any divergence here would silently change forward-pass |
| output and invalidate the audit-chain bundle_id binding. |
| """ |
| from __future__ import annotations |
|
|
| import hashlib |
| import json |
| import logging |
| from pathlib import Path |
|
|
| logger = logging.getLogger(__name__) |
|
|
# Repository root, taken as the directory two levels above this file
# (the parent of the directory that contains this module).
_REPO_ROOT = Path(__file__).resolve().parent.parent

# Location of the pharmacology flag table that is parsed at import time.
_PHARM_FLAGS_PATH = _REPO_ROOT.joinpath("docs", "pharmacology_flags.json")
|
|
| |
| |
| |
# Maps a 4-bit value (0..15) to a trit:
#   0..7   ->  0   (8 of 16 entries)
#   8..11  -> +1   (4 of 16 entries)
#   12..15 -> -1   (4 of 16 entries)
_TRIT_LOOKUP: tuple[int, ...] = (0,) * 8 + (1,) * 4 + (-1,) * 4

# BLAKE2b digest length in bytes (16 bytes = 128 bits).  hash_trits emits
# four trits per digest byte, so 16 bytes yield 64 trits.
_BLAKE2B_DIGEST_SIZE = 16
|
|
# Nitrate drugs recognised by exact canonical name (lowercase, single
# spaces).  pair_derived_flags tests canonicalised drug names for
# membership in this set when evaluating the PDE5-inhibitor + nitrate rule.
_NITRATE_NAMES = frozenset((
    "isosorbide mononitrate",
    "isosorbide dinitrate",
    "nitroglycerin",
))
|
|
| |
# The flag table is parsed once, at import time.  A missing or malformed
# file therefore fails the import rather than the first encode call.
_FLAGS_DOC = json.loads(
    _PHARM_FLAGS_PATH.read_text(encoding="utf-8"),
)
# Ordered flag names; this order fixes the bit positions emitted by flag_bits.
FLAG_KEYS: tuple[str, ...] = tuple(_FLAGS_DOC["flag_keys"])
# Canonical drug name -> table entry; each entry is read via its "flags" key.
_FLAG_DRUGS: dict[str, dict] = _FLAGS_DOC["drugs"]


# Feature-vector geometry.  The two fixed sizes come first, then the sizes
# derived from the loaded table.
N_HASH_TRITS = 64
N_PAIR_DERIVED = 13
N_FLAG_BITS = len(FLAG_KEYS)
N_PER_DRUG = N_HASH_TRITS + N_FLAG_BITS
# Two per-drug segments followed by the pair-level segment
# (193 when the table declares 26 flag keys, per the module docstring).
FEAT_DIM = 2 * N_PER_DRUG + N_PAIR_DERIVED


# Latch so that encode_pair_v8 emits its load-context DEBUG record only on
# the first call in this process.
_LOAD_CONTEXT_LOGGED: bool = False
|
|
|
|
def _canonical(name: str) -> str:
    """Return *name* trimmed, lowercased and with whitespace runs collapsed.

    Leading/trailing whitespace is removed and every internal run of
    whitespace becomes a single space.  The module documents this as the
    same canonicalisation used by ``_encode_drug_token`` in
    ``bitnet_classifier``.
    """
    lowered = name.strip().lower()
    words = lowered.split()
    return " ".join(words)
|
|
|
|
def hash_trits(name: str) -> list[int]:
    """Return the ternary hash encoding of a drug name.

    The canonicalised name is hashed with BLAKE2b using a digest of
    ``_BLAKE2B_DIGEST_SIZE`` bytes.  Each digest byte contributes four
    trits, obtained by looking up four 4-bit windows of the byte in
    ``_TRIT_LOOKUP``.  The result is truncated to ``N_HASH_TRITS`` entries,
    each in {-1, 0, +1}.

    The module documents this output as bit-identical to
    ``engine.bitnet_classifier._encode_drug_token`` and to the v8 trainer,
    so the window order below must not be altered.
    """
    payload = _canonical(name).encode("utf-8")
    digest = hashlib.blake2b(payload, digest_size=_BLAKE2B_DIGEST_SIZE).digest()

    # Bit offsets of the four 4-bit windows read from every byte, in
    # emission order.
    # NOTE(review): the first and third windows are both offset 0, so those
    # two trits are always equal for a given byte.  Kept exactly as-is
    # because the trained weights are documented to depend on this
    # encoding -- confirm against the trainer before touching it.
    window_shifts = (0, 4, 0, 2)

    trits = [
        _TRIT_LOOKUP[(byte >> shift) & 0xF]
        for byte in digest
        for shift in window_shifts
    ]
    return trits[:N_HASH_TRITS]
|
|
|
|
def flag_bits(name: str) -> list[int]:
    """Return one 0/1 bit per entry of ``FLAG_KEYS`` for a single drug.

    The drug is looked up in the flag table by canonical name.  A bit is 1
    when the corresponding flag key appears in the drug's ``"flags"`` list.
    A drug that is absent from the table yields all zeros; per the original
    documentation the v8 trainer used the same fall-through.
    """
    key = _canonical(name)
    if key in _FLAG_DRUGS:
        active = set(_FLAG_DRUGS[key]["flags"])
    else:
        # Unknown drug: no class membership is asserted.
        active = set()
    return [int(flag in active) for flag in FLAG_KEYS]
|
|
|
|
def pair_derived_flags(da: str, db: str) -> list[int]:
    """Return the 13 pair-level DDI-rule bits for a drug pair.

    Each bit is 1 iff its rule applies to the pair, independent of which
    drug is passed first.  Flags are read from the flag table by canonical
    name; a drug missing from the table has no flags.

    Bit order (documented as matching the v8 trainer):

        [0]  cyp3a4_inhib_substrate
        [1]  oatp1b1_inhib_statin
        [2]  p_gp_inhib_substrate
        [3]  cyp2c9_inhib_anticoag
        [4]  maoi_serotonergic
        [5]  pde5_nitrate  (one drug carries ``is_pde5_inhibitor`` and the
                            other's canonical name is in ``_NITRATE_NAMES``)
        [6]  iodinated_contrast_metformin
        [7]  cyp1a2_inhib_substrate
        [8]  xo_thiopurine
        [9]  folate_antagonist_pair  (both drugs carry the same flag)
        [10] tetracycline_retinoid
        [11] ace_neprilysin
        [12] metformin_renal
    """
    name_a = _canonical(da)
    name_b = _canonical(db)
    flags_a = set(_FLAG_DRUGS.get(name_a, {"flags": []})["flags"])
    flags_b = set(_FLAG_DRUGS.get(name_b, {"flags": []})["flags"])

    def crosses(flag_x: str, flag_y: str) -> bool:
        # True when one drug carries flag_x and the other carries flag_y,
        # in either assignment.
        forward = flag_x in flags_a and flag_y in flags_b
        reverse = flag_x in flags_b and flag_y in flags_a
        return forward or reverse

    # Nitrates are identified by exact canonical name, not by a table flag.
    pde5_with_nitrate = (
        ("is_pde5_inhibitor" in flags_a and name_b in _NITRATE_NAMES)
        or ("is_pde5_inhibitor" in flags_b and name_a in _NITRATE_NAMES)
    )

    # Position in this tuple is the bit index; do not reorder.
    fired = (
        crosses("is_cyp3a4_strong_inhibitor", "is_cyp3a4_substrate"),
        crosses("is_oatp1b1_inhibitor", "is_statin"),
        crosses("is_p_gp_inhibitor", "is_p_gp_substrate"),
        crosses("is_cyp2c9_inhibitor", "is_anticoagulant"),
        crosses("is_maoi", "is_serotonergic"),
        pde5_with_nitrate,
        crosses("is_iodinated_contrast", "is_metformin"),
        crosses("is_cyp1a2_inhibitor", "is_cyp1a2_substrate"),
        crosses("is_xanthine_oxidase_inhibitor", "is_thiopurine"),
        "is_folate_antagonist" in flags_a and "is_folate_antagonist" in flags_b,
        crosses("is_tetracycline", "is_retinoid"),
        crosses("is_ace_inhibitor", "is_neprilysin_inhibitor"),
        crosses("is_metformin", "is_renal_state"),
    )
    return [int(hit) for hit in fired]
|
|
|
|
def encode_pair_v8(drug_a: str, drug_b: str) -> list[int]:
    """Encode a drug pair as the V8 feature vector of length ``FEAT_DIM``.

    Layout: ``hash_trits(a) + flag_bits(a) + hash_trits(b) + flag_bits(b)
    + pair_derived_flags(a, b)``, where ``(a, b)`` is the input pair in
    canonical order.

    Ordering: the pair is sorted by *canonical* name (lowercased,
    whitespace-collapsed).  Every downstream step already works on the
    canonical name, so sorting the raw strings would let casing or stray
    whitespace flip the order -- ``("Warfarin", "ibuprofen")`` would encode
    in the opposite order from ``("warfarin", "ibuprofen")``.  For names
    that are already canonical the order is the same as a plain
    lexicographic sort.

    Logging:
        * On the first call only, a DEBUG record ``bitnet_features_v8_loaded``
          describing the loaded flag table.
        * A WARNING record ``bitnet_v8_oov_drug`` whenever either drug is
          missing from the flag table, i.e. it is encoded from its hash
          alone.  Drug names are never logged; only a 16-hex-character
          SHA-256 prefix of each canonical name is.  The original
          documentation notes that the ``43/43`` cohort recall claim covers
          in-distribution drugs only.

    Args:
        drug_a: Name of one drug of the pair (any casing/spacing).
        drug_b: Name of the other drug of the pair.

    Returns:
        A list of ``FEAT_DIM`` ints: hash trits in {-1, 0, +1} and flag /
        rule bits in {0, 1}.

    Raises:
        RuntimeError: If the assembled vector is not ``FEAT_DIM`` long.
    """
    global _LOAD_CONTEXT_LOGGED
    if not _LOAD_CONTEXT_LOGGED:
        # One-time record of the table this process encodes against.
        logger.debug(
            "bitnet_features_v8_loaded",
            extra={
                "flags_path_basename": _PHARM_FLAGS_PATH.name,
                "flag_keys_count": N_FLAG_BITS,
                "drug_count": len(_FLAG_DRUGS),
                "n_pair_derived": N_PAIR_DERIVED,
                "feat_dim": FEAT_DIM,
            },
        )
        _LOAD_CONTEXT_LOGGED = True

    # Order by canonical name so the result does not depend on argument
    # order, casing or whitespace.
    a, b = sorted((drug_a, drug_b), key=_canonical)
    a_canon = _canonical(a)
    b_canon = _canonical(b)
    a_known = a_canon in _FLAG_DRUGS
    b_known = b_canon in _FLAG_DRUGS
    if not (a_known and b_known):
        # Out-of-vocabulary signal: at least one drug has no flag-table
        # entry.  Only hash prefixes are logged, never the names.
        logger.warning(
            "bitnet_v8_oov_drug",
            extra={
                "drug_a_known": a_known,
                "drug_b_known": b_known,
                "drug_a_hash_prefix": hashlib.sha256(
                    a_canon.encode("utf-8")
                ).hexdigest()[:16],
                "drug_b_hash_prefix": hashlib.sha256(
                    b_canon.encode("utf-8")
                ).hexdigest()[:16],
                "fallback": "hash_only_encoding",
                "feat_dim": FEAT_DIM,
            },
        )

    out = (
        hash_trits(a)
        + flag_bits(a)
        + hash_trits(b)
        + flag_bits(b)
        + pair_derived_flags(a, b)
    )
    # Guard against a wrong-sized vector reaching the model.
    if len(out) != FEAT_DIM:
        logger.error(
            "bitnet_v8_encoder_dim_mismatch",
            extra={
                "expected_dim": FEAT_DIM,
                "actual_dim": len(out),
            },
        )
        raise RuntimeError(
            f"v8 encoder produced {len(out)}-dim vector, expected {FEAT_DIM}"
        )
    return out
|
|