"""Linear 8-mer classifier of DNA origin with closed-form attribution."""

import numpy as np

try:
    from safetensors.numpy import load_file
except ImportError:  # only needed to read the weight file; reported in __init__
    load_file = None

K = 8
BASES = "ACGT"
CLASSES = ["human", "eukaryote", "bacteria", "virus", "engineered"]
_B = {b: i for i, b in enumerate(BASES)}
VOCAB = 4 ** K

# Byte value -> base code lookup used by the vectorised k-mer indexer.
_CODE_LUT = np.zeros(256, dtype=np.int64)
for _base, _code in _B.items():
    _CODE_LUT[ord(_base)] = _code


def _index(kmer):
    """Return the base-4 index of `kmer`, or None if it has a non-ACGT char.

    The string is read as a base-4 number (A=0, C=1, G=2, T=3), most
    significant base first. The length is not checked.
    """
    v = 0
    for c in kmer:
        b = _B.get(c)
        if b is None:
            return None
        v = v * 4 + b
    return v


def _valid_bases(seq):
    """Upper-case `seq` and keep only A/C/G/T.

    Returns `(bases, positions)`: the cleaned string, and for each kept base
    the index in the original `seq` of the character it came from.
    """
    bases, positions = [], []
    for pos, ch in enumerate(seq):
        # str.upper() may expand one character into several; walk them all
        # so the result equals filtering seq.upper() character by character.
        for up in ch.upper():
            if up in _B:
                bases.append(up)
                positions.append(pos)
    return "".join(bases), positions


def _window_indices(bases):
    """Return the k-mer index of every length-K window of a clean ACGT string.

    The result is an int64 array of length max(0, len(bases) - K + 1).
    """
    n = len(bases) - K + 1
    if n <= 0:
        return np.zeros(0, dtype=np.int64)
    codes = _CODE_LUT[np.frombuffer(bases.encode("ascii"), dtype=np.uint8)]
    idx = np.zeros(n, dtype=np.int64)
    for off in range(K):
        # Horner's rule over the K offsets, applied to all windows at once.
        idx = idx * 4 + codes[off:off + n]
    return idx


