| |
| import asyncio |
| from typing import Any, Dict, List |
|
|
| from mcp.arxiv import fetch_arxiv |
| from mcp.pubmed import fetch_pubmed |
| from mcp.nlp import extract_umls_concepts |
| from mcp.umls import lookup_umls |
| from mcp.umls_rel import fetch_relations |
| from mcp.openfda import fetch_drug_safety |
| from mcp.ncbi import search_gene, get_mesh_definition |
| from mcp.disgenet import disease_to_genes |
| from mcp.clinicaltrials import search_trials |
| from mcp.mygene import mygene |
| from mcp.opentargets import ot |
| from mcp.cbio import cbio |
| from mcp.openai_utils import ai_summarize, ai_qa |
| from mcp.gemini import gemini_summarize, gemini_qa |
| from mcp.embeddings import embed_texts, cluster_embeddings |
|
|
|
|
def _get_llm(llm: str):
    """
    Route summarization and QA to the chosen engine.

    Args:
        llm: Engine name. Matching is case-insensitive and ignores
            surrounding whitespace. ``"gemini"`` selects the Gemini
            helpers; any other value (including ``None`` or an empty
            string) selects the OpenAI helpers.

    Returns:
        A ``(summarize_fn, qa_fn)`` tuple: ``(gemini_summarize, gemini_qa)``
        for Gemini, otherwise ``(ai_summarize, ai_qa)``.
    """
    # Normalise defensively: a missing value must not raise AttributeError,
    # and stray whitespace must not silently reroute a Gemini request to
    # the OpenAI fallback.
    engine = (llm or "").strip().lower()
    if engine == "gemini":
        return gemini_summarize, gemini_qa
    return ai_summarize, ai_qa
