| import numpy as np |
| from safetensors.numpy import load_file |
|
|
| K = 8 |
| BASES = "ACGT" |
| CLASSES = ["human", "eukaryote", "bacteria", "virus", "engineered"] |
| _B = {b: i for i, b in enumerate(BASES)} |
| VOCAB = 4 ** K |
|
|
|
|
| def _index(kmer): |
| v = 0 |
| for c in kmer: |
| b = _B.get(c) |
| if b is None: |
| return None |
| v = v * 4 + b |
| return v |
|
|
|
|
class DnaOriginClassifier:
    """Discriminative 8-mer classifier of DNA origin.

    Every head is linear in K-mer frequencies, so per-base attributions are
    exact and the score change caused by any single-base substitution can be
    computed exactly.

    A fixed featurizer counts all 65,536 (4**K) K-mers and normalizes the counts
    to within-sequence frequencies; three linear heads read the result: a
    5-class origin head and two binary detectors (host vs non-host, engineered
    vs natural). No alignment, no database. Requires only numpy and safetensors.

    Beyond classify/host_score/engineered_score, the linear form gives:
    - attribute(seq): exact per-base contribution to a head (sums to score - bias)
    - certify(seq): greedy upper bound on the number of base substitutions
      needed to flip a binary head's sign (each greedy step uses exact deltas)

    Every method upper-cases its input and deletes characters outside ACGT
    before taking K-mers, so a K-mer can span a deleted character (e.g. an N).
    """

    def __init__(self, path="model.safetensors"):
        """Load the feature scale and the three heads from a safetensors file."""
        t = load_file(path)
        # Divisor applied to the frequency vector in features().
        self.scale = t["feature_scale"]
        self.OW, self.Ob = t["origin.weight"], t["origin.bias"]
        self.HW, self.Hb = t["host.weight"], t["host.bias"]
        self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]

    # -- featurization -------------------------------------------------------

    @staticmethod
    def _clean(seq):
        """Upper-case seq and drop every character that is not A, C, G or T."""
        return "".join(c for c in seq.upper() if c in _B)

    @staticmethod
    def _kmer_ids(seq):
        """Return the K-mer index of each complete window of a cleaned sequence.

        Entry i encodes seq[i:i + K]; the list is empty when len(seq) < K.
        """
        return [_index(seq[i:i + K]) for i in range(len(seq) - K + 1)]

    def features(self, seq):
        """Return within-sequence K-mer frequencies divided by the feature scale.

        The vector has VOCAB entries. When the cleaned sequence has no complete
        K-mer window (fewer than K bases) the frequency vector is all zeros.
        """
        ids = self._kmer_ids(self._clean(seq))
        # One C-level count instead of a per-window Python loop; explicit intp
        # dtype keeps an empty list valid input for bincount.
        v = np.bincount(np.asarray(ids, dtype=np.intp), minlength=VOCAB).astype(np.float32)
        s = v.sum()
        if s:
            v /= s
        return v / self.scale

    # -- heads ---------------------------------------------------------------

    def logits(self, seq):
        """Origin-head logits; classify() maps their argmax through CLASSES."""
        return self.OW @ self.features(seq) + self.Ob

    def classify(self, seq):
        """Return the most likely origin: human, eukaryote, bacteria, virus, or engineered."""
        return CLASSES[int(self.logits(seq).argmax())]

    def host_score(self, seq):
        """Higher means more human/host-like (host vs non-host head)."""
        return float(self.HW @ self.features(seq) + self.Hb[0])

    def engineered_score(self, seq):
        """Higher means more likely engineered/synthetic (engineered vs natural head)."""
        return float(self.EW @ self.features(seq) + self.Eb[0])

    # -- interpretability / robustness -----------------------------------------

    def _eff(self, head):
        """Per-K-mer weights of a binary head with the feature scale folded in.

        Raises KeyError for any head other than "host" or "engineered".
        """
        w = {"host": self.HW, "engineered": self.EW}[head]
        # ravel() also accepts a (1, VOCAB) weight matrix; a 1-D weight is unchanged.
        return np.ravel(w / self.scale)

    def _bias(self, head):
        """Scalar bias of a binary head ("host" or "engineered")."""
        return float({"host": self.Hb, "engineered": self.Eb}[head][0])

    def attribute(self, seq, head="host"):
        """Exact per-base contribution of each position to the head score.

        The score is a sum over K-mer windows divided by the window count; this
        spreads each window's effective weight evenly over its K bases, so the
        contributions sum to (score - bias) up to floating-point rounding.

        Returns a float array with one entry per base of the *cleaned* sequence
        (upper-cased, non-ACGT characters removed), which can be shorter than
        the input. Raises KeyError for an unknown head.
        """
        seq = self._clean(seq)
        w = self._eff(head)
        ids = self._kmer_ids(seq)
        n = max(1, len(ids))  # same normalizer as features(): number of windows
        contrib = np.zeros(len(seq))
        for i, j in enumerate(ids):
            contrib[i:i + K] += w[j] / n / K
        return contrib

    def certify(self, seq, head="host", max_edits=80):
        """Greedy count of base substitutions that flip the sign of a binary head.

        Each step applies the single substitution that moves the score furthest
        toward (and past) zero; the score change of every candidate is computed
        exactly because the head is linear in K-mer counts. The search is greedy
        and may edit the same position more than once, so the result is an
        upper bound on the true minimum, not a guaranteed minimum.

        Returns the number of substitutions applied once sign * score <= 0, or
        None if that is not reached within max_edits or no substitution can move
        the score further (e.g. a cleaned sequence shorter than K, whose score
        is just the bias). Raises KeyError for an unknown head.
        """
        seq = list(self._clean(seq))
        w = self._eff(head)
        b = self._bias(head)
        ids = self._kmer_ids("".join(seq))  # kept in sync with seq below
        n = max(1, len(ids))
        # place[o]: base-4 digit weight of offset o inside a window (first base
        # is most significant, matching _index).
        place = [4 ** (K - 1 - o) for o in range(K)]

        def score(kmer_ids):
            return sum((w[j] for j in kmer_ids), 0.0) / n + b

        def covering(p):
            # Start indices of the complete windows containing position p;
            # empty when the sequence has no complete window at all.
            return range(max(0, p - K + 1), min(p, len(ids) - 1) + 1)

        sign = 1 if score(ids) > 0 else -1
        edits = 0
        while sign * score(ids) > 0 and edits < max_edits:
            best_d, best = 0.0, None
            for p, cur in enumerate(seq):
                wins = covering(p)
                old = sum(w[ids[a]] for a in wins)
                for nb in BASES:
                    if nb == cur:
                        continue
                    step = _B[nb] - _B[cur]
                    # Re-encode each covering window with position p substituted.
                    new = sum(w[ids[a] + step * place[p - a]] for a in wins)
                    d = (new - old) / n
                    if sign * d < best_d:
                        best_d, best = sign * d, (p, nb)
            if best is None:
                break  # no single substitution lowers sign * score any further
            p, nb = best
            step = _B[nb] - _B[seq[p]]
            for a in covering(p):
                ids[a] += step * place[p - a]
            seq[p] = nb
            edits += 1
        return edits if sign * score(ids) <= 0 else None
|
|
|
|
if __name__ == "__main__":
    # Smoke demo: loads the default "model.safetensors" from the working directory.
    model = DnaOriginClassifier()
    demo = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
    # Computed in the same order as the fields are printed.
    origin = model.classify(demo)
    host = round(model.host_score(demo), 3)
    flips = model.certify(demo)
    top = round(float(model.attribute(demo).max()), 4)
    print("origin:", origin, "host_score:", host,
          "edits_to_flip:", flips, "top_base_contrib:", top)
|
|