"""Pharmacogenomic response prediction. Given the patient's gene variants and a candidate drug, predict expected response (responder / non-responder / adverse) and dose-modification recommendations using: - **CPIC guidelines** (Clinical Pharmacogenetics Implementation Consortium) cross-walked through PharmGKB - **gnomAD / ClinVar variant calls** for the patient's specific variant - **Disease-pathway interaction**: variant in same pathway as drug target is treated as "in-scope" with reduced confidence Bootstrap path uses Neo4j Gene→Drug edges with annotations from PharmGKB already ingested by raras-app. Phase-2: a learned response head over the HGT patient embedding × drug embedding. """ from __future__ import annotations import logging import os from typing import Optional from .types import PharmacogenSpec, PharmacogenAssessment logger = logging.getLogger("gemeo.pharmacogen") async def _safe_query(cypher: str, params: dict = None) -> list: try: from space_graph import _safe_query as q return await q(cypher, params or {}, timeout=10.0) except Exception as e: logger.debug(f"cypher failed: {e}") return [] async def _cpic_lookup(gene_symbol: str, drug_key: str) -> Optional[dict]: cypher = """ MATCH (g:Gene {symbol: $gene})-[r:AFFECTS_RESPONSE_TO|PHARMACOGENOMIC_OF]->(d:Drug) WHERE d.rxcui = $drug OR toLower(d.name) = toLower($drug) RETURN r.cpic_level AS cpic_level, r.recommendation AS recommendation, r.phenotype AS expected_phenotype, r.dose_modification AS dose_modification, r.evidence AS evidence, d.name AS drug_name LIMIT 1 """ rows = await _safe_query(cypher, {"gene": gene_symbol.upper(), "drug": drug_key}) return rows[0] if rows else None async def _pathway_overlap(gene_symbol: str, drug_key: str) -> Optional[dict]: """Indirect: drug acts on a pathway containing the patient's gene.""" cypher = """ MATCH (g:Gene {symbol: $gene})-[:IN_PATHWAY]->(pw:Pathway)<-[:ACTS_ON|TARGETS_PATHWAY]-(d:Drug) WHERE d.rxcui = $drug OR toLower(d.name) = toLower($drug) RETURN pw.name AS pathway, d.name AS drug_name LIMIT 1 """ rows = await _safe_query(cypher, {"gene": gene_symbol.upper(), "drug": drug_key}) return rows[0] if rows else None async def assess_pair(gene_symbol: str, variant: Optional[str], drug: dict) -> Optional[PharmacogenAssessment]: drug_key = drug.get("rxcui") or drug.get("name") if not drug_key: return None cpic = await _cpic_lookup(gene_symbol, drug_key) if cpic: return PharmacogenAssessment( gene=gene_symbol.upper(), variant=variant, drug=cpic.get("drug_name") or drug.get("name"), rxcui=drug.get("rxcui"), expected_phenotype=cpic.get("expected_phenotype") or "see CPIC", recommendation=cpic.get("recommendation") or "", dose_modification=cpic.get("dose_modification") or "", cpic_level=cpic.get("cpic_level") or "", evidence=cpic.get("evidence") or "", confidence=0.85 if cpic.get("cpic_level") in ("A", "B") else 0.65, source="cpic", ) pathway = await _pathway_overlap(gene_symbol, drug_key) if pathway: return PharmacogenAssessment( gene=gene_symbol.upper(), variant=variant, drug=pathway.get("drug_name") or drug.get("name"), rxcui=drug.get("rxcui"), expected_phenotype="indirect (pathway overlap)", recommendation=f"Monitor — drug acts on pathway containing {gene_symbol}", dose_modification="", cpic_level="", evidence=f"Shared pathway: {pathway.get('pathway')}", confidence=0.45, source="pathway", ) return None async def assess( *, genes: list, drug_candidates: list, ) -> PharmacogenSpec: """Cross all gene variants against all drug candidates.""" assessments = [] for g in genes or []: sym = (g.get("symbol") or g.get("gene") or "").upper() if not sym: continue var = g.get("variant") for d in drug_candidates or []: if not isinstance(d, dict): continue try: a = await assess_pair(sym, var, d) except Exception as e: logger.debug(f"pair assess failed: {e}") continue if a is not None: assessments.append(a) # rank by confidence assessments.sort(key=lambda a: a.confidence, reverse=True) return PharmacogenSpec( assessments=assessments, n_pairs=len(genes or []) * len(drug_candidates or []), n_actionable=len([a for a in assessments if a.cpic_level in ("A", "B")]), model="cpic_kg", )