(* ============================================================================
   Certified robustness of the dna-origin-classifier, machine-checked in Rocq 9.

   The published model scores a sequence as a linear function of its k-mer
   counts:

       score(s) = (1/T) * sum over windows w of s of u(w) + bias,

   where u(w) is the effective weight of the k-mer w (the stored weight divided
   by its feature scale) and T is the number of windows. Because the score is
   linear, a single base substitution perturbs the counts of at most k windows,
   each by an effective weight bounded by Umax in magnitude, so one edit moves
   the unnormalized margin by at most k*(2*Umax), and n edits by at most
   n*k*(2*Umax). The main theorem turns that bound into a certificate:

       if n * k * (2*Umax) < |margin(s)|, no n-substitution edit of s can
       change the sign of the call.

   This proven radius is a lower bound on the adversarial edit distance: no
   substitution count within it can flip the call. The Python `certify`
   exhibits an actual flip by greedy search, an upper bound on the same
   distance, so the two bracket it. For the published host head the effective
   weights are heavy-tailed (max |u| about 95) and the achievable margin is
   small (|score| at most about 9), so the proven radius is a few edits at most.

   Everything is proved over the rationals with no axioms and no admits.

   Reading guide (in file order):
     1. model definitions: win, nwin, raw, upd, contains, K;
     2. generic helpers about Qabs and sums over lists;
     3. structural facts: win_eq_off, count_le_k;
     4. single_subst: one substitution moves raw by at most K;
     5. reach / reach_bound: n substitutions move raw by at most n * K;
     6. margin, sign_preserved, certified_robust: the certificate;
     7. certified_robust_dna: the instance with four bases and k = 8.
   ========================================================================== *)

From Stdlib Require Import List Arith Lia QArith Qabs.
From Stdlib Require Import Lqa.
Import ListNotations.
Open Scope Q_scope.

Section Detector.

(* An abstract alphabet. The published model instantiates it with the four DNA
   bases; nothing below depends on the cardinality. *)
Variable Base : Type.

(* k-mer width (8 in the published model). *)
Variable k : nat.

(* Effective per-k-mer weight: the stored weight divided by its feature
   scale. *)
Variable u : list Base -> Q.

(* A uniform bound on the magnitude of every effective weight.
   The published host head has a concrete, finite Umax computed from its 65,536
   weights. *)
Variable Umax : Q.
Hypothesis Hu : forall w, Qabs (u w) <= Umax.

(* A sequence is a total map from index to symbol. When k <= L, only indices
   below the length L are ever inspected by raw.
   NOTE(review): when L < k the truncated subtraction in nwin still yields one
   window (nwin L = 1), and that window reads indices 0 .. k-1, some of which
   are at or beyond L. The theorems below hold for every L, but for L < k the
   score they speak about includes that one out-of-range window; confirm this
   matches what the Python side does for sequences shorter than k. *)

(* The j-th length-k window of f: the list f j, f (j+1), ..., f (j+k-1). *)
Definition win (f : nat -> Base) (j : nat) : list Base :=
  map (fun p => f (j + p)%nat) (seq 0 k).

(* Number of windows of a length-L sequence (for k <= L this is L - k + 1).
   Subtraction on nat is truncated, so the result is always at least 1; in
   particular it is 1 whenever L <= k. *)
Definition nwin (L : nat) : nat := S (L - k).

(* Unnormalized score: the sum of effective weights over all windows
   0 .. nwin L - 1. *)
Definition raw (L : nat) (f : nat -> Base) : Q :=
  fold_right Qplus 0 (map (fun j => u (win f j)) (seq 0 (nwin L))).

(* Substitute symbol b at position i; every other position keeps its symbol. *)
Definition upd (f : nat -> Base) (i : nat) (b : Base) : nat -> Base :=
  fun j => if Nat.eqb j i then b else f j.

(* A window of index j "contains" position i when j <= i < j + k.
   Note the argument order: the position i comes first, the window index j
   second, so that (contains i) is a predicate on window indices. *)
Definition contains (i j : nat) : bool :=
  andb (Nat.leb j i) (Nat.ltb i (j + k)).

(* The per-edit margin bound: one substitution moves raw by at most this.
   It is k * (Umax + Umax): at most k windows change, and each changed window
   moves its weight by at most Umax + Umax. *)
Definition K : Q := inject_Z (Z.of_nat k) * (Umax + Umax).

(* ---------- generic rational / list helpers ---------- *)

