"""Closed-form inverse design against the classifier's linear heads. Because each head is linear in 8-mer counts, sequences that maximize or minimize a head score can be found by coordinate ascent. Two modes: - free: design a sequence from scratch toward maximum or minimum score - synonymous: re-choose codons of a given coding sequence to push the score while preserving the encoded protein Usage: from design import free_design, synonymous_design from model import DnaOriginClassifier clf = DnaOriginClassifier("model.safetensors") seq = free_design(clf, length=300, direction="max") # maximally host-like recoded = synonymous_design(clf, cds, direction="min") # least host-like, same protein """ import random import itertools BASES = "ACGT" _CODON = {} _aa = "KNKNTTTTRSRSIIMIQHQHPPPPRRRRLLLLEDEDAAAAGGGGVVVV*Y*YSSSSLFLF*CWCLFLF" # standard code, see below # build the standard genetic code explicitly to avoid ordering ambiguity _BASES4 = "TCAG" _AAS = ("FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG") _i = 0 for a in _BASES4: for b in _BASES4: for c in _BASES4: _CODON[a + b + c] = _AAS[_i]; _i += 1 _AA2COD = {} for cod, aa in _CODON.items(): _AA2COD.setdefault(aa, []).append(cod) def _protein(seq): return "".join(_CODON.get(seq[i:i + 3], "X") for i in range(0, len(seq) - 2, 3)) def free_design(clf, length=300, direction="max", passes=6, seed=0): rng = random.Random(seed) seq = [rng.choice(BASES) for _ in range(length)] d = 1 if direction == "max" else -1 for _ in range(passes): for p in range(length): best, bs = seq[p], clf.host_score("".join(seq)) for nb in BASES: seq[p] = nb sc = clf.host_score("".join(seq)) if d * sc > d * bs: bs, best = sc, nb seq[p] = best return "".join(seq) def synonymous_design(clf, cds, direction="max", passes=4): cod = [cds[i:i + 3] for i in range(0, len(cds) - 2, 3)] d = 1 if direction == "max" else -1 for _ in range(passes): for ci in range(len(cod)): aa = _CODON.get(cod[ci]) if aa not in _AA2COD: continue best, bs = cod[ci], clf.host_score("".join(cod)) for cand in _AA2COD[aa]: cod[ci] = cand sc = clf.host_score("".join(cod)) if d * sc > d * bs: bs, best = sc, cand cod[ci] = best out = "".join(cod) assert _protein(out) == _protein(cds), "protein not preserved" return out if __name__ == "__main__": from model import DnaOriginClassifier clf = DnaOriginClassifier() mx = free_design(clf, 300, "max") mn = free_design(clf, 300, "min") print("max-host design score:", round(clf.host_score(mx), 2)) print("min-host design score:", round(clf.host_score(mn), 2))