"""Sequence-level frustration contrast scores. contrast_hvlv(seq) = mean_FI(high-variance positions) - mean_FI(low-variance positions) High-variance positions are MSA columns whose across-sequence FI variance is at or above the (default 80th) percentile. """ from __future__ import annotations import numpy as np def high_variance_mask(fi_matrix: np.ndarray, percentile: float = 80.0) -> np.ndarray: """Boolean (L,) mask of high-variance MSA columns. Args: fi_matrix: (N, L) per-residue FI; may contain NaN. percentile: column-variance percentile threshold (default 80). Returns: boolean array of length L (True = high-variance). """ if fi_matrix.ndim != 2: raise ValueError("fi_matrix must be 2-D (N, L)") col_var = np.nanvar(fi_matrix, axis=0) if np.all(np.isnan(col_var)): return np.zeros(fi_matrix.shape[1], dtype=bool) thresh = np.nanpercentile(col_var, percentile) return col_var >= thresh def contrast_hvlv(fi_matrix: np.ndarray, percentile: float = 80.0) -> np.ndarray: """Per-sequence high-variance / low-variance FI contrast. score[i] = mean_FI_over_HV_cols(seq_i) - mean_FI_over_LV_cols(seq_i) NaN-safe: sequences with all-NaN in a group contribute 0 there. Args: fi_matrix: (N, L) per-residue FI matrix. percentile: column-variance percentile defining HV (default 80). Returns: np.ndarray (N,) float64 contrast score per sequence. """ if fi_matrix.ndim != 2: raise ValueError("fi_matrix must be 2-D (N, L)") N = fi_matrix.shape[0] hv = high_variance_mask(fi_matrix, percentile=percentile) lv = ~hv if hv.any(): mean_hv = np.nanmean(fi_matrix[:, hv], axis=1) else: mean_hv = np.zeros(N, dtype=np.float64) if lv.any(): mean_lv = np.nanmean(fi_matrix[:, lv], axis=1) else: mean_lv = np.zeros(N, dtype=np.float64) mean_hv = np.nan_to_num(mean_hv, nan=0.0) mean_lv = np.nan_to_num(mean_lv, nan=0.0) return (mean_hv - mean_lv).astype(np.float64, copy=False)