File size: 7,138 Bytes
5c3cfae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""Pathway-aware gene similarity index for structured reward scoring.



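Uses gseapy pathway libraries (KEGG + Reactome) to build binary pathway
membership vectors per gene, enabling cosine-similarity-based set scoring
instead of substring matching.

Mechanism comparison uses sentence-transformers for semantic similarity.
When gseapy pathway data or sentence-transformers is unavailable, scoring
falls back to exact-match (genes) or token-overlap (mechanisms) heuristics.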

"""

from __future__ import annotations

import logging
from functools import lru_cache
from typing import Dict, List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

# Lazily-populated pathway index, filled in by _ensure_pathway_index().
# _PATHWAY_NAMES doubles as the "already built" sentinel (None = not built yet).
_PATHWAY_SETS: Optional[Dict[str, List[str]]] = None  # pathway name -> member genes
_PATHWAY_NAMES: Optional[List[str]] = None  # sorted pathway names; position = vector column
_GENE_TO_PATHWAY_IDX: Optional[Dict[str, List[int]]] = None  # upper-cased, stripped gene -> column indices
_N_PATHWAYS: int = 0  # number of pathways, i.e. the length of gene_vector() outputs

# SentenceTransformer instance loaded by _ensure_sentence_model(); None until loaded.
_SENTENCE_MODEL = None


def _ensure_pathway_index() -> None:
    """Lazily build the inverted gene→pathway index on first use.

    Loads the ``KEGG_2021_Human`` and ``Reactome_2022`` gene-set libraries via
    ``gseapy.get_library`` and populates the module-level globals:

    * ``_PATHWAY_SETS`` - pathway name -> list of member genes,
    * ``_PATHWAY_NAMES`` - sorted pathway names (defines vector column order),
    * ``_GENE_TO_PATHWAY_IDX`` - upper-cased, stripped gene symbol -> column
      indices of the pathways containing it,
    * ``_N_PATHWAYS`` - number of pathways (vector length).

    ``_PATHWAY_NAMES`` doubles as the "already built" sentinel, so it is
    assigned last. If building the index raises, the next call retries
    instead of observing a half-initialised state. If gseapy is missing, or
    both libraries fail to load, the index is empty (``_N_PATHWAYS == 0``).
    """
    global _PATHWAY_SETS, _PATHWAY_NAMES, _GENE_TO_PATHWAY_IDX, _N_PATHWAYS

    if _PATHWAY_NAMES is not None:
        return

    try:
        import gseapy as gp
    except ImportError:
        logger.warning("gseapy not installed; pathway scoring will use fallback.")
        _PATHWAY_SETS = {}
        _GENE_TO_PATHWAY_IDX = {}
        _N_PATHWAYS = 0
        _PATHWAY_NAMES = []  # sentinel assigned last
        return

    combined: Dict[str, List[str]] = {}
    for lib_name in ("KEGG_2021_Human", "Reactome_2022"):
        try:
            combined.update(gp.get_library(lib_name))
        except Exception as exc:
            # One library failing to load is logged and skipped; the other
            # may still provide usable pathways.
            logger.warning("Failed to load %s: %s", lib_name, exc)

    names = sorted(combined)
    inv: Dict[str, List[int]] = {}
    for idx, pw_name in enumerate(names):
        # The set drops duplicate members so each index is recorded once per gene.
        for gene_key in {gene.upper().strip() for gene in combined[pw_name]}:
            if gene_key:  # skip blank entries
                inv.setdefault(gene_key, []).append(idx)

    # Publish everything, with the sentinel (_PATHWAY_NAMES) last.
    _PATHWAY_SETS = combined
    _GENE_TO_PATHWAY_IDX = inv
    _N_PATHWAYS = len(names)
    _PATHWAY_NAMES = names
    logger.info(
        "Pathway index built: %d pathways, %d genes indexed.",
        _N_PATHWAYS, len(inv),
    )


# Set once a load has been attempted, so a missing/broken model is not
# retried (and re-logged) on every scoring call.
_SENTENCE_MODEL_LOAD_ATTEMPTED = False


def _ensure_sentence_model() -> None:
    """Lazily load the ``all-MiniLM-L6-v2`` sentence-transformer model.

    On success the model is stored in the module-level ``_SENTENCE_MODEL``.
    If sentence-transformers is not installed, or the model cannot be
    constructed, a warning is logged once and ``_SENTENCE_MODEL`` stays
    ``None``. Callers then use their token-overlap fallback. Only one load
    attempt is made per process.
    """
    global _SENTENCE_MODEL, _SENTENCE_MODEL_LOAD_ATTEMPTED
    if _SENTENCE_MODEL is not None or _SENTENCE_MODEL_LOAD_ATTEMPTED:
        return
    _SENTENCE_MODEL_LOAD_ATTEMPTED = True

    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        logger.warning(
            "sentence-transformers not installed; mechanism scoring will use fallback."
        )
        _SENTENCE_MODEL = None
        return

    try:
        _SENTENCE_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    except Exception as exc:
        # Best-effort: a failed model load degrades to the fallback scorer
        # rather than crashing reward computation.
        logger.warning(
            "Failed to load sentence-transformer model (%s); "
            "mechanism scoring will use fallback.",
            exc,
        )
        _SENTENCE_MODEL = None


def gene_vector(gene: str) -> np.ndarray:
    """Return *gene*'s pathway-membership vector scaled to unit length.

    The vector has one float32 entry per indexed pathway, set to 1 for the
    pathways containing *gene* (matched upper-cased and stripped) and then
    L2-normalised. Genes found in no pathway yield an all-zero vector.
    """
    _ensure_pathway_index()
    key = gene.upper().strip()
    hits = _GENE_TO_PATHWAY_IDX.get(key, [])
    out = np.zeros(_N_PATHWAYS, dtype=np.float32)
    if not hits:
        return out

    out[hits] = 1.0
    length = np.linalg.norm(out)
    if length > 0:
        out /= length
    return out


def pathway_similarity(g1: str, g2: str) -> float:
    """Return the cosine similarity of *g1* and *g2* in pathway space.

    Both genes are mapped through :func:`gene_vector` (unit-length or
    all-zero vectors), so their dot product is the cosine similarity. The
    result is clamped to ``[0, 1]``.
    """
    first, second = gene_vector(g1), gene_vector(g2)
    cosine = float(np.dot(first, second))
    return max(0.0, min(1.0, cosine))


def marker_set_score(
    predicted: List[str],
    truth: List[str],
    sigma: float = 0.3,
) -> float:
    """Pathway-weighted Gaussian set similarity for marker genes.

    For each true marker, finds the best-matching predicted gene by pathway
    cosine similarity (clamped to ``[0, 1]``), then applies a Gaussian
    kernel::

        score_i = exp(-d**2 / (2 * sigma**2))   where d = 1 - sim

    A true marker that also appears in *predicted* (case- and
    whitespace-insensitive) counts as ``sim = 1`` even when it belongs to no
    indexed pathway. Such genes have an all-zero pathway vector and could
    otherwise never be matched.

    Args:
        predicted: Predicted gene symbols.
        truth: Ground-truth marker gene symbols.
        sigma: Kernel width; must be non-zero.

    Returns:
        Mean kernel score over *truth*, in ``[0, 1]``. Returns ``0.0`` if
        either list is empty, and exact-match recall when no pathway data is
        available.

    Raises:
        ValueError: If *sigma* is zero and pathway data is available.
    """
    if not truth or not predicted:
        return 0.0

    _ensure_pathway_index()

    if _N_PATHWAYS == 0:
        return _fallback_marker_score(predicted, truth)

    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    # Rows are unit-length or all-zero (see gene_vector), so the matrix
    # product gives pairwise cosine similarities: shape (len(truth), len(predicted)).
    pred_mat = np.stack([gene_vector(g) for g in predicted])
    truth_mat = np.stack([gene_vector(g) for g in truth])
    best_sim = np.clip((truth_mat @ pred_mat.T).max(axis=1), 0.0, 1.0)
    best_sim = best_sim.astype(np.float64)

    # An exact name match is a perfect match regardless of pathway coverage.
    pred_keys = {g.upper().strip() for g in predicted}
    exact = np.array([g.upper().strip() in pred_keys for g in truth], dtype=bool)
    best_sim[exact] = 1.0

    d = 1.0 - best_sim
    scores = np.exp(-(d ** 2) / (2.0 * sigma ** 2))
    return float(scores.mean())


def _fallback_marker_score(predicted: List[str], truth: List[str]) -> float:
    """Exact-match fallback when pathway data is unavailable."""
    pred_set = {g.upper().strip() for g in predicted}
    hits = sum(1 for g in truth if g.upper().strip() in pred_set)
    return hits / len(truth) if truth else 0.0


def mechanism_set_score(predicted: List[str], truth: List[str]) -> float:
    """Semantic similarity between predicted and true mechanism strings.

    Each string is embedded with the sentence-transformer model and the
    embeddings are L2-normalised. For every truth mechanism, the highest
    cosine similarity against any prediction is taken and clamped to
    ``[0, 1]``, and these values are averaged. Falls back to token-overlap
    scoring when the model is unavailable.

    Returns ``0.0`` if either list is empty.
    """
    if not (truth and predicted):
        return 0.0

    _ensure_sentence_model()
    model = _SENTENCE_MODEL
    if model is None:
        return _fallback_mechanism_score(predicted, truth)

    def _unit_rows(texts: List[str]) -> np.ndarray:
        # The small epsilon guards against division by a zero-norm embedding.
        emb = model.encode(texts, convert_to_numpy=True)
        return emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-9)

    pred_unit = _unit_rows(predicted)
    truth_unit = _unit_rows(truth)
    best_per_truth = np.max(truth_unit @ pred_unit.T, axis=1)
    return float(np.clip(best_per_truth, 0.0, 1.0).mean())


def _fallback_mechanism_score(predicted: List[str], truth: List[str]) -> float:
    """Token-overlap fallback when sentence-transformers is unavailable."""
    scores: List[float] = []
    for t in truth:
        t_tokens = set(t.lower().split())
        best = 0.0
        for p in predicted:
            p_tokens = set(p.lower().split())
            union = t_tokens | p_tokens
            if union:
                overlap = len(t_tokens & p_tokens) / len(union)
                best = max(best, overlap)
        scores.append(best)
    return sum(scores) / len(scores) if scores else 0.0


def score_pathways(
    predicted: Dict[str, float],
    truth: Dict[str, float],
) -> float:
    """Score predicted pathway activations against ground truth.

    Keys are matched after lower-casing and stripping whitespace. Each truth
    pathway contributes ``max(0, 1 - |pred - true|)``, or ``0`` if it was not
    predicted, weighted by the magnitude of its true activity.

    Args:
        predicted: Pathway name -> predicted activity.
        truth: Pathway name -> ground-truth activity.

    Returns:
        Weighted mean match score in ``[0, 1]``. Returns ``0.0`` if either
        mapping is empty or every true activity is zero.
    """
    if not truth or not predicted:
        return 0.0

    pred_norm = {k.lower().strip(): v for k, v in predicted.items()}
    total_weight = 0.0
    weighted_score = 0.0

    for pw, true_activity in truth.items():
        pw_key = pw.lower().strip()
        # Weight by magnitude. A negative (signed) activity must not carry
        # negative weight, which could push the score outside [0, 1]. This is
        # identical to the raw activity for non-negative inputs.
        weight = abs(true_activity)
        total_weight += weight
        if pw_key in pred_norm:
            diff = abs(pred_norm[pw_key] - true_activity)
            match_score = max(0.0, 1.0 - diff)
            weighted_score += weight * match_score

    return weighted_score / total_weight if total_weight > 0 else 0.0