File size: 16,480 Bytes
1de0a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
from __future__ import annotations

import os
import json
import re
from typing import Any, Dict, List, Optional
from dotenv import load_dotenv

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool

from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import (
    tavily_search,
    stub_evidence,
    classify_query,
    extract_entities,
    normalize_evidence,
    generate_graph_dot,
    clinicaltrials_search,
    render_dot_to_png_base64
)

# Load environment variables from a local .env file (ANTHROPIC_MODEL,
# MAX_TOOL_LOOPS, API keys used by the tools module).
load_dotenv()

# -----------------------------
# LangChain Tool Wrappers
# -----------------------------
@tool("web_search")
def web_search_tool(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Web search using Tavily. Returns a list of evidence dicts."""
    ev = tavily_search(query=query, max_results=max_results)
    return [e.model_dump() for e in ev]


@tool("stub_evidence")
def stub_evidence_tool(query: str) -> List[Dict[str, Any]]:
    """Deterministic fallback evidence tool (offline/demo)."""
    ev = stub_evidence(query=query)
    return [e.model_dump() for e in ev]

@tool("classify_query")
def classify_query_tool(query: str) -> Dict[str, Any]:
    """Classify query to decide which tools are needed."""
    return classify_query(query)


@tool("extract_entities")
def extract_entities_tool(query: str) -> Dict[str, Optional[str]]:
    """Extract drug and indication from query."""
    return extract_entities(query)


@tool("normalize_evidence")
def normalize_evidence_tool(evidence: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Dedupe and clean evidence."""
    return normalize_evidence(evidence)


@tool("generate_graph_dot")
def generate_graph_dot_tool(

    title: str,

    nodes: List[Dict[str, str]],

    edges: List[Dict[str, str]],

    rankdir: str = "LR",

) -> str:
    """

    Generate Graphviz DOT.

    IMPORTANT: Use this tool instead of writing DOT directly.

    """
    return generate_graph_dot(
        title=title,
        nodes=nodes,
        edges=edges,
        rankdir=rankdir,
    )

@tool("clinicaltrials_search")
def clinicaltrials_search_tool(drug: str, indication: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Search ClinicalTrials.gov (Tavily-based MVP)."""
    ev = clinicaltrials_search(drug=drug, indication=indication, max_results=max_results)
    return [e.model_dump() for e in ev]

@tool("render_dot_to_png_base64")
def render_dot_to_png_base64_tool(dot: str) -> Dict[str, Any]:
    """Render DOT to PNG (base64). Optional dependency on graphviz."""
    return render_dot_to_png_base64(dot)

# Every tool exposed to the model; bound in _build_model() and executed by the
# ToolNode in build_graph(). Order here is the order sent in the tool schemas.
TOOLS = [
    web_search_tool,
    stub_evidence_tool,
    classify_query_tool,
    extract_entities_tool,
    normalize_evidence_tool,
    generate_graph_dot_tool,
    clinicaltrials_search_tool,
    render_dot_to_png_base64_tool
]

# -----------------------------
# LangGraph State
# -----------------------------
class PharmAIState(MessagesState):
    """Workflow state for the PharmAI Navigator graph.

    Extends LangGraph's MessagesState (which supplies the `messages` channel)
    with routing bookkeeping and the final outputs.
    """
    session_id: Optional[str]          # optional caller-supplied session identifier
    user_query: str                    # raw user query driving the run
    decision_brief: str                # final brief produced by synthesize()/end_simple()
    citations: List[str]               # citation URLs extracted from tool outputs only
    confidence_score: float            # NOTE(review): never written by any node in this file — confirm intended use
    tool_loops: int                    # safety counter capping tool-call iterations
    diagram_png_base64: Optional[str]  # base64 PNG captured from render_dot_to_png_base64 output
    diagram_dot: Optional[str]         # DOT source captured from generate_graph_dot output
    intent: str  # "simple" | "diligence" | "diagram"

# -----------------------------
# Guardrails + Prompts
# -----------------------------
# System prompt injected as the first message in llm_call(); governs tool usage,
# guardrails, the simple-query shortcut, and the citations policy.
# NOTE(review): the doubled blank lines below are part of the prompt string sent
# to the model, so they are deliberately left untouched here.
SYSTEM_PROMPT = """You are PharmAI Navigator, an evidence-grounded diligence assistant for drug/asset evaluation.



Your job:

Turn a query like "Assess {Drug} for {Indication}" into a decision-grade brief OR structured output.



CRITICAL TOOL USAGE RULES:

- If the user asks for a diagram, flow, architecture, graph, visualization, or Graphviz:

  → You MUST call `generate_graph_dot`.

  → You MUST NOT write Graphviz DOT directly in your response.

  → If the user asks for an image/PNG, call `render_dot_to_png_base64` AFTER you get DOT.

- If the user asks for trials / phases / NCT IDs / endpoints:

  → Prefer calling `extract_entities` then `clinicaltrials_search`.

- If the user asks for factual claims (approvals, safety, pricing, patents, market):

  → Prefer calling `web_search`.



Guardrails (STRICT):

- Do NOT invent specific facts (approval dates, trial names, endpoints, statistics, patent expiry).

- Any concrete number/date/claim MUST be supported by tool evidence.

- If evidence is insufficient, clearly list Evidence Gaps.

- Be concise, structured, and decision-oriented.

- Avoid medical advice; present as diligence/analysis.



Simple Query Rule (CRITICAL):

- If the user asks a simple definitional question ("what is", "define", "explain") and you can answer without external verification, do NOT call tools and respond directly.

- Only use tools when you need current/specific data (trials, approvals, patents, market data).



Citations policy:

- The final response's "Citations" section is handled by the system.

- Do NOT create your own citation list.

"""

# Final-turn instruction appended by synthesize(); fixes the section layout of
# the decision brief. The string content (including blank lines) is runtime
# prompt text and is left byte-identical.
FINAL_PROMPT = """Write the FINAL decision brief with these sections:



1) Executive Recommendation (1–2 lines)

2) Scientific Rationale (bullets)

