# dna-origin-classifier / boundary.py
# Author: phanerozoic
# Commit f6775f9 (verified): Real foundation model (DNABERT) gain-vs-D, field-wide
#   sweep, protein/RNA parity; boundary.py multiclass+regression
"""The composition-solvability criterion.

Given two classes of sequences, D = -N ln(rho) measures whether their k-mer
counts can separate them. Here rho is the Bhattacharyya overlap of the
class-conditional k-mer distributions, and N is the per-sequence k-mer count.
D is read directly from the data, with no classifier and no training.

When D is near 0, the classes share their k-mer distribution. No composition
method exceeds chance, and a context model is required (the deep regime).
When D is large, the k-mer counts are a sufficient statistic, and a linear
model reaches the Bayes rate (the closed-form regime).

The threshold lies near D = 1.4. The properties D controls are proved in
CompositionBoundary.v (BC_N = rho^N, the Bayes-error bound, the dichotomy).
rho is bias-corrected against a matched same-distribution null, so two finite
samples of the same distribution score D near 0. Requires only numpy.
"""
import math
from itertools import product
import numpy as np
def _index_map(alphabet, k):
return {"".join(p): i for i, p in enumerate(product(alphabet, repeat=k))}
def _theta(seqs, k, idx, vsize, alpha=0.5):
counts = np.full(vsize, alpha)
lengths = []
for s in seqs:
s = s.upper(); n = 0
for i in range(len(s) - k + 1):
j = idx.get(s[i:i + k])
if j is not None:
counts[j] += 1; n += 1
if n:
lengths.append(n)
return counts / counts.sum(), (float(np.median(lengths)) if lengths else 0.0)
def _overlap(a, b):
return float(np.sum(np.sqrt(a * b)))
def boundary_distance(class_a, class_b, alphabet="ACGT", k=6, splits=6, seed=0,
                      threshold=1.4):
    """Effective Bhattacharyya distance D between two classes of sequences.

    Each of `splits` rounds shuffles both classes and cuts each one into two
    halves. A round then records two overlaps:
      * the between-class overlap of the first halves of each class;
      * a matched null, the geometric mean of the within-class
        half-vs-half overlaps.
    Averaged over rounds, rho = rho_between / rho_null (capped at 1) and
    D = -N ln(rho). N is the rounded mean of the median per-sequence k-mer
    counts of the first halves.

    Args:
        class_a, class_b: iterables of sequence strings. Each class needs at
            least two sequences so that both halves are non-empty.
        alphabet: symbols forming the k-mers. Windows containing any other
            symbol (after upper-casing) are ignored.
        k: k-mer length.
        splits: number of random half-splits to average (>= 1).
        seed: seed for the shuffling RNG.
        threshold: D at or above which the pair is labelled
            "composition-solvable".

    Returns:
        (D, info). info carries the between-class and matched-null overlaps,
        the bias-corrected overlap, N, and the predicted regime.

    Raises:
        ValueError: if either class has fewer than two sequences, or if
            splits < 1.
    """
    a = list(class_a)
    b = list(class_b)
    # With a single sequence one half is empty; _theta would then return the
    # bare pseudocount prior and D would be meaningless.
    if len(a) < 2 or len(b) < 2:
        raise ValueError(
            "each class needs at least 2 sequences so both half-splits are "
            f"non-empty (got {len(a)} and {len(b)})")
    if splits < 1:
        raise ValueError(f"splits must be >= 1 (got {splits})")

    rng = np.random.RandomState(seed)
    idx = _index_map(alphabet, k)
    vsize = len(alphabet) ** k
    between, null, ns = [], [], []
    for _ in range(splits):
        # Shuffle the local copies, then halve each class.
        rng.shuffle(a)
        rng.shuffle(b)
        a0, a1 = a[:len(a) // 2], a[len(a) // 2:]
        b0, b1 = b[:len(b) // 2], b[len(b) // 2:]
        ta0, na = _theta(a0, k, idx, vsize)
        ta1, _ = _theta(a1, k, idx, vsize)
        tb0, nb = _theta(b0, k, idx, vsize)
        tb1, _ = _theta(b1, k, idx, vsize)
        # Between-class overlap on half-size samples, so it is comparable to
        # the half-vs-half null below.
        between.append(_overlap(ta0, tb0))
        # Null: overlap of two samples of the same class, i.e. what finite
        # sampling alone produces; used to bias-correct rho_between.
        null.append(math.sqrt(_overlap(ta0, ta1) * _overlap(tb0, tb1)))
        ns.append((na + nb) / 2)

    rho_between = float(np.mean(between))
    rho_null = float(np.mean(null))
    N = round(float(np.mean(ns)))
    rho = min(1.0, rho_between / rho_null) if rho_null > 0 else 1.0
    if rho >= 1.0:
        D = 0.0
    elif rho > 0.0:
        # N * (-ln rho) rather than -N * ln rho so that N == 0 gives 0.0, not -0.0.
        D = N * -math.log(rho)
    else:
        D = float("inf")
    return D, {"rho_between": round(rho_between, 4), "rho_null": round(rho_null, 4),
               "rho_corrected": round(rho, 4), "N": N,
               "regime": "composition-solvable" if D >= threshold else "context-required"}
def boundary_distance_multiclass(classes, alphabet="ACGT", k=6, splits=6, seed=0,
                                 threshold=1.4):
    """Minimum pairwise D over all class pairs.

    `classes` is a dict mapping label -> sequences. A task is
    composition-solvable only if every pair of classes is separable, so the
    minimum pairwise D governs.

    Args:
        classes: mapping of label to an iterable of sequence strings.
        alphabet, k, splits, seed: forwarded to `boundary_distance` for
            every pair.
        threshold: minimum D at or above which the task is labelled
            "composition-solvable".

    Returns:
        (D_min, info). D_min is the unrounded minimum pairwise D. The regime
        is decided on this value, not on the rounded display values.
        info["pairwise_D"] maps "a|b" to D rounded to 2 decimals.

    Raises:
        ValueError: if fewer than two classes are given.
    """
    from itertools import combinations
    # Materialize each class once: a class takes part in several pairs, and a
    # one-shot iterator would be empty after its first pair.
    pools = {label: list(seqs) for label, seqs in classes.items()}
    if len(pools) < 2:
        raise ValueError(f"need at least 2 classes to compare (got {len(pools)})")
    raw = {}
    for la, lb in combinations(pools, 2):
        D, _ = boundary_distance(pools[la], pools[lb], alphabet, k, splits, seed)
        raw[f"{la}|{lb}"] = D
    # Decide the regime on the exact minimum. Rounding first could push e.g.
    # 1.396 up to 1.40 and flip the verdict.
    d_min = min(raw.values())
    return d_min, {"pairwise_D": {pair: round(d, 2) for pair, d in raw.items()},
                   "regime": "composition-solvable" if d_min >= threshold else "context-required"}
def boundary_distance_regression(seqs, values, alphabet="ACGT", k=6, splits=6, seed=0):
    """Binarize a continuous target at its median, then measure D between the
    high and low halves.

    Sequences whose value is strictly above the median form the high class.
    The rest, including every tie at the median, form the low class. The two
    classes are then compared with `boundary_distance`.

    Args:
        seqs: iterable of sequence strings.
        values: one numeric target per sequence, in the same order.
        alphabet, k, splits, seed: forwarded to `boundary_distance`.

    Returns:
        The (D, info) pair produced by `boundary_distance`.

    Raises:
        ValueError: if any of the following holds:
            * `values` is not 1-D or does not match `seqs` in length;
            * `values` contains NaN;
            * there are fewer than 4 samples;
            * the median split leaves fewer than two sequences in either
              class (e.g. heavily tied targets).
    """
    # Materialize once: seqs is scanned twice below, and a one-shot iterator
    # would silently leave the second (low) class empty.
    seqs = list(seqs)
    v = np.asarray(values, dtype=float)
    if v.ndim != 1 or len(v) != len(seqs):
        raise ValueError(
            "values must be a 1-D sequence matching seqs in length "
            f"(got shape {v.shape} for {len(seqs)} sequences)")
    if np.isnan(v).any():
        raise ValueError("values contain NaN; the median split is undefined")
    if len(seqs) < 4:
        raise ValueError(f"need at least 4 samples (2 per half), got {len(seqs)}")
    med = float(np.median(v))
    hi = [s for s, x in zip(seqs, v) if x > med]
    lo = [s for s, x in zip(seqs, v) if x <= med]
    if len(hi) < 2 or len(lo) < 2:
        raise ValueError(
            f"median split is degenerate ({len(hi)} above, {len(lo)} at or below "
            "the median); too many tied target values")
    return boundary_distance(hi, lo, alphabet, k, splits, seed)
if __name__ == "__main__":
    # Demo on synthetic i.i.d. sequences with a controlled GC content.
    gen = np.random.RandomState(0)

    def sample(gc, n=400, L=300):
        """Draw n random sequences of length L with GC fraction `gc`."""
        at_half, gc_half = (1 - gc) / 2, gc / 2
        probs = [at_half, gc_half, gc_half, at_half]  # A, C, G, T
        symbols = list("ACGT")
        out = []
        for _ in range(n):
            out.append("".join(gen.choice(symbols, p=probs, size=L)))
        return out

    # Draw order matters for reproducibility: base, then 0.60, then 0.40.
    reference = sample(0.40)
    comparisons = [("GC 0.40 vs 0.60", sample(0.60)),
                   ("GC 0.40 vs 0.40", sample(0.40))]
    for name, pool in comparisons:
        dist, details = boundary_distance(reference, pool)
        print(f"{name}: D = {dist:.2f} ({details['regime']})")

    grouped = {}
    grouped["AT"] = sample(0.30)
    grouped["mid"] = sample(0.50)
    grouped["GC"] = sample(0.70)
    d_lowest, summary = boundary_distance_multiclass(grouped)
    print(f"3-class GC: min pairwise D = {d_lowest:.2f} ({summary['regime']})")