(* dna-origin-classifier / CompositionBoundary.v
   Uploaded by phanerozoic; commit 9e6124a (verified):
   Comprehensive card rewrite; Markov-correlated generalization in the proof;
   correlation figure. *)
(* ============================================================================
The composition boundary, machine-checked in Rocq 9.
A composition classifier sees a sequence only through its symbol counts, a
sufficient statistic of an iid model, so model a length-N sequence over V
symbols as N independent draws from a class-conditional distribution P0 or P1
(a list of V nonnegative probabilities).
The Bhattacharyya coefficient of the two class distributions over sequences,
BCseq = sum over all length-N sequences s of sqrt(P0(s) * P1(s)), factorizes
exactly over the independent coordinates:
BCseq = rho ^ N, rho = sum_g sqrt(P0_g * P1_g).
The Bayes error of the equiprobable two-class test is 1/2 * sum min(P0,P1),
bounded by 1/2 * BCseq = 1/2 * rho^N. Hence the dichotomy:
- identical class distributions (P0 = P1, normalized) give rho = 1, so
BCseq = 1 and the Bayes error is exactly 1/2: composition is at chance;
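- rho < 1 gives BCseq = rho^N, nonincreasing in N and at most rho for
N >= 1, so the Bayes-error bound 1/2 * rho^N shrinks as N grows and the
classes separate; since counts are sufficient under the iid model, a
composition classifier can attain that Bayes rate.
Proved over the Stdlib reals with no admits and no axioms beyond those the
Stdlib Reals library itself relies on.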
========================================================================== *)
From Stdlib Require Import List Reals Lra Lia.
Import ListNotations.
Open Scope R_scope.
(* Rsum l : the sum of the reals in l, computed as a right fold of Rplus
   starting from 0 (so the empty list sums to 0).
   NOTE(review): the proofs below apply simpl / change / unfold to Rsum
   terms; changing the shape of this definition may affect them. *)
Definition Rsum : list R -> R := fold_right Rplus 0.
(* ---------- list-sum helpers ---------- *)
(* Summing a concatenation is the sum of the two partial sums. *)
Lemma Rsum_app : forall l1 l2, Rsum (l1 ++ l2) = Rsum l1 + Rsum l2.
Proof.
  induction l1 as [|x rest IHrest]; intros l2; simpl.
  - lra.
  - rewrite IHrest. lra.
Qed.
(* A constant left factor c can be pulled out of a mapped sum. *)
Lemma Rsum_map_mult_l : forall (A : Type) (c : R) (f : A -> R) (l : list A),
  Rsum (map (fun x => c * f x) l) = c * Rsum (map f l).
Proof.
  intros A c f l.
  induction l as [|a tl IHtl]; simpl.
  - ring.
  - rewrite IHtl. ring.
Qed.
(* A constant right factor c can be pulled out of a mapped sum. *)
Lemma Rsum_map_mult_r : forall (A : Type) (c : R) (f : A -> R) (l : list A),
  Rsum (map (fun x => f x * c) l) = Rsum (map f l) * c.
Proof.
  intros A c f l.
  induction l as [|a tl IHtl]; simpl.
  - ring.
  - rewrite IHtl. ring.
Qed.
(* Pointwise domination on the members of l lifts to the mapped sums. *)
Lemma Rsum_map_le : forall (A : Type) (f g : A -> R) (l : list A),
  (forall x, In x l -> f x <= g x) -> Rsum (map f l) <= Rsum (map g l).
Proof.
  intros A f g l.
  induction l as [|a tl IHtl]; intros Hdom; simpl.
  - lra.
  - apply Rplus_le_compat.
    + apply Hdom. left. reflexivity.
    + apply IHtl. intros y Hy. apply Hdom. right. exact Hy.
Qed.
(* Summing F over flat_map G L regroups as: for each x in L, the sum of F
   over the block G x. The block element type B is arbitrary; this file
   instantiates it with B = list nat (symbol sequences). B is declared
   implicit (see the Arguments line) so existing explicit applications
   Rsum_flat_map A F G L, and bare rewrite Rsum_flat_map, keep working. *)
Lemma Rsum_flat_map : forall (A B : Type) (F : B -> R) (G : A -> list B) (L : list A),
  Rsum (map F (flat_map G L)) = Rsum (map (fun x => Rsum (map F (G x))) L).
Proof.
  intros A B F G; induction L as [|x L IH]; simpl; [reflexivity|].
  (* flat_map on a cons is an append; split the sum and recurse. *)
  rewrite map_app, Rsum_app, IH; reflexivity.
Qed.
Arguments Rsum_flat_map A {B} F G L.
(* ---------- iid model ---------- *)
(* jp P s : probability of the symbol sequence s under the iid model whose
   per-symbol probabilities are the entries of P. It is the product of
   nth k P 0 over the symbols k of s, so symbols at or past length P get
   weight 0 and the empty sequence gets probability 1. *)
Definition jp (P : list R) (s : list nat) : R :=
  fold_right Rmult 1 (map (fun k => nth k P 0) s).
(* Prepending symbol g multiplies the sequence probability by P_g. *)
Lemma jp_cons : forall P g s, jp P (g :: s) = nth g P 0 * jp P s.
Proof.
  intros P g s. unfold jp. simpl. reflexivity.
Qed.
(* allseq V N : every length-N sequence over the symbols 0 .. V-1.
   Length 0 has just the empty sequence; a length-(k+1) sequence is each
   symbol prepended to each length-k sequence. *)
Fixpoint allseq (V N : nat) : list (list nat) :=
  match N with
  | O => [ [] ]
  | S k =>
      flat_map (fun tail => map (fun sym => sym :: tail) (seq 0 V)) (allseq V k)
  end.
(* rho1 P0 P1 V : single-symbol Bhattacharyya overlap, i.e. the sum over
   symbols k in 0 .. V-1 of sqrt (P0_k * P1_k), with missing entries read
   as 0. *)
Definition rho1 (P0 P1 : list R) (V : nat) : R :=
  Rsum (map (fun k => sqrt (nth k P0 0 * nth k P1 0)) (seq 0 V)).
(* BCseq P0 P1 V N : Bhattacharyya coefficient of the two iid sequence
   models, i.e. the sum over every length-N sequence w (symbols 0 .. V-1)
   of sqrt (jp P0 w * jp P1 w). *)
Definition BCseq (P0 P1 : list R) (V N : nat) : R :=
  Rsum (map (fun w => sqrt (jp P0 w * jp P1 w)) (allseq V N)).
(* ---------- nonnegativity ---------- *)
(* Every lookup into a nonnegative list is nonnegative, including the
   default 0 returned past its end. Proved by induction on the Forall
   witness, with the index generalized. *)
Lemma nth_nonneg : forall P g, Forall (Rle 0) P -> 0 <= nth g P 0.
Proof.
  intros P g HP. revert g.
  induction HP as [|x xs Hx Hxs IH]; intros [|k]; simpl.
  - apply Rle_refl.
  - apply Rle_refl.
  - exact Hx.
  - apply IH.
Qed.
(* A product of nonnegative symbol probabilities is nonnegative. *)
Lemma jp_nonneg : forall P s, Forall (Rle 0) P -> 0 <= jp P s.
Proof.
  intros P s HP.
  induction s as [|sym rest IHrest]; [unfold jp; simpl; lra |].
  rewrite jp_cons.
  apply Rmult_le_pos.
  - apply nth_nonneg. exact HP.
  - exact IHrest.
Qed.
(* ---------- per-coordinate sums ---------- *)
(* inner_sum: one coordinate of the factorization. Summing the amplitude
   sqrt (jp P0 s' * jp P1 s') over the one-symbol extensions g :: s of a
   fixed sequence s (g ranging over seq 0 V) pulls out the single-symbol
   overlap rho1 P0 P1 V as a factor. Nonnegativity of P0 and P1 is needed
   so that sqrt splits over the product. *)
Lemma inner_sum : forall P0 P1, Forall (Rle 0) P0 -> Forall (Rle 0) P1 ->
forall V s,
Rsum (map (fun s' => sqrt (jp P0 s' * jp P1 s')) (map (fun g => g :: s) (seq 0 V)))
= rho1 P0 P1 V * sqrt (jp P0 s * jp P1 s).
Proof.
intros P0 P1 H0 H1 V s. rewrite map_map. cbv beta.
(* BODY: the amplitude of g :: s is the symbol amplitude times the
   amplitude of s; sqrt_mult requires both factors to be nonnegative. *)
assert (BODY : forall g, sqrt (jp P0 (g :: s) * jp P1 (g :: s))
= sqrt (nth g P0 0 * nth g P1 0) * sqrt (jp P0 s * jp P1 s)).
{ intro g. rewrite !jp_cons.
replace (nth g P0 0 * jp P0 s * (nth g P1 0 * jp P1 s))
with ((nth g P0 0 * nth g P1 0) * (jp P0 s * jp P1 s)) by ring.
rewrite sqrt_mult.
- reflexivity.
- apply Rmult_le_pos; apply nth_nonneg; assumption.
- apply Rmult_le_pos; apply jp_nonneg; assumption. }
(* Rewrite each summand, then factor the common sqrt out on the right. *)
rewrite (map_ext _ _ BODY).
rewrite (Rsum_map_mult_r _ (sqrt (jp P0 s * jp P1 s)) (fun g => sqrt (nth g P0 0 * nth g P1 0)) (seq 0 V)).
unfold rho1. reflexivity.
Qed.
(* inner_total: the probability-mass analogue of inner_sum. Summing jp P
   over the one-symbol extensions g :: s (g in seq 0 V) gives the mass of
   the first V symbols times jp P s. No nonnegativity hypothesis is used. *)
Lemma inner_total : forall P V s,
Rsum (map (fun s' => jp P s') (map (fun g => g :: s) (seq 0 V)))
= Rsum (map (fun g => nth g P 0) (seq 0 V)) * jp P s.
Proof.
intros P V s. rewrite map_map. cbv beta.
assert (BODY : forall g, jp P (g :: s) = nth g P 0 * jp P s) by (intro g; apply jp_cons).
rewrite (map_ext _ _ BODY).
rewrite (Rsum_map_mult_r _ (jp P s) (fun g => nth g P 0) (seq 0 V)). reflexivity.
Qed.
(* ---------- the factorization: BCseq = rho ^ N ---------- *)
(* BC_factorizes: for nonnegative P0 and P1, the Bhattacharyya coefficient
   of the two iid sequence models over all length-N sequences equals the
   single-symbol overlap rho1 raised to the N-th power. The proof is by
   induction on N: each length-(S N) sequence is a symbol prepended to a
   length-N one, and inner_sum factors rho1 out of each such group. *)
Theorem BC_factorizes : forall P0 P1 V N,
Forall (Rle 0) P0 -> Forall (Rle 0) P1 ->
BCseq P0 P1 V N = (rho1 P0 P1 V) ^ N.
Proof.
intros P0 P1 V N H0 H1. induction N as [|N IH].
(* N = 0: the only sequence is [], with probability 1 under both models,
   so the sum is sqrt (1 * 1) = 1 = rho1 ^ 0. *)
- unfold BCseq. simpl. unfold jp; simpl. rewrite Rmult_1_r, sqrt_1. simpl; lra.
(* S N: unfold allseq one step, regroup the sum by the length-N tail,
   factor rho1 out of each group, then fold back into BCseq to use IH. *)
- unfold BCseq.
replace (allseq V (S N)) with (flat_map (fun s => map (fun g => g :: s) (seq 0 V)) (allseq V N)) by reflexivity.
rewrite Rsum_flat_map. cbv beta.
rewrite (map_ext _ _ (inner_sum P0 P1 H0 H1 V)).
rewrite (Rsum_map_mult_l _ (rho1 P0 P1 V) (fun s => sqrt (jp P0 s * jp P1 s)) (allseq V N)).
change (Rsum (map (fun s => sqrt (jp P0 s * jp P1 s)) (allseq V N))) with (BCseq P0 P1 V N).
rewrite IH. reflexivity.
Qed.
(* total_prob: the total mass of the iid model over all length-N sequences
   is (mass of the first V symbols) ^ N. It holds for any P, with no
   nonnegativity hypothesis, and the proof follows the same induction as
   BC_factorizes. *)
Theorem total_prob : forall P V N,
Rsum (map (fun s => jp P s) (allseq V N)) = (Rsum (map (fun g => nth g P 0) (seq 0 V))) ^ N.
Proof.
intros P V N. induction N as [|N IH].
(* N = 0: the single empty sequence has jp = 1. *)
- simpl. unfold jp; simpl. lra.
(* S N: regroup by the length-N tail and factor the symbol mass out. *)
- replace (allseq V (S N)) with (flat_map (fun s => map (fun g => g :: s) (seq 0 V)) (allseq V N)) by reflexivity.
rewrite Rsum_flat_map. cbv beta.
rewrite (map_ext _ _ (inner_total P V)).
rewrite (Rsum_map_mult_l _ (Rsum (map (fun g => nth g P 0) (seq 0 V))) (fun s => jp P s) (allseq V N)).
rewrite IH. reflexivity.
Qed.
(* ---------- Bayes error ---------- *)
(* pe P0 P1 V N : Bayes error of the equal-prior two-class test between
   the iid sequence models P0 and P1 at length N, i.e. one half of the sum
   over all length-N sequences w of min (jp P0 w) (jp P1 w). *)
Definition pe (P0 P1 : list R) (V N : nat) : R :=
  (1/2) * Rsum (map (fun w => Rmin (jp P0 w) (jp P1 w)) (allseq V N)).
(* Rmin_le_sqrt_mult: for nonnegative a and b, min a b <= sqrt (a * b),
   i.e. the minimum never exceeds the geometric mean. Both sides are
   nonnegative, so it suffices to compare their squares. *)
Lemma Rmin_le_sqrt_mult : forall a b, 0 <= a -> 0 <= b -> Rmin a b <= sqrt (a * b).
Proof.
intros a b Ha Hb. apply Rsqr_incr_0.
(* Squared goal: min a b * min a b <= a * b; split on which one is the min. *)
- unfold Rsqr. rewrite sqrt_sqrt by (apply Rmult_le_pos; assumption).
destruct (Rle_dec a b) as [H|H].
+ rewrite Rmin_left by exact H. apply Rmult_le_compat_l; [exact Ha | exact H].
+ rewrite Rmin_right by lra. rewrite (Rmult_comm a b).
apply Rmult_le_compat_l; [exact Hb | lra].
(* Side conditions of Rsqr_incr_0: both sides are nonnegative. *)
- apply Rmin_glb; assumption.
- apply sqrt_pos.
Qed.
(* bayes_bound: for nonnegative P0 and P1 the Bayes error is at most
   1/2 * rho1 ^ N. The proof turns rho1 ^ N back into BCseq and then
   compares the sums termwise using Rmin_le_sqrt_mult. *)
Theorem bayes_bound : forall P0 P1 V N,
Forall (Rle 0) P0 -> Forall (Rle 0) P1 ->
pe P0 P1 V N <= (1/2) * (rho1 P0 P1 V) ^ N.
Proof.
intros P0 P1 V N H0 H1. unfold pe.
rewrite <- BC_factorizes by assumption. unfold BCseq.
apply Rmult_le_compat_l; [lra|].
apply Rsum_map_le. intros s _. apply Rmin_le_sqrt_mult; apply jp_nonneg; assumption.
Qed.
(* ---------- dichotomy ---------- *)
(* degenerate_is_chance: when both classes share the same nonnegative P
   and its first V entries sum to 1, the overlap rho1 is exactly 1 and the
   Bayes error is exactly 1/2 at every length N (chance level). *)
Theorem degenerate_is_chance : forall P V N,
Forall (Rle 0) P ->
Rsum (map (fun g => nth g P 0) (seq 0 V)) = 1 ->
rho1 P P V = 1 /\ pe P P V N = 1/2.
Proof.
intros P V N HP Hnorm. split.
(* rho1 P P V = 1: sqrt (x * x) = x for x >= 0, leaving the normalization. *)
- unfold rho1.
assert (B : forall g, sqrt (nth g P 0 * nth g P 0) = nth g P 0).
{ intro g. replace (nth g P 0 * nth g P 0) with (Rsqr (nth g P 0)) by (unfold Rsqr; ring).
apply sqrt_Rsqr. apply nth_nonneg; exact HP. }
rewrite (map_ext _ _ B). exact Hnorm.
(* pe P P V N = 1/2: min x x = x, so the sum is the total mass 1 ^ N = 1. *)
- unfold pe.
assert (B : forall s, Rmin (jp P s) (jp P s) = jp P s) by (intro s; apply Rmin_left; apply Rle_refl).
rewrite (map_ext _ _ B). rewrite total_prob, Hnorm, pow1. lra.
Qed.
(* Powers of a number in [0, 1] stay at most 1. *)
Lemma pow_le_one : forall r n, 0 <= r -> r <= 1 -> r ^ n <= 1.
Proof.
  intros r n Hnonneg Hle1.
  induction n as [|k IHk].
  - simpl. lra.
  - simpl.
    apply Rle_trans with (r * 1).
    + apply Rmult_le_compat_l; assumption.
    + lra.
Qed.
(* For 0 <= r < 1, one more factor of r never increases r ^ N, and every
   positive power of r is at most r itself. *)
Theorem separable_shrinks : forall r N,
  0 <= r -> r < 1 -> r ^ (S N) <= r ^ N /\ r ^ (S N) <= r.
Proof.
  intros r N Hr0 Hr1.
  assert (Hpow : 0 <= r ^ N) by (apply pow_le; exact Hr0).
  assert (Hle1 : r ^ N <= 1) by (apply pow_le_one; lra).
  simpl. split.
  - apply Rle_trans with (1 * r ^ N).
    + apply Rmult_le_compat_r; [exact Hpow | lra].
    + lra.
  - apply Rle_trans with (r * 1).
    + apply Rmult_le_compat_l; [exact Hr0 | exact Hle1].
    + lra.
Qed.
(* ============================================================================
The Jensen-Shannon identity tying the boundary to mutual information.
For a balanced binary label Y with X|Y distributed as P or Q, the mutual
information I(Y;X) equals 1/2 KL(P||M) + 1/2 KL(Q||M) with M = (P+Q)/2, the
information form of the Jensen-Shannon divergence. That equals the entropy
form H(M) - 1/2 H(P) - 1/2 H(Q), with H(X) = - sum_x X_x ln X_x. Below,
jsd_info is the information form summed coordinatewise and jsd_entropy is
the entropy form; jsd_info_eq_entropy shows they coincide on strictly
positive distributions of equal length. Since X may be the k-mer count
vector, this identifies the boundary's information fraction with a mutual
information.
========================================================================== *)
(* Logarithm of a quotient of positive reals. *)
Lemma lndiv : forall x y, 0 < x -> 0 < y -> ln (x / y) = ln x - ln y.
Proof.
  intros x y Hx Hy.
  assert (Hinv : 0 < / y) by (apply Rinv_0_lt_compat; exact Hy).
  unfold Rdiv.
  rewrite (ln_mult x (/ y) Hx Hinv), (ln_Rinv y Hy).
  ring.
Qed.
(* Information form of the Jensen-Shannon divergence, summed over paired
   coordinates: each pair (p, q) with mixture m = (p + q) / 2 contributes
   1/2 * p * ln (p / m) + 1/2 * q * ln (q / m). Pairing stops at the end of
   the shorter list. *)
Fixpoint jsd_info (P Q : list R) : R :=
  match P, Q with
  | p :: Prest, q :: Qrest =>
      (1/2) * (p * ln (p / ((p + q) / 2)))
      + (1/2) * (q * ln (q / ((p + q) / 2)))
      + jsd_info Prest Qrest
  | _, _ => 0
  end.
(* Entropy form of the Jensen-Shannon divergence, summed over paired
   coordinates: each pair (p, q) with mixture m = (p + q) / 2 contributes
   - m * ln m + 1/2 * p * ln p + 1/2 * q * ln q. Pairing stops at the end
   of the shorter list. *)
Fixpoint jsd_entropy (P Q : list R) : R :=
  match P, Q with
  | p :: Prest, q :: Qrest =>
      - (((p + q) / 2) * ln ((p + q) / 2))
      + (1/2) * (p * ln p)
      + (1/2) * (q * ln q)
      + jsd_entropy Prest Qrest
  | _, _ => 0
  end.
(* jsd_info_eq_entropy: on strictly positive lists of equal length, the
   information and entropy forms agree. Induction on P: both sides peel one
   coordinate, the ln of each quotient is split with lndiv (positivity of
   p, q and (p + q) / 2 discharges its side conditions), and field closes
   the remaining algebra. *)
Theorem jsd_info_eq_entropy : forall P Q,
Forall (Rlt 0) P -> Forall (Rlt 0) Q -> length P = length Q ->
jsd_info P Q = jsd_entropy P Q.
Proof.
intros P. induction P as [|p P IH]; intros Q HP HQ Hlen;
destruct Q as [|q Q]; simpl in *; try reflexivity; try discriminate.
(* Extract 0 < p, 0 < q and the positivity of the tails. *)
inversion HP as [|x xs Hp HP' Heq0]; subst.
inversion HQ as [|x xs Hq HQ' Heq1]; subst.
assert (Hlen' : length P = length Q) by lia.
rewrite (lndiv p ((p + q) / 2)) by lra.
rewrite (lndiv q ((p + q) / 2)) by lra.
rewrite (IH Q HP' HQ' Hlen').
field.
Qed.
(* ============================================================================
Correlated generalization: the Markov transfer recursion.
The bag-of-k-mers model treats k-mers as independent; real k-mers overlap and
are correlated. Model the sequence as an order-1 Markov chain over states
(the k-mers): amp i = sqrt(pi0_i * pi1_i) is the initial Bhattacharyya
amplitude and M i j = sqrt(T0(j|i) * T1(j|i)) the per-transition Bhattacharyya
weight. The Bhattacharyya coefficient of the two length-L path distributions
is the transfer-matrix evaluation below: bw n i sums the transition weight
over all length-n continuations from state i, and BCpath L weights it by amp.
Its asymptotic per-step growth rate is the top eigenvalue (spectral radius)
of M, the correlated analogue of the scalar overlap rho; that spectral
statement is background and is not proved here. When the chain is iid,
i.e. M i j = r j for every i, the recursion collapses to (sum r)^n and, for
L >= 1, the coefficient to (sum amp)(sum r)^(L-1), recovering the
bag-of-k-mers result. Proved over the reals, no admits.
========================================================================== *)
(* Section Markov: the transfer-matrix form of the path Bhattacharyya
   coefficient. It is parameterized by the number of states Q, the
   transition weights M and the initial amplitudes amp. *)
Section Markov.
Variable Q : nat. (* number of states *)
Variable M : nat -> nat -> R. (* per-transition Bhattacharyya weight *)
Variable amp : nat -> R. (* initial Bhattacharyya amplitude *)
(* bw n i : total transition weight of all length-n continuations from
   state i, with next states j ranging over seq 0 Q. It is 1 for n = 0. *)
Fixpoint bw (n i : nat) : R :=
match n with
| O => 1
| S m => Rsum (map (fun j => M i j * bw m j) (seq 0 Q))
end.
(* BCpath L : coefficient for paths of L states, i.e. L - 1 transitions,
   weighting bw from each start state by amp.
   NOTE(review): BCpath 0 is defined as 0, whereas BCseq at length 0 is 1
   (a single empty path). Only L >= 1 is used below; confirm that this
   convention is intended. *)
Definition BCpath (L : nat) : R :=
match L with
| O => 0
| S n => Rsum (map (fun i => amp i * bw n i) (seq 0 Q))
end.
(* the transfer-matrix recursion, stated explicitly: one step applies M *)
Lemma bw_transfer : forall n i,
bw (S n) i = Rsum (map (fun j => M i j * bw n j) (seq 0 Q)).
Proof. reflexivity. Qed.
(* iid case: constant transfer rows collapse the recursion to a scalar power *)
Lemma bw_iid : forall r, (forall i j, M i j = r j) ->
forall n i, bw n i = (Rsum (map r (seq 0 Q))) ^ n.
Proof.
intros r Hr n. induction n as [|m IH]; intro i; simpl.
- reflexivity.
(* Replace each summand M i j * bw m j by r j * (sum r) ^ m, then factor. *)
- erewrite map_ext.
2:{ intro j. rewrite (Hr i j). rewrite (IH j). reflexivity. }
rewrite (Rsum_map_mult_r nat ((Rsum (map r (seq 0 Q))) ^ m) r (seq 0 Q)).
reflexivity.
Qed.
(* BC_iid_reduces: with constant rows, a path of S n states has coefficient
   (sum amp) * (sum r) ^ n, the bag-of-k-mers form. *)
Theorem BC_iid_reduces : forall r, (forall i j, M i j = r j) ->
forall n, BCpath (S n) = Rsum (map amp (seq 0 Q)) * (Rsum (map r (seq 0 Q))) ^ n.
Proof.
intros r Hr n. unfold BCpath.
erewrite map_ext.
2:{ intro i. rewrite (bw_iid r Hr n i). reflexivity. }
rewrite (Rsum_map_mult_r nat ((Rsum (map r (seq 0 Q))) ^ n) amp (seq 0 Q)).
reflexivity.
Qed.
End Markov.