| """Closed-form inverse design against the classifier's linear heads. |
| |
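Because each head is linear in 8-mer counts, sequences that raise or lower a
head score can be searched for by greedy coordinate ascent over single-base (or
single-codon) changes; this is a local search, so the result is not guaranteed
to be the global optimum. Two modes: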
| - free: design a sequence from scratch toward maximum or minimum score |
| - synonymous: re-choose codons of a given coding sequence to push the score while |
| preserving the encoded protein |
| |
| Usage: |
| from design import free_design, synonymous_design |
| from model import DnaOriginClassifier |
| clf = DnaOriginClassifier("model.safetensors") |
| seq = free_design(clf, length=300, direction="max") # maximally host-like |
| recoded = synonymous_design(clf, cds, direction="min") # least host-like, same protein |
| """ |
| import random |
| import itertools |
|
|
BASES = "ACGT"

# Standard genetic code (NCBI translation table 1). The one-letter amino-acid
# string below is written for codons enumerated with bases in "TCAG" order at
# every position (TTT, TTC, TTA, TTG, TCT, ...); "*" marks a stop codon.
_BASES4 = "TCAG"
_AAS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"

# codon -> amino acid, e.g. "ATG" -> "M", "TAA" -> "*".
# Built with comprehensions so no loop variables leak into the module namespace.
_CODON = {
    "".join(triplet): amino
    for triplet, amino in zip(itertools.product(_BASES4, repeat=3), _AAS)
}

# amino acid -> synonymous codons. Both the key order and each codon list
# follow the TCAG enumeration of _CODON; synonymous_design tries candidates
# in this order, so the order is part of its (tie-breaking) behavior.
_AA2COD = {
    amino: [codon for codon, residue in _CODON.items() if residue == amino]
    for amino in dict.fromkeys(_CODON.values())
}
|
|
|
|
def _protein(seq):
    """Translate *seq* codon-by-codon into a one-letter amino-acid string.

    Only complete codons are translated; a trailing 1-2 base remainder is
    ignored. A triplet absent from ``_CODON`` (for example lowercase or
    containing ``N``) is rendered as ``"X"``.
    """
    whole = len(seq) - len(seq) % 3
    codons = (seq[start:start + 3] for start in range(0, whole, 3))
    return "".join(_CODON.get(codon, "X") for codon in codons)
|
|
|
|
def free_design(clf, length=300, direction="max", passes=6, seed=0):
    """Design a sequence from scratch that pushes ``clf.host_score`` up or down.

    Starts from a uniformly random sequence and runs greedy coordinate ascent:
    at each position every alternative base is scored and the best one (by a
    strict improvement) is kept. This is a local search, so the result is not
    guaranteed to be the global optimum.

    Args:
        clf: object exposing ``host_score(sequence: str) -> number``.
        length: number of bases in the designed sequence.
        direction: ``"max"`` to raise the score, ``"min"`` to lower it.
        passes: maximum number of sweeps over all positions. The search stops
            early once a sweep changes nothing, since a further sweep would
            repeat it exactly (assuming ``host_score`` is deterministic).
        seed: seed for the random starting sequence.

    Returns:
        The designed sequence, a string over ``BASES`` of length ``length``.

    Raises:
        ValueError: if ``direction`` is not ``"max"`` or ``"min"``.
    """
    if direction not in ("max", "min"):
        # Previously any other value silently meant "min".
        raise ValueError(f"direction must be 'max' or 'min', got {direction!r}")
    sign = 1 if direction == "max" else -1
    rng = random.Random(seed)
    seq = [rng.choice(BASES) for _ in range(length)]
    # Score of the current sequence. It is carried forward rather than
    # recomputed: the sequence only changes when a base with a strictly better
    # score is kept, and that score is recorded at the same time.
    score = clf.host_score("".join(seq))
    for _ in range(passes):
        changed = False
        for p in range(length):
            current = seq[p]
            best = current
            for nb in BASES:
                if nb == current:
                    continue  # its score is already `score`
                seq[p] = nb
                sc = clf.host_score("".join(seq))
                if sign * sc > sign * score:
                    score, best = sc, nb
            seq[p] = best
            if best != current:
                changed = True
        if not changed:
            break
    return "".join(seq)
|
|
|
|
def synonymous_design(clf, cds, direction="max", passes=4):
    """Re-choose synonymous codons of *cds* to push ``clf.host_score`` up or down.

    Each codon is replaced, greedily and one at a time, by whichever synonymous
    codon (same amino acid per ``_AA2COD``, stop codons included) gives a
    strictly better score, so the encoded protein is unchanged. Triplets that
    are not in the codon table (e.g. containing ``N``) are left as they are.

    Args:
        clf: object exposing ``host_score(sequence: str) -> number``.
        cds: coding sequence. It is upper-cased first, because ``_CODON`` only
            has uppercase keys and lowercase input would otherwise never be
            recoded.
        direction: ``"max"`` to raise the score, ``"min"`` to lower it.
        passes: maximum number of sweeps over all codons. The search stops
            early once a sweep changes nothing (assuming ``host_score`` is
            deterministic).

    Returns:
        The recoded, uppercase sequence, with the same length as *cds*. A
        trailing 1-2 base remainder that does not form a full codon is kept
        verbatim and included in scoring.

    Raises:
        ValueError: if ``direction`` is not ``"max"`` or ``"min"``.
        AssertionError: internal check that the protein was preserved.
    """
    if direction not in ("max", "min"):
        # Previously any other value silently meant "min".
        raise ValueError(f"direction must be 'max' or 'min', got {direction!r}")
    sign = 1 if direction == "max" else -1
    cds = cds.upper()
    n_full = len(cds) - len(cds) % 3
    cod = [cds[i:i + 3] for i in range(0, n_full, 3)]
    # Previously the partial trailing codon was silently dropped from the output.
    tail = cds[n_full:]
    # Score of the current sequence, carried forward: it only changes when a
    # strictly better codon is kept, and that score is recorded at that point.
    score = clf.host_score("".join(cod) + tail)
    for _ in range(passes):
        changed = False
        for ci in range(len(cod)):
            current = cod[ci]
            aa = _CODON.get(current)
            if aa not in _AA2COD:
                continue  # unrecognized triplet: leave it untouched
            best = current
            for cand in _AA2COD[aa]:
                if cand == current:
                    continue  # its score is already `score`
                cod[ci] = cand
                sc = clf.host_score("".join(cod) + tail)
                if sign * sc > sign * score:
                    score, best = sc, cand
            cod[ci] = best
            if best != current:
                changed = True
        if not changed:
            break
    out = "".join(cod) + tail
    assert _protein(out) == _protein(cds), "protein not preserved"
    return out
|
|
|
|
if __name__ == "__main__":
    # Demo: design one maximally and one minimally host-like 300-nt sequence
    # with the classifier's default construction, then print both scores.
    from model import DnaOriginClassifier

    classifier = DnaOriginClassifier()
    goals = (
        ("max", "max-host design score:"),
        ("min", "min-host design score:"),
    )
    # Run both designs before scoring/printing, matching the original order.
    designed = [(label, free_design(classifier, 300, goal)) for goal, label in goals]
    for label, sequence in designed:
        print(label, round(classifier.host_score(sequence), 2))
|
|