class DnaOriginClassifier:
    """Discriminative 8-mer classifier of DNA origin.

    A fixed featurizer counts all 65,536 8-mers and normalizes them to
    within-sequence frequency; three linear heads read that vector: a 5-class
    origin head and two binary detectors (host vs non-host, engineered vs
    natural). No alignment, no database. Requires only numpy and safetensors.

    Because every head is linear in the 8-mer frequencies, two extra queries
    have closed forms:

    - attribute(seq): per-base contribution to a binary head; the
      contributions sum to (score - bias).
    - certify(seq): number of base substitutions a greedy search needs to
      flip a binary head's sign, using exact per-edit score deltas.
    """

    def __init__(self, path="model.safetensors"):
        """Load the feature scale and the three heads from a safetensors file.

        Raises:
            ImportError: if the `safetensors` package is not installed.
        """
        if load_file is None:
            raise ImportError(
                "safetensors is required to load model weights"
            )
        t = load_file(path)
        self.scale = t["feature_scale"]
        self.OW, self.Ob = t["origin.weight"], t["origin.bias"]
        self.HW, self.Hb = t["host.weight"], t["host.bias"]
        self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]

    # ---- core ----
    def features(self, seq) -> np.ndarray:
        """Return the scaled 8-mer frequency vector of `seq`.

        Non-ACGT characters are removed before windows are formed, so a
        window may span a removed character. Sequences with fewer than K
        valid bases give an all-zero vector.
        """
        bases, _ = _valid_bases(seq)
        idx = _window_indices(bases)
        v = np.zeros(VOCAB, dtype=np.float32)
        if idx.size:
            v = np.bincount(idx, minlength=VOCAB).astype(np.float32)
            v /= np.float32(idx.size)
        return v / self.scale

    def logits(self, seq):
        """Return the raw scores of the 5-class origin head."""
        return self.OW @ self.features(seq) + self.Ob

    def classify(self, seq) -> str:
        """Return the most likely origin: human, eukaryote, bacteria, virus, or engineered."""
        return CLASSES[int(self.logits(seq).argmax())]

    def _head_score(self, weight, bias, seq) -> float:
        """Score `seq` with one binary head.

        Works whether the weight is stored as (VOCAB,) or (1, VOCAB): the
        single output element is extracted explicitly instead of relying on
        float() of a 1-element array.
        """
        raw = np.ravel(weight @ self.features(seq))[0]
        return float(raw + np.ravel(bias)[0])

    def host_score(self, seq) -> float:
        """Higher means more human/host-like (host vs non-host head)."""
        return self._head_score(self.HW, self.Hb, seq)

    def engineered_score(self, seq) -> float:
        """Higher means more likely engineered/synthetic (engineered vs natural head)."""
        return self._head_score(self.EW, self.Eb, seq)

    # ---- closed-form interpretability and robustness ----
    def _eff(self, head) -> np.ndarray:
        """Return the head's weight per unit of raw 8-mer frequency.

        This is weight / feature_scale, flattened to a 1-D float64 array so
        that `w[j]` is the weight of k-mer index j for either storage shape.

        Raises:
            KeyError: if `head` is not "host" or "engineered".
        """
        w = {"host": self.HW, "engineered": self.EW}[head]
        return np.ravel(np.asarray(w, dtype=np.float64) / self.scale)

    def _bias(self, head) -> float:
        """Return the scalar bias of a binary head ("host" or "engineered")."""
        b = {"host": self.Hb, "engineered": self.Eb}[head]
        return float(np.ravel(b)[0])

    def attribute(self, seq, head="host") -> np.ndarray:
        """Per-base contribution of each position to the head score.

        The score is a mean over 8-mer windows; each window's weight is split
        evenly across its 8 bases, so the contributions sum to
        (score - bias) up to floating-point rounding.

        Returns an array of length len(seq). Positions holding a non-ACGT
        character (which the featurizer drops) get 0.
        """
        bases, positions = _valid_bases(seq)
        w = self._eff(head)
        idx = _window_indices(bases)

        per_base = np.zeros(len(bases))
        if idx.size:
            share = w[idx] / idx.size / K
            for off in range(K):
                # Window i covers bases i .. i+K-1; add its share at each offset.
                per_base[off:off + idx.size] += share

        contrib = np.zeros(len(seq))
        # Map back to the caller's coordinates. add.at accumulates in case
        # two kept bases came from the same original character.
        np.add.at(contrib, np.asarray(positions, dtype=np.intp), per_base)
        return contrib

    def certify(self, seq, head="host", max_edits=80):
        """Greedily count base substitutions needed to flip the head's sign.

        Each round applies the single substitution that moves the score
        furthest toward the opposite sign, with the delta computed exactly
        from the affected windows. Search stops when the score reaches or
        crosses zero, when no substitution helps, or after `max_edits`.

        Returns the number of edits applied, or None if the sign was not
        flipped. Being greedy, the count is an upper bound on the true
        minimum number of substitutions, not the minimum itself. A score of
        exactly zero counts as already flipped (returns 0).
        """
        bases, _ = _valid_bases(seq)
        w = self._eff(head)
        bias = self._bias(head)
        n_windows = len(bases) - K + 1
        norm = max(1, n_windows)

        codes = [_B[c] for c in bases]
        idx = _window_indices(bases).tolist()
        # Numeric weight of the base at each offset inside a window.
        place = [4 ** (K - 1 - off) for off in range(K)]

        def score():
            return float(sum(w[j] for j in idx)) / norm + bias

        current = score()
        sign = 1 if current > 0 else -1
        edits = 0
        while sign * current > 0 and edits < max_edits:
            best_d, best = 0.0, None
            for p, old_code in enumerate(codes):
                # Only complete windows count; for sequences shorter than K
                # this range is empty and no edit can change the score.
                wins = range(max(0, p - K + 1), min(p, n_windows - 1) + 1)
                if not wins:
                    continue
                old = sum(w[idx[a]] for a in wins)
                for new_code in range(len(BASES)):
                    if new_code == old_code:
                        continue
                    step = new_code - old_code
                    new = sum(w[idx[a] + step * place[p - a]] for a in wins)
                    d = sign * (new - old) / norm
                    if d < best_d:
                        best_d, best = d, (p, new_code)
            if best is None:
                break
            p, new_code = best
            step = new_code - codes[p]
            for a in range(max(0, p - K + 1), min(p, n_windows - 1) + 1):
                idx[a] += step * place[p - a]
            codes[p] = new_code
            edits += 1
            current = score()
        return edits if sign * current <= 0 else None


if __name__ == "__main__":
    clf = DnaOriginClassifier()
    seq = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
    print("origin:", clf.classify(seq),
          "host_score:", round(clf.host_score(seq), 3),
          "edits_to_flip:", clf.certify(seq),
          "top_base_contrib:", round(float(clf.attribute(seq).max()), 4))