3) Clinical Evidence Snapshot (bullets)

4) IP / Exclusivity Quick View (bullets)

5) Market / SoC Snapshot (bullets)

6) Key Risks + Next Actions (bullets)



Rules:

- If evidence is insufficient, include "Evidence Gaps" with bullets.

- Do NOT add a citations section yourself; the system will append it.

Return plain text only.

"""

# Placeholder detection to avoid wasting tokens on "Drug X / Indication Y".
PLACEHOLDER_PATTERNS = [
    r"\bdrug\s*x\b",
    r"\bindication\s*y\b",
    r"\bdrug\s*name\b",
    r"\bindication\s*name\b",
]
# Compile once at import time: the check runs on every synthesize() call and
# the pattern set is fixed, so there is no reason to re-resolve patterns per call.
_PLACEHOLDER_RES = [re.compile(p) for p in PLACEHOLDER_PATTERNS]


def _looks_like_placeholder(q: str) -> bool:
    """Return True when the query still contains template placeholders.

    Detects stand-in entities such as "Drug X" or "Indication Y" so the
    pipeline can short-circuit with a guardrail response instead of burning
    tool calls on an unanswerable query. Accepts None/empty input safely.
    """
    ql = (q or "").strip().lower()
    return any(rx.search(ql) for rx in _PLACEHOLDER_RES)


def _build_model() -> ChatAnthropic:
    """Construct the Claude chat model with all tools bound.

    Model name comes from the ANTHROPIC_MODEL env var, defaulting to
    claude-3-7-sonnet-latest.
    """
    llm = ChatAnthropic(
        model=os.getenv("ANTHROPIC_MODEL", "claude-3-7-sonnet-latest"),
        temperature=0.2,
        max_tokens=10000,
        timeout=120,
        streaming=False,
        stop=None,
    )
    return llm.bind_tools(TOOLS)


# Safety cap to avoid endless tool loops; override via the MAX_TOOL_LOOPS env var.
MAX_TOOL_LOOPS = int(os.getenv("MAX_TOOL_LOOPS", "4"))


def llm_call(state: PharmAIState) -> Dict[str, Any]:
    """Invoke Claude with tool schemas bound and append its reply to state.

    Ensures the system prompt is the first message; once the tool-loop budget
    (MAX_TOOL_LOOPS) is exhausted, injects an instruction forcing the model to
    stop calling tools and synthesize from what it already has.
    """
    llm = _build_model()
    convo: List[BaseMessage] = state["messages"]

    # Prepend the system prompt exactly once.
    if not convo or not isinstance(convo[0], SystemMessage):
        convo = [SystemMessage(content=SYSTEM_PROMPT), *convo]

    if state.get("tool_loops", 0) >= MAX_TOOL_LOOPS:
        # Budget exhausted: force final synthesis instead of more tool calls.
        convo = [
            *convo,
            HumanMessage(
                content=(
                    "Stop calling tools now. Proceed to final synthesis using what you already have. "
                    "If evidence is insufficient, clearly list Evidence Gaps."
                )
            ),
        ]

    return {"messages": [llm.invoke(convo)]}


# -----------------------------
# Citations extraction (tool-only)
# -----------------------------
def _clean_url(u: str) -> str:
    """Strip whitespace and stray punctuation commonly glued onto scraped URLs."""
    trimmed = u.strip()
    return trimmed.strip("),.]}\"'")

def _extract_citations_from_messages(messages: List[BaseMessage]) -> List[str]:
    """Collect citation URLs from tool outputs only (single source of truth).

    - Reads ToolMessage contents exclusively (actual tool results); string
      content only — non-string content is skipped.
    - When the tool output parses as JSON (list/dict), pulls `source` fields.
    - Also regex-scans the raw text for URLs as a fallback.
    - Returns an order-preserving de-duplicated list, dropping very short
      (clearly truncated) URLs.
    """
    url_re = re.compile(r"https?://[^\s\]\)\}\",']+")
    found: List[str] = []

    def _maybe_add_source(obj: Any) -> None:
        # Record a `source` field only when it is an absolute http(s) URL.
        if isinstance(obj, dict):
            src = obj.get("source")
            if isinstance(src, str) and src.startswith(("http://", "https://")):
                found.append(_clean_url(src))

    for msg in messages:
        if not isinstance(msg, ToolMessage):
            continue

        content = getattr(msg, "content", None)
        if not content or not isinstance(content, str):
            continue

        try:
            parsed = json.loads(content)
        except Exception:
            parsed = None

        if isinstance(parsed, list):
            for item in parsed:
                _maybe_add_source(item)
        else:
            _maybe_add_source(parsed)  # no-op unless parsed is a dict

        # Regex fallback over the raw text; duplicates are removed below.
        found.extend(_clean_url(u) for u in url_re.findall(content))

    # Order-preserving de-duplication, skipping clearly broken/truncated URLs.
    seen: set = set()
    unique: List[str] = []
    for url in found:
        if len(url) < 12 or url in seen:
            continue
        seen.add(url)
        unique.append(url)
    return unique


def _append_citations_section(brief_text: str, citations: List[str]) -> str:
    """Append the system-owned Citations section to the brief.

    Enforces the "single source of truth" rule: any Citations section the
    model produced is stripped (best effort), then citations derived from
    tool outputs are appended — or an explicit "no sources" note if empty.
    """
    body = (brief_text or "").strip()

    # Drop everything from a model-generated 'Citations' header onwards
    # (handles '## Citations' and bare 'Citations' header lines).
    body = re.split(
        r"\n#{1,3}\s*Citations\s*\n|\nCitations\s*\n", body, maxsplit=1
    )[0].rstrip()

    if not citations:
        return body + "\n\n## Citations\n- (No external sources retrieved.)"

    numbered = [f"{i}. {c}" for i, c in enumerate(citations, 1)]
    return body + "\n" + "\n".join(["", "## Citations", *numbered])

def capture_diagram(state: PharmAIState) -> Dict[str, Any]:
    """Capture diagram artifacts from the most recent tool output, if any.

    Returns a partial state update: {"diagram_png_base64": ...} for the render
    tool, {"diagram_dot": ...} for the DOT generator, or {} otherwise.
    """
    last_tool = next(
        (m for m in reversed(state["messages"]) if isinstance(m, ToolMessage)),
        None,
    )
    if last_tool is None:
        return {}

    tool_name = getattr(last_tool, "name", "") or ""
    content = getattr(last_tool, "content", "")

    # NOTE(review): assumes the render tool's message content carries the
    # base64 payload directly, and generate_graph_dot's carries the DOT
    # source — confirm against tools.py return shapes.
    if tool_name == "render_dot_to_png_base64":
        return {"diagram_png_base64": content}
    if tool_name == "generate_graph_dot":
        return {"diagram_dot": content}
    return {}

def route_after_tools(state: PharmAIState) -> str:
    """After tools: stop once the final diagram artifact exists, else keep looping."""
    diagram_done = bool(state.get("diagram_png_base64"))
    return END if diagram_done else "bump_tool_loop"

def preprocess(state: PharmAIState) -> Dict[str, Any]:
    """Classify the user query into a routing intent before the first LLM call.

    Returns a partial state update with "intent" set to one of:
    - "diagram": the query asks for a visualization (diagram/flowchart/DOT...)
    - "simple": a short definitional question answerable without tools
    - "diligence": everything else (full evidence-gathering pipeline)
    """
    q = (state.get("user_query") or "").strip().lower()

    # Word-boundary matching instead of naive substring checks: previously
    # "withdrawal" triggered "draw" and "anecdote"/"endotoxin" triggered
    # "dot", misrouting ordinary diligence queries to the diagram intent.
    # Prefix match keeps plurals/forms like "diagrams"/"drawing"; "dot" must
    # be an exact word.
    if re.search(r"\b(?:diagram|flowchart|architecture|graphviz|draw)|\bdot\b", q):
        return {"intent": "diagram"}

    # Short definitional questions skip tools entirely (see SYSTEM_PROMPT).
    if re.match(r"^(what is|define|explain)\b", q) and len(q) < 120:
        return {"intent": "simple"}

    return {"intent": "diligence"}

def route_after_llm(state: PharmAIState):
    """Route after an LLM turn.

    Simple queries end immediately; a message with tool calls goes to the
    tool node; otherwise we proceed to final synthesis.
    """
    if state.get("intent") == "simple":
        return "end_simple"

    last_msg = state["messages"][-1]
    return "tools" if getattr(last_msg, "tool_calls", None) else "synthesize"

def end_simple(state: PharmAIState) -> Dict[str, Any]:
    """Finalize a simple query: the last assistant message IS the answer.

    Coerces non-string content (e.g. structured content blocks) to str and
    returns it as the decision brief with no citations.
    """
    last = state["messages"][-1]
    content = getattr(last, "content", "")
    if not isinstance(content, str):
        content = str(content)
    return {"decision_brief": content, "citations": []}


# -----------------------------
# Final Synthesis Node
# -----------------------------
def synthesize(state: PharmAIState) -> Dict[str, Any]:
    """Produce the final decision brief with tool-derived citations appended.

    Placeholder queries ("Drug X" etc.) short-circuit to a guardrail response
    without spending another LLM call; otherwise FINAL_PROMPT is appended and
    the model writes the brief, whose Citations section is system-owned.
    """
    user_query = state.get("user_query", "")

    # Fast guardrail: placeholders -> short response without tool burn.
    if _looks_like_placeholder(user_query):
        brief = (
            "# FINAL DECISION BRIEF\n\n"
            "I need the **actual drug name** and **specific indication** to perform diligence.\n\n"
            "## Evidence Gaps\n"
            "- Drug name (e.g., semaglutide)\n"
            "- Indication (e.g., obesity)\n"
            "- Trial/program context (if any)\n"
        )
        return {
            "decision_brief": _append_citations_section(brief, []),
            "citations": [],
            "messages": [HumanMessage(content="(placeholder query detected; returned guardrail response)")],
        }

    llm = _build_model()
    convo: List[BaseMessage] = [*state["messages"], HumanMessage(content=FINAL_PROMPT)]
    resp = llm.invoke(convo)

    # Citations come exclusively from tool outputs (single source of truth).
    tool_citations = _extract_citations_from_messages(state["messages"])
    raw = resp.content
    brief_text = raw if isinstance(raw, str) else str(raw)

    return {
        "decision_brief": _append_citations_section(brief_text, tool_citations),
        "citations": tool_citations,
        "messages": [resp],
    }


# -----------------------------
# Build + Compile Graph
# -----------------------------
def build_graph():
    """Assemble and compile the LangGraph workflow.

    Flow: START -> preprocess -> llm_call, then conditional routing:
    tool calls run through the ToolNode, diagram capture, and a loop
    counter before returning to llm_call; simple answers and final
    synthesis terminate the graph.
    """
    builder = StateGraph(PharmAIState)

    builder.add_node("preprocess", preprocess)
    builder.add_node("llm_call", llm_call)
    builder.add_node("tools", ToolNode(TOOLS))
    builder.add_node("capture_diagram", capture_diagram)
    # Increment the safety counter each time control returns from tools.
    builder.add_node("bump_tool_loop", lambda s: {"tool_loops": s.get("tool_loops", 0) + 1})
    builder.add_node("synthesize", synthesize)
    builder.add_node("end_simple", end_simple)

    builder.add_edge(START, "preprocess")
    builder.add_edge("preprocess", "llm_call")

    # After the LLM turn: run tools, go straight to synthesis, or finish a
    # simple query immediately.
    builder.add_conditional_edges(
        "llm_call",
        route_after_llm,
        {"tools": "tools", "synthesize": "synthesize", "end_simple": "end_simple"},
    )

    # Tool outputs may contain diagram artifacts worth persisting into state.
    builder.add_edge("tools", "capture_diagram")

    # Stop as soon as a rendered diagram exists; otherwise bump the loop
    # counter and hand control back to the LLM.
    builder.add_conditional_edges(
        "capture_diagram",
        route_after_tools,
        {END: END, "bump_tool_loop": "bump_tool_loop"},
    )

    builder.add_edge("bump_tool_loop", "llm_call")
    builder.add_edge("end_simple", END)
    builder.add_edge("synthesize", END)

    return builder.compile()

# -----------------------------
# Test execution
# -----------------------------
if __name__ == "__main__":
    print("Building PharmAI Navigator graph...")
    graph = build_graph()
    print("Graph compiled successfully!")

    # Test query designed to trigger generate_graph_dot tool
    #test_query = "Assess semaglutide for obesity"
    #test_query = "Assess donanemab for early Alzheimer’s disease. Retrieve key clinical trials, summarize efficacy and safety outcomes, normalize the evidence, and generate a system architecture graph showing how PharmAI Navigator evaluates this asset."
    #test_query = "Create a DOT graph showing the relationship between Drug, Indication, Clinical Trials, FDA Approval, and Market Launch and render it as png"
    test_query = "What is pembrolizumab?"
    print(f"\nRunning test query: {test_query}")

    result = graph.invoke({
        "messages": [HumanMessage(content=test_query)],
        "user_query": test_query,
        "tool_loops": 0,
    })

    print("\n" + "=" * 60)
    print("OUTPUT:")
    print("=" * 60)
    print(result.get("decision_brief", "No output"))

    print("\n" + "=" * 60)
    print("CITATIONS (tool-only):")
    print("=" * 60)
    for i, citation in enumerate(result.get("citations", []), 1):
        print(f"{i}. {citation}")