Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import json | |
| import re | |
| from typing import Any, Dict, List, Optional | |
| from dotenv import load_dotenv | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage | |
| from langchain_core.tools import tool | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from tools import ( | |
| tavily_search, | |
| stub_evidence, | |
| classify_query, | |
| extract_entities, | |
| normalize_evidence, | |
| generate_graph_dot, | |
| clinicaltrials_search, | |
| render_dot_to_png_base64 | |
| ) | |
# Load environment variables at import time, before the os.getenv() lookups
# further down in this module (ANTHROPIC_MODEL, MAX_TOOL_LOOPS) are evaluated.
load_dotenv()
| # ----------------------------- | |
| # LangChain Tool Wrappers | |
| # ----------------------------- | |
def web_search_tool(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Web search using Tavily. Returns a list of evidence dicts."""
    # NOTE(review): docstring kept verbatim — presumably surfaced to the model
    # as the tool description when this callable is bound; confirm before editing.
    hits = tavily_search(query=query, max_results=max_results)
    serialized: List[Dict[str, Any]] = []
    for hit in hits:
        # Each evidence object is converted to a plain dict via model_dump().
        serialized.append(hit.model_dump())
    return serialized
def stub_evidence_tool(query: str) -> List[Dict[str, Any]]:
    """Deterministic fallback evidence tool (offline/demo)."""
    # Delegates to stub_evidence() and flattens each result to a plain dict.
    canned = stub_evidence(query=query)
    as_dicts: List[Dict[str, Any]] = []
    for record in canned:
        as_dicts.append(record.model_dump())
    return as_dicts
def classify_query_tool(query: str) -> Dict[str, Any]:
    """Classify query to decide which tools are needed."""
    # Thin pass-through to classify_query(); the result is returned unchanged.
    classification = classify_query(query)
    return classification
def extract_entities_tool(query: str) -> Dict[str, Optional[str]]:
    """Extract drug and indication from query."""
    # Thin pass-through to extract_entities(); the result is returned unchanged.
    entities = extract_entities(query)
    return entities
def normalize_evidence_tool(evidence: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Dedupe and clean evidence."""
    # Thin pass-through to normalize_evidence(); the result is returned unchanged.
    cleaned = normalize_evidence(evidence)
    return cleaned
def generate_graph_dot_tool(
    title: str,
    nodes: List[Dict[str, str]],
    edges: List[Dict[str, str]],
    rankdir: str = "LR",
) -> str:
    """
    Generate Graphviz DOT.
    IMPORTANT: Use this tool instead of writing DOT directly.
    """
    # All four arguments are forwarded by keyword to generate_graph_dot().
    dot_kwargs = {
        "title": title,
        "nodes": nodes,
        "edges": edges,
        "rankdir": rankdir,
    }
    return generate_graph_dot(**dot_kwargs)
def clinicaltrials_search_tool(drug: str, indication: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Search ClinicalTrials.gov (Tavily-based MVP)."""
    trials = clinicaltrials_search(
        drug=drug,
        indication=indication,
        max_results=max_results,
    )
    # Flatten each evidence object to a plain dict via model_dump().
    flattened: List[Dict[str, Any]] = []
    for trial in trials:
        flattened.append(trial.model_dump())
    return flattened
def render_dot_to_png_base64_tool(dot: str) -> Dict[str, Any]:
    """Render DOT to PNG (base64). Optional dependency on graphviz."""
    # Thin pass-through to render_dot_to_png_base64(); the dict it returns is
    # handed back unchanged.
    rendered = render_dot_to_png_base64(dot)
    return rendered
# Tool registry handed to both `bind_tools` (in _build_model) and `ToolNode`
# (in build_graph).
#
# Each wrapper is registered under an explicit name WITHOUT the `_tool` suffix,
# because the rest of this module refers to tools by those bare names:
# SYSTEM_PROMPT tells the model to call `generate_graph_dot`, `web_search`,
# `clinicaltrials_search`, ... and capture_diagram compares ToolMessage names
# against "generate_graph_dot" / "render_dot_to_png_base64". Registering the raw
# functions would expose them under their function names (`web_search_tool`,
# ...), so neither the prompt nor the diagram capture would line up with the
# tools the model actually sees.
#
# The wrapper functions themselves stay plain callables; only the registry
# entries are tool objects. Their docstrings become the tool descriptions.
TOOLS = [
    tool("web_search")(web_search_tool),
    tool("stub_evidence")(stub_evidence_tool),
    tool("classify_query")(classify_query_tool),
    tool("extract_entities")(extract_entities_tool),
    tool("normalize_evidence")(normalize_evidence_tool),
    tool("generate_graph_dot")(generate_graph_dot_tool),
    tool("clinicaltrials_search")(clinicaltrials_search_tool),
    tool("render_dot_to_png_base64")(render_dot_to_png_base64_tool),
]
| # ----------------------------- | |
| # LangGraph State | |
| # ----------------------------- | |
class PharmAIState(MessagesState):
    # Graph state shared by every node in this module. It extends MessagesState
    # (the nodes read and append to state["messages"]) with the fields below.
    session_id: Optional[str]  # not read or written by any node in this module
    user_query: str  # raw user question; read by preprocess and synthesize
    decision_brief: str  # final answer text; written by end_simple and synthesize
    citations: List[str]  # source URLs; written by end_simple (empty) and synthesize
    confidence_score: float  # declared but not populated by any node in this module
    tool_loops: int # safety counter: bumped after each tool round, compared to MAX_TOOL_LOOPS in llm_call
    diagram_png_base64: Optional[str] # written by capture_diagram; a truthy value makes route_after_tools end the run
    diagram_dot: Optional[str] # written by capture_diagram with the DOT tool's output
    intent: str # "simple" | "diligence" | "diagram" — set by preprocess, read by route_after_llm
| # ----------------------------- | |
| # Guardrails + Prompts | |
| # ----------------------------- | |
# System prompt for the agent. llm_call prepends it as a SystemMessage whenever
# the conversation does not already start with one.
# The tool names mentioned in the text must match the names under which the
# entries of TOOLS are registered with the model.
SYSTEM_PROMPT = """You are PharmAI Navigator, an evidence-grounded diligence assistant for drug/asset evaluation.
Your job:
Turn a query like "Assess {Drug} for {Indication}" into a decision-grade brief OR structured output.
CRITICAL TOOL USAGE RULES:
- If the user asks for a diagram, flow, architecture, graph, visualization, or Graphviz:
→ You MUST call `generate_graph_dot`.
→ You MUST NOT write Graphviz DOT directly in your response.
→ If the user asks for an image/PNG, call `render_dot_to_png_base64` AFTER you get DOT.
- If the user asks for trials / phases / NCT IDs / endpoints:
→ Prefer calling `extract_entities` then `clinicaltrials_search`.
- If the user asks for factual claims (approvals, safety, pricing, patents, market):
→ Prefer calling `web_search`.
Guardrails (STRICT):
- Do NOT invent specific facts (approval dates, trial names, endpoints, statistics, patent expiry).
- Any concrete number/date/claim MUST be supported by tool evidence.
- If evidence is insufficient, clearly list Evidence Gaps.
- Be concise, structured, and decision-oriented.
- Avoid medical advice; present as diligence/analysis.
Simple Query Rule (CRITICAL):
- If the user asks a simple definitional question ("what is", "define", "explain") and you can answer without external verification, do NOT call tools and respond directly.
- Only use tools when you need current/specific data (trials, approvals, patents, market data).
Citations policy:
- The final response's "Citations" section is handled by the system.
- Do NOT create your own citation list.
"""
# Instruction appended (as a HumanMessage) by synthesize to request the final
# brief. The model is told not to write its own citations section because
# _append_citations_section adds one built from tool outputs.
FINAL_PROMPT = """Write the FINAL decision brief with these sections:
1) Executive Recommendation (1–2 lines)
2) Scientific Rationale (bullets)
3) Clinical Evidence Snapshot (bullets)
4) IP / Exclusivity Quick View (bullets)
5) Market / SoC Snapshot (bullets)
6) Key Risks + Next Actions (bullets)
Rules:
- If evidence is insufficient, include "Evidence Gaps" with bullets.
- Do NOT add a citations section yourself; the system will append it.
Return plain text only.
"""
| # Placeholder detection to avoid wasting tokens on "Drug X / Indication Y" | |
| PLACEHOLDER_PATTERNS = [ | |
| r"\bdrug\s*x\b", | |
| r"\bindication\s*y\b", | |
| r"\bdrug\s*name\b", | |
| r"\bindication\s*name\b", | |
| ] | |
| def _looks_like_placeholder(q: str) -> bool: | |
| ql = (q or "").strip().lower() | |
| return any(re.search(p, ql) for p in PLACEHOLDER_PATTERNS) | |
def _build_model() -> ChatAnthropic:
    """Create the Claude chat model with every entry of TOOLS bound to it.

    The model id is read from the ANTHROPIC_MODEL environment variable and
    falls back to "claude-3-7-sonnet-latest". A fresh instance is built on
    every call.
    """
    chat_kwargs = {
        "model": os.getenv("ANTHROPIC_MODEL", "claude-3-7-sonnet-latest"),
        "temperature": 0.2,
        "max_tokens": 10000,
        "timeout": 120,
        "streaming": False,
        "stop": None,
    }
    base_model = ChatAnthropic(**chat_kwargs)
    # What callers receive is the result of bind_tools(), i.e. the model with
    # the tool schemas attached.
    return base_model.bind_tools(TOOLS)
# Number of tool rounds after which llm_call asks the model to stop calling
# tools. Overridable through the MAX_TOOL_LOOPS environment variable.
MAX_TOOL_LOOPS = int(os.getenv("MAX_TOOL_LOOPS", "4"))


def llm_call(state: PharmAIState) -> Dict[str, Any]:
    """Invoke the tool-bound model on the conversation so far.

    The system prompt is prepended when the history does not already start
    with a SystemMessage. Once ``tool_loops`` has reached MAX_TOOL_LOOPS an
    extra instruction asking the model to stop calling tools is appended.

    Returns a state update whose "messages" list holds only the new reply.
    """
    model = _build_model()

    history: List[BaseMessage] = state["messages"]
    starts_with_system = bool(history) and isinstance(history[0], SystemMessage)
    if not starts_with_system:
        history = [SystemMessage(content=SYSTEM_PROMPT)] + history

    # NOTE(review): the model still has tools bound at this point, so this
    # nudge is a request rather than a hard stop.
    if state.get("tool_loops", 0) >= MAX_TOOL_LOOPS:
        wrap_up = HumanMessage(
            content=(
                "Stop calling tools now. Proceed to final synthesis using what you already have. "
                "If evidence is insufficient, clearly list Evidence Gaps."
            )
        )
        history = history + [wrap_up]

    reply = model.invoke(history)
    return {"messages": [reply]}
| # ----------------------------- | |
| # Citations extraction (tool-only) | |
| # ----------------------------- | |
| def _clean_url(u: str) -> str: | |
| return u.strip().strip("),.]}\"'") | |
def _extract_citations_from_messages(messages: List[BaseMessage]) -> List[str]:
    """Collect source URLs from tool outputs only (single source of truth).

    Only ToolMessage instances with non-empty string content are inspected.
    For each one:
      * if the content parses as JSON (a dict, or a list of dicts), the
        ``source`` field of each dict is taken when it is an http(s) URL;
      * regardless of that, every URL found by regex in the raw text is
        taken as well.

    Returns the cleaned URLs in first-seen order, de-duplicated, with
    candidates shorter than 12 characters dropped as clearly truncated.
    """
    url_re = re.compile(r"https?://[^\s\]\)\}\",']+")

    def _source_urls(payload: Any) -> List[str]:
        # One code path for both shapes: a bare dict is treated as a
        # one-element list; non-dict entries are ignored.
        entries = payload if isinstance(payload, list) else [payload]
        found: List[str] = []
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            src = entry.get("source")
            if isinstance(src, str) and src.startswith(("http://", "https://")):
                found.append(_clean_url(src))
        return found

    citations: List[str] = []
    for msg in messages:
        if not isinstance(msg, ToolMessage):
            continue
        content = getattr(msg, "content", None)
        if not content or not isinstance(content, str):
            continue

        try:
            parsed = json.loads(content)
        except (ValueError, RecursionError):
            # Not JSON (or pathologically nested): rely on the regex pass below.
            parsed = None
        citations.extend(_source_urls(parsed))

        citations.extend(_clean_url(hit) for hit in url_re.findall(content))

    # dict.fromkeys keeps first-seen order while removing duplicates.
    return list(dict.fromkeys(c for c in citations if len(c) >= 12))
| def _append_citations_section(brief_text: str, citations: List[str]) -> str: | |
| """ | |
| Enforces "single source of truth": | |
| - Removes any existing 'Citations' section the model may have produced | |
| - Appends citations derived from tool outputs only | |
| """ | |
| text = (brief_text or "").strip() | |
| # Remove any model-generated citations section (best-effort) | |
| # (handles '## Citations' or 'Citations' headers) | |
| text = re.split(r"\n#{1,3}\s*Citations\s*\n|\nCitations\s*\n", text, maxsplit=1)[0].rstrip() | |
| if citations: | |
| lines = ["", "## Citations"] | |
| for i, c in enumerate(citations, 1): | |
| lines.append(f"{i}. {c}") | |
| text = text + "\n" + "\n".join(lines) | |
| else: | |
| text = text + "\n\n## Citations\n- (No external sources retrieved.)" | |
| return text | |
# ToolMessage names identifying the two diagram tools. Both the bare names
# (as used in SYSTEM_PROMPT) and the `_tool`-suffixed wrapper function names
# are accepted, so the capture works whichever way the tools are registered.
_DOT_TOOL_NAMES = frozenset({"generate_graph_dot", "generate_graph_dot_tool"})
_PNG_TOOL_NAMES = frozenset({"render_dot_to_png_base64", "render_dot_to_png_base64_tool"})


def capture_diagram(state: PharmAIState) -> Dict[str, Any]:
    """Copy diagram tool outputs from the latest tool round into state.

    Scans the most recent contiguous run of ToolMessage entries (one tool
    round may contain several tool results, not just one) and records:
      * ``diagram_png_base64`` — content of the PNG rendering tool's message;
      * ``diagram_dot`` — content of the DOT generation tool's message.

    Tool results flagged with ``status == "error"`` are skipped so a failed
    render is never stored as a finished diagram (which would end the run in
    route_after_tools). Returns an empty update when nothing matches.
    """
    captured: Dict[str, Any] = {}
    in_latest_round = False

    for msg in reversed(state["messages"]):
        if not isinstance(msg, ToolMessage):
            if in_latest_round:
                # Reached the message that preceded the latest tool round.
                break
            continue
        in_latest_round = True

        if getattr(msg, "status", None) == "error":
            continue

        tool_name = getattr(msg, "name", "") or ""
        content = getattr(msg, "content", "")

        # setdefault: iterating newest-first, so the most recent output wins.
        if tool_name in _PNG_TOOL_NAMES:
            # NOTE(review): render_dot_to_png_base64_tool is annotated as
            # returning a dict, so this content is presumably a serialized
            # dict rather than a bare base64 string — confirm what consumers
            # of diagram_png_base64 expect.
            captured.setdefault("diagram_png_base64", content)
        elif tool_name in _DOT_TOOL_NAMES:
            captured.setdefault("diagram_dot", content)

    return captured
def route_after_tools(state: PharmAIState) -> str:
    """Pick the next step after a tool round.

    Returns END once a rendered diagram (``diagram_png_base64``) is present
    in state; otherwise returns "bump_tool_loop" to continue the loop.
    """
    has_png = bool(state.get("diagram_png_base64"))
    return END if has_png else "bump_tool_loop"
# Diagram-related keywords, matched as whole words. Plain substring checks
# would fire on unrelated pharma vocabulary: "dot" is inside "antidote" and
# "anecdotal", "draw" is inside "withdrawal" and "drawback".
_DIAGRAM_INTENT_RE = re.compile(
    r"\b(?:diagram(?:s|med|ming)?|flow[ -]?charts?|architectures?|graphviz|dot"
    r"|draw(?:s|ing|n)?)\b"
)
# Definitional openers that mark a short question as "simple".
_SIMPLE_INTENT_RE = re.compile(r"^(what is|define|explain)\b")


def preprocess(state: PharmAIState) -> Dict[str, Any]:
    """Classify the user query into an intent before the first model call.

    The query (``state["user_query"]``; missing/None is treated as empty) is
    trimmed and lower-cased, then classified in priority order:
      1. "diagram"   — mentions a diagram-related keyword as a whole word;
      2. "simple"    — starts with "what is" / "define" / "explain" and is
                       shorter than 120 characters;
      3. "diligence" — everything else.

    Returns a state update of the form ``{"intent": <label>}``.
    """
    query = (state.get("user_query") or "").strip().lower()

    if _DIAGRAM_INTENT_RE.search(query):
        return {"intent": "diagram"}
    if _SIMPLE_INTENT_RE.match(query) and len(query) < 120:
        return {"intent": "simple"}
    return {"intent": "diligence"}
def route_after_llm(state: PharmAIState):
    """Pick the next node after a model reply.

    * "end_simple" — the query was classified as simple; this takes priority
      even if the reply contains tool calls.
    * "tools"      — the latest message carries tool calls.
    * "synthesize" — otherwise.
    """
    if state.get("intent") == "simple":
        return "end_simple"

    latest = state["messages"][-1]
    wants_tools = bool(getattr(latest, "tool_calls", None))
    return "tools" if wants_tools else "synthesize"
def end_simple(state: PharmAIState) -> Dict[str, Any]:
    """Finish a "simple" query by returning the last message as the answer.

    Message content may be a plain string or a list of content blocks (e.g.
    text blocks mixed with tool-use blocks — route_after_llm sends simple
    queries here even when the reply contains tool calls). For a list, only
    the text is kept: string items and ``{"type": "text"}`` blocks are
    concatenated; other blocks are dropped rather than dumped as a repr.

    Returns ``{"decision_brief": <text>, "citations": []}``.
    """
    last = state["messages"][-1]
    content = getattr(last, "content", "")

    if isinstance(content, str):
        text = content
    elif isinstance(content, list):
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                pieces.append(str(block.get("text", "")))
        text = "".join(pieces)
    else:
        text = str(content)

    return {"decision_brief": text, "citations": []}
| # ----------------------------- | |
| # Final Synthesis Node | |
| # ----------------------------- | |
def _flatten_message_content(content: Any) -> str:
    """Return the text of a message content value (str or list of blocks)."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Keep only textual parts; tool-use and other blocks are not prose.
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                pieces.append(str(block.get("text", "")))
        return "".join(pieces)
    return str(content)


