# Author: phanerozoic
# Commit 7f7d423 (verified): "Rich README + certify/attribute methods, design.py,
# bundled 8-mer atlas"
import numpy as np
from safetensors.numpy import load_file
K = 8  # k-mer (window) length used by the featurizer
BASES = "ACGT"  # alphabet; a base's position here is its base-4 digit
CLASSES = ["human", "eukaryote", "bacteria", "virus", "engineered"]  # origin-head output order
_B = dict(zip(BASES, range(len(BASES))))  # base character -> digit 0..3
VOCAB = len(BASES) ** K  # number of distinct k-mers (65,536 for K = 8)
def _index(kmer):
    """Encode *kmer* as a base-4 integer, first character most significant.

    Returns None as soon as a character not present in ``_B`` (i.e. not one of
    BASES) is encountered. No length check is done: any string is encoded.
    """
    code = 0
    for ch in kmer:
        if ch not in _B:
            return None
        # Shift in one base-4 digit; equivalent to code * 4 + digit.
        code = (code << 2) | _B[ch]
    return code
class DnaOriginClassifier:
    """Discriminative 8-mer classifier of DNA origin, with exact closed-form
    interpretability and robustness because the model is linear in 8-mer counts.

    A fixed featurizer counts all 65,536 8-mers and normalizes to within-sequence
    frequency; three discriminatively trained linear heads read it: a 5-class origin
    head and two binary detectors (host vs non-host, engineered vs natural). No
    alignment, no database. Requires only numpy and safetensors.

    Beyond classify/host_score/engineered_score, the linear form gives:
    - attribute(seq): per-base contribution to a head (sums to score - bias)
    - certify(seq): base substitutions found by a greedy search (exact per-edit
      deltas) to flip a call; an upper bound on the true minimum

    Input handling: sequences are upper-cased and every character outside
    BASES (N, IUPAC codes, gaps, whitespace, ...) is removed *before* windowing,
    so a window may span a removed character.
    NOTE(review): presumably this matches the featurization used at training
    time -- confirm before changing it.
    """

    def __init__(self, path="model.safetensors"):
        """Load the feature scale and the three linear heads.

        Args:
            path: safetensors file holding ``feature_scale``,
                ``origin.weight``/``origin.bias``, ``host.weight``/``host.bias``
                and ``engineered.weight``/``engineered.bias``.
        """
        t = load_file(path)
        self.scale = t["feature_scale"]
        self.OW, self.Ob = t["origin.weight"], t["origin.bias"]
        self.HW, self.Hb = t["host.weight"], t["host.bias"]
        self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]

    # ---- featurization helpers ----
    @staticmethod
    def _clean(seq):
        """Upper-case *seq* and drop every character that is not in BASES."""
        return "".join(c for c in seq.upper() if c in _B)

    @staticmethod
    def _window_indices(seq):
        """Vocabulary index of every full K-mer window of an already-cleaned seq.

        Gives the same values as ``_index(seq[i:i + K])`` for each window start
        i, computed with whole-array integer Horner steps. Returns an int64
        array of length ``max(0, len(seq) - K + 1)``. Every character of *seq*
        must be in BASES (raises KeyError otherwise).
        """
        nwin = len(seq) - K + 1
        if nwin <= 0:
            return np.zeros(0, dtype=np.int64)
        codes = np.fromiter((_B[c] for c in seq), dtype=np.int64, count=len(seq))
        idx = np.zeros(nwin, dtype=np.int64)
        for off in range(K):
            idx = idx * 4 + codes[off:off + nwin]
        return idx

    # ---- core ----
    def features(self, seq):
        """Scaled 8-mer frequency vector of *seq* (float32, length VOCAB).

        Counts every full window of the cleaned sequence, divides by the total
        window count (the vector stays all-zero when fewer than K bases remain),
        then divides element-wise by ``feature_scale``.
        """
        idx = self._window_indices(self._clean(seq))
        # Integer counts are exact in float32 (below 2**24), so this matches
        # per-window accumulation while running at numpy speed.
        v = np.bincount(idx, minlength=VOCAB).astype(np.float32)
        s = v.sum()
        if s:
            v /= s
        return v / self.scale

    def logits(self, seq):
        """Raw origin-head logits, one per entry of CLASSES (same order)."""
        return self.OW @ self.features(seq) + self.Ob

    def classify(self, seq):
        """Return the most likely origin: human, eukaryote, bacteria, virus, or engineered."""
        return CLASSES[int(self.logits(seq).argmax())]

    def host_score(self, seq):
        """Higher means more human/host-like (host vs non-host head)."""
        w, b = self._head("host")
        return float(w @ self.features(seq) + b)

    def engineered_score(self, seq):
        """Higher means more likely engineered/synthetic (engineered vs natural head)."""
        w, b = self._head("engineered")
        return float(w @ self.features(seq) + b)

    # ---- closed-form interpretability and robustness ----
    def _head(self, head):
        """Return (flat weight vector, bias scalar) of binary head "host" or "engineered".

        Raises:
            KeyError: for any other head name.
        """
        heads = {"host": (self.HW, self.Hb), "engineered": (self.EW, self.Eb)}
        if head not in heads:
            raise KeyError(f"unknown head {head!r}; expected 'host' or 'engineered'")
        weight, bias = heads[head]
        # ravel accepts weights stored as (VOCAB,) or (1, VOCAB); a 1-D tensor
        # passes through unchanged.
        return np.ravel(weight), np.ravel(bias)[0]

    def _eff(self, head):
        """Effective per-k-mer weights of a binary head: weight / feature_scale."""
        w, _ = self._head(head)
        return w / self.scale

    def _bias(self, head):
        """Bias of a binary head as a Python float."""
        return float(self._head(head)[1])

    def attribute(self, seq, head="host"):
        """Per-base contribution of each retained position to a binary head score.

        The score equals bias + (sum over windows of w_eff[kmer]) / n_windows,
        with w_eff = weight / feature_scale. Each window's term is split evenly
        over its K bases, so ``attribute(seq).sum()`` equals ``score - bias`` up
        to floating-point rounding (features are float32, this is float64).

        Args:
            seq: DNA string; case-insensitive, non-ACGT characters are dropped.
            head: "host" or "engineered".

        Returns:
            float64 array with one entry per *retained* base. Its length is
            that of the cleaned sequence, which is shorter than len(seq) when
            seq contains N, whitespace or other non-ACGT characters. All zeros
            when fewer than K bases remain.
        """
        seq = self._clean(seq)
        w = self._eff(head)
        idx = self._window_indices(seq)
        n = max(1, len(idx))
        contrib = np.zeros(len(seq))
        for i, j in enumerate(idx):
            contrib[i:i + K] += w[j] / n / K
        return contrib

    def certify(self, seq, head="host", max_edits=80):
        """Greedy count of base substitutions needed to flip a binary head's sign.

        At each step every single-base substitution is scored exactly (only
        the <= K windows covering the edited position change). The one moving
        the score furthest toward the opposite sign is applied. Ties go to the
        earliest position, then to the first base in BASES order. Because the
        search is greedy, the result is an upper bound on the true minimum
        adversarial radius, not a proven minimum.

        Args:
            seq: DNA string; case-insensitive, non-ACGT characters are dropped.
            head: "host" or "engineered".
            max_edits: maximum number of substitutions to try.

        Returns:
            The number of substitutions after which the score has crossed (or
            reached) zero from its original side. Returns 0 when the score is
            already exactly 0. Returns None when it is not flipped within
            max_edits, or when no single substitution moves it further. That
            includes sequences with fewer than K retained bases, whose score is
            the constant bias.
        """
        w = self._eff(head)
        b = self._bias(head)
        seq = list(self._clean(seq))
        # Window indices of the current sequence; updated in place on each edit.
        cur = self._window_indices("".join(seq))
        nwin = len(cur)
        n = max(1, nwin)
        # Place value of offset o inside a window's base-4 index.
        place = [4 ** (K - 1 - o) for o in range(K)]

        def score():
            tot = 0.0
            for j in cur:
                tot += w[j]
            return tot / n + b

        def covering(p):
            # Start positions of the full windows that contain position p
            # (empty when the sequence has no full window at all).
            return range(max(0, p - K + 1), min(p, nwin - 1) + 1)

        sign = 1 if score() > 0 else -1
        edits = 0
        while sign * score() > 0 and edits < max_edits:
            best_d, best = 0.0, None
            for p, base in enumerate(seq):
                wins = covering(p)
                old = sum(w[cur[a]] for a in wins)
                for nb in BASES:
                    if nb == base:
                        continue
                    # Substituting one base shifts each covering window's index
                    # by (digit change) * (place value of p within that window).
                    shift = _B[nb] - _B[base]
                    new = sum(w[cur[a] + shift * place[p - a]] for a in wins)
                    d = (new - old) / n
                    if sign * d < best_d:
                        best_d, best = sign * d, (p, nb)
            if best is None:
                break
            p, nb = best
            shift = _B[nb] - _B[seq[p]]
            for a in covering(p):
                cur[a] += shift * place[p - a]
            seq[p] = nb
            edits += 1
        return edits if sign * score() <= 0 else None
if __name__ == "__main__":
    # Demo: load the bundled weights and report every score for one repeated
    # demo sequence.
    model = DnaOriginClassifier()
    demo = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
    origin = model.classify(demo)
    host = round(model.host_score(demo), 3)
    flips = model.certify(demo)
    peak = round(float(model.attribute(demo).max()), 4)
    print("origin:", origin, "host_score:", host,
          "edits_to_flip:", flips, "top_base_contrib:", peak)