Real foundation model (DNABERT) gain-vs-D, field-wide sweep, protein/RNA parity; boundary.py multiclass+regression
f6775f9 verified | """The composition-solvability criterion. | |
| Given two classes of sequences, D = -N ln rho measures whether their k-mer | |
| counts can separate them, where rho is the Bhattacharyya overlap of the | |
| class-conditional k-mer distributions and N is the per-sequence k-mer count. | |
| D is read from the data with no classifier and no training. | |
| When D is near 0, the classes share their k-mer distribution: no composition | |
| method exceeds chance, and a context model is required (the deep regime). | |
| When D is large, the k-mer counts are a sufficient statistic and a linear | |
| model reaches the Bayes rate (the closed-form regime). | |
| The threshold is near D = 1.4. What D controls is proved in CompositionBoundary.v | |
| (BC_N = rho^N, the Bayes-error bound, the dichotomy). rho is bias-corrected | |
| against a matched same-distribution null, so two finite samples of one | |
| distribution score D near 0. Requires only numpy. | |
| """ | |
| import math | |
| from itertools import product | |
| import numpy as np | |
def _index_map(alphabet, k):
    """Assign a dense integer id to every length-k word over `alphabet`.

    Ids follow `itertools.product` order, i.e. lexicographic in the order the
    symbols appear in `alphabet`. If `alphabet` repeats a symbol, the later
    (larger) id wins for the duplicated word.
    """
    words = ("".join(letters) for letters in product(alphabet, repeat=k))
    return {word: position for position, word in enumerate(words)}
def _theta(seqs, k, idx, vsize, alpha=0.5):
    """Pooled, pseudocount-smoothed k-mer distribution of a set of sequences.

    Every length-k window of every (upper-cased) sequence is looked up in
    `idx`. Windows missing from `idx`, such as those containing symbols
    outside the alphabet, are skipped. Each of the `vsize` cells starts at
    `alpha`.

    Returns (theta, n_typical):
      theta     -- counts normalized to sum to 1.
      n_typical -- median number of counted k-mers per sequence, taken over
                   sequences that contributed at least one k-mer; 0.0 if
                   none did.
    """
    counts = np.full(vsize, alpha)
    per_sequence = []
    for raw in seqs:
        text = raw.upper()
        windows = (text[pos:pos + k] for pos in range(len(text) - k + 1))
        hits = [slot for slot in map(idx.get, windows) if slot is not None]
        for slot in hits:
            counts[slot] += 1
        if hits:
            per_sequence.append(len(hits))
    n_typical = float(np.median(per_sequence)) if per_sequence else 0.0
    return counts / counts.sum(), n_typical
def _overlap(a, b):
    """Bhattacharyya coefficient sum_i sqrt(a_i * b_i) of two distributions.

    For normalized inputs this is 1 when they are identical and 0 when their
    supports are disjoint. The result is returned as a plain Python float.
    """
    return float(np.sqrt(a * b).sum())
