hcli-bitnet-b158 / bitnet_classifier.py
star-ga's picture
Initial release — BitNet b1.58 cascade (A gate + B specialist) for clinical DDI severity classification, FDA SaMD reproducibility primitive
d76dce2 verified
"""BitNet b1.58 ternary drug-interaction classifier.
A clean-room Python implementation of a BitNet b1.58-style ternary linear
classifier for drug-drug-interaction (DDI) severity. The forward pass is
pure integer arithmetic over Q16.16 fixed-point activations and ternary
weights ∈ {-1, 0, +1}, which makes the output **bit-identical across
architectures** (ARM, x86_64, CUDA, NPU) — the same reproducibility
guarantee the FDA expects for production AI / ML SaMD.
Reference: Ma, Wang, Wang et al., "The Era of 1-bit LLMs: All Large Language
Models are in 1.58 Bits," arXiv:2402.17764, 2024.
Why this layer matters in ClinicalMem
─────────────────────────────────────
The 4-tier interaction pipeline already catches known pairs deterministically
and verifies novel ones via 5-LLM US-based consensus. The BitNet layer sits at "Layer
4.5": a determinism-checked, FDA-grade classifier that:
1. Reproduces the deterministic table's outputs bit-identically.
2. Emits a Q16.16-scaled severity logit vector that the audit chain can
hash into the per-decision preimage (TAG_v1 schema).
3. Returns a `repro_hash` that any auditor with this Python file and the
ternary weights bundle can verify against without floating-point math.
The classifier is intentionally small (200-pair training corpus, 64-dim
hidden) — accuracy is bounded by the deterministic table the weights are
fit to. The load-bearing claim is the *architecture*, not the absolute
accuracy: bit-identical integer arithmetic across hardware.
Public scope
────────────
This file is Apache-2.0 licensed alongside the rest of ClinicalMem. It does NOT
vendor any source from the STARGA proprietary toolchain (MindLLM,
rfn-mind, mind-runtime, mind-flow are commercial-licensed and live in
private repositories). The BitNet b1.58 architecture is described in the
public arXiv paper above; this file implements it in pure Python.
Copyright 2026 STARGA, Inc. — Apache-2.0 License.
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# Q16.16 fixed-point: 16 integer bits, 16 fractional bits, signed 32-bit.
# Range: [-32768.0, 32767.99999847412].
Q16_ONE: int = 1 << 16 # 65536 — represents 1.0
Q16_HALF: int = 1 << 15 # 32768 — represents 0.5
Q16_ZERO: int = 0
_Q16_MIN: int = -(1 << 31)
_Q16_MAX: int = (1 << 31) - 1
# Severity classes — must match the deterministic table in clinical_scoring.py
SEVERITY_NONE: int = 0
SEVERITY_MINOR: int = 1
SEVERITY_MODERATE: int = 2
SEVERITY_MAJOR: int = 3
SEVERITY_CONTRAINDICATED: int = 4
# Iter-275 v8 promotion: vocab aligned with the corpus / cache /
# trainer (`retrain_runpod/train_bitnet_v8_h256.py:39 SEV_NAMES`).
# Pre-v8 (cfadb4f6) used `(none, minor, moderate, major,
# contraindicated)` — the engine's first-era vocab — but v3+ trainers
# all use the corpus vocab `(none, moderate, serious, major,
# contraindicated)`. Engine output now matches the cache ground-truth
# vocabulary directly: a class-2 logit emits "serious" (cache match),
# not "moderate" (vocab-skewed v1 mapping).
_SEVERITY_NAMES: tuple[str, ...] = (
"none",
"moderate",
"serious",
"major",
"contraindicated",
)
@dataclass(frozen=True)
class BitNetResult:
"""Result of a BitNet b1.58 forward pass on a drug-pair input.
Every field is integer-valued so the audit chain can record the result
without any float-to-string conversion ambiguity.
"""
severity: int # 0..4 (see SEVERITY_* constants)
severity_name: str # e.g. "major"
logits_q16: tuple[int, ...] # Q16.16 logit per class, in canonical class order
feature_hash: str # Hex SHA-256 over the canonical input encoding
repro_hash: str # Hex SHA-256 over (feature_hash, logits_q16, severity, weights_id)
weights_id: str # The bundle hash recorded at load time
deterministic_table_match: bool # True if the weights reproduce a row in the deterministic table
# ─── Q16.16 arithmetic primitives (bit-identical across architectures) ─────
def _q16_clamp(value: int) -> int:
"""Saturating clamp into the signed 32-bit Q16.16 range."""
if value > _Q16_MAX:
return _Q16_MAX
if value < _Q16_MIN:
return _Q16_MIN
return value
def _q16_relu(value: int) -> int:
"""Clamp negative values to zero. Pure integer compare; no float."""
return value if value > 0 else 0
def _q16_dot_ternary(activations_q16: list[int], ternary_weights: list[int]) -> int:
"""Dot product of a Q16.16 activation vector and a ternary weight row.
Ternary weights are one of {-1, 0, +1}. The product is the activation
itself (or its negation, or zero) — no multiplication required, only
addition and subtraction. The result is the canonical Q16.16 sum
accumulated in row-major left-to-right order (same reduction order
as the rest of the MIND ecosystem's deterministic kernels).
Bit-identical guarantee: this function uses only Python's
arbitrary-precision integers; the output is independent of the
underlying CPU/GPU architecture, FMA ordering, and tensor-core
accumulate semantics.
"""
if len(activations_q16) != len(ternary_weights):
raise ValueError(
f"shape mismatch: act={len(activations_q16)} ternary={len(ternary_weights)}"
)
acc: int = 0
for activation, weight in zip(activations_q16, ternary_weights, strict=True):
if weight == 1:
acc += activation
elif weight == -1:
acc -= activation
# weight == 0 contributes nothing — skipped by design
return _q16_clamp(acc)
# ─── Deterministic feature encoding ────────────────────────────────────────
def _encode_drug_token(rxcui_or_name: str) -> list[int]:
"""Encode a drug identifier as a 64-dim ternary feature vector ∈ {-1, 0, +1}.
The encoding is purely deterministic: the input string is canonicalised
(lowercased, whitespace-collapsed) and hashed with BLAKE2b. Each pair
of bits in the digest produces one ternary feature value via a
distribution-balanced trit table. Same string → same vector on every
machine.
The 64-dim feature size is small enough that the full DrugBank
interaction matrix can be linearly separated by a 64×5 ternary
classifier head; large enough that two distinct drug names hash to
distinct vectors with negligible collision probability.
"""
canonical = " ".join(rxcui_or_name.strip().lower().split())
digest = hashlib.blake2b(canonical.encode("utf-8"), digest_size=16).digest()
# 16 bytes × 4 trits/byte = 64 trits. 4 trits encoded per byte using the
# 2-bit window mapping below (50/50/25/25 distribution biased toward 0
# so most features stay sparse — important for the ternary linear
# classifier's effective rank).
_TRIT_LOOKUP: tuple[int, ...] = (0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, -1, -1, -1, -1)
out: list[int] = []
for byte in digest:
out.append(_TRIT_LOOKUP[(byte >> 0) & 0xF])
out.append(_TRIT_LOOKUP[(byte >> 4) & 0xF])
out.append(_TRIT_LOOKUP[byte & 0xF])
out.append(_TRIT_LOOKUP[(byte >> 2) & 0xF])
return out[:64]
def _q16_scale_features(ternary_features: list[int]) -> list[int]:
"""Lift a ternary feature vector to Q16.16 activations.
Ternary {-1, 0, +1} → Q16.16 {-Q16_ONE, 0, +Q16_ONE}. The scale is
canonical so the dot products in the first linear layer accumulate
in the standard Q16.16 range.
"""
return [v * Q16_ONE for v in ternary_features]
# ─── Weights bundle ────────────────────────────────────────────────────────
_SCHEMA_V1 = "bitnet_classifier_v1"
_SCHEMA_V3_ATC = "bitnet_classifier_v3_atc_flags"
@dataclass(frozen=True)
class BitNetWeights:
"""Loaded ternary-weights bundle.
Layout (matches `engine/bitnet_weights.json`):
schema : one of ``bitnet_classifier_v1`` (128-dim hash-only
encoding, hidden=64) or ``bitnet_classifier_v3_atc_flags``
(193-dim hash + 26 ATC flag + 13 pair-derived encoding,
hidden=256). Drives encoder dispatch in ``classify``.
hidden_w : ``hidden_features`` × ``in_features`` ternary matrix.
hidden_b : Q16.16 bias vector (length ``hidden_features``)
output_w : ``out_features`` × ``hidden_features`` ternary matrix.
output_b : Q16.16 bias vector (length ``out_features``, = 5 for
the 5-severity classifier)
bundle_id : SHA-256 over the canonical JSON encoding of the four
matrices above (stable across loads — the audit chain
records this as the "weights_id" so a verifier can
pin the exact bundle a decision was made under).
"""
hidden_w: list[list[int]]
hidden_b: list[int]
output_w: list[list[int]]
output_b: list[int]
bundle_id: str
schema: str = _SCHEMA_V1
in_features: int = 128
hidden_features: int = 64
out_features: int = 5
def _bundle_id(payload: dict[str, Any]) -> str:
"""SHA-256 over the canonical-JSON encoding of the four weight matrices.
Stable across loads on every machine; same payload → same hash.
"""
canonical = json.dumps(
{
"hidden_w": payload["hidden_w"],
"hidden_b": payload["hidden_b"],
"output_w": payload["output_w"],
"output_b": payload["output_b"],
},
sort_keys=True,
separators=(",", ":"),
)
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def load_weights(path: str | os.PathLike[str] | None = None) -> BitNetWeights:
"""Load the ternary-weights bundle from disk.
Defaults to `engine/bitnet_weights.json` next to this file.
"""
if path is None:
path = Path(__file__).parent / "bitnet_weights.json"
# DEBUG entry — visibility into when the bundle gets parsed
# (rotation, first-load, etc.). Mirrors the fda_label_search_start
# convention (PHI-safe: only path basename + size).
raw = Path(path).read_text(encoding="utf-8")
logger.debug(
"bitnet_load_weights_start",
extra={
"path_basename": Path(path).name,
"raw_size_bytes": len(raw),
},
)
payload = json.loads(raw)
hidden_w = [list(row) for row in payload["hidden_w"]]
hidden_b = list(payload["hidden_b"])
output_w = [list(row) for row in payload["output_w"]]
output_b = list(payload["output_b"])
meta = payload.get("_meta", {})
schema = meta.get("schema", _SCHEMA_V1)
if schema not in (_SCHEMA_V1, _SCHEMA_V3_ATC):
logger.error(
"bitnet_weights_unknown_schema",
extra={"schema": schema, "path": str(path)},
)
raise ValueError(
f"Unknown bitnet schema {schema!r}; expected one of "
f"{_SCHEMA_V1!r}, {_SCHEMA_V3_ATC!r}"
)
hidden_features = len(hidden_w)
in_features = len(hidden_w[0]) if hidden_w else 0
out_features = len(output_w)
meta_in = meta.get("in_features", in_features)
meta_hidden = meta.get("hidden_features", hidden_features)
meta_out = meta.get("out_features", out_features)
for field, observed, declared in (
("in_features", in_features, meta_in),
("hidden_features", hidden_features, meta_hidden),
("out_features", out_features, meta_out),
):
if observed != declared:
logger.error(
"bitnet_weights_meta_mismatch",
extra={
"field": field,
"matrix_dim": observed,
"meta_dim": declared,
"path": str(path),
},
)
raise ValueError(
f"{field}: matrix dim {observed} != _meta declaration {declared}"
)
if any(len(row) != in_features for row in hidden_w):
logger.error(
"bitnet_weights_shape_mismatch",
extra={
"field": "hidden_w",
"expected_cols": in_features,
"path": str(path),
},
)
raise ValueError(
f"hidden_w rows must all be length {in_features} (drug-pair feature dim)"
)
if len(hidden_b) != hidden_features:
logger.error(
"bitnet_weights_shape_mismatch",
extra={
"field": "hidden_b",
"expected_len": hidden_features,
"actual_len": len(hidden_b),
"path": str(path),
},
)
raise ValueError(
f"hidden_b must have {hidden_features} entries; got {len(hidden_b)}"
)
if out_features != 5:
logger.error(
"bitnet_weights_shape_mismatch",
extra={
"field": "output_w",
"expected_rows": 5,
"actual_rows": out_features,
"path": str(path),
},
)
raise ValueError(
f"output_w must have 5 rows (one per severity class); got {out_features}"
)
if any(len(row) != hidden_features for row in output_w):
logger.error(
"bitnet_weights_shape_mismatch",
extra={
"field": "output_w",
"expected_cols": hidden_features,
"path": str(path),
},
)
raise ValueError(
f"output_w rows must all be length {hidden_features} (hidden dim)"
)
if len(output_b) != out_features:
logger.error(
"bitnet_weights_shape_mismatch",
extra={
"field": "output_b",
"expected_len": out_features,
"actual_len": len(output_b),
"path": str(path),
},
)
raise ValueError(f"output_b must have {out_features} entries; got {len(output_b)}")
expected_in = 128 if schema == _SCHEMA_V1 else 193
if in_features != expected_in:
logger.error(
"bitnet_weights_schema_dim_mismatch",
extra={
"schema": schema,
"expected_in_features": expected_in,
"actual_in_features": in_features,
"path": str(path),
},
)
raise ValueError(
f"schema {schema!r} expects in_features={expected_in}, got {in_features}"
)
for matrix_name, matrix in (("hidden_w", hidden_w), ("output_w", output_w)):
for i, row in enumerate(matrix):
for j, weight in enumerate(row):
if weight not in (-1, 0, 1):
logger.error(
"bitnet_weights_non_ternary",
extra={"matrix": matrix_name, "row": i, "col": j, "value": weight},
)
raise ValueError(
f"{matrix_name}[{i}][{j}] = {weight!r}; weights must be ternary"
)
weights = BitNetWeights(
hidden_w=hidden_w,
hidden_b=hidden_b,
output_w=output_w,
output_b=output_b,
bundle_id=_bundle_id(payload),
schema=schema,
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
)
logger.info(
"bitnet_weights_loaded",
extra={
"bundle_id": weights.bundle_id,
"path": str(path),
"schema": schema,
"in_features": in_features,
"hidden_features": hidden_features,
},
)
return weights
# ─── Forward pass ──────────────────────────────────────────────────────────
def load_weights_b(path: str | os.PathLike[str] | None = None) -> BitNetWeights | None:
"""Load the optional Path B tier-2 specialist bundle (iter-421).
Returns None if the bundle file is absent — callers must treat single-
bundle mode (A-only) as the default. The specialist is trained ONLY on
the 95 non-contra samples (4 major + 69 serious + 22 moderate); engine
dispatch applies a constrained argmax over {moderate, serious, major}
so it can never emit ``contraindicated`` (class 4) or ``none`` (class 0).
"""
if path is None:
path = Path(__file__).parent / "bitnet_weights_b_specialist.json"
p = Path(path)
if not p.exists():
return None
return load_weights(p)
def _classify_constrained_b(
a_canonical: str,
b_canonical: str,
weights_b: BitNetWeights,
) -> tuple[int, tuple[int, ...]]:
"""Forward pass through B with constrained argmax over {1, 2, 3}.
Returns ``(severity_int, logits_q16)`` where severity_int is in
{1, 2, 3} = {moderate, serious, major}. Classes 0 (none) and 4
(contraindicated) are masked because B was never trained on them.
The same Q16.16 ternary kernels as ``classify`` are reused so B's
forward pass is bit-identical across architectures alongside A's.
"""
if weights_b.schema == _SCHEMA_V3_ATC:
from engine.bitnet_features_v8 import encode_pair_v8
pair_features = encode_pair_v8(a_canonical, b_canonical)
else:
feature_a = _encode_drug_token(a_canonical)
feature_b = _encode_drug_token(b_canonical)
pair_features = feature_a + feature_b
activations_q16 = _q16_scale_features(pair_features)
hidden_pre_q16 = [
_q16_clamp(_q16_dot_ternary(activations_q16, weights_b.hidden_w[j]) + weights_b.hidden_b[j])
for j in range(weights_b.hidden_features)
]
hidden_q16 = [_q16_relu(v) for v in hidden_pre_q16]
logits_q16 = [
_q16_clamp(_q16_dot_ternary(hidden_q16, weights_b.output_w[k]) + weights_b.output_b[k])
for k in range(weights_b.out_features)
]
# Constrained argmax over classes {1, 2, 3} only. Ties broken by
# lower index (consistent with the unconstrained argmax in classify).
severity = 1
best_logit = logits_q16[1]
for k in (2, 3):
if logits_q16[k] > best_logit:
best_logit = logits_q16[k]
severity = k
return severity, tuple(logits_q16)
def classify(
drug_a: str,
drug_b: str,
weights: BitNetWeights,
*,
deterministic_table_severity: int | None = None,
weights_b: BitNetWeights | None = None,
) -> BitNetResult:
"""Ternary classifier forward pass: drug-pair -> severity class.
The pair is order-canonicalised (lex sort) so {warfarin, ibuprofen} and
{ibuprofen, warfarin} produce the same feature vector and the same
output. The logits and severity decision are reproducible bit-for-bit
on any platform that runs Python.
If `deterministic_table_severity` is provided (a known result from
`engine.clinical_scoring`'s 4-tier deterministic table), the result's
`deterministic_table_match` flag records whether the BitNet output
agrees. Disagreement is a release-blocking event — surfaces in the
`tests/test_engine/test_bitnet_classifier.py` regression set.
"""
a_canonical, b_canonical = sorted((drug_a, drug_b))
if weights.schema == _SCHEMA_V3_ATC:
from engine.bitnet_features_v8 import encode_pair_v8
pair_features = encode_pair_v8(a_canonical, b_canonical)
else:
feature_a = _encode_drug_token(a_canonical)
feature_b = _encode_drug_token(b_canonical)
pair_features = feature_a + feature_b
if len(pair_features) != weights.in_features:
raise RuntimeError(
f"internal error: pair features length {len(pair_features)} != "
f"weights.in_features {weights.in_features}"
)
activations_q16 = _q16_scale_features(pair_features)
feature_hash = hashlib.sha256(
bytes((v + 1) for v in pair_features) # ternary -> {0,1,2}
).hexdigest()
# First linear layer: in_features -> hidden_features
hidden_pre_q16 = [
_q16_clamp(_q16_dot_ternary(activations_q16, weights.hidden_w[j]) + weights.hidden_b[j])
for j in range(weights.hidden_features)
]
hidden_q16 = [_q16_relu(v) for v in hidden_pre_q16]
# Second linear layer: hidden_features -> out_features (5 severity classes)
logits_q16 = [
_q16_clamp(_q16_dot_ternary(hidden_q16, weights.output_w[k]) + weights.output_b[k])
for k in range(weights.out_features)
]
# Argmax — pure integer compare; ties broken by lower-index class.
severity = 0
best_logit = logits_q16[0]
for k in range(1, 5):
if logits_q16[k] > best_logit:
best_logit = logits_q16[k]
severity = k
# iter-421 Path B cascade: when a tier-2 specialist bundle is supplied,
# A's contraindicated verdict ALWAYS wins (frozen FDA-grade contra
# gate, 100% recall + 0 FP). For all non-contra A predictions, B's
# constrained argmax over {moderate, serious, major} replaces A's
# raw argmax. B was trained without contra anchors, so its capacity
# is fully spent on the non-contra discrimination v8 historically
# under-fit (84% serious / 91% moderate standalone).
weights_id_for_audit = weights.bundle_id
logits_q16_b: tuple[int, ...] | None = None
if weights_b is not None and severity != 4:
# 4 = contraindicated; preserve A's contra verdict.
b_severity, logits_q16_b = _classify_constrained_b(
a_canonical, b_canonical, weights_b
)
severity = b_severity
# Audit-chain: composite weights_id captures both bundle hashes
# so a verifier can replay the cascade decision exactly.
weights_id_for_audit = f"{weights.bundle_id}+{weights_b.bundle_id}"
repro_hash_payload = {
"feature_hash": feature_hash,
"logits_q16": logits_q16,
"severity": severity,
"weights_id": weights_id_for_audit,
}
if logits_q16_b is not None:
repro_hash_payload["logits_q16_b"] = list(logits_q16_b)
repro_hash_payload["bundle_id_b"] = weights_b.bundle_id # type: ignore[union-attr]
repro_hash = hashlib.sha256(
json.dumps(
repro_hash_payload,
sort_keys=True,
separators=(",", ":"),
).encode("utf-8")
).hexdigest()
deterministic_match = True
if deterministic_table_severity is not None:
deterministic_match = (severity == deterministic_table_severity)
# Audit-grade trace: structured DEBUG log so production INFO-level
# surfaces stay quiet but a reviewer can opt in by raising verbosity.
# iter-309 PHI fix: replace raw drug_a/drug_b with 16-char SHA-256
# pair_hash_prefix (lex-sorted canonical form). Same iter-291 /
# iter-284 / iter-279 PHI discipline class — drug-pair identity stays
# grep-able for forensic correlation but raw names never reach handlers.
# Pre-iter-309 this event leaked drug_a + drug_b on EVERY classification
# (live since the iter-72-era classifier landing); caught by audit
# because both keys are absent from the iter-240 forbidden-extras-keys
# list (which is now extended in iter-309 to catch this regression class).
_pair_hash_prefix = hashlib.sha256(
f"{a_canonical}+{b_canonical}".encode("utf-8")
).hexdigest()[:16]
# iter-432 observability ratchet: categorical `ensemble_path` field
# disambiguates the 3 dispatch states a forensic reader otherwise
# has to reverse-engineer from `weights_id` length + `ensemble_active`:
# - "cascade_fired" : A predicted non-contra AND B was loaded;
# B's constrained argmax replaced A's class.
# - "a_only_contra_veto" : A predicted contra (severity=4); B was
# available but bypassed by the safety
# contract (A's contra ALWAYS wins).
# - "a_only_no_b" : B was not loaded (single-bundle mode);
# ensemble cascade unreachable for this
# classification regardless of A's output.
# Strict subset of the existing `ensemble_active` bool — preserved
# alongside for backwards compat with parsers built pre-iter-432.
if logits_q16_b is not None:
_ensemble_path = "cascade_fired"
elif weights_b is None:
_ensemble_path = "a_only_no_b"
else:
# weights_b supplied AND severity == 4 (contra) AND no logits_q16_b:
# cascade was bypassed by the contra-veto safety contract.
_ensemble_path = "a_only_contra_veto"
logger.debug(
"bitnet_classified",
extra={
"pair_hash_prefix": _pair_hash_prefix,
"severity": severity,
"severity_name": _SEVERITY_NAMES[severity],
"repro_hash": repro_hash,
"weights_id": weights_id_for_audit,
"deterministic_match": deterministic_match,
"ensemble_active": logits_q16_b is not None,
"ensemble_path": _ensemble_path,
},
)
return BitNetResult(
severity=severity,
severity_name=_SEVERITY_NAMES[severity],
logits_q16=tuple(logits_q16),
feature_hash=feature_hash,
repro_hash=repro_hash,
weights_id=weights_id_for_audit,
deterministic_table_match=deterministic_match,
)
# ─── Convenience wrapper for the consensus pipeline ────────────────────────
import threading
_CACHED_WEIGHTS: BitNetWeights | None = None
_PINNED_BUNDLE_ID: str | None = None
# iter-421 Path B: tier-2 specialist cache + pin (parallel to A's cache).
# When the bundle file is absent the slot stays None and the engine falls
# back to single-bundle mode automatically.
_CACHED_WEIGHTS_B: BitNetWeights | None = None
_PINNED_BUNDLE_ID_B: str | None = None
_B_LOAD_ATTEMPTED: bool = False
_CACHE_LOCK = threading.Lock()
class WeightsTamperError(RuntimeError):
"""Raised when the on-disk weights bundle's bundle_id no longer matches
the value pinned at first load. Indicates the file was swapped under
the running process — a release-blocking integrity violation."""
def reload_weights() -> BitNetWeights:
"""Force a fresh load + re-pin. Use after a confirmed weights rotation."""
global _CACHED_WEIGHTS, _PINNED_BUNDLE_ID
global _CACHED_WEIGHTS_B, _PINNED_BUNDLE_ID_B, _B_LOAD_ATTEMPTED
with _CACHE_LOCK:
previous_id = _PINNED_BUNDLE_ID
previous_id_b = _PINNED_BUNDLE_ID_B
weights = load_weights()
_CACHED_WEIGHTS = weights
_PINNED_BUNDLE_ID = weights.bundle_id
# iter-421 Path B: re-pin tier-2 specialist alongside A. If the
# bundle disappears between rotations, ensemble drops to A-only.
_B_LOAD_ATTEMPTED = True
_CACHED_WEIGHTS_B = load_weights_b()
_PINNED_BUNDLE_ID_B = (
_CACHED_WEIGHTS_B.bundle_id if _CACHED_WEIGHTS_B is not None else None
)
logger.warning(
"bitnet_weights_reloaded",
extra={
"previous_bundle_id": previous_id,
"new_bundle_id": weights.bundle_id,
"previous_bundle_id_b": previous_id_b,
"new_bundle_id_b": _PINNED_BUNDLE_ID_B,
},
)
return weights
def classifier_layer(drug_a: str, drug_b: str) -> BitNetResult:
"""Layer-4.5 entry point used by `engine.consensus_engine`.
The first call loads the weights bundle and pins its `bundle_id`.
**Every subsequent call re-loads the bundle and verifies the
`bundle_id` still matches** — an inexpensive SHA-256 compare that
closes the security gap where a tampered `bitnet_weights.json` swapped
on disk would silently produce wrong severity verdicts for the entire
process lifetime. A mismatch raises `WeightsTamperError`, which the
pipeline must treat as release-blocking.
The cache is guarded by a `threading.Lock` so high-throughput clinical
deployments (10K+ pairs/min, multi-threaded) cannot trigger the
thundering-herd race that would otherwise re-parse the JSON on every
contended call.
"""
global _CACHED_WEIGHTS, _PINNED_BUNDLE_ID
global _CACHED_WEIGHTS_B, _PINNED_BUNDLE_ID_B, _B_LOAD_ATTEMPTED
with _CACHE_LOCK:
if _CACHED_WEIGHTS is None:
_CACHED_WEIGHTS = load_weights()
_PINNED_BUNDLE_ID = _CACHED_WEIGHTS.bundle_id
# iter-421 Path B: opportunistic tier-2 load + pin (audit-clean
# — same SHA-256-canonical-JSON integrity primitive as A). Absent
# bundle leaves the slot None and engine falls back to A-only.
if not _B_LOAD_ATTEMPTED:
_B_LOAD_ATTEMPTED = True
_CACHED_WEIGHTS_B = load_weights_b()
if _CACHED_WEIGHTS_B is not None:
_PINNED_BUNDLE_ID_B = _CACHED_WEIGHTS_B.bundle_id
logger.info(
"bitnet_classifier_b_load_pinned",
extra={
"bundle_id_b_prefix": _PINNED_BUNDLE_ID_B[:16],
"hidden_features_b": len(_CACHED_WEIGHTS_B.hidden_w),
},
)
# First-load pinning event — fires ONCE per process. Lets
# auditors correlate every BitNetResult emitted in the
# process to the bundle_id that was pinned at startup.
# bundle_id is SHA-256-of-canonical-JSON, NOT secret material;
# safe to log in full (it IS the integrity primitive).
logger.info(
"bitnet_classifier_first_load_pinned",
extra={
"bundle_id_prefix": _PINNED_BUNDLE_ID[:16],
"hidden_features": len(_CACHED_WEIGHTS.hidden_w),
"in_features": (
len(_CACHED_WEIGHTS.hidden_w[0])
if _CACHED_WEIGHTS.hidden_w else 0
),
"out_features": len(_CACHED_WEIGHTS.output_w),
"ensemble_active": _CACHED_WEIGHTS_B is not None,
},
)
else:
# Re-load + verify the pinned bundle_id on every call. The full
# JSON parse is ~1 ms; the alternative is a class of FDA-blocking
# silent-tampering bugs.
current = load_weights()
if current.bundle_id != _PINNED_BUNDLE_ID:
# Pre-raise structured CRITICAL — this is a release-blocking
# FDA SaMD integrity violation. Bundle IDs are SHA-256
# prefixes of the canonical-JSON weights file, safe to log
# (they ARE the integrity primitive, not secret material).
logger.critical(
"bitnet_weights_tamper_detected",
extra={
"pinned_bundle_id": (
_PINNED_BUNDLE_ID[:16] if _PINNED_BUNDLE_ID else None
),
"on_disk_bundle_id": current.bundle_id[:16],
},
)
raise WeightsTamperError(
f"bitnet_weights.json bundle_id changed under the running "
f"process: pinned {_PINNED_BUNDLE_ID[:16]}... "
f"on-disk {current.bundle_id[:16]}... — call "
f"reload_weights() after a deliberate rotation."
)
_CACHED_WEIGHTS = current
# iter-421 Path B: same tamper check on B if it was pinned.
if _PINNED_BUNDLE_ID_B is not None:
current_b = load_weights_b()
if current_b is None or current_b.bundle_id != _PINNED_BUNDLE_ID_B:
logger.critical(
"bitnet_weights_b_tamper_detected",
extra={
"pinned_bundle_id_b": _PINNED_BUNDLE_ID_B[:16],
"on_disk_bundle_id_b": (
current_b.bundle_id[:16] if current_b else None
),
},
)
raise WeightsTamperError(
f"bitnet_weights_b_specialist.json bundle_id changed under "
f"the running process: pinned {_PINNED_BUNDLE_ID_B[:16]}... "
f"— call reload_weights() after a deliberate rotation."
)
_CACHED_WEIGHTS_B = current_b
return classify(drug_a, drug_b, _CACHED_WEIGHTS, weights_b=_CACHED_WEIGHTS_B)
__all__ = [
"BitNetResult",
"BitNetWeights",
"Q16_ONE",
"Q16_HALF",
"Q16_ZERO",
"SEVERITY_NONE",
"SEVERITY_MINOR",
"SEVERITY_MODERATE",
"SEVERITY_MAJOR",
"SEVERITY_CONTRAINDICATED",
"WeightsTamperError",
"classify",
"classifier_layer",
"load_weights",
"load_weights_b",
"reload_weights",
]