def synthesize(state: PharmAIState) -> Dict[str, Any]:
    """Produce the final decision brief and its tool-derived citations.

    * Placeholder queries ("Drug X", "Indication Y", ...) get a fixed
      guardrail brief without another model call.
    * Otherwise the model is invoked on the conversation plus FINAL_PROMPT,
      and the citations section is rebuilt from ToolMessage outputs only.

    Returns a state update with "decision_brief", "citations" and the new
    message to append under "messages".
    """
    # NOTE(review): this guard runs only at synthesis time, i.e. after
    # llm_call (and any tool rounds) have already executed for the query.
    uq = state.get("user_query", "")
    if _looks_like_placeholder(uq):
        brief = (
            "# FINAL DECISION BRIEF\n\n"
            "I need the **actual drug name** and **specific indication** to perform diligence.\n\n"
            "## Evidence Gaps\n"
            "- Drug name (e.g., semaglutide)\n"
            "- Indication (e.g., obesity)\n"
            "- Trial/program context (if any)\n"
        )
        return {
            "decision_brief": _append_citations_section(brief, []),
            "citations": [],
            "messages": [HumanMessage(content="(placeholder query detected; returned guardrail response)")],
        }

    llm = _build_model()
    messages: List[BaseMessage] = state["messages"]
    messages = messages + [HumanMessage(content=FINAL_PROMPT)]
    resp = llm.invoke(messages)

    # Citations come from the tool outputs already in state, never from the
    # model's own text.
    tool_citations = _extract_citations_from_messages(state["messages"])

    # The reply content may be a list of content blocks; flatten it to text
    # instead of embedding a Python repr of the list in the brief.
    brief_text = _flatten_message_content(resp.content)
    brief_text = _append_citations_section(brief_text, tool_citations)

    return {
        "decision_brief": brief_text,
        "citations": tool_citations,
        "messages": [resp],
    }