(* Triangle inequality for a difference: |a - b| <= |a| + |b|. *)
Lemma Qabs_minus_le : forall a b : Q, Qabs (a - b) <= Qabs a + Qabs b.
Proof.
  intros a b. unfold Qminus.
  (* a - b is a + (- b); apply the triangle inequality, then drop the sign. *)
  eapply Qle_trans; [apply Qabs_triangle|].
  rewrite Qabs_opp. apply Qle_refl.
Qed.

(* |0| <= 0, in the inequality form needed to close bounds below. *)
Lemma Qabs_0_le : Qabs 0 <= 0.
Proof.
  rewrite Qabs_pos by apply Qle_refl. apply Qle_refl.
Qed.

(* The absolute value of a sum is at most the sum of the absolute values. *)
Lemma Qabs_sum_le : forall l : list Q,
  Qabs (fold_right Qplus 0 l) <= fold_right Qplus 0 (map Qabs l).
Proof.
  induction l as [|x l IH]; simpl.
  - apply Qle_refl.
  - eapply Qle_trans; [apply Qabs_triangle|].
    apply Qplus_le_compat; [apply Qle_refl| exact IH].
Qed.

(* Sums are monotone: if a j <= b j for every index j in l, then the sum of a
   over l is at most the sum of b over l. *)
Lemma fold_right_Qplus_le : forall (l : list nat) (a b : nat -> Q),
  (forall j, In j l -> a j <= b j) ->
  fold_right Qplus 0 (map a l) <= fold_right Qplus 0 (map b l).
Proof.
  induction l as [|x l IH]; intros a b H; simpl.
  - apply Qle_refl.
  - apply Qplus_le_compat.
    + apply H. left. reflexivity.
    + apply IH. intros j Hj. apply H. right. exact Hj.
Qed.

(* The difference of two sums over the same index list is the sum of the
   pointwise differences. *)
Lemma diff_of_folds : forall (l : list nat) (A B : nat -> Q),
  fold_right Qplus 0 (map A l) - fold_right Qplus 0 (map B l)
  == fold_right Qplus 0 (map (fun j => A j - B j) l).
Proof.
  induction l as [|x l IH]; intros A B; simpl.
  - ring.
  - rewrite <- IH. ring.
Qed.

(* The embedding nat -> Z -> Q sends a successor to "plus one". *)
Lemma inject_Z_Snat : forall n : nat,
  inject_Z (Z.of_nat (S n)) == inject_Z (Z.of_nat n) + 1.
Proof.
  intros n.
  assert (HZ : Z.of_nat (S n) = (Z.of_nat n + 1)%Z) by lia.
  rewrite HZ. rewrite inject_Z_plus.
  replace (inject_Z 1) with 1 by reflexivity.
  reflexivity.
Qed.

(* Summing the constant c over the indices that satisfy P (and 0 elsewhere)
   gives c times the number of indices in l that satisfy P. *)
Lemma indicator_sum : forall (l : list nat) (P : nat -> bool) (c : Q),
  fold_right Qplus 0 (map (fun j => if P j then c else 0) l)
  == c * inject_Z (Z.of_nat (length (filter P l))).
Proof.
  induction l as [|x l IH]; intros P c; simpl.
  - ring.
  - destruct (P x) eqn:Hx; simpl.
    + rewrite IH. rewrite (inject_Z_Snat (length (filter P l))). ring.
    + rewrite IH. ring.
Qed.

(* ---------- structural facts about windows and substitution ---------- *)

(* A window that does not contain the edited position is unchanged. *)
Lemma win_eq_off : forall f i b j,
  contains i j = false -> win (upd f i b) j = win f j.
Proof.
  intros f i b j Hc. unfold win.
  apply map_ext_in. intros p Hp.
  (* p ranges over seq 0 k, so p < k. *)
  apply in_seq in Hp. destruct Hp as [_ Hp]. simpl in Hp.
  unfold upd.
  destruct (Nat.eqb (j + p) i) eqn:E; [|reflexivity].
  (* If j + p = i with p < k, then window j does contain i, contradicting
     the hypothesis. *)
  apply Nat.eqb_eq in E. exfalso.
  assert (Hct : contains i j = true).
  { unfold contains. apply andb_true_intro. split.
    - apply Nat.leb_le. lia.
    - apply Nat.ltb_lt. lia. }
  rewrite Hct in Hc. discriminate.
Qed.

(* At most k windows contain a given position. The window indices containing i
   form a duplicate-free list included in seq (i + 1 - k) k, which has
   length k. *)
Lemma count_le_k : forall L i,
  (length (filter (contains i) (seq 0 (nwin L))) <= k)%nat.
