File size: 5,433 Bytes
af9e909
 
 
b6dcfa1
af9e909
 
b6dcfa1
 
af9e909
 
b6dcfa1
 
 
 
 
 
 
 
af9e909
 
 
7f7d423
 
af9e909
b6dcfa1
 
 
 
7f7d423
 
 
 
af9e909
 
 
 
 
 
 
 
 
7f7d423
af9e909
b6dcfa1
 
 
 
 
 
 
 
 
af9e909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f7d423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af9e909
 
 
 
7f7d423
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
from safetensors.numpy import load_file

# Length of the k-mers counted by the featurizer.
K = 8

# Nucleotide alphabet; a base's position in this string is its base-4 digit.
BASES = "ACGT"

# Origin labels; position i names the i-th entry of the origin head's logits.
CLASSES = [
    "human",
    "eukaryote",
    "bacteria",
    "virus",
    "engineered",
]

# Lookup from base character to its base-4 digit (A=0, C=1, G=2, T=3).
_B = dict(zip(BASES, range(len(BASES))))

# Number of distinct K-mers over the alphabet, i.e. the feature dimension.
VOCAB = len(BASES) ** K


def _index(kmer):
    """Return the base-4 integer code of ``kmer``, or None if it has a foreign base.

    Characters are read left to right as base-4 digits (most significant
    first) using the ``_B`` lookup. Any character missing from ``_B`` makes
    the whole k-mer invalid and yields None. An empty ``kmer`` encodes to 0.
    """
    try:
        digits = [_B[ch] for ch in kmer]
    except KeyError:
        # At least one character is not a recognised base.
        return None

    code = 0
    for digit in digits:
        code = 4 * code + digit
    return code


