# phanerozoic's picture
# Rich README + certify/attribute methods, design.py, bundled 8-mer atlas
# 6f4134a verified
"""Closed-form inverse design against the classifier's linear heads.

Because each head is linear in 8-mer counts, sequences that maximize or minimize a
head score can be found by coordinate ascent. Two modes:

- free: design a sequence from scratch toward maximum or minimum score
- synonymous: re-choose codons of a given coding sequence to push the score while
  preserving the encoded protein

Usage:
    from design import free_design, synonymous_design
    from model import DnaOriginClassifier

    clf = DnaOriginClassifier("model.safetensors")
    seq = free_design(clf, length=300, direction="max")  # maximally host-like
    recoded = synonymous_design(clf, cds, direction="min")  # least host-like, same protein
"""
import random
import itertools
# Alphabet tried at every position by free_design.
BASES = "ACGT"

# Standard genetic code written out in the conventional TCAG layout: the first
# base varies slowest and the third fastest, so the residue for
# _BASES4[i] + _BASES4[j] + _BASES4[k] is _AAS[16 * i + 4 * j + k].
_BASES4 = "TCAG"
_AAS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"

# codon -> one-letter amino acid; "*" marks a stop codon.
# itertools.product enumerates triplets in exactly the nested T/C/A/G order
# described above, so pairing it with _AAS needs no manual index bookkeeping.
_CODON = {
    "".join(triplet): residue
    for triplet, residue in zip(itertools.product(_BASES4, repeat=3), _AAS)
}

# amino acid -> list of its synonymous codons. The lists keep the TCAG
# enumeration order; synonymous_design breaks score ties in favour of the
# earliest candidate, so this order is part of the observable behaviour.
_AA2COD = {}
for _codon, _residue in _CODON.items():
    _AA2COD.setdefault(_residue, []).append(_codon)
del _codon, _residue  # keep loop temporaries out of the module namespace
def _protein(seq):
    """Translate *seq* into a one-letter protein string, reading frame 0.

    Only complete codons are translated; one or two trailing bases are
    ignored. A triplet missing from the codon table (for example one that
    contains a character other than upper-case A/C/G/T) becomes "X". Stop
    codons are emitted as "*" and do not end the translation.
    """
    residues = []
    for start in range(0, len(seq) - 2, 3):
        triplet = seq[start:start + 3]
        residues.append(_CODON.get(triplet, "X"))
    return "".join(residues)
def free_design(clf, length=300, direction="max", passes=6, seed=0):
    """Design a sequence from scratch by greedy coordinate ascent on the score.

    Starts from a random sequence and sweeps over the positions, keeping at
    each one the base that most improves ``clf.host_score``. Ties keep the
    incumbent base, and among equally good replacements the first in
    ``BASES`` order wins.

    Args:
        clf: Object exposing ``host_score(sequence: str)`` that returns a
            number. The sweep assumes the score is deterministic for a given
            sequence (true for a head that is linear in k-mer counts).
        length: Number of bases in the designed sequence.
        direction: ``"max"`` to maximize the score; any other value minimizes.
        passes: Upper bound on the number of full sweeps over the sequence.
        seed: Seed for the random starting sequence, so runs are reproducible.

    Returns:
        The designed sequence as a string of ``length`` bases.
    """
    rng = random.Random(seed)
    seq = [rng.choice(BASES) for _ in range(length)]
    sign = 1 if direction == "max" else -1

    # Nothing to optimize: do not query the classifier at all (an empty
    # sequence may not be a valid input for it).
    if not seq or passes <= 0:
        return "".join(seq)

    # Score of the sequence as it currently stands. It is carried from one
    # position to the next instead of being recomputed, because after a
    # position is settled the best score seen IS the score of the sequence.
    current = clf.host_score("".join(seq))

    for _ in range(passes):
        changed = False
        for pos in range(len(seq)):
            start = best = seq[pos]
            for base in BASES:
                if base == start:
                    # Re-scoring the incumbent cannot beat a strict comparison.
                    continue
                seq[pos] = base
                trial = clf.host_score("".join(seq))
                if sign * trial > sign * current:
                    current, best = trial, base
            seq[pos] = best
            changed = changed or best != start
        if not changed:
            # A sweep that changes nothing is a fixed point; further sweeps
            # would repeat it exactly.
            break

    return "".join(seq)
def synonymous_design(clf, cds, direction="max", passes=4):
    """Re-choose synonymous codons of *cds* to push the score, same protein.

    Sweeps over the codons in frame 0 and, at each one, keeps the synonymous
    codon that most improves ``clf.host_score``. Ties keep the incumbent
    codon, and among equally good replacements the first in the codon table's
    order wins. Codons that are not in the table (for example lower-case or
    ambiguous triplets) are left untouched. Stop codons may be swapped for
    other stop codons.

    Args:
        clf: Object exposing ``host_score(sequence: str)`` that returns a
            number. The sweep assumes the score is deterministic for a given
            sequence.
        cds: Coding sequence, read in frame from its first base. If its
            length is not a multiple of three, the trailing one or two bases
            are kept unchanged in the result and in every scored sequence.
        direction: ``"max"`` to maximize the score; any other value minimizes.
        passes: Upper bound on the number of full sweeps over the codons.

    Returns:
        The recoded sequence, the same length as *cds*.

    Raises:
        AssertionError: If the recoded sequence does not translate to the
            same protein as *cds* (an internal invariant violation).
    """
    in_frame = len(cds) - len(cds) % 3
    codons = [cds[i:i + 3] for i in range(0, in_frame, 3)]
    tail = cds[in_frame:]  # incomplete final codon, carried through verbatim
    sign = 1 if direction == "max" else -1

    # Swaps are synonymous, so the amino acid (and hence the candidate list)
    # at each position is fixed for the whole run. Unknown codons map to an
    # empty candidate list; single-codon amino acids have nothing to choose.
    options = [_AA2COD.get(_CODON.get(codon), ()) for codon in codons]
    recodable = [i for i, opts in enumerate(options) if len(opts) > 1]

    if recodable and passes > 0:
        # Score of the sequence as it currently stands, carried between
        # positions rather than recomputed before every codon.
        current = clf.host_score("".join(codons) + tail)
        for _ in range(passes):
            changed = False
            for i in recodable:
                start = best = codons[i]
                for cand in options[i]:
                    if cand == start:
                        # The incumbent cannot beat a strict comparison.
                        continue
                    codons[i] = cand
                    trial = clf.host_score("".join(codons) + tail)
                    if sign * trial > sign * current:
                        current, best = trial, cand
                codons[i] = best
                changed = changed or best != start
            if not changed:
                # Fixed point reached; later sweeps would repeat it exactly.
                break

    out = "".join(codons) + tail
    # Explicit raise rather than `assert` so the check survives `python -O`.
    if _protein(out) != _protein(cds):
        raise AssertionError("protein not preserved")
    return out
if __name__ == "__main__":
    # Demo: design one maximally and one minimally host-like 300-mer with the
    # default-constructed classifier and print the score each one reaches.
    from model import DnaOriginClassifier

    clf = DnaOriginClassifier()
    reports = (
        ("max-host design score:", "max"),
        ("min-host design score:", "min"),
    )
    # Run both designs before printing anything, then report in order.
    designed = [free_design(clf, 300, mode) for _, mode in reports]
    for (label, _), sequence in zip(reports, designed):
        print(label, round(clf.host_score(sequence), 2))