from __future__ import annotations import os import json import re from typing import Any, Dict, List, Optional from dotenv import load_dotenv from langchain_anthropic import ChatAnthropic from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.tools import tool from langgraph.graph import StateGraph, START, END, MessagesState from langgraph.prebuilt import ToolNode, tools_condition from tools import ( tavily_search, stub_evidence, classify_query, extract_entities, normalize_evidence, generate_graph_dot, clinicaltrials_search, render_dot_to_png_base64 ) # Load environment variables load_dotenv() # ----------------------------- # LangChain Tool Wrappers # ----------------------------- @tool("web_search") def web_search_tool(query: str, max_results: int = 5) -> List[Dict[str, Any]]: """Web search using Tavily. Returns a list of evidence dicts.""" ev = tavily_search(query=query, max_results=max_results) return [e.model_dump() for e in ev] @tool("stub_evidence") def stub_evidence_tool(query: str) -> List[Dict[str, Any]]: """Deterministic fallback evidence tool (offline/demo).""" ev = stub_evidence(query=query) return [e.model_dump() for e in ev] @tool("classify_query") def classify_query_tool(query: str) -> Dict[str, Any]: """Classify query to decide which tools are needed.""" return classify_query(query) @tool("extract_entities") def extract_entities_tool(query: str) -> Dict[str, Optional[str]]: """Extract drug and indication from query.""" return extract_entities(query) @tool("normalize_evidence") def normalize_evidence_tool(evidence: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Dedupe and clean evidence.""" return normalize_evidence(evidence) @tool("generate_graph_dot") def generate_graph_dot_tool( title: str, nodes: List[Dict[str, str]], edges: List[Dict[str, str]], rankdir: str = "LR", ) -> str: """ Generate Graphviz DOT. IMPORTANT: Use this tool instead of writing DOT directly. """ return generate_graph_dot( title=title, nodes=nodes, edges=edges, rankdir=rankdir, ) @tool("clinicaltrials_search") def clinicaltrials_search_tool(drug: str, indication: str, max_results: int = 5) -> List[Dict[str, Any]]: """Search ClinicalTrials.gov (Tavily-based MVP).""" ev = clinicaltrials_search(drug=drug, indication=indication, max_results=max_results) return [e.model_dump() for e in ev] @tool("render_dot_to_png_base64") def render_dot_to_png_base64_tool(dot: str) -> Dict[str, Any]: """Render DOT to PNG (base64). Optional dependency on graphviz.""" return render_dot_to_png_base64(dot) TOOLS = [ web_search_tool, stub_evidence_tool, classify_query_tool, extract_entities_tool, normalize_evidence_tool, generate_graph_dot_tool, clinicaltrials_search_tool, render_dot_to_png_base64_tool ] # ----------------------------- # LangGraph State # ----------------------------- class PharmAIState(MessagesState): session_id: Optional[str] user_query: str decision_brief: str citations: List[str] confidence_score: float tool_loops: int # safety counter diagram_png_base64: Optional[str] # <-- add diagram_dot: Optional[str] # <-- optional intent: str # "simple" | "diligence" | "diagram" # ----------------------------- # Guardrails + Prompts # ----------------------------- SYSTEM_PROMPT = """You are PharmAI Navigator, an evidence-grounded diligence assistant for drug/asset evaluation. Your job: Turn a query like "Assess {Drug} for {Indication}" into a decision-grade brief OR structured output. CRITICAL TOOL USAGE RULES: - If the user asks for a diagram, flow, architecture, graph, visualization, or Graphviz: → You MUST call `generate_graph_dot`. → You MUST NOT write Graphviz DOT directly in your response. → If the user asks for an image/PNG, call `render_dot_to_png_base64` AFTER you get DOT. - If the user asks for trials / phases / NCT IDs / endpoints: → Prefer calling `extract_entities` then `clinicaltrials_search`. - If the user asks for factual claims (approvals, safety, pricing, patents, market): → Prefer calling `web_search`. Guardrails (STRICT): - Do NOT invent specific facts (approval dates, trial names, endpoints, statistics, patent expiry). - Any concrete number/date/claim MUST be supported by tool evidence. - If evidence is insufficient, clearly list Evidence Gaps. - Be concise, structured, and decision-oriented. - Avoid medical advice; present as diligence/analysis. Simple Query Rule (CRITICAL): - If the user asks a simple definitional question ("what is", "define", "explain") and you can answer without external verification, do NOT call tools and respond directly. - Only use tools when you need current/specific data (trials, approvals, patents, market data). Citations policy: - The final response's "Citations" section is handled by the system. - Do NOT create your own citation list. """ FINAL_PROMPT = """Write the FINAL decision brief with these sections: 1) Executive Recommendation (1–2 lines) 2) Scientific Rationale (bullets) 3) Clinical Evidence Snapshot (bullets) 4) IP / Exclusivity Quick View (bullets) 5) Market / SoC Snapshot (bullets) 6) Key Risks + Next Actions (bullets) Rules: - If evidence is insufficient, include "Evidence Gaps" with bullets. - Do NOT add a citations section yourself; the system will append it. Return plain text only. """ # Placeholder detection to avoid wasting tokens on "Drug X / Indication Y" PLACEHOLDER_PATTERNS = [ r"\bdrug\s*x\b", r"\bindication\s*y\b", r"\bdrug\s*name\b", r"\bindication\s*name\b", ] def _looks_like_placeholder(q: str) -> bool: ql = (q or "").strip().lower() return any(re.search(p, ql) for p in PLACEHOLDER_PATTERNS) def _build_model() -> ChatAnthropic: model_name = os.getenv("ANTHROPIC_MODEL", "claude-3-7-sonnet-latest") return ChatAnthropic( model=model_name, temperature=0.2, max_tokens=10000, timeout=120, streaming=False, stop=None ).bind_tools(TOOLS) # Safety cap to avoid endless tool loops MAX_TOOL_LOOPS = int(os.getenv("MAX_TOOL_LOOPS", "4")) def llm_call(state: PharmAIState) -> Dict[str, Any]: """ Calls Claude with tool schemas attached. Returns new messages to append into state["messages"]. """ llm = _build_model() messages: List[BaseMessage] = state["messages"] if not messages or not isinstance(messages[0], SystemMessage): messages = [SystemMessage(content=SYSTEM_PROMPT)] + messages tool_loops = state.get("tool_loops", 0) if tool_loops >= MAX_TOOL_LOOPS: # Stop tool-calling loop and force synthesis stop_msg = HumanMessage( content=( "Stop calling tools now. Proceed to final synthesis using what you already have. " "If evidence is insufficient, clearly list Evidence Gaps." ) ) messages = messages + [stop_msg] resp = llm.invoke(messages) return {"messages": [resp]} # ----------------------------- # Citations extraction (tool-only) # ----------------------------- def _clean_url(u: str) -> str: return u.strip().strip("),.]}\"'") def _extract_citations_from_messages(messages: List[BaseMessage]) -> List[str]: """ Tool-only citation extraction (single source of truth): - ONLY reads ToolMessage contents (actual tool outputs). - If tool output is JSON (list/dict), pull `source` fields. - Fallback: regex URL extraction from tool text. """ citations: List[str] = [] url_re = re.compile(r"https?://[^\s\]\)\}\",']+") for m in messages: if not isinstance(m, ToolMessage): continue content = getattr(m, "content", None) if not content: continue if isinstance(content, str): parsed = None try: parsed = json.loads(content) except Exception: parsed = None if isinstance(parsed, list): for item in parsed: if isinstance(item, dict): src = item.get("source") if isinstance(src, str) and src.startswith(("http://", "https://")): citations.append(_clean_url(src)) elif isinstance(parsed, dict): src = parsed.get("source") if isinstance(src, str) and src.startswith(("http://", "https://")): citations.append(_clean_url(src)) for u in url_re.findall(content): citations.append(_clean_url(u)) # De-duplicate seen = set() out = [] for c in citations: # drop clearly broken/truncated URLs if len(c) < 12: continue if c not in seen: seen.add(c) out.append(c) return out def _append_citations_section(brief_text: str, citations: List[str]) -> str: """ Enforces "single source of truth": - Removes any existing 'Citations' section the model may have produced - Appends citations derived from tool outputs only """ text = (brief_text or "").strip() # Remove any model-generated citations section (best-effort) # (handles '## Citations' or 'Citations' headers) text = re.split(r"\n#{1,3}\s*Citations\s*\n|\nCitations\s*\n", text, maxsplit=1)[0].rstrip() if citations: lines = ["", "## Citations"] for i, c in enumerate(citations, 1): lines.append(f"{i}. {c}") text = text + "\n" + "\n".join(lines) else: text = text + "\n\n## Citations\n- (No external sources retrieved.)" return text def capture_diagram(state: PharmAIState) -> Dict[str, Any]: # Find the last ToolMessage (most recent tool output) last_tool = None for m in reversed(state["messages"]): if isinstance(m, ToolMessage): last_tool = m break if not last_tool: return {} tool_name = getattr(last_tool, "name", "") or "" content = getattr(last_tool, "content", "") # If your render tool returns base64 string directly if tool_name == "render_dot_to_png_base64": return {"diagram_png_base64": content} # If your generate_graph_dot returns dot string if tool_name == "generate_graph_dot": return {"diagram_dot": content} return {} def route_after_tools(state: PharmAIState) -> str: # If we already have the final diagram artifact, stop. if state.get("diagram_png_base64"): return END return "bump_tool_loop" def preprocess(state: PharmAIState) -> Dict[str, Any]: q = (state.get("user_query") or "").strip().lower() if any(k in q for k in ["diagram", "flowchart", "architecture", "graphviz", "dot", "draw"]): return {"intent": "diagram"} if re.match(r"^(what is|define|explain)\b", q) and len(q) < 120: return {"intent": "simple"} return {"intent": "diligence"} def route_after_llm(state: PharmAIState): # If query is simple, never call tools/synthesize if state.get("intent") == "simple": return "end_simple" # If the model asked for tools, go tools last = state["messages"][-1] if getattr(last, "tool_calls", None): return "tools" return "synthesize" def end_simple(state: PharmAIState) -> Dict[str, Any]: # Return the last assistant content as the final answer last = state["messages"][-1] text = getattr(last, "content", "") if isinstance(getattr(last, "content", ""), str) else str(getattr(last, "content", "")) return {"decision_brief": text, "citations": []} # ----------------------------- # Final Synthesis Node # ----------------------------- def synthesize(state: PharmAIState) -> Dict[str, Any]: # Fast guardrail: placeholders -> short response without tool burn uq = state.get("user_query", "") if _looks_like_placeholder(uq): brief = ( "# FINAL DECISION BRIEF\n\n" "I need the **actual drug name** and **specific indication** to perform diligence.\n\n" "## Evidence Gaps\n" "- Drug name (e.g., semaglutide)\n" "- Indication (e.g., obesity)\n" "- Trial/program context (if any)\n" ) return { "decision_brief": _append_citations_section(brief, []), "citations": [], "messages": [HumanMessage(content="(placeholder query detected; returned guardrail response)")], } llm = _build_model() messages: List[BaseMessage] = state["messages"] messages = messages + [HumanMessage(content=FINAL_PROMPT)] resp = llm.invoke(messages) tool_citations = _extract_citations_from_messages(state["messages"]) brief_text = resp.content if isinstance(resp.content, str) else str(resp.content) brief_text = _append_citations_section(brief_text, tool_citations) return { "decision_brief": brief_text, "citations": tool_citations, "messages": [resp], } # ----------------------------- # Build + Compile Graph # ----------------------------- def build_graph(): """ Graph with preprocessing and smart routing. """ g = StateGraph(PharmAIState) g.add_node("preprocess", preprocess) g.add_node("llm_call", llm_call) g.add_node("tools", ToolNode(TOOLS)) g.add_node("capture_diagram", capture_diagram) g.add_node("bump_tool_loop", lambda s: {"tool_loops": s.get("tool_loops", 0) + 1}) g.add_node("synthesize", synthesize) g.add_node("end_simple", end_simple) g.add_edge(START, "preprocess") g.add_edge("preprocess", "llm_call") # After LLM: route based on intent and tool calls g.add_conditional_edges( "llm_call", route_after_llm, { "tools": "tools", "synthesize": "synthesize", "end_simple": "end_simple", }, ) # After tools: capture diagram data g.add_edge("tools", "capture_diagram") # After capture: check if we should stop (diagram complete) or continue g.add_conditional_edges( "capture_diagram", route_after_tools, { END: END, # Stop if diagram is complete "bump_tool_loop": "bump_tool_loop", # Continue otherwise }, ) g.add_edge("bump_tool_loop", "llm_call") g.add_edge("end_simple", END) g.add_edge("synthesize", END) return g.compile() # ----------------------------- # Test execution # ----------------------------- if __name__ == "__main__": print("Building PharmAI Navigator graph...") graph = build_graph() print("Graph compiled successfully!") # Test query designed to trigger generate_graph_dot tool #test_query = "Assess semaglutide for obesity" #test_query = "Assess donanemab for early Alzheimer’s disease. Retrieve key clinical trials, summarize efficacy and safety outcomes, normalize the evidence, and generate a system architecture graph showing how PharmAI Navigator evaluates this asset." #test_query = "Create a DOT graph showing the relationship between Drug, Indication, Clinical Trials, FDA Approval, and Market Launch and render it as png" test_query = "What is pembrolizumab?" print(f"\nRunning test query: {test_query}") result = graph.invoke({ "messages": [HumanMessage(content=test_query)], "user_query": test_query, "tool_loops": 0, }) print("\n" + "=" * 60) print("OUTPUT:") print("=" * 60) print(result.get("decision_brief", "No output")) print("\n" + "=" * 60) print("CITATIONS (tool-only):") print("=" * 60) for i, citation in enumerate(result.get("citations", []), 1): print(f"{i}. {citation}")