"""Sequence-level frustration contrast scores.

contrast_hvlv(seq) = mean_FI(high-variance positions) - mean_FI(low-variance positions)

High-variance (HV) positions are MSA columns whose across-sequence FI variance
is at or above a percentile of all column variances (default: 80th).
Low-variance (LV) positions are all remaining columns.
"""
from __future__ import annotations

import numpy as np


def high_variance_mask(fi_matrix: np.ndarray,
                       percentile: float = 80.0) -> np.ndarray:
    """Boolean (L,) mask of high-variance MSA columns.

    Args:
        fi_matrix:  (N, L) per-residue FI; may contain NaN.
        percentile: column-variance percentile threshold (default 80).

    Returns:
        boolean array of length L (True = high-variance). Columns that are
        entirely NaN are never marked high-variance.
    """
    if fi_matrix.ndim != 2:
        raise ValueError("fi_matrix must be 2-D (N, L)")
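    # Compute variance only for columns that hold data; this avoids NumPy's
    # RuntimeWarning on all-NaN slices. Such columns keep a NaN variance.
    has_data = ~np.isnan(fi_matrix).all(axis=0)
    if not has_data.any():
        return np.zeros(fi_matrix.shape[1], dtype=bool)
    col_var = np.full(fi_matrix.shape[1], np.nan, dtype=np.float64)
    col_var[has_data] = np.nanvar(fi_matrix[:, has_data], axis=0)
    thresh = np.nanpercentile(col_var, percentile)
    # NaN >= thresh evaluates to False, so all-NaN columns are never flagged.
    return col_var >= thresh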


def contrast_hvlv(fi_matrix: np.ndarray,
                  percentile: float = 80.0) -> np.ndarray:
    """Per-sequence high-variance / low-variance FI contrast.

    score[i] = mean_FI_over_HV_cols(seq_i) - mean_FI_over_LV_cols(seq_i)

    NaN-safe: a sequence that is all-NaN within a group, or a group with no
    columns, gets a mean of 0 for that group.

    Args:
        fi_matrix:  (N, L) per-residue FI matrix.
        percentile: column-variance percentile defining HV (default 80).

    Returns:
        np.ndarray (N,) float64 contrast score per sequence.
    """
    if fi_matrix.ndim != 2:
        raise ValueError("fi_matrix must be 2-D (N, L)")
    N = fi_matrix.shape[0]

    hv = high_variance_mask(fi_matrix, percentile=percentile)
    lv = ~hv

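    # An empty group (no HV or no LV columns) contributes 0 for every sequence.
    if hv.any():
        mean_hv = np.nanmean(fi_matrix[:, hv], axis=1)
    else:
        mean_hv = np.zeros(N, dtype=np.float64)
    if lv.any():
        mean_lv = np.nanmean(fi_matrix[:, lv], axis=1)
    else:
        mean_lv = np.zeros(N, dtype=np.float64)

    # nanmean returns NaN for rows that are all-NaN within a group; map to 0.
    mean_hv = np.nan_to_num(mean_hv, nan=0.0)
    mean_lv = np.nan_to_num(mean_lv, nan=0.0)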
    return (mean_hv - mean_lv).astype(np.float64, copy=False)
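

# Worked example (an illustrative sketch, not part of the original module).
# The helper name `_demo_contrast_by_hand` and the toy FI values are made up
# for illustration. It recomputes the HV/LV contrast step by step with plain
# NumPy calls, so the definition in the module docstring can be checked by
# hand on a matrix small enough to follow.
def _demo_contrast_by_hand():
    """Return the HV-minus-LV contrast for a 3 x 5 toy FI matrix."""
    import numpy as np  # local import so the sketch also runs on its own

    # Columns 0-3 are identical across sequences (variance 0); only the last
    # column varies, so it is the single high-variance column.
    fi = np.array([
        [0.0, 1.0, 2.0, 3.0, 10.0],
        [0.0, 1.0, 2.0, 3.0, 0.0],
        [0.0, 1.0, 2.0, 3.0, -10.0],
    ])
    col_var = np.nanvar(fi, axis=0)            # per-column variance
    thresh = np.nanpercentile(col_var, 80.0)   # 80th-percentile cutoff
    hv = col_var >= thresh                     # mask of high-variance columns
    mean_hv = np.nanmean(fi[:, hv], axis=1)    # per-sequence mean over HV
    mean_lv = np.nanmean(fi[:, ~hv], axis=1)   # per-sequence mean over LV
    return mean_hv - mean_lv                   # -> [8.5, -1.5, -11.5]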