SF-Cluster / src /sf_cluster /score.py
chq1155's picture
Initial OSS release: mosaic + gradient subset builders (verified KaiB 95.0%, GA98 92.5%, GB98 50.0% on Phase XII pilot)
ccbe063 verified
Raw
History Blame Contribute Delete
2.13 kB
"""Sequence-level frustration contrast scores.

contrast_hvlv(fi_matrix)[i] = mean_FI(seq_i, high-variance columns) - mean_FI(seq_i, low-variance columns)

High-variance columns are MSA columns whose across-sequence FI variance is
at or above the (default 80th) percentile of the column variances; every
other column is treated as low-variance.
"""
from __future__ import annotations

import warnings

import numpy as np
def high_variance_mask(fi_matrix: np.ndarray,
                       percentile: float = 80.0) -> np.ndarray:
    """Return a boolean (L,) mask of high-variance MSA columns.

    A column is flagged when its NaN-ignoring variance across sequences is at
    or above the ``percentile``-th percentile of all defined column variances.
    Columns whose variance is undefined (all entries NaN) are never flagged.

    Args:
        fi_matrix: (N, L) per-residue FI; may contain NaN. Any array-like
            accepted by ``np.asarray`` (e.g. nested lists) works.
        percentile: column-variance percentile threshold (default 80).

    Returns:
        boolean array of length L (True = high-variance). All False when no
        column has a defined variance.

    Raises:
        ValueError: if ``fi_matrix`` is not 2-D.
    """
    fi = np.asarray(fi_matrix)
    if fi.ndim != 2:
        raise ValueError("fi_matrix must be 2-D (N, L)")

    # np.nanvar warns ("Degrees of freedom <= 0 for slice") for every all-NaN
    # column. Those columns are an expected input here and are handled below,
    # so the warning is noise (and fatal under warnings-as-errors).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        col_var = np.nanvar(fi, axis=0)

    # Without any defined variance there is no percentile to compute.
    if np.isnan(col_var).all():
        return np.zeros(fi.shape[1], dtype=bool)

    thresh = np.nanpercentile(col_var, percentile)

    # NaN >= thresh evaluates to False, so undefined columns stay unflagged;
    # errstate guards against "invalid value" noise from that comparison.
    with np.errstate(invalid="ignore"):
        return col_var >= thresh
def contrast_hvlv(fi_matrix: np.ndarray,
                  percentile: float = 80.0) -> np.ndarray:
    """Per-sequence high-variance / low-variance FI contrast.

    score[i] = mean_FI_over_HV_cols(seq_i) - mean_FI_over_LV_cols(seq_i)

    HV columns come from ``high_variance_mask``; LV columns are all the
    others. Means ignore NaN. A sequence whose entries in a group are all NaN
    (or a group with no columns at all) contributes 0 for that group.

    Args:
        fi_matrix: (N, L) per-residue FI matrix; may contain NaN. Any
            array-like accepted by ``np.asarray`` works.
        percentile: column-variance percentile defining HV (default 80).

    Returns:
        np.ndarray (N,) float64 contrast score per sequence.

    Raises:
        ValueError: if ``fi_matrix`` is not 2-D.
    """
    fi = np.asarray(fi_matrix)
    if fi.ndim != 2:
        raise ValueError("fi_matrix must be 2-D (N, L)")

    hv = high_variance_mask(fi, percentile=percentile)
    lv = ~hv

    # np.nanmean yields NaN and warns ("Mean of empty slice") for rows with no
    # usable value in a group, including a group with zero columns. That case
    # is part of the documented contract, so the warning is silenced and the
    # NaN is mapped to 0 just below.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_hv = np.nanmean(fi[:, hv], axis=1)
        mean_lv = np.nanmean(fi[:, lv], axis=1)

    mean_hv = np.nan_to_num(mean_hv, nan=0.0)
    mean_lv = np.nan_to_num(mean_lv, nan=0.0)
    return (mean_hv - mean_lv).astype(np.float64, copy=False)