PharmAI_Navigator / tools.py
arittrabag's picture
Added main app files
1de0a51 verified
from typing import List, Dict, Any, Optional
import os
import uuid
import re
import base64
from schemas import EvidenceItem, EvidenceType
# Helper functions
def _etype(name: str, default: EvidenceType) -> EvidenceType:
    """Look up the ``EvidenceType`` attribute called *name*.

    If ``EvidenceType`` has no such attribute, *default* is returned instead,
    so callers can ask for members that may not be defined without breaking.
    """
    try:
        return getattr(EvidenceType, name)
    except AttributeError:
        return default
def _short(s: str, n: int = 700) -> str:
return (s or "")[:n]
def _is_url(s: str) -> bool:
return isinstance(s, str) and s.startswith(("http://", "https://"))
# Tool 1: Tavily Web Search (existing, unchanged)
def tavily_search(query: str, max_results: int = 5) -> List[EvidenceItem]:
    """
    Run a web search through the Tavily API and wrap each hit as evidence.

    Args:
        query: Free-text search query passed to Tavily.
        max_results: Maximum number of results requested from Tavily.

    Returns:
        One ``EvidenceType.LITERATURE`` item per search result, with
        ``source`` = the result URL (``"unknown"`` if missing/empty),
        ``summary`` = the first 500 characters of the result content
        (``""`` if missing/None), ``confidence`` = 0.6 and ``raw`` = the
        result dict.

        If ``TAVILY_API_KEY`` is not set, a single ``EvidenceType.OTHER``
        item with source ``"tavily_disabled"`` is returned. If importing the
        client, searching, or building the items raises, a single
        ``EvidenceType.OTHER`` item with source ``"tavily_error"`` describing
        the exception is returned instead of propagating it.
    """
    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        return [
            EvidenceItem(
                type=EvidenceType.OTHER,
                source="tavily_disabled",
                summary="Tavily API key not configured; search skipped.",
                confidence=0.0,
            )
        ]
    try:
        # Imported lazily so this module still loads when `tavily` isn't installed.
        from tavily import TavilyClient

        client = TavilyClient(api_key=api_key)
        results = client.search(
            query=query,
            max_results=max_results,
            include_raw_content=False,
        )
        evidence: List[EvidenceItem] = []
        for r in results.get("results") or []:
            evidence.append(
                EvidenceItem(
                    type=EvidenceType.LITERATURE,
                    # `or` (not a .get default) so explicit None values are
                    # also replaced; otherwise one incomplete hit would raise
                    # and discard every other result.
                    source=r.get("url") or "unknown",
                    summary=(r.get("content") or "")[:500],
                    confidence=0.6,
                    raw=r,
                )
            )
        return evidence
    except Exception as e:
        # Best-effort tool: report the failure as evidence instead of raising.
        return [
            EvidenceItem(
                type=EvidenceType.OTHER,
                source="tavily_error",
                summary=f"Tavily search failed: {str(e)}",
                confidence=0.0,
            )
        ]
# Tool 2: Stub Evidence Generator (existing, unchanged)
def stub_evidence(query: str) -> List[EvidenceItem]:
    """
    Offline placeholder for a real retrieval tool.

    Always returns exactly one ``EvidenceType.OTHER`` item (source
    ``"stub_tool"``, confidence 0.2) whose summary echoes *query*. Its
    ``raw`` payload carries a fresh random UUID under ``"id"``, so the
    output differs between calls even for the same query. Useful for
    demos, offline mode, or exercising agent logic.
    """
    placeholder_summary = (
        f"Stub evidence generated for query: '{query}'. "
        "This indicates where real retrieval will plug in."
    )
    metadata = {
        "id": str(uuid.uuid4()),
        "note": "Replace with real retrieval later",
    }
    item = EvidenceItem(
        type=EvidenceType.OTHER,
        source="stub_tool",
        summary=placeholder_summary,
        confidence=0.2,
        raw=metadata,
    )
    return [item]