Proof.
  intros L i.
  apply Nat.le_trans with (length (seq (i + 1 - k) k)).
  - apply NoDup_incl_length.
    + apply NoDup_filter. apply seq_NoDup.
    + intros x Hx. apply filter_In in Hx. destruct Hx as [_ Hc].
      unfold contains in Hc. apply andb_true_iff in Hc.
      destruct Hc as [Hle Hlt].
      apply Nat.leb_le in Hle. apply Nat.ltb_lt in Hlt.
      apply in_seq. split; lia.
  - rewrite length_seq. apply Nat.le_refl.
Qed.

(* ---------- one substitution moves raw by at most K ---------- *)

(* A single substitution at any position i, with any symbol b, changes the
   unnormalized score by at most K in absolute value. *)
Lemma single_subst : forall L f i b,
  Qabs (raw L f - raw L (upd f i b)) <= K.
Proof.
  intros L f i b. unfold raw.
  (* Step 1: rewrite the difference of sums as a sum of per-window
     differences. *)
  rewrite (diff_of_folds (seq 0 (nwin L))
             (fun j => u (win f j))
             (fun j => u (win (upd f i b) j))).
  (* Step 2: push the absolute value inside the sum. *)
  eapply Qle_trans.
  { apply Qabs_sum_le. }
  rewrite map_map.
  (* Step 3: bound each term by Umax + Umax if window j contains i, and by 0
     otherwise (the window is then unchanged). *)
  eapply Qle_trans.
  { apply (fold_right_Qplus_le (seq 0 (nwin L))
             (fun j => Qabs (u (win f j) - u (win (upd f i b) j)))
             (fun j => if contains i j then Umax + Umax else 0)).
    intros j Hj. cbv beta.
    destruct (contains i j) eqn:Hcj.
    - eapply Qle_trans; [apply Qabs_minus_le|].
      apply Qplus_le_compat; apply Hu.
    - assert (Hw : win (upd f i b) j = win f j)
        by (apply win_eq_off; exact Hcj).
      rewrite Hw.
      setoid_replace (u (win f j) - u (win f j)) with 0 by ring.
      apply Qabs_0_le. }
  (* Step 4: the indicator sum is (Umax + Umax) times the number of windows
     containing i, and that number is at most k. *)
  rewrite (indicator_sum (seq 0 (nwin L)) (contains i) (Umax + Umax)).
  unfold K. rewrite (Qmult_comm (Umax + Umax)).
  apply Qmult_le_compat_r.
  - rewrite <- Zle_Qle.
    pose proof (count_le_k L i) as Hck. lia.
  - (* Umax is nonnegative because it bounds an absolute value. *)
    assert (HU0 : 0 <= Umax).
    { eapply Qle_trans; [apply (Qabs_nonneg (u nil))| apply Hu]. }
    lra.
Qed.

(* ---------- n substitutions move raw by at most n * K ---------- *)

(* g is reachable from f by exactly n single-symbol substitution steps.
   A step places no constraint on the position i or the symbol b, so a step may
   rewrite a position with the symbol it already holds; hence reach with count
   n also covers every edit made of fewer than n effective substitutions.
   The parameter L is not used by the constructors; it is carried so that reach
   takes the same leading arguments as raw and margin. *)
Inductive reach (L : nat) (f : nat -> Base) : (nat -> Base) -> nat -> Prop :=
| reach0 : reach L f f 0
| reachS : forall g i b n,
    reach L f g n -> reach L f (upd g i b) (S n).

(* n substitution steps change the unnormalized score by at most n * K. *)
Lemma reach_bound : forall L f g n,
  reach L f g n ->
  Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K.
Proof.
  intros L f g n H. induction H.
  - (* Zero steps: the difference is 0 and so is the bound. *)
    setoid_replace (raw L f - raw L f) with 0 by ring.
    setoid_replace (inject_Z (Z.of_nat 0) * K) with 0
      by (replace (inject_Z (Z.of_nat 0)) with 0 by reflexivity; ring).
    apply Qabs_0_le.
  - (* One more step: split the difference through the intermediate g and use
       the triangle inequality, the induction hypothesis and single_subst. *)
    eapply Qle_trans.
    + setoid_replace (raw L f - raw L (upd g i b))
        with ((raw L f - raw L g) + (raw L g - raw L (upd g i b))) by ring.
      apply Qabs_triangle.
    + eapply Qle_trans.
      { apply Qplus_le_compat; [exact IHreach| apply single_subst]. }
      setoid_replace (inject_Z (Z.of_nat (S n)) * K)
        with (inject_Z (Z.of_nat n) * K + K)
        by (rewrite (inject_Z_Snat n); ring).
      apply Qle_refl.
Qed.

(* ---------- the certificate ---------- *)