| # ----------------------------- | |
| # Build + Compile Graph | |
| # ----------------------------- | |
def build_graph():
    """Assemble and compile the PharmAI Navigator state graph.

    Flow: START -> preprocess -> llm_call, then route_after_llm chooses
    tools / synthesize / end_simple. A tool round goes tools ->
    capture_diagram, after which route_after_tools either ends the run
    (diagram PNG captured) or goes bump_tool_loop -> llm_call again.
    end_simple and synthesize both lead to END.
    """
    builder = StateGraph(PharmAIState)

    # Nodes, registered in a fixed order.
    node_table = (
        ("preprocess", preprocess),
        ("llm_call", llm_call),
        ("tools", ToolNode(TOOLS)),
        ("capture_diagram", capture_diagram),
        # Increments the safety counter that llm_call compares to MAX_TOOL_LOOPS.
        ("bump_tool_loop", lambda s: {"tool_loops": s.get("tool_loops", 0) + 1}),
        ("synthesize", synthesize),
        ("end_simple", end_simple),
    )
    for node_name, node_fn in node_table:
        builder.add_node(node_name, node_fn)

    # Entry path.
    builder.add_edge(START, "preprocess")
    builder.add_edge("preprocess", "llm_call")

    # After the model replies: branch on intent and on requested tool calls.
    after_llm_targets = {
        "tools": "tools",
        "synthesize": "synthesize",
        "end_simple": "end_simple",
    }
    builder.add_conditional_edges("llm_call", route_after_llm, after_llm_targets)

    # After a tool round: record any diagram output first.
    builder.add_edge("tools", "capture_diagram")

    # Then either stop (diagram complete) or go around the loop again.
    after_tools_targets = {
        END: END,
        "bump_tool_loop": "bump_tool_loop",
    }
    builder.add_conditional_edges("capture_diagram", route_after_tools, after_tools_targets)

    builder.add_edge("bump_tool_loop", "llm_call")
    builder.add_edge("end_simple", END)
    builder.add_edge("synthesize", END)

    return builder.compile()