# Tool 3: Query Classifier (planner helper)
def _keyword_pattern(keywords: List[str]) -> "re.Pattern[str]":
    """Compile a regex matching any of *keywords* at the start of a word."""
    return re.compile(r"\b(?:" + "|".join(re.escape(k) for k in keywords) + r")")


# Lowercase keyword groups. Each is anchored at a word start so that e.g.
# "trial" still matches "trials" and "nct" matches "NCT04368728", while
# "arch" no longer fires on "research", "dot" on "antidote", "draw" on
# "withdrawal", "aria" on "ovarian"/"malaria", or "nct" on "function".
_GRAPH_RE = _keyword_pattern(
    ["diagram", "graph", "graphviz", "dot", "flow", "architecture", "arch", "draw"]
)
_TRIALS_RE = _keyword_pattern(
    ["trial", "clinical", "phase", "nct", "primary endpoint", "secondary endpoint"]
)
_FACTS_RE = _keyword_pattern(
    ["fda", "approval", "label", "patent", "exclusivity", "pricing", "aria", "safety", "market"]
)
_ENTITIES_RE = _keyword_pattern(["evaluate", "assess", "analyze", "repurpose", "for "])


def classify_query(query: str) -> Dict[str, Any]:
    """
    Lightweight keyword classifier that helps the agent decide which tools
    (if any) are required.

    Matching is case-insensitive and a keyword must begin at a word boundary
    (it may be followed by more letters, so plurals still match).

    Args:
        query: The user query; ``None`` is treated as an empty string.

    Returns:
        A dict of boolean flags:
        ``needs_graph`` (diagram/graph wording),
        ``needs_clinical_trials`` (trial/phase/NCT wording),
        ``needs_web_search`` (fact-style wording or any trial wording), and
        ``needs_entity_extraction`` (evaluate/assess/analyze/repurpose/"for ").
    """
    q = (query or "").lower()
    needs_graph = _GRAPH_RE.search(q) is not None
    needs_trials = _TRIALS_RE.search(q) is not None
    needs_facts = _FACTS_RE.search(q) is not None
    needs_entities = _ENTITIES_RE.search(q) is not None
    return {
        "needs_graph": needs_graph,
        "needs_clinical_trials": needs_trials,
        "needs_web_search": needs_facts or needs_trials,
        "needs_entity_extraction": needs_entities,
    }
# Tool 4: Entity Extraction (Drug / Indication)
# "<verb> <drug> for <indication>"; the verb list mirrors the entity-trigger
# words used by classify_query (plus British "analyse" and "repurposing").
_ENTITY_RE = re.compile(
    r"(?:evaluate|assess|analy[sz]e|repurpose|repurposing)\s+"
    r"(?P<drug>.+?)\s+for\s+(?P<indication>.+)",
    re.IGNORECASE,
)
# Sentence punctuation that should not end up inside an extracted entity.
_ENTITY_TRAILING = " \t.,;:!?"


def extract_entities(query: str) -> Dict[str, Optional[str]]:
    """
    Minimal drug / indication extractor for the MVP.

    Recognises phrasings such as "Evaluate <drug> for <indication>" and
    "Can we repurpose <drug> for <indication>?" (case-insensitive). The
    drug is the shortest text between the verb and " for "; the indication
    is the rest of that line.

    Args:
        query: The user query; ``None`` is treated as an empty string.

    Returns:
        ``{"drug": str, "indication": str}`` with surrounding whitespace and
        trailing sentence punctuation removed, or
        ``{"drug": None, "indication": None}`` if nothing usable was found.
    """
    text = (query or "").strip()
    m = _ENTITY_RE.search(text)
    if m:
        drug = m.group("drug").strip().rstrip(_ENTITY_TRAILING)
        indication = m.group("indication").strip().rstrip(_ENTITY_TRAILING)
        # A capture made only of punctuation (e.g. "assess X for ?") is not an entity.
        if drug and indication:
            return {"drug": drug, "indication": indication}
    return {"drug": None, "indication": None}
