(* dna-origin-classifier / Detector.v
phanerozoic's picture
Partition law and first-principles crossover; machine-checked robustness proof; provenance watermark
4d79fd9 verified *)
(* ============================================================================
Certified robustness of the dna-origin-classifier, machine-checked in Rocq 9.
The published model scores a sequence as a linear function of its k-mer
counts: score(s) = (1/T) * (sum over windows w of s of u(w)) + bias, where
u(w) is the effective weight of the k-mer w (the stored weight divided by
its feature scale), T is the number of windows, and the bias is added once,
outside the sum.
Because the score is linear, a single base substitution changes at most k
windows. Each changed window swaps one k-mer for another, so its contribution
moves from one effective weight to another, both bounded by Umax in
magnitude, i.e. it moves by at most 2*Umax. One edit therefore moves the
unnormalized margin by at most k*(2*Umax), and n edits by at most
n*k*(2*Umax). The main theorem turns that bound into a certificate: if
n * k * (2*Umax) < |margin(s)|, no n-substitution edit of s can change the
sign of the call.
This proven radius is a lower bound on the adversarial edit distance: no
substitution count within it can flip the call. The Python `certify` exhibits
an actual flip by greedy search, an upper bound on the same distance, so the
two bracket it. For the published host head the effective weights are
heavy-tailed (max |u| about 95) and the achievable margin is small (|score|
at most about 9), so the proven radius is a few edits at most.
Everything is proved over the rationals with no axioms and no admits.
========================================================================== *)
From Stdlib Require Import List Arith Lia QArith Qabs.
From Stdlib Require Import Lqa.
Import ListNotations.
Open Scope Q_scope.
Section Detector.
(* An abstract alphabet. The published model instantiates it with the four
DNA bases; nothing below depends on the cardinality. *)
Variable Base : Type.
(* k-mer width (8 in the published model). *)
Variable k : nat.
(* Effective per-k-mer weight: the stored weight divided by its feature scale.
u is a total function on lists of symbols of any length. *)
Variable u : list Base -> Q.
(* A uniform bound on the magnitude of every effective weight. The published
host head has a concrete, finite Umax computed from its 65,536 weights. *)
Variable Umax : Q.
(* Hu is the only assumption made about u. Instantiated at the empty list it
also yields 0 <= Umax, which the proof of single_subst relies on. *)
Hypothesis Hu : forall w, Qabs (u w) <= Umax.
(* A sequence is modelled as a total map f : nat -> Base from index to symbol.
   win f j is the length-k window that starts at index j, i.e. the list
   f j, f (j+1), ..., f (j+k-1).
   With nwin L windows and k <= L the largest index read is L - 1; when
   k > L the single window 0 still reads indices up to k - 1. *)
Definition win (f : nat -> Base) (j : nat) : list Base :=
  map (fun offset => f (j + offset)%nat)
      (seq 0 k).
(* Number of windows of a length-L sequence (for k <= L this is L - k + 1).
Subtraction on nat truncates at zero, so for k > L the value is 1 rather
than 0: the window count is always positive. *)
Definition nwin (L : nat) : nat := S (L - k).
(* Unnormalized score of f: the effective weight u summed over the windows
   starting at 0, 1, ..., nwin L - 1, folded from the right onto 0. *)
Definition raw (L : nat) (f : nat -> Base) : Q :=
  fold_right Qplus 0
    (map (fun start => u (win f start))
         (seq 0 (nwin L))).
(* Point update: upd f i b agrees with f everywhere except at index i,
   where it returns b. *)
Definition upd (f : nat -> Base) (i : nat) (b : Base) : nat -> Base :=
  fun pos =>
    if Nat.eqb pos i then b else f pos.