(* The unnormalized margin: positive favors the class, negative opposes it.
   It is the score multiplied through by the window count: raw plus the bias b0
   times nwin L. It shares the sign of the normalized score because the window
   count is positive (nwin is a successor), and a same-length edit leaves the
   window count fixed. *)
Definition margin (L : nat) (b0 : Q) (f : nat -> Base) : Q :=
  raw L f + b0 * inject_Z (Z.of_nat (nwin L)).

(* The bias term cancels in a difference of margins at the same length. *)
Lemma margin_diff : forall L b0 f g,
  margin L b0 f - margin L b0 g == raw L f - raw L g.
Proof.
  intros. unfold margin. ring.
Qed.

(* A perturbation d strictly smaller in magnitude than x cannot move x across
   zero: x - d keeps the strict sign of x. *)
Lemma sign_preserved : forall x d : Q,
  Qabs d < Qabs x ->
  (0 < x -> 0 < x - d) /\ (x < 0 -> x - d < 0).
Proof.
  intros x d Hlt. split; intro Hx.
  - (* x positive: |x| = x and d <= |d| < x. *)
    assert (Hax : Qabs x == x) by (apply Qabs_pos; lra).
    assert (Hd : d <= Qabs d) by apply Qle_Qabs.
    lra.
  - (* x negative: |x| = - x and d >= - |d| > x. *)
    assert (Hax : Qabs x == - x) by (apply Qabs_neg; lra).
    assert (Hcc := proj1 (Qabs_Qle_condition d (Qabs d)) (Qle_refl (Qabs d))).
    destruct Hcc as [Hc _].
    lra.
Qed.

(* Main theorem. If g is reachable from f by n substitutions and the certificate
   n * K < |margin(f)| holds, the call's sign cannot change.

   K = k * (2 * Umax), so the certified radius is the largest n with
   n * k * 2 * Umax < |margin|, which is what `certify` returns.

   The conclusion is stated for both strict signs; a margin of exactly zero can
   never satisfy the certificate, since n * K is nonnegative. *)
Theorem certified_robust : forall L b0 f g n,
  reach L f g n ->
  inject_Z (Z.of_nat n) * K < Qabs (margin L b0 f) ->
  (0 < margin L b0 f -> 0 < margin L b0 g) /\
  (margin L b0 f < 0 -> margin L b0 g < 0).
Proof.
  intros L b0 f g n Hreach Hcert.
  assert (Hb : Qabs (raw L f - raw L g) <= inject_Z (Z.of_nat n) * K)
    by (apply reach_bound; exact Hreach).
  set (x := margin L b0 f). set (y := margin L b0 g).
  (* The margin moves by the same amount as raw, which is below |x|. *)
  assert (Hd : Qabs (x - y) < Qabs x).
  { setoid_replace (x - y) with (raw L f - raw L g)
      by (unfold x, y; apply margin_diff).
    eapply Qle_lt_trans; [exact Hb| exact Hcert]. }
  (* Apply sign_preserved with d := x - y, then rewrite y as x - (x - y). *)
  destruct (sign_preserved x (x - y) Hd) as [Hpos Hneg].
  split; intro Hsgn.
  - assert (Hy : 0 < x - (x - y)) by (apply Hpos; exact Hsgn).
    setoid_replace y with (x - (x - y)) by ring. exact Hy.
  - assert (Hy : x - (x - y) < 0) by (apply Hneg; exact Hsgn).
    setoid_replace y with (x - (x - y)) by ring. exact Hy.
Qed.

End Detector.

(* Instantiation to the published model's setting: the four DNA bases and
   k = 8. The host head's effective weights have a concrete finite bound
   Umax_host, so the guarantee holds with per-substitution margin constant
   K = 8 * (2 * Umax_host). The radius `certify` reports is the largest n with
   n * K < |margin|.

   After the section is closed, each constant takes the section variables it
   actually uses as explicit leading arguments: reach takes the alphabet,
   K takes k and Umax, and margin takes the alphabet, k and u. *)
Inductive DNA := dA | dC | dG | dT.

Corollary certified_robust_dna :
  forall (u : list DNA -> Q) (Umax : Q),
    (forall w, Qabs (u w) <= Umax) ->
    forall (L : nat) (b0 : Q) (f g : nat -> DNA) (n : nat),
      reach DNA L f g n ->
      inject_Z (Z.of_nat n) * K 8 Umax < Qabs (margin DNA 8 u L b0 f) ->
      (0 < margin DNA 8 u L b0 f -> 0 < margin DNA 8 u L b0 g) /\
      (margin DNA 8 u L b0 f < 0 -> margin DNA 8 u L b0 g < 0).
Proof.
  intros u Umax Hu.
  exact (certified_robust DNA 8 u Umax Hu).
Qed.