| """Pathway-aware gene similarity index for structured reward scoring. | |
| Uses gseapy pathway libraries (KEGG + Reactome) to build binary pathway | |
| membership vectors per gene, enabling cosine-similarity-based set scoring | |
| instead of substring matching. | |
| Mechanism comparison uses sentence-transformers for semantic similarity. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from functools import lru_cache | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lazily populated pathway caches; all are filled by _ensure_pathway_index().
# Pathway name -> member gene list, merged across the loaded libraries.
_PATHWAY_SETS: Optional[Dict[str, List[str]]] = None
# Sorted pathway names; a pathway's position is its vector column index.
# ``None`` means the index has not been built yet.
_PATHWAY_NAMES: Optional[List[str]] = None
# Upper-cased, stripped gene symbol -> indices into _PATHWAY_NAMES.
_GENE_TO_PATHWAY_IDX: Optional[Dict[str, List[int]]] = None
# Number of indexed pathways, i.e. the length of every gene vector.
_N_PATHWAYS: int = 0

# Sentence-transformer model set by _ensure_sentence_model(); stays ``None``
# until loaded, or when the package is not installed.
_SENTENCE_MODEL = None
def _ensure_pathway_index() -> None:
    """Lazily build the inverted gene→pathway index on first use.

    Loads the KEGG and Reactome gene-set libraries through gseapy, merges
    them into one mapping and fills the module-level caches
    ``_PATHWAY_SETS``, ``_PATHWAY_NAMES``, ``_GENE_TO_PATHWAY_IDX`` and
    ``_N_PATHWAYS``. Subsequent calls return immediately.

    When gseapy is not installed the caches are set to empty containers so
    that callers can detect ``_N_PATHWAYS == 0`` and use their fallbacks.
    A library that fails to load is logged and skipped (best effort).
    """
    global _PATHWAY_SETS, _PATHWAY_NAMES, _GENE_TO_PATHWAY_IDX, _N_PATHWAYS

    # _PATHWAY_NAMES doubles as the "already initialised" marker.
    if _PATHWAY_NAMES is not None:
        return

    try:
        import gseapy as gp
    except ImportError:
        logger.warning("gseapy not installed; pathway scoring will use fallback.")
        _PATHWAY_SETS = {}
        _GENE_TO_PATHWAY_IDX = {}
        _N_PATHWAYS = 0
        # Publish the marker last so the other caches are never None once
        # the guard above starts reporting "initialised".
        _PATHWAY_NAMES = []
        return

    combined: Dict[str, List[str]] = {}
    for lib_name in ("KEGG_2021_Human", "Reactome_2022"):
        try:
            # A pathway name present in both libraries is overwritten by
            # the later one (dict.update semantics).
            combined.update(gp.get_library(lib_name))
        except Exception as exc:
            logger.warning("Failed to load %s: %s", lib_name, exc)

    # Build everything in locals first: if anything below raises, the
    # module state stays untouched and the next call simply retries,
    # instead of leaving a half-initialised index behind.
    names = sorted(combined)
    inv: Dict[str, List[int]] = {}
    for idx, pw_name in enumerate(names):
        for gene in combined[pw_name]:
            symbol = gene.upper().strip()
            if not symbol:
                # Blank entries carry no gene identity; do not index them.
                continue
            bucket = inv.setdefault(symbol, [])
            # Pathways are visited in index order, so a repeated gene inside
            # one pathway can only duplicate the most recent entry.
            if not bucket or bucket[-1] != idx:
                bucket.append(idx)

    _PATHWAY_SETS = combined
    _GENE_TO_PATHWAY_IDX = inv
    _N_PATHWAYS = len(names)
    _PATHWAY_NAMES = names  # marker assigned last, see above

    logger.info(
        "Pathway index built: %d pathways, %d genes indexed.",
        _N_PATHWAYS, len(inv),
    )
# Set once the sentence-transformers import has failed, so the import is not
# retried (and the warning not re-logged) on every scoring call.
_SENTENCE_MODEL_UNAVAILABLE: bool = False


def _ensure_sentence_model() -> None:
    """Lazily load the sentence-transformer model.

    On success the model is stored in ``_SENTENCE_MODEL``. When the
    ``sentence_transformers`` package is not installed, ``_SENTENCE_MODEL``
    stays ``None`` (callers use that to switch to their fallback), a single
    warning is logged, and later calls return immediately.

    Errors raised while constructing the model itself are not caught here
    and propagate to the caller.
    """
    global _SENTENCE_MODEL, _SENTENCE_MODEL_UNAVAILABLE

    if _SENTENCE_MODEL is not None or _SENTENCE_MODEL_UNAVAILABLE:
        return

    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        logger.warning(
            "sentence-transformers not installed; mechanism scoring will use fallback."
        )
        _SENTENCE_MODEL = None
        _SENTENCE_MODEL_UNAVAILABLE = True
        return

    _SENTENCE_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
def gene_vector(gene: str) -> np.ndarray:
    """Return the unit-length binary pathway membership vector of *gene*.

    The symbol is upper-cased and stripped before lookup. The result has one
    float32 entry per indexed pathway; a gene that belongs to no indexed
    pathway yields an all-zero vector.
    """
    _ensure_pathway_index()
    symbol = gene.upper().strip()
    membership = np.zeros(_N_PATHWAYS, dtype=np.float32)

    pathway_ids = _GENE_TO_PATHWAY_IDX.get(symbol, [])
    if not pathway_ids:
        return membership

    membership[pathway_ids] = 1.0
    length = np.linalg.norm(membership)
    if length > 0:
        membership /= length
    return membership
def pathway_similarity(g1: str, g2: str) -> float:
    """Return the cosine similarity of two genes in pathway space.

    Both genes are mapped to their normalised membership vectors, so the
    dot product is the cosine; the value is clamped into ``[0.0, 1.0]``.
    """
    first, second = gene_vector(g1), gene_vector(g2)
    cosine = float(np.dot(first, second))
    capped = min(1.0, cosine)
    return max(0.0, capped)
def marker_set_score(
    predicted: List[str],
    truth: List[str],
    sigma: float = 0.3,
) -> float:
    """Pathway-weighted Gaussian set similarity for marker genes.

    For each true marker, finds the best-matching predicted gene by
    pathway cosine similarity, then applies a Gaussian kernel:

        score_i = exp(-d^2 / (2 * sigma^2))   where d = 1 - sim

    A predicted gene whose normalised symbol equals the true marker counts
    as a perfect match (sim = 1) even when the gene is absent from the
    pathway index; otherwise such a gene would have an all-zero vector and
    an exact hit would score close to zero.

    Args:
        predicted: Predicted marker gene symbols.
        truth: Ground-truth marker gene symbols.
        sigma: Width of the Gaussian kernel; must be non-zero.

    Returns:
        The mean score over all true markers, or 0.0 when either list is
        empty. Falls back to exact-match scoring when no pathway data is
        available.
    """
    if not truth or not predicted:
        return 0.0

    _ensure_pathway_index()
    if _N_PATHWAYS == 0:
        return _fallback_marker_score(predicted, truth)

    predicted_symbols = {g.upper().strip() for g in predicted}
    # One row per predicted gene, so every truth gene needs a single
    # matrix-vector product instead of a Python loop over predictions.
    pred_matrix = np.stack([gene_vector(g) for g in predicted])
    two_sigma_sq = 2.0 * sigma ** 2

    scores: List[float] = []
    for true_gene in truth:
        if true_gene.upper().strip() in predicted_symbols:
            best_sim = 1.0
        else:
            sims = pred_matrix @ gene_vector(true_gene)
            # Clamp to guard against float32 rounding just outside [0, 1].
            best_sim = min(1.0, max(0.0, float(sims.max())))
        d = 1.0 - best_sim
        scores.append(float(np.exp(-(d ** 2) / two_sigma_sq)))
    return sum(scores) / len(scores)
def _fallback_marker_score(predicted: List[str], truth: List[str]) -> float:
    """Exact-match fallback used when pathway data is unavailable.

    Returns the fraction of true markers whose upper-cased, stripped symbol
    also occurs among the predictions, or 0.0 for an empty *truth*.
    """
    known = set()
    for symbol in predicted:
        known.add(symbol.upper().strip())

    matched = 0
    for symbol in truth:
        if symbol.upper().strip() in known:
            matched += 1

    if not truth:
        return 0.0
    return matched / len(truth)
def mechanism_set_score(predicted: List[str], truth: List[str]) -> float:
    """Semantic similarity between predicted and true mechanism strings.

    Every string is embedded with the sentence-transformer model; each true
    mechanism is paired with its most similar prediction (cosine), the
    similarity is clipped to ``[0, 1]`` and the mean over all true
    mechanisms is returned. Returns 0.0 when either list is empty and uses
    the token-overlap fallback when no model is available.
    """
    if not (truth and predicted):
        return 0.0

    _ensure_sentence_model()
    if _SENTENCE_MODEL is None:
        return _fallback_mechanism_score(predicted, truth)

    def _unit_rows(matrix):
        # The small epsilon keeps all-zero rows from dividing by zero.
        lengths = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / (lengths + 1e-9)

    candidate_unit = _unit_rows(
        _SENTENCE_MODEL.encode(predicted, convert_to_numpy=True)
    )
    reference_unit = _unit_rows(
        _SENTENCE_MODEL.encode(truth, convert_to_numpy=True)
    )

    # Rows: true mechanisms, columns: predictions.
    cosine = reference_unit @ candidate_unit.T
    top_matches = np.clip(cosine.max(axis=1), 0.0, 1.0)
    return float(np.mean(top_matches))
def _fallback_mechanism_score(predicted: List[str], truth: List[str]) -> float:
    """Token-overlap fallback when sentence-transformers is unavailable.

    Strings are lower-cased and split on whitespace. Each true mechanism is
    scored by its best Jaccard overlap (|intersection| / |union|) with any
    predicted mechanism; the mean over all true mechanisms is returned, or
    0.0 when *truth* is empty.
    """
    if not truth:
        return 0.0

    # Tokenise the predictions once instead of once per truth string.
    predicted_tokens = [set(p.lower().split()) for p in predicted]

    scores: List[float] = []
    for t in truth:
        t_tokens = set(t.lower().split())
        best = 0.0
        for p_tokens in predicted_tokens:
            union = t_tokens | p_tokens
            if union:  # both strings empty -> no overlap defined, skip
                best = max(best, len(t_tokens & p_tokens) / len(union))
        scores.append(best)
    return sum(scores) / len(scores)
def score_pathways(
    predicted: Dict[str, float],
    truth: Dict[str, float],
) -> float:
    """Score predicted pathway activations against ground truth.

    Pathway names are matched after lower-casing and stripping. Each true
    pathway contributes ``max(0, 1 - |predicted - true|)`` (0 when it was
    not predicted), weighted by the magnitude of its true activity.
    Weighting by magnitude rather than the signed value keeps the result in
    ``[0, 1]`` when activities are negative; non-negative activities are
    weighted exactly as before.

    Args:
        predicted: Predicted pathway name -> activity.
        truth: Ground-truth pathway name -> activity.

    Returns:
        The weighted mean match score, or 0.0 when either mapping is empty
        or all true activities are zero.
    """
    if not truth or not predicted:
        return 0.0

    pred_norm = {k.lower().strip(): v for k, v in predicted.items()}

    total_weight = 0.0
    weighted_score = 0.0
    for pw, true_activity in truth.items():
        weight = abs(true_activity)
        total_weight += weight

        pw_key = pw.lower().strip()
        if pw_key not in pred_norm:
            continue  # a missing prediction earns nothing for this pathway
        diff = abs(pred_norm[pw_key] - true_activity)
        weighted_score += weight * max(0.0, 1.0 - diff)

    return weighted_score / total_weight if total_weight > 0 else 0.0