(* A window of index j "contains" position i when j <= i < j + k.
Note the argument order: the position i comes first, the window start j
second, so (contains i) can be used directly as a filter over window starts. *)
Definition contains (i j : nat) : bool := andb (Nat.leb j i) (Nat.ltb i (j + k)).
(* The per-edit margin bound: one substitution moves raw by at most this.
K is k * (Umax + Umax): at most k windows change, and each changed window
replaces one weight by another, both bounded by Umax in magnitude.
inject_Z (Z.of_nat k) is the embedding of the natural number k into Q. *)
Definition K : Q := inject_Z (Z.of_nat k) * (Umax + Umax).
(* ---------- generic rational / list helpers ---------- *)
(* Triangle inequality for a difference: |a - b| <= |a| + |b|. *)
Lemma Qabs_minus_le : forall a b : Q, Qabs (a - b) <= Qabs a + Qabs b.
Proof.
  intros a b.
  (* a - b is a + (- b). Rewriting |b| as |- b| on the right turns the goal
     into the plain triangle inequality for a and - b. *)
  unfold Qminus.
  rewrite <- (Qabs_opp b).
  apply Qabs_triangle.
Qed.
(* |0| <= 0. Qabs 0 computes to 0, so the statement is an instance of
   reflexivity of <= up to conversion. *)
Lemma Qabs_0_le : Qabs 0 <= 0.
Proof.
  exact (Qle_refl 0).
Qed.
(* Triangle inequality for a finite sum: the absolute value of a sum is at
   most the sum of the absolute values. *)
Lemma Qabs_sum_le : forall l : list Q,
  Qabs (fold_right Qplus 0 l) <= fold_right Qplus 0 (map Qabs l).
Proof.
  induction l as [| q rest IHrest]; simpl.
  - (* Empty sum: both sides are 0. *)
    apply Qle_refl.
  - (* Split off the head with the binary triangle inequality, then bound
       the tail with the induction hypothesis. *)
    eapply Qle_trans.
    + apply Qabs_triangle.
    + apply Qplus_le_compat.
      * apply Qle_refl.
      * exact IHrest.
Qed.
(* Monotonicity of finite sums: if a j <= b j for every index j occurring in
   l, then the sum of a over l is at most the sum of b over l. *)
Lemma fold_right_Qplus_le : forall (l : list nat) (a b : nat -> Q),
  (forall j, In j l -> a j <= b j) ->
  fold_right Qplus 0 (map a l) <= fold_right Qplus 0 (map b l).
Proof.
  induction l as [| j rest IHrest]; intros a b Hpt; simpl.
  - apply Qle_refl.
  - (* The pointwise hypothesis restricted to the tail. *)
    assert (Hrest : forall i, In i rest -> a i <= b i).
    { intros i Hi. apply Hpt. right. exact Hi. }
    apply Qplus_le_compat.
    + apply Hpt. left. reflexivity.
    + apply IHrest. exact Hrest.
Qed.
(* The difference of two sums over the same index list equals the sum of the
   termwise differences. *)
Lemma diff_of_folds : forall (l : list nat) (A B : nat -> Q),
  fold_right Qplus 0 (map A l) - fold_right Qplus 0 (map B l)
  == fold_right Qplus 0 (map (fun j => A j - B j) l).
Proof.
  induction l as [| j rest IHrest]; intros A B; simpl.
  - ring.
  - (* Replace the tail of the right-hand side by the difference of the two
       tail sums; what remains is a ring identity. *)
    rewrite <- (IHrest A B).
    ring.
Qed.
(* Embedding a successor into Q: the image of S n is the image of n plus 1. *)
Lemma inject_Z_Snat : forall n : nat,
  inject_Z (Z.of_nat (S n)) == inject_Z (Z.of_nat n) + 1.
Proof.
  intros n.
  (* Move the successor from nat to Z, then push inject_Z through the sum. *)
  replace (Z.of_nat (S n)) with (Z.of_nat n + 1)%Z by lia.
  rewrite inject_Z_plus.
  (* inject_Z 1 and the rational literal 1 are the same term up to
     computation. *)
  change (inject_Z 1) with 1.
  reflexivity.