|
|
|
|
async def orchestrate_search(query: str, llm: str = "openai") -> Dict[str, Any]:
    """
    Fetch papers, extract concepts & relations, enrich data,
    compute embeddings+clusters, and run LLM summary.

    Args:
        query: Free-text query, passed to the arXiv, PubMed and
            clinical-trials lookups.
        llm: Engine name handed to ``_get_llm`` to pick the summarizer.

    Returns:
        A dict with the keys ``papers``, ``umls``, ``umls_relations``,
        ``drug_safety``, ``genes``, ``mesh_defs``, ``gene_disease``,
        ``clinical_trials``, ``ot_associations``, ``variants``,
        ``embeddings``, ``clusters``, ``ai_summary`` and ``llm_used``.
        ``genes`` and ``mesh_defs`` are one-element lists wrapping the
        lookup result for the first concept name. ``umls_relations`` and
        ``drug_safety`` may contain exception objects for individual
        lookups that failed. ``llm_used`` is ``"gemini"`` or ``"openai"``
        according to the summarizer that was actually selected.

    Raises:
        Exception: Anything raised by concept extraction, the
            single-concept enrichment lookups, the trials search, or the
            embedding/clustering helpers propagates unchanged.
    """
    import logging

    log = logging.getLogger(__name__)

    # Resolve the engine before any network work so an unusable ``llm``
    # value fails immediately instead of after every fetch has completed.
    summarize_fn, _ = _get_llm(llm)

    # --- Literature -------------------------------------------------------
    # Both sources are queried concurrently. A source that fails is skipped
    # (best-effort), but the failure is logged rather than dropped silently.
    lit_results = await asyncio.gather(
        fetch_arxiv(query),
        fetch_pubmed(query),
        return_exceptions=True,
    )
    papers: List[Dict] = []
    for res in lit_results:
        if isinstance(res, list):
            papers.extend(res)
        elif isinstance(res, BaseException):
            log.warning("Literature source failed: %r", res)

    # --- Concepts ---------------------------------------------------------
    # A paper may carry an explicit ``None`` summary; coerce it to "" so the
    # join below cannot raise TypeError.
    summaries = [p.get("summary") or "" for p in papers]
    blob = " ".join(summaries)
    umls = await extract_umls_concepts(blob)

    # --- Relations (one lookup per concept, failures kept in place) -------
    umls_relations = await asyncio.gather(
        *(fetch_relations(c["cui"]) for c in umls),
        return_exceptions=True,
    )

    # --- Enrichment -------------------------------------------------------
    names = [c["name"] for c in umls]
    fda_tasks = [fetch_drug_safety(n) for n in names]
    if names:
        # Only the first extracted concept drives the single-concept lookups.
        primary = names[0]
        gene_task = search_gene(primary)
        mesh_task = get_mesh_definition(primary)
        dis_task = disease_to_genes(primary)
        ot_task = ot.fetch(primary)
        cbio_task = cbio.fetch_variants(primary)
    else:
        # No concepts: substitute already-resolved awaitables so the gather
        # below keeps a fixed shape.
        gene_task = asyncio.sleep(0, result=[])
        mesh_task = asyncio.sleep(0, result="")
        dis_task = asyncio.sleep(0, result=[])
        ot_task = asyncio.sleep(0, result=[])
        cbio_task = asyncio.sleep(0, result=[])
    trials_task = search_trials(query)

    # The per-drug safety lookups tolerate individual failures (inner
    # gather). The remaining lookups do not: the first exception among them
    # propagates to the caller.
    fda, gene, mesh, dis, trials, ot_assoc, variants = await asyncio.gather(
        asyncio.gather(*fda_tasks, return_exceptions=True),
        gene_task,
        mesh_task,
        dis_task,
        trials_task,
        ot_task,
        cbio_task,
        return_exceptions=False,
    )

    # --- Embeddings & clusters --------------------------------------------
    if summaries:
        embeddings = await embed_texts(summaries)
        count = len(embeddings)
        # Aim for about one cluster per two papers, bounded to 2..10, but
        # never more clusters than embeddings: a single paper would
        # otherwise request 2 clusters for 1 point.
        n_clusters = max(1, min(count, max(2, min(10, count // 2))))
        clusters = await cluster_embeddings(embeddings, n_clusters=n_clusters)
    else:
        embeddings, clusters = [], []

    # --- LLM summary (best-effort) ----------------------------------------
    try:
        ai_summary = await summarize_fn(blob)
    except Exception:
        # Callers still get a summary string; the cause goes to the log.
        log.exception("LLM summary failed")
        ai_summary = "LLM summary failed."

    # Report the engine that actually ran, not the raw request: unknown
    # names fall back to OpenAI and must not be echoed as if they were used.
    llm_used = "gemini" if summarize_fn is gemini_summarize else "openai"

    return {
        "papers": papers,
        "umls": umls,
        "umls_relations": umls_relations,
        "drug_safety": fda,
        "genes": [gene],
        "mesh_defs": [mesh],
        "gene_disease": dis,
        "clinical_trials": trials,
        "ot_associations": ot_assoc,
        "variants": variants,
        "embeddings": embeddings,
        "clusters": clusters,
        "ai_summary": ai_summary,
        "llm_used": llm_used,
    }
|
|
|
|
async def answer_ai_question(question: str, context: str = "", llm: str = "openai") -> Dict[str, str]:
    """
    Follow-up Q&A via chosen LLM.

    Args:
        question: The question forwarded to the engine's QA function.
        context: Supporting text forwarded alongside the question.
        llm: Engine name handed to ``_get_llm`` to pick the QA function.

    Returns:
        ``{"answer": ans}`` where ``ans`` is the QA function's result, or
        the fixed string ``"LLM follow-up failed."`` if that call raised.
    """
    import logging

    _, qa_fn = _get_llm(llm)
    try:
        ans = await qa_fn(question, context)
    except Exception:
        # Keep the best-effort contract (callers always receive an answer),
        # but record the cause instead of discarding it.
        logging.getLogger(__name__).exception("LLM follow-up failed")
        ans = "LLM follow-up failed."
    return {"answer": ans}
|
|