def boundary_distance(class_a, class_b, alphabet="ACGT", k=6, splits=6, seed=0):
    """Effective Bhattacharyya distance D between two classes of sequences.

    Procedure, repeated for each of `splits` rounds:
      1. Shuffle both classes and cut each into two halves.
      2. Between-class overlap: compare the first half of A with the first
         half of B.
      3. Matched null: compare the two halves within each class, then take
         the geometric mean over the two classes.

    Both overlaps are measured on half-size samples, so dividing the
    between-class overlap by the null cancels the finite-sample shortfall.
    As a result, two samples of one distribution score D near 0.

    D = N * (-ln rho), where:
      rho -- the bias-corrected overlap, capped at 1.
      N   -- median per-sequence k-mer count, averaged over both classes and
             all splits, then rounded to an int.

    Parameters
    ----------
    class_a, class_b : iterable of str
        Sequences of each class. Each class needs at least two sequences so
        that both halves are non-empty. Sequences are upper-cased, and
        k-mers containing symbols outside `alphabet` are skipped.
    alphabet : str
        Symbols that k-mers are built from.
    k : int
        k-mer length.
    splits : int
        Number of random half-splits averaged; must be >= 1.
    seed : int
        Seed of the RandomState that drives the shuffles.

    Returns
    -------
    (D, info)
        `info` contains:
          rho_between, rho_null, rho_corrected -- overlaps, rounded to 4 places.
          N      -- the effective k-mer count.
          regime -- "composition-solvable" if D >= 1.4, else "context-required".

    Raises
    ------
    ValueError
        If either class has fewer than two sequences, or `splits` < 1.
    """
    # Copy the inputs: the halves come from shuffling these lists in place,
    # which must not reorder the caller's data.
    a = list(class_a)
    b = list(class_b)
    # With fewer than two sequences one half is empty, and its "distribution"
    # would be the bare uniform pseudocount prior -- D would be meaningless.
    if len(a) < 2 or len(b) < 2:
        raise ValueError(
            f"each class needs at least 2 sequences (got {len(a)} and {len(b)})"
        )
    if splits < 1:
        raise ValueError(f"splits must be at least 1 (got {splits})")

    rng = np.random.RandomState(seed)
    idx = _index_map(alphabet, k)
    vsize = len(alphabet) ** k
    between, null, ns = [], [], []
    for _ in range(splits):
        rng.shuffle(a)
        rng.shuffle(b)
        a0, a1 = a[:len(a) // 2], a[len(a) // 2:]
        b0, b1 = b[:len(b) // 2], b[len(b) // 2:]
        ta0, na = _theta(a0, k, idx, vsize)
        ta1, _ = _theta(a1, k, idx, vsize)
        tb0, nb = _theta(b0, k, idx, vsize)
        tb1, _ = _theta(b1, k, idx, vsize)
        between.append(_overlap(ta0, tb0))
        # Same-distribution null at the same (half) sample size as `between`.
        null.append(math.sqrt(_overlap(ta0, ta1) * _overlap(tb0, tb1)))
        ns.append((na + nb) / 2)

    rho_between = float(np.mean(between))
    rho_null = float(np.mean(null))
    N = round(float(np.mean(ns)))
    rho = min(1.0, rho_between / rho_null) if rho_null > 0 else 1.0
    if rho >= 1.0:
        D = 0.0
    elif rho > 0.0:
        # Written as N * (-ln rho) so that N == 0 yields 0.0, not -0.0.
        D = N * -math.log(rho)
    else:
        D = float("inf")
    return D, {"rho_between": round(rho_between, 4), "rho_null": round(rho_null, 4),
               "rho_corrected": round(rho, 4), "N": N,
               # 1.4 is the module's stated composition/context threshold.
               "regime": "composition-solvable" if D >= 1.4 else "context-required"}
def boundary_distance_multiclass(classes, alphabet="ACGT", k=6, splits=6, seed=0):
    """Minimum pairwise D over all class pairs.

    `classes` is a dict mapping label -> sequences. A task is
    composition-solvable only if every pair of classes is separable, so the
    minimum pairwise D governs.

    The minimum and the regime decision use the unrounded per-pair D.
    Rounding is applied only to the values reported in `pairwise_D`, so the
    verdict agrees with `boundary_distance` right at the 1.4 threshold.

    Returns (D_min, info), where `info` contains:
      pairwise_D   -- {"a|b": D rounded to 2 places}.
      closest_pair -- key of the pair attaining D_min.
      regime       -- "composition-solvable" if D_min >= 1.4,
                      else "context-required".

    Raises ValueError if fewer than two classes are given. Errors from
    `boundary_distance` (e.g. a class with fewer than two sequences)
    propagate unchanged.
    """
    from itertools import combinations
    labels = list(classes)
    if len(labels) < 2:
        raise ValueError(f"need at least 2 classes to compare (got {len(labels)})")
    raw = {}
    for a, b in combinations(labels, 2):
        D, _ = boundary_distance(classes[a], classes[b], alphabet, k, splits, seed)
        raw[f"{a}|{b}"] = D
    closest = min(raw, key=raw.get)
    d_min = raw[closest]
    return d_min, {"pairwise_D": {key: round(D, 2) for key, D in raw.items()},
                   "closest_pair": closest,
                   "regime": "composition-solvable" if d_min >= 1.4 else "context-required"}
def boundary_distance_regression(seqs, values, alphabet="ACGT", k=6, splits=6, seed=0):
    """Binarize a continuous target at its median, then measure D between the
    high and low halves.

    Sequences whose value is strictly above the median form the high class.
    All others, including ties at the median, form the low class. The return
    value is whatever `boundary_distance(high, low, ...)` returns.

    Raises
    ------
    ValueError
        In any of these cases:
          - `values` is not 1-D with one entry per sequence.
          - The input is empty.
          - `values` contains NaN.
          - The median split leaves fewer than two sequences in either
            class (e.g. a constant or heavily tied target).
    """
    # Materialize once: both comprehensions below iterate `seqs`, and a
    # generator would be exhausted by the first one.
    seqs = list(seqs)
    v = np.asarray(values, dtype=float)
    if v.shape != (len(seqs),):
        raise ValueError(
            f"values must be 1-D with one entry per sequence "
            f"(got shape {v.shape} for {len(seqs)} sequences)"
        )
    if not seqs:
        raise ValueError("no sequences given")
    if np.isnan(v).any():
        raise ValueError("values contain NaN; drop or impute them first")
    med = float(np.median(v))
    hi = [s for s, x in zip(seqs, v) if x > med]
    lo = [s for s, x in zip(seqs, v) if x <= med]
    if len(hi) < 2 or len(lo) < 2:
        raise ValueError(
            f"median split is degenerate ({len(hi)} above, {len(lo)} at or below "
            f"median {med}); the target has too few distinct values"
        )
    return boundary_distance(hi, lo, alphabet, k, splits, seed)
if __name__ == "__main__":
    # Synthetic demo: i.i.d. DNA sequences that differ only in GC content.
    sampler = np.random.RandomState(0)

    def draw(gc, n=400, L=300):
        """Return n random length-L DNA strings with expected GC fraction gc."""
        at_half = (1 - gc) / 2
        gc_half = gc / 2
        probs = [at_half, gc_half, gc_half, at_half]  # A, C, G, T
        letters = list("ACGT")
        return ["".join(sampler.choice(letters, p=probs, size=L)) for _ in range(n)]

    # Draw order matters for reproducibility: base, 0.60, 0.40, then 0.30/0.50/0.70.
    base = draw(0.40)
    shifted = draw(0.60)
    matched = draw(0.40)
    for label, other in (("GC 0.40 vs 0.60", shifted), ("GC 0.40 vs 0.40", matched)):
        D, info = boundary_distance(base, other)
        print(f"{label}: D = {D:.2f} ({info['regime']})")

    three = {}
    for name, gc in (("AT", 0.30), ("mid", 0.50), ("GC", 0.70)):
        three[name] = draw(gc)
    Dm, infom = boundary_distance_multiclass(three)
    print(f"3-class GC: min pairwise D = {Dm:.2f} ({infom['regime']})")