Qed.
(* Summing the constant c over the elements of l that satisfy P (and 0 over
the others) gives c times the number of elements of l that satisfy P. *)
Lemma indicator_sum : forall (l : list nat) (P : nat -> bool) (c : Q),
fold_right Qplus 0 (map (fun j => if P j then c else 0) l)
== c * inject_Z (Z.of_nat (length (filter P l))).
Proof.
induction l as [|x l IH]; intros P c; simpl.
- ring.
(* Case split on whether the head x satisfies P. *)
- destruct (P x) eqn:Hx; simpl.
(* Head counted: the filtered length grows by one, which inject_Z_Snat turns
into an added 1 on the rational side. *)
+ rewrite IH. rewrite (inject_Z_Snat (length (filter P l))). ring.
(* Head not counted: it contributes 0 and the filtered length is unchanged. *)
+ rewrite IH. ring.
Qed.
(* ---------- structural facts about windows and substitution ---------- *)
(* A window that does not contain the edited position is unchanged. *)
Lemma win_eq_off : forall f i b j,
contains i j = false -> win (upd f i b) j = win f j.
Proof.
(* Both windows are maps over seq 0 k, so compare them offset by offset. *)
intros f i b j Hc. unfold win. apply map_ext_in.
(* in_seq bounds the offset p by 0 <= p < 0 + k; only the upper bound is kept. *)
intros p Hp. apply in_seq in Hp. destruct Hp as [_ Hp]. simpl in Hp.
(* If j + p is not the edited index, upd returns f (j + p) and we are done. *)
unfold upd. destruct (Nat.eqb (j + p) i) eqn:E; [|reflexivity].
(* Otherwise j + p = i with p < k, so window j contains i, contradicting Hc. *)
apply Nat.eqb_eq in E. exfalso.
assert (Hct : contains i j = true).
{ unfold contains. apply andb_true_intro. split.
- apply Nat.leb_le. lia.
- apply Nat.ltb_lt. lia. }
rewrite Hct in Hc. discriminate.
Qed.
(* At most k windows contain a given position.
Proof idea: the window starts that contain i form a duplicate-free list all
of whose elements lie among the k consecutive naturals starting at
i + 1 - k, so there are at most k of them. The subtraction is truncated, so
that interval starts at 0 when k > i + 1. *)
Lemma count_le_k : forall L i,
(length (filter (contains i) (seq 0 (nwin L))) <= k)%nat.
Proof.
intros L i.
apply Nat.le_trans with (length (seq (i + 1 - k) k)).
(* A duplicate-free list included in another list is no longer than it. *)
- apply NoDup_incl_length.
+ apply NoDup_filter. apply seq_NoDup.
(* Inclusion: any x with contains i x = true satisfies x <= i < x + k, which
places x inside seq (i + 1 - k) k. Membership in seq 0 (nwin L) is not needed. *)
+ intros x Hx. apply filter_In in Hx. destruct Hx as [_ Hc].
unfold contains in Hc. apply andb_true_iff in Hc. destruct Hc as [Hle Hlt].
apply Nat.leb_le in Hle. apply Nat.ltb_lt in Hlt.
apply in_seq. split; lia.
(* The bounding interval has exactly k elements. *)
- rewrite length_seq. apply Nat.le_refl.
Qed.
(* ---------- one substitution moves raw by at most K ---------- *)
(* Replacing the symbol at any single position i by any symbol b changes the
unnormalized score by at most K in absolute value. No side condition on i
or L is required. *)
Lemma single_subst : forall L f i b,
Qabs (raw L f - raw L (upd f i b)) <= K.
Proof.
intros L f i b. unfold raw.
(* Step 1: write the difference of the two sums as one sum of per-window
differences. *)
rewrite (diff_of_folds (seq 0 (nwin L))
(fun j => u (win f j)) (fun j => u (win (upd f i b) j))).
(* Step 2: move the absolute value inside the sum. *)
eapply Qle_trans.
{ apply Qabs_sum_le. }
rewrite map_map.
(* Step 3: bound each term by Umax + Umax if window j contains i, else by 0. *)
eapply Qle_trans.
{ apply (fold_right_Qplus_le (seq 0 (nwin L))
(fun j => Qabs (u (win f j) - u (win (upd f i b) j)))
(fun j => if contains i j then Umax + Umax else 0)).
intros j Hj. cbv beta. destruct (contains i j) eqn:Hcj.
(* Affected window: both weights are bounded by Umax via Hu. *)
- eapply Qle_trans; [apply Qabs_minus_le|].
apply Qplus_le_compat; apply Hu.
(* Unaffected window: the two windows are equal, so the difference is 0. *)
- assert (Hw : win (upd f i b) j = win f j) by (apply win_eq_off; exact Hcj).
rewrite Hw.
setoid_replace (u (win f j) - u (win f j)) with 0 by ring.
apply Qabs_0_le. }
(* Step 4: the indicator sum is (Umax + Umax) times the number of windows
containing i, and that number is at most k. *)
rewrite (indicator_sum (seq 0 (nwin L)) (contains i) (Umax + Umax)).
unfold K.
rewrite (Qmult_comm (Umax + Umax)).
apply Qmult_le_compat_r.
- rewrite <- Zle_Qle. pose proof (count_le_k L i) as Hck. lia.
(* Multiplying an inequality by Umax + Umax needs 0 <= Umax, obtained from
Hu at the empty list together with nonnegativity of Qabs. *)
- assert (HU0 : 0 <= Umax).
{ eapply Qle_trans; [apply (Qabs_nonneg (u nil))| apply Hu]. }
lra.
Qed.
(* ---------- n substitutions move raw by at most n * K ---------- *)
(* reach L f g n: g is obtained from f by a chain of n single-position
updates. A step may write the symbol that is already present, so n is an
upper bound on the number of positions that actually change, not an exact
count. The length L is carried as a parameter for uniformity with raw and
margin; neither constructor constrains the edited index by it. *)
Inductive reach (L : nat) (f : nat -> Base) : (nat -> Base) -> nat -> Prop :=
| reach0 : reach L f f 0
| reachS : forall g i b n, reach L f g n -> reach L f (upd g i b) (S n).
(* A chain of n updates moves the unnormalized score by at most n * K.
Proved by induction on the derivation of reach. *)
Lemma reach_bound : forall L f g n,
reach L f g n ->
Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K.
Proof.
intros L f g n H. induction H.
(* Zero updates: the difference is 0 and the bound is 0 * K = 0. *)
- setoid_replace (raw L f - raw L f) with 0 by ring.
setoid_replace (inject_Z (Z.of_nat 0) * K) with 0
by (replace (inject_Z (Z.of_nat 0)) with 0 by reflexivity; ring).
apply Qabs_0_le.
(* One more update: insert the intermediate sequence g and use the triangle
inequality, bounding the first part by the induction hypothesis and the
second by single_subst. *)
- eapply Qle_trans.
+ setoid_replace (raw L f - raw L (upd g i b))
with ((raw L f - raw L g) + (raw L g - raw L (upd g i b))) by ring.
apply Qabs_triangle.
+ eapply Qle_trans.
{ apply Qplus_le_compat; [exact IHreach| apply single_subst]. }
(* (n + 1) * K = n * K + K. *)
setoid_replace (inject_Z (Z.of_nat (S n)) * K)
with (inject_Z (Z.of_nat n) * K + K) by (rewrite (inject_Z_Snat n); ring).
apply Qle_refl.
Qed.
(* ---------- the certificate ---------- *)
(* The unnormalized margin: positive favors the class, negative opposes it.
It shares the sign of the normalized score because the window count is
positive, and a same-length edit leaves the window count fixed.
Concretely margin = raw + b0 * nwin L, i.e. the normalized score
raw / nwin L + b0 multiplied through by the window count nwin L; b0 is the
bias of the normalized score. *)
Definition margin (L : nat) (b0 : Q) (f : nat -> Base) : Q :=
raw L f + b0 * inject_Z (Z.of_nat (nwin L)).
(* For two sequences scored at the same length L and bias b0, the difference
   of margins equals the difference of unnormalized scores. *)