# Tool 5: Evidence Normalizer (dedupe + cleanup)
def normalize_evidence(evidence: List[EvidenceItem]) -> List[EvidenceItem]:
    """
    Collapse evidence items that share a ``source`` and strip bulky fields.

    The first item seen for each source wins, and input order is preserved.
    Every surviving item is rebuilt with its summary cut to 800 characters
    (a missing summary becomes ``""``) and ``raw`` set to ``None`` so heavy
    payloads are dropped.
    """
    first_by_source: Dict[str, EvidenceItem] = {}
    for item in evidence:
        if item.source not in first_by_source:
            first_by_source[item.source] = item
    return [
        EvidenceItem(
            type=kept.type,
            source=kept.source,
            summary=(kept.summary or "")[:800],
            confidence=kept.confidence,
            raw=None,
        )
        for kept in first_by_source.values()
    ]
# Tool 6: Graph Generation (Graphviz DOT only)
# Values Graphviz accepts for the graph-level `rankdir` attribute.
_DOT_RANKDIRS = frozenset({"TB", "LR", "BT", "RL"})


def _dot_quote(value: Any) -> str:
    """Render *value* as a double-quoted DOT string (embedded ``"`` become ``'``)."""
    return '"' + str(value).replace('"', "'") + '"'


def generate_graph_dot(
    title: str,
    nodes: List[Dict[str, str]],
    edges: List[Dict[str, str]],
    rankdir: str = "LR",
) -> str:
    """
    Generate Graphviz DOT code for a titled directed graph.
    IMPORTANT: LLM must call this tool; never output DOT directly.

    Args:
        title: Graph title shown at the top; defaults to "PharmAI Graph".
        nodes: Dicts with ``"id"`` (required; nodes without one are skipped)
            and optional ``"label"`` (defaults to the id).
        edges: Dicts with ``"from"`` and ``"to"`` (edges missing either are
            skipped) and an optional ``"label"``.
        rankdir: Layout direction, one of TB/LR/BT/RL (case-insensitive);
            any other value falls back to "LR".

    Returns:
        The DOT source as a single newline-joined string. Node IDs and
        labels are emitted as quoted strings, so ids with spaces or hyphens
        stay valid; embedded double quotes are replaced with single quotes.
    """
    from html import escape  # stdlib; title sits inside an HTML-like label

    direction = str(rankdir or "LR").strip().upper()
    if direction not in _DOT_RANKDIRS:
        # Never interpolate arbitrary text into the DOT header.
        direction = "LR"

    # `&`, `<`, `>` would break the HTML-like <...> title label, so escape them.
    safe_title = escape((title or "PharmAI Graph").replace('"', "'"), quote=False)
    lines = [
        "digraph G {",
        f" rankdir={direction};",
        ' labelloc="t";',
        ' labeljust="c";',
        f' label=<<B><FONT POINT-SIZE="28">{safe_title}</FONT></B>>;',
        " node [shape=box, style=rounded];",
        "",
    ]
    for node in nodes or []:
        nid = node.get("id")
        if not nid:
            continue  # an id-less node cannot be referenced by any edge
        label = node.get("label") or nid
        lines.append(f" {_dot_quote(nid)} [label={_dot_quote(label)}];")
    lines.append("")
    for edge in edges or []:
        src = edge.get("from")
        tgt = edge.get("to")
        if not (src and tgt):
            continue
        lbl = edge.get("label")
        attrs = f" [label={_dot_quote(lbl)}]" if lbl else ""
        lines.append(f" {_dot_quote(src)} -> {_dot_quote(tgt)}{attrs};")
    lines.append("}")
    return "\n".join(lines)