class DnaOriginClassifier:
    """Discriminative 8-mer classifier of DNA origin, with exact closed-form
    interpretability and robustness because the model is linear in 8-mer counts.

    A fixed featurizer counts all 65,536 8-mers and normalizes to within-sequence
    frequency; three discriminatively trained linear heads read it: a 5-class origin
    head and two binary detectors (host vs non-host, engineered vs natural). No
    alignment, no database. Requires only numpy and safetensors.

    Every method first upper-cases the input and drops all characters other than
    A, C, G and T; positions reported by attribute() refer to that cleaned sequence.

    Beyond classify/host_score/engineered_score, the linear form gives:
      - attribute(seq): exact per-base contribution to a head (sums to score - bias)
      - certify(seq):   greedy count of base substitutions that flips a call
                        (exact per-edit deltas; an upper bound on the true minimum)
    """

    def __init__(self, path="model.safetensors"):
        """Load the feature scale and the three linear heads.

        Args:
            path: safetensors file holding ``feature_scale`` and, for each of the
                ``origin``, ``host`` and ``engineered`` heads, ``<head>.weight``
                and ``<head>.bias``.
        """
        t = load_file(path)
        self.scale = t["feature_scale"]
        self.OW, self.Ob = t["origin.weight"], t["origin.bias"]
        # NOTE(review): attribute()/certify() index the host and engineered
        # weights with a single k-mer id, so they are presumably stored as flat
        # (VOCAB,) vectors - confirm against the exported model file.
        self.HW, self.Hb = t["host.weight"], t["host.bias"]
        self.EW, self.Eb = t["engineered.weight"], t["engineered.bias"]

    # ---- core ----
    @staticmethod
    def _clean(seq):
        """Upper-case ``seq`` and keep only A/C/G/T.

        Every K-long slice of the result is a valid k-mer, so ``_index`` never
        returns None for it.
        """
        return "".join(c for c in seq.upper() if c in _B)

    def features(self, seq):
        """Return the scaled 8-mer frequency vector of ``seq``.

        Counts every overlapping 8-mer of the cleaned sequence, divides by the
        total count (skipped when there are no windows, leaving all zeros) and
        then divides by the loaded ``feature_scale``.
        """
        cleaned = self._clean(seq)
        counts = np.zeros(VOCAB, dtype=np.float32)
        for start in range(len(cleaned) - K + 1):
            counts[_index(cleaned[start:start + K])] += 1.0
        total = counts.sum()
        if total:
            counts /= total
        return counts / self.scale

    def logits(self, seq):
        """Return the origin head's raw outputs, one per entry of CLASSES."""
        return self.OW @ self.features(seq) + self.Ob

    def classify(self, seq):
        """Return the most likely origin: human, eukaryote, bacteria, virus, or engineered."""
        return CLASSES[int(self.logits(seq).argmax())]

    def host_score(self, seq):
        """Higher means more human/host-like (host vs non-host head)."""
        return float(self.HW @ self.features(seq) + self.Hb[0])

    def engineered_score(self, seq):
        """Higher means more likely engineered/synthetic (engineered vs natural head)."""
        return float(self.EW @ self.features(seq) + self.Eb[0])

    # ---- closed-form interpretability and robustness ----
    def _eff(self, head):
        """Return the head's weights divided by the feature scale.

        With these weights the head score equals bias plus the mean weight over
        the sequence's 8-mer windows. Raises KeyError unless ``head`` is "host"
        or "engineered".
        """
        w = {"host": self.HW, "engineered": self.EW}[head]
        return w / self.scale

    def _bias(self, head):
        """Return the scalar bias of the "host" or "engineered" head as a float."""
        return float({"host": self.Hb, "engineered": self.Eb}[head][0])

    def attribute(self, seq, head="host"):
        """Exact per-base contribution of each position to the head score.

        The score is a sum over 8-mer windows; this distributes each window's weight
        evenly across its 8 bases, so the contributions sum to (score - bias) up to
        floating-point rounding.

        Args:
            seq: DNA sequence; it is upper-cased and non-ACGT characters are dropped.
            head: "host" or "engineered".

        Returns:
            A float array with one entry per base of the *cleaned* sequence (so it
            is shorter than ``seq`` when ``seq`` contains other characters). All
            zeros when the cleaned sequence is shorter than 8 bases.
        """
        cleaned = self._clean(seq)
        w = self._eff(head)
        n = max(1, len(cleaned) - K + 1)
        contrib = np.zeros(len(cleaned))
        for start in range(len(cleaned) - K + 1):
            # Each window carries w/n of the score; split it over its K bases.
            contrib[start:start + K] += w[_index(cleaned[start:start + K])] / n / K
        return contrib

    def certify(self, seq, head="host", max_edits=80):
        """Greedily count base substitutions needed to flip the head's sign.

        Each round evaluates the exact score change of every single-base
        substitution and applies the one that moves the score furthest towards
        the opposite sign. The same position may be edited in more than one
        round; every round counts as one edit. Because the search is greedy, the
        result is an upper bound on the true minimum number of substitutions.

        Args:
            seq: DNA sequence; it is upper-cased and non-ACGT characters are dropped.
            head: "host" or "engineered".
            max_edits: Maximum number of greedy rounds.

        Returns:
            The number of edits applied once the score is no longer strictly on
            its original side of zero, or None if that did not happen within
            ``max_edits`` rounds or no substitution can move the score.
        """
        s = self._clean(seq)
        w = self._eff(head)
        b = self._bias(head)
        n = len(s) - K + 1
        if n <= 0:
            # No complete window: the score is the bias alone and no
            # substitution can move it, so there is nothing to search.
            return 0 if b == 0 else None

        def score(text):
            tot = 0.0
            for i in range(n):
                tot += w[_index(text[i:i + K])]
            return tot / n + b

        def window_sum(chunk):
            # Total weight of every complete window inside chunk, left to right.
            return sum(w[_index(chunk[a:a + K])] for a in range(len(chunk) - K + 1))

        current = score(s)
        sign = 1 if current > 0 else -1
        edits = 0
        while sign * current > 0 and edits < max_edits:
            best_d, best = 0.0, None
            for p in range(len(s)):
                # Only windows starting in [first, last] contain position p, so
                # re-scoring that neighbourhood gives the exact delta without
                # copying the whole sequence for every candidate.
                first = max(0, p - K + 1)
                last = min(p, n - 1)
                chunk = s[first:last + K]
                off = p - first
                old = window_sum(chunk)
                for nb in BASES:
                    if nb == s[p]:
                        continue
                    new = window_sum(chunk[:off] + nb + chunk[off + 1:])
                    d = (new - old) / n
                    if sign * d < best_d:
                        best_d, best = sign * d, (p, nb)
            if best is None:
                # No substitution moves the score towards the other sign.
                break
            p, nb = best
            s = s[:p] + nb + s[p + 1:]
            edits += 1
            current = score(s)
        return edits if sign * current <= 0 else None


if __name__ == "__main__":
    # Demo: load the default model file and print each public output for one
    # short sequence repeated five times.
    clf = DnaOriginClassifier()
    seq = "ATGGCTAGCAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGTGATGTT" * 5
    report = (
        "origin:", clf.classify(seq),
        "host_score:", round(clf.host_score(seq), 3),
        "edits_to_flip:", clf.certify(seq),
        "top_base_contrib:", round(float(clf.attribute(seq).max()), 4),
    )
    print(*report)