Lemma margin_diff : forall L b0 f g,
  margin L b0 f - margin L b0 g == raw L f - raw L g.
Proof.
  intros L b0 f g.
  unfold margin.
  (* The bias term b0 * nwin L occurs on both sides and cancels. *)
  ring.
Qed.
(* A perturbation d strictly smaller in magnitude than x cannot change the
   sign of x: if x is positive so is x - d, and if x is negative so is x - d. *)
Lemma sign_preserved : forall x d : Q,
  Qabs d < Qabs x ->
  (0 < x -> 0 < x - d) /\ (x < 0 -> x - d < 0).
Proof.
  intros x d Hsmall.
  (* Both one-sided bounds on d in terms of its magnitude, obtained once:
     - Qabs d <= d  and  d <= Qabs d. *)
  destruct (proj1 (Qabs_Qle_condition d (Qabs d)) (Qle_refl (Qabs d)))
    as [Hlow Hup].
  split; intro Hsign.
  - (* x positive: Qabs x is x, and d <= Qabs d < x. *)
    assert (Hx : Qabs x == x) by (apply Qabs_pos; lra).
    lra.
  - (* x negative: Qabs x is - x, and x < - Qabs d <= d. *)
    assert (Hx : Qabs x == - x) by (apply Qabs_neg; lra).
    lra.