| # ----------------------------- | |
| # Test execution | |
| # ----------------------------- | |
if __name__ == "__main__":
    # Manual smoke test: build the graph, run one query, print the brief and
    # the tool-derived citations.
    rule = "=" * 60

    print("Building PharmAI Navigator graph...")
    graph = build_graph()
    print("Graph compiled successfully!")

    # Alternative queries kept for manual testing of other paths:
    #test_query = "Assess semaglutide for obesity"
    #test_query = "Assess donanemab for early Alzheimer’s disease. Retrieve key clinical trials, summarize efficacy and safety outcomes, normalize the evidence, and generate a system architecture graph showing how PharmAI Navigator evaluates this asset."
    #test_query = "Create a DOT graph showing the relationship between Drug, Indication, Clinical Trials, FDA Approval, and Market Launch and render it as png"
    test_query = "What is pembrolizumab?"
    print(f"\nRunning test query: {test_query}")

    initial_state = {
        "messages": [HumanMessage(content=test_query)],
        "user_query": test_query,
        "tool_loops": 0,
    }
    result = graph.invoke(initial_state)

    print("\n" + rule)
    print("OUTPUT:")
    print(rule)
    print(result.get("decision_brief", "No output"))

    print("\n" + rule)
    print("CITATIONS (tool-only):")
    print(rule)
    for i, citation in enumerate(result.get("citations", []), 1):
        print(f"{i}. {citation}")