import numpy as np
from safetensors.numpy import load_file
# Length of the k-mers counted by the featurizer.
K = 8

# Nucleotide alphabet; a base's position in this string is its base-4 digit.
BASES = "ACGT"

# Origin labels; classify() indexes this list with the argmax of the origin logits.
CLASSES = ["human", "eukaryote", "bacteria", "virus", "engineered"]

# Lookup table: base character -> base-4 digit (A=0, C=1, G=2, T=3).
_B = dict(zip(BASES, range(len(BASES))))

# Number of distinct K-mers over the four-letter alphabet (4 ** 8 == 65,536).
VOCAB = len(BASES) ** K
def _index(kmer):
    """Encode *kmer* as a base-4 integer, most significant base first.

    Every character is looked up in ``_B``. If any character is missing from
    that table the function returns ``None``; otherwise it returns the integer
    obtained by reading the digits left to right. An empty *kmer* encodes to 0.
    """
    digits = [_B.get(ch) for ch in kmer]
    if any(d is None for d in digits):
        return None
    code = 0
    for d in digits:
        # Digits are in 0..3, so shift-and-or is the same as ``code * 4 + d``.
        code = (code << 2) | d
    return code
class DnaOriginClassifier:
    """Discriminative 8-mer classifier of DNA origin, linear in 8-mer counts.

    A fixed featurizer counts all 65,536 8-mers and normalizes them to
    within-sequence frequency; three linear heads read that vector: a 5-class
    origin head and two binary detectors (host vs non-host, engineered vs
    natural). No alignment, no database. Requires only numpy and safetensors.

    Because every head is linear in the k-mer counts, two extra tools are
    available beyond classify/host_score/engineered_score:

    - attribute(seq): per-base contribution to a binary head; the
      contributions sum to (score - bias) up to floating-point rounding.
    - certify(seq): number of base substitutions a greedy search needs to
      flip a binary head's sign (an upper bound on the true minimum).

    All methods upper-case the input and drop every character that is not
    A, C, G or T *before* windowing, so positions reported by ``attribute``
    refer to the cleaned sequence, not to the raw input string.
    """

    def __init__(self, path="model.safetensors"):
        """Load the feature scale and the three heads from a safetensors file.

        The file must provide the tensors ``feature_scale``,
        ``origin.weight``/``origin.bias``, ``host.weight``/``host.bias`` and
        ``engineered.weight``/``engineered.bias``.
        NOTE(review): ``attribute``/``certify`` index the binary-head weights
        with a single k-mer index, so those weights are presumably 1-D of
        length VOCAB - confirm against the exported model.
        """
        tensors = load_file(path)
        self.scale = tensors["feature_scale"]
        self.OW, self.Ob = tensors["origin.weight"], tensors["origin.bias"]
        self.HW, self.Hb = tensors["host.weight"], tensors["host.bias"]
        self.EW, self.Eb = tensors["engineered.weight"], tensors["engineered.bias"]

    # ---- core ----
    @staticmethod
    def _clean(seq):
        """Return *seq* upper-cased with every non-A/C/G/T character removed."""
        return "".join(c for c in seq.upper() if c in _B)

    def features(self, seq):
        """Return the scaled 8-mer frequency vector of *seq*.

        Overlapping K-mers of the cleaned sequence are counted into a
        length-VOCAB float32 vector, the counts are divided by their total
        (when non-zero), and the result is divided by ``self.scale``.
        A cleaned sequence shorter than K produces all-zero counts.
        """
        seq = self._clean(seq)
        v = np.zeros(VOCAB, dtype=np.float32)
        for i in range(len(seq) - K + 1):
            j = _index(seq[i:i + K])
            # Defensive only: after cleaning, every full window is encodable.
            if j is not None:
                v[j] += 1.0
        s = v.sum()
        if s:
            v /= s
        return v / self.scale

    def logits(self, seq):
        """Return the origin head's raw outputs (weight @ features + bias)."""
        return self.OW @ self.features(seq) + self.Ob

    def classify(self, seq):
        """Return the most likely origin: human, eukaryote, bacteria, virus, or engineered."""
        return CLASSES[int(self.logits(seq).argmax())]

    def host_score(self, seq):
        """Higher means more human/host-like (host vs non-host head)."""
        return float(self.HW @ self.features(seq) + self.Hb[0])

    def engineered_score(self, seq):
        """Higher means more likely engineered/synthetic (engineered vs natural head)."""
        return float(self.EW @ self.features(seq) + self.Eb[0])

    # ---- closed-form interpretability and robustness ----
    def _eff(self, head):
        """Return the head's weights divided by the feature scale.

        Folding the scale into the weights lets callers score raw k-mer
        frequencies directly. *head* must be ``"host"`` or ``"engineered"``;
        any other value raises ``KeyError``.
        """
        w = {"host": self.HW, "engineered": self.EW}[head]
        return w / self.scale

    def _bias(self, head):
        """Return the first bias entry of *head* as a Python float.

        *head* must be ``"host"`` or ``"engineered"`` (``KeyError`` otherwise).
        """
        return float({"host": self.Hb, "engineered": self.Eb}[head][0])

    def attribute(self, seq, head="host"):
        """Per-base contribution of each position to a binary head's score.

        The score is an average over 8-mer windows; each window's effective
        weight is split evenly across its 8 bases, so the contributions sum
        to (score - bias) up to floating-point rounding.

        Returns a float array with one entry per base of the *cleaned*
        sequence (non-A/C/G/T characters are removed first, so the array can
        be shorter than ``len(seq)``). A cleaned sequence shorter than K has
        no window and yields all zeros.
        """
        seq = self._clean(seq)
        w = self._eff(head)
        # Same normalizer as features(): the number of windows, at least 1.
        n = max(1, len(seq) - K + 1)
        contrib = np.zeros(len(seq))
        for i in range(len(seq) - K + 1):
            j = _index(seq[i:i + K])
            if j is None:
                continue
            contrib[i:i + K] += w[j] / n / K
        return contrib

    def certify(self, seq, head="host", max_edits=80):
        """Greedily count base substitutions needed to flip a head's sign.

        Each round evaluates the exact score change of every single-base
        substitution and applies the one that moves the score furthest
        toward the opposite sign. The search stops when the score is no
        longer strictly on its original side, when no substitution helps,
        or after *max_edits* edits.

        Returns the number of edits applied, or ``None`` if the sign was not
        flipped. Because the search is greedy, the count is an upper bound
        on the true minimum number of substitutions, not the minimum itself.
        A score that is not strictly positive to begin with is treated as
        negative-side, so an exactly-zero score returns 0.
        """
        seq = list(self._clean(seq))
        w = self._eff(head)
        b = self._bias(head)
        last_start = len(seq) - K  # start of the final full window; < 0 if none
        n = max(1, last_start + 1)

        def score(s):
            # Accumulate in the same order as features() walks the windows.
            tot = 0.0
            for i in range(len(s) - K + 1):
                j = _index(s[i:i + K])
                if j is not None:
                    tot += w[j]
            return tot / n + b

        def window_sum(s, starts):
            # Sum of effective weights over the windows beginning at *starts*;
            # each window is encoded once.
            return sum(
                w[j]
                for j in (_index(s[a:a + K]) for a in starts)
                if j is not None
            )

        current = score("".join(seq))
        sign = 1 if current > 0 else -1

        if last_start < 0:
            # No full window exists: the score equals the bias whatever the
            # bases are, so no substitution can ever change it.
            return 0 if sign * current <= 0 else None

        edits = 0
        while sign * current > 0 and edits < max_edits:
            s = "".join(seq)
            best_d, best = 0.0, None
            for p, base in enumerate(seq):
                # Only windows covering position p can change when p changes.
                starts = range(max(0, p - K + 1), min(p, last_start) + 1)
                old = window_sum(s, starts)
                for nb in BASES:
                    if nb == base:
                        continue
                    s2 = s[:p] + nb + s[p + 1:]
                    d = (window_sum(s2, starts) - old) / n
                    # Keep the edit that pushes hardest against the current sign.
                    if sign * d < best_d:
                        best_d, best = sign * d, (p, nb)
            if best is None:
                break
            seq[best[0]] = best[1]
            edits += 1
            current = score("".join(seq))
        return edits if sign * current <= 0 else None
if __name__ == "__main__":
    # Smoke test: load the model from the default path and run every public
    # entry point once on a repeated sample sequence.
    clf = DnaOriginClassifier()
    seq = 5 * "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT"
    origin = clf.classify(seq)
    host = round(clf.host_score(seq), 3)
    flips = clf.certify(seq)
    peak = round(float(clf.attribute(seq).max()), 4)
    print("origin:", origin, "host_score:", host,
          "edits_to_flip:", flips, "top_base_contrib:", peak)
|