Qed.
(* Main theorem. If g is reachable from f by n substitutions and the
certificate n * K < |margin(f)| holds, the call's sign cannot change.
K = k * (2 * Umax), so the certified radius is the largest n with
n * k * 2 * Umax < |margin|, which is what `certify` returns.
The conclusion is stated separately for the two signs; a margin of exactly
0 never satisfies the certificate, because n * K is nonnegative. *)
Theorem certified_robust : forall L b0 f g n,
reach L f g n ->
inject_Z (Z.of_nat n) * K < Qabs (margin L b0 f) ->
(0 < margin L b0 f -> 0 < margin L b0 g) /\
(margin L b0 f < 0 -> margin L b0 g < 0).
Proof.
intros L b0 f g n Hreach Hcert.
(* The n updates move raw by at most n * K. *)
assert (Hb : Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K)
by (apply reach_bound; exact Hreach).
(* Abbreviate the two margins; their difference equals the raw difference,
so it is strictly smaller in magnitude than the original margin. *)
set (x := margin L b0 f). set (y := margin L b0 g).
assert (Hd : Qabs (x - y) < Qabs x).
{ setoid_replace (x - y) with (raw L f - raw L g)
by (unfold x, y; apply margin_diff).
eapply Qle_lt_trans; [exact Hb| exact Hcert]. }
(* Apply sign preservation with perturbation d = x - y, then rewrite
y as x - (x - y) to match its conclusion. *)
destruct (sign_preserved x (x - y) Hd) as [Hpos Hneg].
split; intro Hsgn.
- assert (Hy : 0 < x - (x - y)) by (apply Hpos; exact Hsgn).
setoid_replace y with (x - (x - y)) by ring. exact Hy.
- assert (Hy : x - (x - y) < 0) by (apply Hneg; exact Hsgn).
setoid_replace y with (x - (x - y)) by ring. exact Hy.
Qed.
End Detector.
(* Instantiation to the published model's setting: the four DNA bases and k = 8.
The host head's effective weights have a concrete finite bound Umax_host, so
the guarantee holds with per-substitution margin constant K = 8 * (2 * Umax_host).
The radius `certify` reports is the largest n with n * K < |margin|. *)
(* The four-letter DNA alphabet used to instantiate the abstract Base. *)
Inductive DNA := dA | dC | dG | dT.
(* certified_robust specialised to the DNA alphabet with k = 8. The weight
   function u and its bound Umax stay universally quantified, so the
   statement covers any weights over 8-mers that satisfy the bound. *)
Corollary certified_robust_dna :
forall (u : list DNA -> Q) (Umax : Q),
(forall w, Qabs (u w) <= Umax) ->
forall (L : nat) (b0 : Q) (f g : nat -> DNA) (n : nat),
reach DNA L f g n ->
inject_Z (Z.of_nat n) * K 8 Umax < Qabs (margin DNA 8 u L b0 f) ->
(0 < margin DNA 8 u L b0 f -> 0 < margin DNA 8 u L b0 g) /\
(margin DNA 8 u L b0 f < 0 -> margin DNA 8 u L b0 g < 0).
Proof.
  intros u Umax Hu L b0 f g n Hreach Hcert.
  (* Supply the section parameters Base := DNA and k := 8 along with the
     weight bound, then the reachability and certificate hypotheses. *)
  exact (certified_robust DNA 8 u Umax Hu L b0 f g n Hreach Hcert).
Qed.