# Tool 7: ClinicalTrials.gov search (lightweight, Tavily-based)
def clinicaltrials_search(drug: str, indication: str, max_results: int = 5) -> List[EvidenceItem]:
    """
    Search for clinical-trial evidence on ClinicalTrials.gov via Tavily (MVP).

    Args:
        drug: Drug name; surrounding whitespace is stripped.
        indication: Indication / disease; surrounding whitespace is stripped.
        max_results: Passed through to ``tavily_search``.

    Returns:
        - If either argument is empty: one ``EvidenceType.OTHER`` item with
          source ``"clinicaltrials_search_invalid_input"``.
        - If any result URL points at clinicaltrials.gov: only those results,
          typed as clinical-trial evidence (``EvidenceType.CLINICAL_TRIAL`` if
          defined, else ``LITERATURE``) with confidence at least 0.55.
        - Otherwise: every URL result re-typed the same way (confidence kept,
          0.4 if missing). Non-URL items such as ``tavily_search``'s
          "tavily_disabled" / "tavily_error" status items are passed through
          unchanged, so failures are not disguised as trial evidence.
    """
    drug = (drug or "").strip()
    indication = (indication or "").strip()
    if not drug or not indication:
        return [
            EvidenceItem(
                type=EvidenceType.OTHER,
                source="clinicaltrials_search_invalid_input",
                summary="Missing drug or indication for clinical trials search.",
                confidence=0.0,
            )
        ]
    # A stray double quote would terminate the quoted phrase in the query.
    q_drug = drug.replace('"', "")
    q_indication = indication.replace('"', "")
    query = f'site:clinicaltrials.gov ("{q_drug}") ("{q_indication}") NCT'
    ev = tavily_search(query=query, max_results=max_results)
    trial_type = _etype("CLINICAL_TRIAL", EvidenceType.LITERATURE)

    # Preferred path: keep only results hosted on ClinicalTrials.gov.
    ct_hits = [
        e for e in ev
        if _is_url(e.source) and "clinicaltrials.gov" in e.source.lower()
    ]
    if ct_hits:
        return [
            EvidenceItem(
                type=trial_type,
                source=e.source,
                summary=e.summary,
                confidence=max(0.55, float(e.confidence or 0.0)),
                raw=e.raw,
            )
            for e in ct_hits
        ]

    # Fallback: return whatever Tavily gave (still structured).
    out: List[EvidenceItem] = []
    for e in ev:
        if not _is_url(e.source):
            # Status/error item (or a URL-less hit): pass through as-is rather
            # than relabelling it as a trial and inflating its confidence.
            out.append(e)
            continue
        confidence = 0.4 if e.confidence is None else float(e.confidence)
        out.append(
            EvidenceItem(
                type=trial_type,
                source=e.source,
                summary=e.summary,
                confidence=confidence,
                raw=e.raw,
            )
        )
    return out
# Tool 8: DOT -> PNG rendering
def render_dot_to_png_base64(dot: str) -> Dict[str, Any]:
    """
    Render Graphviz DOT source to a base64-encoded PNG.

    Returns ``{"ok": True, "png_base64": <str>}`` on success.
    Returns ``{"ok": False, "error": <message>}`` when *dot* is empty,
    whitespace-only or ``None``, or when rendering raises for any reason.
    That includes the optional ``graphviz`` Python package (imported lazily
    here) or its system binaries being unavailable.
    """
    source_text = dot.strip() if dot else ""
    if source_text == "":
        return {"ok": False, "error": "Empty DOT string"}

    try:
        from graphviz import Source  # optional dependency

        png_bytes = Source(source_text).pipe(format="png")
        encoded = base64.b64encode(png_bytes).decode("utf-8")
    except Exception as exc:
        return {
            "ok": False,
            "error": f"DOT->PNG render failed. Ensure `graphviz` Python package and system binaries are installed. Details: {str(exc)}",
        }
    return {"ok": True, "png_base64": encoded}
# Tool Registry (extended, backward compatible)
# Public tool name -> implementing callable. Insertion order is kept stable;
# "web_search" and "stub_evidence" are the original entries and keep their
# names for backward compatibility. Presumably looked up by name by the
# agent/planner — confirm against callers before renaming any key.
TOOL_REGISTRY: Dict[str, Any] = {
    # Retrieval (original entries)
    "web_search": tavily_search,
    "stub_evidence": stub_evidence,
    # Planning helpers
    "classify_query": classify_query,
    "extract_entities": extract_entities,
    # Post-processing / domain retrieval
    "normalize_evidence": normalize_evidence,
    "generate_graph_dot": generate_graph_dot,
    "clinicaltrials_search": clinicaltrials_search,
    "render_dot_to_png_base64": render_dot_to_png_base64,
}